Skip to content

Commit 1bfb530

Browse files
GregoryComer authored and facebook-github-bot committed
Support channels_last format in portable upsample kernels
Summary: Support channels_last input format in the portable CPU upsample_bilinear2d and upsample_nearest2d kernels. This is useful for resize-in-model patterns when the user wants to pass inputs in channels_last format. It also (theoretically) allows for more effective auto-vectorization along the channels dim when there is a larger number of channels. I considered generalizing the kernel to handle arbitrary dim order, but having a specialized channels_last version allows for traversing the output in contiguous order. I could add a separate, arbitrarily-strided variant, but we can take that as a follow-up if needed.
To accomplish this, this PR makes the following changes:
- Update `check_upsample_2d_common_args` to relax the dim order restriction. It now allows both default and channels_last dim order and verifies that the output dim order matches the input.
- In the upsample kernels (bilinear and nearest), split out NCHW and NHWC variants. The NHWC variant interchanges the loop order so as to maintain contiguous output accesses.
- Add test coverage to ensure ATen numerical parity.
Differential Revision: D71690379
1 parent 76ae537 commit 1bfb530

File tree

7 files changed

+281
-12
lines changed

7 files changed

+281
-12
lines changed

kernels/portable/cpu/op_upsample_bilinear2d.cpp

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using executorch::aten::SizesType;
2020

2121
namespace {
2222
template <typename CTYPE>
23-
void upsample_bilinear2d_kernel_impl(
23+
void upsample_bilinear2d_kernel_impl_nchw(
2424
const Tensor& in,
2525
bool align_corners,
2626
const float scale_h,
@@ -86,6 +86,71 @@ void upsample_bilinear2d_kernel_impl(
8686
}
8787
}
8888
}
89+
90+
template <typename CTYPE>
91+
void upsample_bilinear2d_kernel_impl_nhwc(
92+
const Tensor& in,
93+
bool align_corners,
94+
const float scale_h,
95+
const float scale_w,
96+
Tensor& out) {
97+
const auto in_data = in.const_data_ptr<CTYPE>();
98+
auto out_data = out.mutable_data_ptr<CTYPE>();
99+
100+
for ([[maybe_unused]] const auto n : c10::irange(out.size(0))) {
101+
for (const auto h : c10::irange(out.size(2))) {
102+
// Compute source index and weights.
103+
int64_t in_h1, in_h2;
104+
float weight_h, inv_weight_h;
105+
106+
compute_source_index_and_lambda(
107+
in_h1,
108+
in_h2,
109+
weight_h,
110+
inv_weight_h,
111+
scale_h,
112+
h,
113+
in.sizes()[2],
114+
out.sizes()[2],
115+
align_corners);
116+
117+
for (const auto w : c10::irange(out.size(3))) {
118+
int64_t in_w1, in_w2;
119+
float weight_w, inv_weight_w;
120+
121+
compute_source_index_and_lambda(
122+
in_w1,
123+
in_w2,
124+
weight_w,
125+
inv_weight_w,
126+
scale_w,
127+
w,
128+
in.sizes()[3],
129+
out.sizes()[3],
130+
align_corners);
131+
132+
for ([[maybe_unused]] const auto c : c10::irange(out.size(1))) {
133+
const auto top_left =
134+
in_data[in_h1 * in.strides()[2] + in_w1 * in.strides()[3] + c * in.strides()[1]];
135+
const auto top_right =
136+
in_data[in_h1 * in.strides()[2] + in_w2 * in.strides()[3] + c * in.strides()[1]];
137+
const auto bottom_left =
138+
in_data[in_h2 * in.strides()[2] + in_w1 * in.strides()[3] + c * in.strides()[1]];
139+
const auto bottom_right =
140+
in_data[in_h2 * in.strides()[2] + in_w2 * in.strides()[3] + c * in.strides()[1]];
141+
142+
const auto top = top_left * weight_w + top_right * inv_weight_w;
143+
const auto bottom =
144+
bottom_left * weight_w + bottom_right * inv_weight_w;
145+
const auto val = top * weight_h + bottom * inv_weight_h;
146+
147+
*out_data = val;
148+
out_data++;
149+
}
150+
}
151+
}
152+
}
153+
}
89154
} // namespace
90155

