Skip to content

Support channels_last format in portable upsample kernels #9526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 96 additions & 3 deletions kernels/portable/cpu/op_upsample_bilinear2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using executorch::aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl(
void upsample_bilinear2d_kernel_impl_nchw(
const Tensor& in,
bool align_corners,
const float scale_h,
Expand Down Expand Up @@ -86,6 +86,99 @@ void upsample_bilinear2d_kernel_impl(
}
}
}

// Bilinear 2D upsample over a channels-last (NHWC dim order) tensor.
// Walks the output in memory order (n, h, w, c) and writes sequentially,
// gathering the four bilinear corner samples from the strided input.
//
// Preconditions (enforced by the caller's arg checks): `in` and `out` are
// rank-4, same dtype, and both channels-last, so the sequential `out_data`
// writes match the output layout.
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl_nhwc(
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  // Hoist loop-invariant stride lookups out of the hot loops.
  const auto in_stride_n = in.strides()[0];
  const auto in_stride_c = in.strides()[1];
  const auto in_stride_h = in.strides()[2];
  const auto in_stride_w = in.strides()[3];

  for ([[maybe_unused]] const auto n : c10::irange(out.size(0))) {
    for (const auto h : c10::irange(out.size(2))) {
      // Source row indices and interpolation weights along H.
      int64_t in_h1, in_h2;
      float weight_h, inv_weight_h;

      compute_source_index_and_lambda(
          in_h1,
          in_h2,
          weight_h,
          inv_weight_h,
          scale_h,
          h,
          in.sizes()[2],
          out.sizes()[2],
          align_corners);

      for (const auto w : c10::irange(out.size(3))) {
        // Source column indices and interpolation weights along W.
        int64_t in_w1, in_w2;
        float weight_w, inv_weight_w;

        compute_source_index_and_lambda(
            in_w1,
            in_w2,
            weight_w,
            inv_weight_w,
            scale_w,
            w,
            in.sizes()[3],
            out.sizes()[3],
            align_corners);

        // Offsets of the four corner samples, shared by every channel.
        const auto top_offset = in_h1 * in_stride_h;
        const auto bottom_offset = in_h2 * in_stride_h;
        const auto left_offset = in_w1 * in_stride_w;
        const auto right_offset = in_w2 * in_stride_w;

        for ([[maybe_unused]] const auto c : c10::irange(out.size(1))) {
          const auto c_offset = c * in_stride_c;
          const auto top_left = in_data[top_offset + left_offset + c_offset];
          const auto top_right =
              in_data[top_offset + right_offset + c_offset];
          const auto bottom_left =
              in_data[bottom_offset + left_offset + c_offset];
          const auto bottom_right =
              in_data[bottom_offset + right_offset + c_offset];

          // Interpolate along W first, then along H.
          const auto top = top_left * weight_w + top_right * inv_weight_w;
          const auto bottom =
              bottom_left * weight_w + bottom_right * inv_weight_w;
          const auto val = top * weight_h + bottom * inv_weight_h;

          *out_data = val;
          out_data++;
        }
      }
    }

    // Advance the input base pointer to the next batch entry.
    in_data += in_stride_n;
  }
}

// Dispatches to the NCHW or NHWC bilinear kernel based on the input's
// dim order. Any other dim order fails the context; the argument checks
// in the caller should already have rejected it.
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  if (is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size())) {
    upsample_bilinear2d_kernel_impl_nchw<CTYPE>(
        in, align_corners, scale_h, scale_w, out);
    return;
  }

  if (is_channels_last_dim_order(
          in.dim_order().data(), in.dim_order().size())) {
    upsample_bilinear2d_kernel_impl_nhwc<CTYPE>(
        in, align_corners, scale_h, scale_w, out);
    return;
  }

  // Unreachable given the argument checks, but fail defensively.
  ET_LOG(Error, "Unsupported dim order");
  ctx.fail(Error::InvalidArgument);
}
} // namespace

