
Commit 4dbd460

manuelcandales authored and facebook-github-bot committed
Add max_pool2d_with_indices_backward (#8940)
Summary: Pull Request resolved: #8940
Reviewed By: JacobSzwejbka
Differential Revision: D70577129
1 parent ef73540 commit 4dbd460

File tree

8 files changed, +427 -4 lines changed


kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -249,6 +249,8 @@

 - op: max_pool2d_with_indices.out

+- op: max_pool2d_with_indices_backward.grad_input
+
 - op: max.dim_max

 - op: max.unary_out
kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp

Lines changed: 180 additions & 0 deletions

@@ -0,0 +1,180 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using IntArrayRef = executorch::aten::ArrayRef<int64_t>;

namespace {

bool check_max_pool2d_backward_args(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices,
    const Tensor& grad_input) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));

  ET_CHECK_OR_RETURN_FALSE(
      check_max_pool2d_with_indices_args(
          input,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          grad_output,
          indices),
      "Invalid max_pool_2d arguments");

  size_t output_ndim = 0;
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  executorch::aten::SizesType output_sizes[kTensorDimensionLimit];
  get_max_pool2d_with_indices_out_target_size(
      input,
      kernel_size,
      stride,
      padding,
      dilation,
      ceil_mode,
      output_sizes,
      &output_ndim);

  ET_LOG_AND_RETURN_IF_FALSE(
      output_size_is_valid({output_sizes, output_ndim}, 2));

  ET_CHECK_OR_RETURN_FALSE(
      grad_output.dim() == input.dim(),
      "grad_output should have same number of dimensions as input");

  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));

  return true;
}

template <typename CTYPE, bool is_3d>
void max_pool_backward_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& indices) {
  const CTYPE* grad_output_data = grad_output.const_data_ptr<CTYPE>();
  const int64_t* indices_data = indices.const_data_ptr<int64_t>();
  CTYPE* grad_input_data = grad_input.mutable_data_ptr<CTYPE>();

  // treat batch size and channels as one dimension
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t ndim = grad_output.dim();
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? grad_output.size(0)
                         : grad_output.size(0) * grad_output.size(1);
  } else {
    channels = ndim == 3 ? grad_output.size(0)
                         : grad_output.size(0) * grad_output.size(1);
  }
  int64_t input_depth = is_3d ? grad_input.size(-3) : 1;

  int64_t input_height = grad_input.size(ndim - 2);
  int64_t input_width = grad_input.size(ndim - 1);
  int64_t output_depth = is_3d ? grad_output.size(ndim - 3) : 1;
  int64_t output_height = grad_output.size(ndim - 2);
  int64_t output_width = grad_output.size(ndim - 1);

  for (int64_t c = 0; c < channels; ++c) {
    CTYPE* grad_input_ptr =
        grad_input_data + c * input_depth * input_height * input_width;
    const CTYPE* grad_output_ptr =
        grad_output_data + c * output_depth * output_height * output_width;
    const int64_t* indices_ptr =
        indices_data + c * output_depth * output_height * output_width;

    for (int64_t od = 0; od < output_depth; od++) {
      for (int64_t oh = 0; oh < output_height; oh++) {
        for (int64_t ow = 0; ow < output_width; ow++) {
          // retrieve position of max
          int64_t index =
              od * output_height * output_width + oh * output_width + ow;
          int64_t maxindex = indices_ptr[index];
          if (maxindex != -1) {
            // update gradient
            grad_input_ptr[maxindex] += grad_output_ptr[index];
          }
        }
      }
    }
  }
}

} // namespace

Tensor& max_pool2d_with_indices_backward_out(
    KernelRuntimeContext& ctx,
    const Tensor& grad_output,
    const Tensor& input,
    ET_UNUSED IntArrayRef kernel_size,
    ET_UNUSED IntArrayRef stride,
    ET_UNUSED IntArrayRef padding,
    ET_UNUSED IntArrayRef dilation,
    ET_UNUSED bool ceil_mode,
    const Tensor& indices,
    Tensor& grad_input) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_max_pool2d_backward_args(
          grad_output,
          input,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          indices,
          grad_input),
      InvalidArgument,
      grad_input);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_input, input.sizes()) == Error::Ok,
      InvalidArgument,
      grad_input);

  constexpr auto name = "max_pool2d_with_indices_backward.grad_input";

  ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
    max_pool_backward_impl<CTYPE, false>(grad_input, grad_output, indices);
  });

  return grad_input;
}

} // namespace native
} // namespace executor
} // namespace torch
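
The core of max_pool_backward_impl is the scatter-accumulate loop: each grad_output element is added to the grad_input slot whose flat offset within the channel plane was recorded by the forward pass in indices. A standalone sketch of that routing on raw arrays (hypothetical 4x4 input / 2x2 output sizes and index values; not the ExecuTorch API):

#include <cstdint>
#include <cstdio>

int main() {
  // One channel plane: a 4x4 input max-pooled (2x2 window, stride 2)
  // into a 2x2 output.
  const int64_t IH = 4, IW = 4, OH = 2, OW = 2;
  float grad_input[IH * IW] = {};  // must start zeroed: the loop accumulates
  const float grad_output[OH * OW] = {1.f, 2.f, 3.f, 4.f};
  // Flat offset of each window's max within the 4x4 plane, as the
  // forward max_pool2d_with_indices would have recorded it.
  const int64_t indices[OH * OW] = {5, 2, 8, 15};

  // Scatter-accumulate: route each output gradient back to the input
  // position that produced the max.
  for (int64_t o = 0; o < OH * OW; ++o) {
    const int64_t maxindex = indices[o];
    if (maxindex != -1) {  // mirror the kernel's guard on the -1 sentinel
      grad_input[maxindex] += grad_output[o];
    }
  }

  for (int64_t i = 0; i < IH * IW; ++i) {
    printf("%g%c", grad_input[i], (i % IW == IW - 1) ? '\n' : ' ');
  }
  return 0;
}

Accumulating with += rather than assigning matters when stride is smaller than the kernel: overlapping windows can select the same input element as their max, and that element's gradient is the sum of all such contributions.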

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 2 additions & 2 deletions
@@ -470,8 +470,8 @@ bool check_max_pool2d_with_indices_args(
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode,
-    Tensor& out,
-    Tensor& indices) {
+    const Tensor& out,
+    const Tensor& indices) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_CHECK_OR_RETURN_FALSE(
       indices.scalar_type() == ScalarType::Long,

kernels/portable/cpu/util/kernel_ops_util.h

Lines changed: 2 additions & 2 deletions
@@ -442,8 +442,8 @@ bool check_max_pool2d_with_indices_args(
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode,
-    Tensor& out,
-    Tensor& indices);
+    const Tensor& out,
+    const Tensor& indices);
 
 void get_max_pool2d_with_indices_out_target_size(
     const Tensor& in,
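
The only change here is const-qualifying the last two parameters (mirrored in the .cpp above): the new backward kernel passes its const grad_output and indices into this shared checker, while the forward kernel's mutable out and indices still bind to const Tensor& unchanged. A toy illustration of the binding rule (not ExecuTorch code):

// Toy types only; shows why relaxing Tensor& to const Tensor& is a
// source-compatible change for existing callers.
struct Tensor {};

bool check(const Tensor& out, const Tensor& indices) {
  return true;  // validation elided
}

void forward_kernel(Tensor& out, Tensor& indices) {
  check(out, indices);  // non-const lvalues bind to const references
}

void backward_kernel(const Tensor& grad_output, const Tensor& indices) {
  check(grad_output, indices);  // would not compile if check took Tensor&
}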

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -572,6 +572,11 @@
     - arg_meta: null
       kernel_name: torch::executor::max_pool2d_with_indices_out
 
+- op: max_pool2d_with_indices_backward.grad_input
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::max_pool2d_with_indices_backward_out
+
 - op: mean.out
   kernels:
     - arg_meta: null
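
This entry points the ExecuTorch kernel codegen at the new C++ entry point, so calls dispatched on the max_pool2d_with_indices_backward.grad_input schema reach torch::executor::max_pool2d_with_indices_backward_out. Conceptually, the generated glue amounts to a name-to-kernel registration; a rough sketch of that idea (illustrative only, the real generated registration code differs):

#include <functional>
#include <map>
#include <string>

// Illustrative sketch of what a functions.yaml entry expresses: binding an
// op's schema name to a kernel function. Real ExecuTorch registration is
// generated code with typed kernel signatures, not this map.
using KernelFn = std::function<void()>;  // real kernels take ctx + tensors

std::map<std::string, KernelFn>& kernel_registry() {
  static std::map<std::string, KernelFn> registry;
  return registry;
}

void register_max_pool2d_backward() {
  kernel_registry()["aten::max_pool2d_with_indices_backward.grad_input"] =
      [] { /* would call torch::executor::max_pool2d_with_indices_backward_out */ };
}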

0 commit comments