91156
// Signatures are auto-generated, so disable pass-by-value lint.
@@ -101,7 +166,7 @@ Tensor& upsample_bilinear2d_vec_out(
101166
// Preconditions (checked in check_..._args):
102167
// In and out tensors have same dtype.
103168
// In and out tensors are rank 4 and have same dim[0] and dim[1].
104-
// In and out tensors are default dim order (NCHW).
169+
// In and out tensors are NHWC or NCHW dim order.
105170
ET_KERNEL_CHECK(
106171
ctx,
107172
check_upsample_bilinear2d_args(
@@ -124,11 +189,24 @@ Tensor& upsample_bilinear2d_vec_out(
124189
const auto kernel_scale_w = area_pixel_compute_scale<double>(
125190
in.sizes()[3], out.sizes()[3], align_corners, scale_w);
126191

127-
ET_SWITCH_REALHBF16_TYPES(
128-
in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
129-
upsample_bilinear2d_kernel_impl<CTYPE>(
130-
in, align_corners, kernel_scale_h, kernel_scale_w, out);
131-
});
192+
if (executorch::runtime::tensor_is_default_dim_order(in)) {
193+
ET_SWITCH_REALHBF16_TYPES(
194+
in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
195+
upsample_bilinear2d_kernel_impl_nchw<CTYPE>(
196+
in, align_corners, kernel_scale_h, kernel_scale_w, out);
197+
});
198+
} else if (executorch::runtime::tensor_is_channels_last_dim_order(in)) {
199+
ET_SWITCH_REALHBF16_TYPES(
200+
in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
201+
upsample_bilinear2d_kernel_impl_nhwc<CTYPE>(
202+
in, align_corners, kernel_scale_h, kernel_scale_w, out);
203+
});
204+
} else {
205+
// Shouldn't be reachable because of args checks, but just in case.
206+
ET_LOG(Error, "Unsupported dim order");
207+
ctx.fail(Error::InvalidArgument);
208+
return out;
209+
}
132210

133211
return out;
134212
}

kernels/portable/cpu/op_upsample_nearest2d.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using executorch::aten::SizesType;
1919

2020
namespace {
2121
template <typename CTYPE>
22-
void upsample_nearest2d_kernel_impl(
22+
void upsample_nearest2d_kernel_impl_nchw(
2323
const Tensor& in,
2424
const float scale_h,
2525
const float scale_w,
@@ -46,6 +46,31 @@ void upsample_nearest2d_kernel_impl(
4646
}
4747
}
4848
}
49+
50+
template <typename CTYPE>
51+
void upsample_nearest2d_kernel_impl_nhwc(
52+
const Tensor& in,
53+
const float scale_h,
54+
const float scale_w,
55+
Tensor& out) {
56+
const auto in_data = in.const_data_ptr<CTYPE>();
57+
auto out_data = out.mutable_data_ptr<CTYPE>();
58+
59+
for (auto n = 0; n < out.size(0); n++) {
60+
for (auto h = 0; h < out.size(2); h++) {
61+
const auto in_h =
62+
nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]);
63+
for (auto w = 0; w < out.size(3); w++) {
64+
const auto in_w =
65+
nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]);
66+
for (auto c = 0; c < out.size(1); c++) {
67+
*out_data = in_data[in_h * in.strides()[2] + in_w * in.strides()[3] + c * in.strides()[1]];
68+
out_data++;
69+
}
70+
}
71+
}
72+
}
73+
}
4974
} // namespace
5075