// Signatures are auto-generated, so disable pass-by-value lint.
Expand All @@ -101,7 +194,7 @@ Tensor& upsample_bilinear2d_vec_out(
// Preconditions (checked in check_..._args):
// In and out tensors have same dtype.
// In and out tensors are rank 4 and have same dim[0] and dim[1].
// In and out tensors are default dim order (NCHW).
// In and out tensors are NHWC or NCHW dim order.
ET_KERNEL_CHECK(
ctx,
check_upsample_bilinear2d_args(
Expand All @@ -127,7 +220,7 @@ Tensor& upsample_bilinear2d_vec_out(
ET_SWITCH_REALHBF16_TYPES(
in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
upsample_bilinear2d_kernel_impl<CTYPE>(
in, align_corners, kernel_scale_h, kernel_scale_w, out);
ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
});

return out;
Expand Down
52 changes: 50 additions & 2 deletions kernels/portable/cpu/op_upsample_nearest2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using executorch::aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_nearest2d_kernel_impl(
void upsample_nearest2d_kernel_impl_nchw(
const Tensor& in,
const float scale_h,
const float scale_w,
Expand All @@ -46,6 +46,54 @@ void upsample_nearest2d_kernel_impl(
}
}
}

// Nearest-neighbor 2D upsample over a channels-last (NHWC dim order)
// tensor. Walks the output in memory order (n, h, w, c) and writes
// sequentially, reading the nearest source element from the strided input.
template <typename CTYPE>
void upsample_nearest2d_kernel_impl_nhwc(
    const Tensor& in,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  // Hoist loop-invariant stride lookups out of the hot loops.
  const auto in_stride_n = in.strides()[0];
  const auto in_stride_c = in.strides()[1];
  const auto in_stride_h = in.strides()[2];
  const auto in_stride_w = in.strides()[3];

  // Use 64-bit loop indices: `auto n = 0` would deduce `int`, which can
  // overflow against 64-bit tensor sizes.
  for (int64_t n = 0; n < out.size(0); n++) {
    for (int64_t h = 0; h < out.size(2); h++) {
      const auto in_h =
          nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]);
      // Row offset is invariant across the W and C loops.
      const auto row_offset = in_h * in_stride_h;
      for (int64_t w = 0; w < out.size(3); w++) {
        const auto in_w =
            nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]);
        // Pixel offset is invariant across the channel loop.
        const auto pixel_offset = row_offset + in_w * in_stride_w;
        for (int64_t c = 0; c < out.size(1); c++) {
          *out_data = in_data[pixel_offset + c * in_stride_c];
          out_data++;
        }
      }
    }

    // Advance the input base pointer to the next batch entry.
    in_data += in_stride_n;
  }
}

// Dispatches to the NCHW or NHWC nearest-neighbor kernel based on the
// input's dim order. Any other dim order fails the context; the argument
// checks in the caller should already have rejected it.
template <typename CTYPE>
void upsample_nearest2d_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  if (is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size())) {
    upsample_nearest2d_kernel_impl_nchw<CTYPE>(in, scale_h, scale_w, out);
    return;
  }

  if (is_channels_last_dim_order(
          in.dim_order().data(), in.dim_order().size())) {
    upsample_nearest2d_kernel_impl_nhwc<CTYPE>(in, scale_h, scale_w, out);
    return;
  }

  // Unreachable given the argument checks, but fail defensively.
  ET_LOG(Error, "Unsupported dim order");
  ctx.fail(Error::InvalidArgument);
}
} // namespace

Tensor& upsample_nearest2d_vec_out(
Expand Down Expand Up @@ -82,7 +130,7 @@ Tensor& upsample_nearest2d_vec_out(
ET_SWITCH_REALHBF16_TYPES(
in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() {
upsample_nearest2d_kernel_impl<CTYPE>(
in, kernel_scale_h, kernel_scale_w, out);
ctx, in, kernel_scale_h, kernel_scale_w, out);
});

return out;
Expand Down
5 changes: 3 additions & 2 deletions kernels/portable/cpu/util/upsample_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ bool check_upsample_2d_common_args(
const executorch::aten::OptionalArrayRef<double>& scale_factors,
Tensor& out) {
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dim_order(in, out));
ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4);
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in));
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out));
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
ET_LOG_AND_RETURN_IF_FALSE(
output_size.has_value() ^ scale_factors.has_value());
if (scale_factors.has_value()) {
Expand Down
148 changes: 148 additions & 0 deletions kernels/test/op_upsample_bilinear2d_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,28 @@ TEST_F(OpUpsampleBilinear2dTest, ZeroComputedOutputSizeDies) {
out));
}

