You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Support channels_last format in portable upsample kernels (#9526)
Summary:
Support channels_last input format in portable CPU upsample_bilinear2d and upsample_nearest2d kernels. This is useful for resize-in-model patterns when the user wants to pass inputs in channels_last format. It also (theoretically) allows for more effective auto-vectorization when vectorizing along the channels dim when there are a larger number of channels.
I considered generalizing the kernel to handle arbitrary dim order, but having a specialized channels last version allows for traversing the output in contiguous order. I could add a separate, arbitrarily-strided variant, but we can take that as a follow-up if needed.
To accomplish this, this PR makes the following changes:
- Update `check_upsample_2d_common_args` to relax the dim order restriction. It now allows for both default and channels_last dim order and verifies that the output dim order matches the input.
- In the upsample kernels (bilinear and nearest), split out NCHW and NHWC variants. The NHWC variant interchanges the loop order as to maintain contiguous output accesses.
- Add test coverage to ensure ATen numerical parity.
Reviewed By: manuelcandales
Differential Revision: D71690379
0 commit comments