5176
Tensor& upsample_nearest2d_vec_out(
@@ -79,11 +104,25 @@ Tensor& upsample_nearest2d_vec_out(
79104
const auto kernel_scale_w = area_pixel_compute_scale<double>(
80105
in.sizes()[3], out.sizes()[3], false, scale_w);
81106

107+
if (tensor_is_default_dim_order(in)) {
82108
ET_SWITCH_REALHBF16_TYPES(
83109
in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() {
84-
upsample_nearest2d_kernel_impl<CTYPE>(
110+
upsample_nearest2d_kernel_impl_nchw<CTYPE>(
85111
in, kernel_scale_h, kernel_scale_w, out);
86112
});
113+
} else if (executorch::runtime::tensor_is_channels_last_dim_order(in)) {
114+
ET_SWITCH_REALHBF16_TYPES(
115+
in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() {
116+
upsample_nearest2d_kernel_impl_nhwc<CTYPE>(
117+
in, kernel_scale_h, kernel_scale_w, out);
118+
});
119+
} else {
120+
// Shouldn't be reachable because of args checks, but just in case.
121+
ET_LOG(Error, "Unsupported dim order");
122+
ctx.fail(Error::InvalidArgument);
123+
return out;
124+
}
125+
87126

88127
return out;
89128
}

kernels/portable/cpu/util/upsample_util.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ bool check_upsample_2d_common_args(
1818
const executorch::aten::OptionalArrayRef<double>& scale_factors,
1919
Tensor& out) {
2020
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
21+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dim_order(in, out));
2122
ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4);
2223
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
23-
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in));
24-
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out));
24+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
25+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
2526
ET_LOG_AND_RETURN_IF_FALSE(
2627
output_size.has_value() ^ scale_factors.has_value());
2728
if (scale_factors.has_value()) {

kernels/portable/test/op_upsample_bilinear2d_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,26 @@ def test_upsample_bilinear2d_aten_parity_f32(self):
6363
input, scale_factors=(out_h / h, out_w / w), align_corners=align_corners
6464
)
6565

66+
def test_upsample_bilinear2d_aten_parity_f32_channels_last(self):
    """Run the float32 upsample check on channels_last inputs.

    Every combination of batch, channel, input size, output size and
    align_corners is passed to run_upsample_test twice: once with an
    explicit output size and once with the equivalent scale factors.
    """
    batch_sizes = [1, 2]
    channel_counts = [1, 3]
    in_heights = [1, 3, 50, 1001]
    in_widths = [1, 2, 62, 1237]
    out_heights = [5, 21]
    out_widths = [7, 31]
    align_corners_options = [True, False]

    cases = itertools.product(
        batch_sizes,
        channel_counts,
        in_heights,
        in_widths,
        out_heights,
        out_widths,
        align_corners_options,
    )
    for n, c, h, w, out_h, out_w, align_corners in cases:
        # Random data is generated in the default layout, then converted.
        input = torch.randn(n, c, h, w).to(memory_format=torch.channels_last)

        self.run_upsample_test(
            input, output_size=(out_h, out_w), align_corners=align_corners
        )
        self.run_upsample_test(
            input, scale_factors=(out_h / h, out_w / w), align_corners=align_corners
        )
85+
6686
def test_upsample_bilinear2d_aten_parity_u8(self):
6787
N = [1, 2]
6888
C = [1, 3]
@@ -85,3 +105,26 @@ def test_upsample_bilinear2d_aten_parity_u8(self):
85105
align_corners=align_corners,
86106
atol=2,
87107
)
108+
109+
def test_upsample_bilinear2d_aten_parity_u8_channels_last(self):
    """Run the uint8 upsample check on channels_last inputs.

    Every combination of batch, channel, input size, output size and
    align_corners is passed to run_upsample_test twice (explicit output
    size, then equivalent scale factors), each time with atol=2.
    """
    batch_sizes = [1, 2]
    channel_counts = [1, 3]
    in_heights = [1, 3, 50, 1001]
    in_widths = [1, 2, 62, 1237]
    out_heights = [5, 21]
    out_widths = [7, 31]
    align_corners_options = [True, False]

    cases = itertools.product(
        batch_sizes,
        channel_counts,
        in_heights,
        in_widths,
        out_heights,
        out_widths,
        align_corners_options,
    )
    for n, c, h, w, out_h, out_w, align_corners in cases:
        # Random data is generated in the default layout, then converted.
        input = torch.randint(0, 255, (n, c, h, w), dtype=torch.uint8).to(
            memory_format=torch.channels_last
        )

        self.run_upsample_test(
            input, output_size=(out_h, out_w), align_corners=align_corners, atol=2
        )
        self.run_upsample_test(
            input,
            scale_factors=(out_h / h, out_w / w),
            align_corners=align_corners,
            atol=2,
        )