// The portable kernel must reject an input/output pair whose dim orders
// differ (NCHW input with NHWC output).
TEST_F(OpUpsampleBilinear2dTest, MismatchedDimOrderDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can implicitly convert dim order";
  }

  TensorFactory<ScalarType::Float> tf;

  // Contiguous (NCHW) input paired with a channels-last output.
  const auto input = tf.ones({1, 1, 1, 2});
  auto out = tf.zeros_channels_last({1, 1, 1, 4});
  std::array<double, 2> scales = {2, 2};

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_upsample_bilinear2d_vec_out(
          input,
          {},
          false,
          OptionalArrayRef<double>({scales.data(), scales.size()}),
          out));
}

TEST_F(OpUpsampleBilinear2dTest, NumericsCheck) {
TensorFactory<ScalarType::Float> tf;

Expand Down Expand Up @@ -577,3 +599,129 @@ TEST_F(OpUpsampleBilinear2dTest, Simple5x1To4x1AlignCorners) {

EXPECT_TENSOR_CLOSE(out, expected);
}

// Bilinear upscale of a single row, 1x2 -> 1x4, channels-last layout,
// align_corners=false.
TEST_F(OpUpsampleBilinear2dTest, Simple1x2To1x4ChannelsLast) {
  TensorFactory<ScalarType::Float> tf;

  const auto input = tf.make_channels_last({1, 1, 1, 2}, {1.0, 4.0});
  auto out = tf.zeros_channels_last({1, 1, 1, 4});
  std::array<int64_t, 2> target_size = {1, 4};

  op_upsample_bilinear2d_vec_out(
      input,
      OptionalArrayRef<int64_t>({target_size.data(), target_size.size()}),
      false,
      {},
      out);

  // Endpoints are preserved; interior samples are the 1/4 and 3/4 blends.
  const auto expected =
      tf.make_channels_last({1, 1, 1, 4}, {1.0, 1.75, 3.25, 4.0});
  EXPECT_TENSOR_EQ(out, expected);
}

// Smoke test: bilinear 2x upscale (3x4 -> 6x8) of a two-channel,
// channels-last tensor, compared against a precomputed reference.
TEST_F(OpUpsampleBilinear2dTest, SmokeTestChannelsLast) {
  TensorFactory<ScalarType::Float> tf;

  // Channels-last layout interleaves the channels: the flat data is
  // (c0, c1) pairs per spatial position, with channel 1 = channel 0 + 12.
  const auto input = tf.make_channels_last(
      {1, 2, 3, 4}, {0.0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17,
                     6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23});
  std::array<int64_t, 2> output_size = {6, 8};
  auto out = tf.zeros_channels_last({1, 2, 6, 8});

  // align_corners=false; explicit output size, no scale factors.
  op_upsample_bilinear2d_vec_out(
      input,
      OptionalArrayRef<int64_t>({output_size.data(), output_size.size()}),
      false,
      {},
      out);

  // Precomputed expected output, also stored channels-last ((c0, c1)
  // pairs per pixel; channel 1 stays channel 0 + 12 throughout).
  const auto expected = tf.make_channels_last(
      {1, 2, 6, 8},
      {0.0000, 12.0000, 0.2500, 12.2500, 0.7500, 12.7500, 1.2500, 13.2500,
       1.7500, 13.7500, 2.2500, 14.2500, 2.7500, 14.7500, 3.0000, 15.0000,
       1.0000, 13.0000, 1.2500, 13.2500, 1.7500, 13.7500, 2.2500, 14.2500,
       2.7500, 14.7500, 3.2500, 15.2500, 3.7500, 15.7500, 4.0000, 16.0000,
       3.0000, 15.0000, 3.2500, 15.2500, 3.7500, 15.7500, 4.2500, 16.2500,
       4.7500, 16.7500, 5.2500, 17.2500, 5.7500, 17.7500, 6.0000, 18.0000,
       5.0000, 17.0000, 5.2500, 17.2500, 5.7500, 17.7500, 6.2500, 18.2500,
       6.7500, 18.7500, 7.2500, 19.2500, 7.7500, 19.7500, 8.0000, 20.0000,
       7.0000, 19.0000, 7.2500, 19.2500, 7.7500, 19.7500, 8.2500, 20.2500,
       8.7500, 20.7500, 9.2500, 21.2500, 9.7500, 21.7500, 10.0000, 22.0000,
       8.0000, 20.0000, 8.2500, 20.2500, 8.7500, 20.7500, 9.2500, 21.2500,
       9.7500, 21.7500, 10.2500, 22.2500, 10.7500, 22.7500, 11.0000, 23.0000});

  EXPECT_TENSOR_CLOSE(out, expected);
}

// Spot-check bilinear output values (align_corners=false) on a larger
// channels-last tensor filled with a ramp, at a few sampled indices.
TEST_F(OpUpsampleBilinear2dTest, NumericsCheckChannelsLast) {
  TensorFactory<ScalarType::Float> tf;

  const auto input = tf.zeros_channels_last({3, 7, 47, 99});
  auto out = tf.zeros_channels_last({3, 7, 291, 512});
  std::array<int64_t, 2> output_size = {291, 512};

  // Fill the input with 0, 1, 2, ... so every element is distinct.
  auto* input_data = static_cast<float*>(input.mutable_data_ptr());
  for (auto i = 0ul; i < input.numel(); i++) {
    input_data[i] = static_cast<float>(i);
  }

  op_upsample_bilinear2d_vec_out(
      input,
      OptionalArrayRef<int64_t>({output_size.data(), output_size.size()}),
      false,
      {},
      out);

  // (n, c, h, w) -> expected output value.
  const std::vector<std::tuple<int, int, int, int, float>> spot_checks = {
      {0, 2, 60, 200, 6695.0137},
      {1, 6, 5, 503, 33524.098},
      {2, 0, 111, 300, 77678.68},
  };

  const auto* output_data = static_cast<const float*>(out.const_data_ptr());
  for (const auto& [n, c, h, w, expected] : spot_checks) {
    const auto offset = n * out.strides()[0] + c * out.strides()[1] +
        h * out.strides()[2] + w * out.strides()[3];
    EXPECT_FLOAT_EQ(expected, output_data[offset]);
  }
}

// Spot-check bilinear output values with align_corners=true on a larger
// channels-last tensor filled with a ramp, at a few sampled indices.
TEST_F(OpUpsampleBilinear2dTest, NumericsCheckAlignCornersChannelsLast) {
  TensorFactory<ScalarType::Float> tf;

  const auto input = tf.zeros_channels_last({3, 7, 47, 99});
  auto out = tf.zeros_channels_last({3, 7, 291, 512});
  std::array<int64_t, 2> output_size = {291, 512};

  // Fill the input with 0, 1, 2, ... so every element is distinct.
  auto input_ptr = static_cast<float*>(input.mutable_data_ptr());
  for (auto i = 0ul; i < input.numel(); i++) {
    input_ptr[i] = static_cast<float>(i);
  }

  op_upsample_bilinear2d_vec_out(
      input,
      OptionalArrayRef<int64_t>({output_size.data(), output_size.size()}),
      true,  // align_corners
      {},
      out);

  // (n, c, h, w) indices and expected values to evaluate.
  std::vector<std::tuple<int, int, int, int, float>> test_values = {
      {0, 2, 60, 200, 6865.9414},
      {1, 6, 5, 503, 33801.883},
      {2, 0, 111, 300, 77746.32},
  };

  const auto output_data = static_cast<const float*>(out.const_data_ptr());
  for (const auto& test_case : test_values) {
    const auto [n, c, h, w, expected] = test_case;
    // Read the element through the output's strides so the check is
    // layout-agnostic.
    const auto actual = output_data
        [n * out.strides()[0] + c * out.strides()[1] + h * out.strides()[2] +
         w * out.strides()[3]];
    EXPECT_FLOAT_EQ(expected, actual);
  }
}
Loading
Loading