kernels/test/op_upsample_bilinear2d_test.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,59 @@ TEST_F(OpUpsampleBilinear2dTest, Simple5x1To4x1AlignCorners) {
577577

578578
EXPECT_TENSOR_CLOSE(out, expected);
579579
}
580+
581+
// Upsamples a channels-last 1x1x1x2 tensor to width 4 without align_corners
// and expects an exact match against hand-computed values.
TEST_F(OpUpsampleBilinear2dTest, Simple1x2To1x4ChannelsLast) {
  TensorFactory<ScalarType::Float> factory;

  const auto src = factory.make_channels_last({1, 1, 1, 2}, {1.0, 4.0});
  const auto want =
      factory.make_channels_last({1, 1, 1, 4}, {1.0, 1.75, 3.25, 4.0});

  auto dst = factory.zeros_channels_last({1, 1, 1, 4});
  std::array<int64_t, 2> out_hw = {1, 4};
  const bool align_corners = false;

  op_upsample_bilinear2d_vec_out(
      src,
      OptionalArrayRef<int64_t>({out_hw.data(), out_hw.size()}),
      align_corners,
      {},
      dst);

  EXPECT_TENSOR_EQ(dst, want);
}
599+
600+
// Upsamples a channels-last 1x2x3x4 tensor to 6x8 without align_corners and
// compares against precomputed values (listed in channels-last order).
TEST_F(OpUpsampleBilinear2dTest, SmokeTestChannelsLast) {
  TensorFactory<ScalarType::Float> factory;

  const auto src = factory.make_channels_last(
      {1, 2, 3, 4},
      {
          0.0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20,
          9, 21, 10, 22, 11, 23
      });

  const auto want = factory.make_channels_last(
      {1, 2, 6, 8},
      {0.0000, 12.0000, 0.2500, 12.2500, 0.7500, 12.7500, 1.2500, 13.2500,
       1.7500, 13.7500, 2.2500, 14.2500, 2.7500, 14.7500, 3.0000, 15.0000,
       1.0000, 13.0000, 1.2500, 13.2500, 1.7500, 13.7500, 2.2500, 14.2500,
       2.7500, 14.7500, 3.2500, 15.2500, 3.7500, 15.7500, 4.0000, 16.0000,
       3.0000, 15.0000, 3.2500, 15.2500, 3.7500, 15.7500, 4.2500, 16.2500,
       4.7500, 16.7500, 5.2500, 17.2500, 5.7500, 17.7500, 6.0000, 18.0000,
       5.0000, 17.0000, 5.2500, 17.2500, 5.7500, 17.7500, 6.2500, 18.2500,
       6.7500, 18.7500, 7.2500, 19.2500, 7.7500, 19.7500, 8.0000, 20.0000,
       7.0000, 19.0000, 7.2500, 19.2500, 7.7500, 19.7500, 8.2500, 20.2500,
       8.7500, 20.7500, 9.2500, 21.2500, 9.7500, 21.7500, 10.0000, 22.0000,
       8.0000, 20.0000, 8.2500, 20.2500, 8.7500, 20.7500, 9.2500, 21.2500,
       9.7500, 21.7500, 10.2500, 22.2500, 10.7500, 22.7500, 11.0000, 23.0000});

  auto dst = factory.zeros_channels_last({1, 2, 6, 8});
  std::array<int64_t, 2> out_hw = {6, 8};
  const bool align_corners = false;

  op_upsample_bilinear2d_vec_out(
      src,
      OptionalArrayRef<int64_t>({out_hw.data(), out_hw.size()}),
      align_corners,
      {},
      dst);

  EXPECT_TENSOR_CLOSE(dst, want);
}

kernels/test/op_upsample_nearest2d_test.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,39 @@ TEST_F(OpUpsampleNearest2dTest, ZeroComputedOutputSizeDies) {
411411
{scale_factors.data(), scale_factors.size()}),
412412
out));
413413
}
414+
415+
// Upsamples a channels-last 1x2x2x2 tensor to 4x4 with nearest-neighbour
// sampling and expects an exact match (values listed in channels-last order).
TEST_F(OpUpsampleNearest2dTest, SmokeTestChannelsLast) {
  TensorFactory<ScalarType::Float> factory;

  const auto src = factory.make_channels_last(
      {1, 2, 2, 2},
      {
          0.1,
          2.1,
          0.2,
          2.2,
          1.1,
          3.1,
          1.2,
          3.2,
      });

  const auto want = factory.make_channels_last(
      {1, 2, 4, 4},
      {
          0.1000, 2.1000, 0.1000, 2.1000, 0.2000, 2.2000, 0.2000, 2.2000, 0.1000,
          2.1000, 0.1000, 2.1000, 0.2000, 2.2000, 0.2000, 2.2000, 1.1000, 3.1000,
          1.1000, 3.1000, 1.2000, 3.2000, 1.2000, 3.2000, 1.1000, 3.1000, 1.1000,
          3.1000, 1.2000, 3.2000, 1.2000, 3.2000
      });

  auto dst = factory.zeros_channels_last({1, 2, 4, 4});
  std::array<int64_t, 2> out_hw = {4, 4};

  op_upsample_nearest2d_out(
      src,
      OptionalArrayRef<int64_t>({out_hw.data(), out_hw.size()}),
      {},
      dst);

  EXPECT_TENSOR_EQ(dst, want);
}

runtime/core/exec_aten/testing_util/tensor_factory.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,30 @@ class TensorFactory {
451451
.to(at::MemoryFormat::ChannelsLast);
452452
}
453453

454+
/**
455+
* Returns a new Tensor with the specified shape, containing channels-last
456+
* contiguous data with all `0` elements.
457+
*
* NOTE(review): this function is named `zeros`, yet it converts its result to
* at::MemoryFormat::ChannelsLast. The declaration that follows in this diff is
* named `zeros_channels_last` but is documented as plain contiguous data. The
* two names look swapped; confirm and fix both declarations together.
*
458+
* @param[in] sizes The sizes of the dimensions of the Tensor.
* @param[in] dynamism Not used by this implementation (marked ET_UNUSED).
459+
* @return A new Tensor with the specified shape.
460+
*/
461+
at::Tensor zeros(
462+
const std::vector<int32_t>& sizes,
463+
ET_UNUSED TensorShapeDynamism dynamism =
464+
TensorShapeDynamism::DYNAMIC_UNBOUND) {
465+
// at::IntArrayRef takes 64-bit sizes, so widen the 32-bit input first.
auto sizes64 = vec_32_to_64(sizes);
466+
return at::zeros(at::IntArrayRef(sizes64), at::dtype(DTYPE))
467+
// Re-lays the zero tensor out as channels-last before returning it.
.to(at::MemoryFormat::ChannelsLast);
468+
}
469+
454470
/**
455471
* Returns a new Tensor with the specified shape, containing contiguous data
456472
* with all `0` elements.
457473
*
458474
* @param[in] sizes The sizes of the dimensions of the Tensor.
459475
* @return A new Tensor with the specified shape.
460476
*/
461-
at::Tensor zeros(
477+
at::Tensor zeros_channels_last(
462478
const std::vector<int32_t>& sizes,
463479
ET_UNUSED TensorShapeDynamism dynamism =
464480
TensorShapeDynamism::DYNAMIC_UNBOUND) {

0 commit comments

Comments
 (0)