
Commit 0617b60

manuelcandales authored and facebook-github-bot committed
Add max_pool2d_with_indices_backward
Differential Revision: D70577129
1 parent d897e73 commit 0617b60

8 files changed: +412 −4 lines

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -249,6 +249,8 @@
 
 - op: max_pool2d_with_indices.out
 
+- op: max_pool2d_with_indices_backward.grad_input
+
 - op: max.dim_max
 
 - op: max.unary_out
kernels/portable/cpu/op_max_pool2d_with_indices_backward.cpp (new file)

Lines changed: 176 additions & 0 deletions

@@ -0,0 +1,176 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using IntArrayRef = executorch::aten::ArrayRef<int64_t>;

namespace {

bool check_max_pool2d_backward_args(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices,
    const Tensor& grad_input) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));

  ET_CHECK_OR_RETURN_FALSE(
      check_max_pool2d_with_indices_args(
          input,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          grad_output,
          indices),
      "Invalid max_pool_2d arguments");

  size_t output_ndim = 0;
  executorch::aten::SizesType output_sizes[kTensorDimensionLimit];
  get_max_pool2d_with_indices_out_target_size(
      input,
      kernel_size,
      stride,
      padding,
      dilation,
      ceil_mode,
      output_sizes,
      &output_ndim);

  ET_LOG_AND_RETURN_IF_FALSE(
      output_size_is_valid({output_sizes, output_ndim}, 2));

  ET_CHECK_OR_RETURN_FALSE(
      grad_output.dim() == input.dim(),
      "grad_output should have same number of dimensions as input");

  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));

  return true;
}

template <typename CTYPE, bool is_3d>
void max_pool_backward_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& indices) {
  const CTYPE* grad_output_data = grad_output.const_data_ptr<CTYPE>();
  const int64_t* indices_data = indices.const_data_ptr<int64_t>();
  CTYPE* grad_input_data = grad_input.mutable_data_ptr<CTYPE>();

  // Treat batch size and channels as one dimension:
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t ndim = grad_output.dim();
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? grad_output.size(0)
                         : grad_output.size(0) * grad_output.size(1);
  } else {
    channels = ndim == 3 ? grad_output.size(0)
                         : grad_output.size(0) * grad_output.size(1);
  }
  int64_t input_depth = is_3d ? grad_input.size(-3) : 1;
  int64_t input_height = grad_input.size(ndim - 2);
  int64_t input_width = grad_input.size(ndim - 1);
  int64_t output_depth = is_3d ? grad_output.size(ndim - 3) : 1;
  int64_t output_height = grad_output.size(ndim - 2);
  int64_t output_width = grad_output.size(ndim - 1);

  for (int64_t c = 0; c < channels; ++c) {
    CTYPE* grad_input_ptr =
        grad_input_data + c * input_depth * input_height * input_width;
    const CTYPE* grad_output_ptr =
        grad_output_data + c * output_depth * output_height * output_width;
    const int64_t* indices_ptr =
        indices_data + c * output_depth * output_height * output_width;

    for (int64_t od = 0; od < output_depth; od++) {
      for (int64_t oh = 0; oh < output_height; oh++) {
        for (int64_t ow = 0; ow < output_width; ow++) {
          // Retrieve the position of the max for this output element.
          int64_t index =
              od * output_height * output_width + oh * output_width + ow;
          int64_t maxindex = indices_ptr[index];
          if (maxindex != -1) {
            // Route the output gradient back to the max location.
            grad_input_ptr[maxindex] += grad_output_ptr[index];
          }
        }
      }
    }
  }
}

} // namespace

Tensor& max_pool2d_with_indices_backward_out(
    KernelRuntimeContext& ctx,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices,
    Tensor& grad_input) {
  ET_KERNEL_CHECK(
      ctx,
      check_max_pool2d_backward_args(
          grad_output,
          input,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          indices,
          grad_input),
      InvalidArgument,
      grad_input);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_input, input.sizes()) == Error::Ok,
      InvalidArgument,
      grad_input);

  constexpr auto name = "max_pool2d_with_indices_backward.grad_input";

  ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
    max_pool_backward_impl<CTYPE, false>(grad_input, grad_output, indices);
  });

  return grad_input;
}

} // namespace native
} // namespace executor
} // namespace torch
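The core of the new kernel is a scatter-add: each grad_output element is routed back to the grad_input position recorded in indices. Below is a minimal standalone sketch of that inner loop on raw arrays, for a hypothetical single-channel 4x4 input pooled with a 2x2 kernel at stride 2; the sizes, index values, and gradients are illustrative assumptions, not taken from the commit.

#include <cstdint>
#include <cstdio>

int main() {
  // One channel of a 4x4 input pooled by a 2x2 kernel at stride 2,
  // giving a 2x2 output.
  constexpr int64_t input_height = 4, input_width = 4;
  constexpr int64_t output_height = 2, output_width = 2;

  // Flattened position of each window's max within the 4x4 input, as
  // the forward max_pool2d_with_indices would have recorded them.
  const int64_t indices[] = {5, 7, 13, 15};

  // Upstream gradient, one value per pooled output element.
  const float grad_output[] = {1.0f, 2.0f, 3.0f, 4.0f};

  // grad_input starts at zero; only max locations receive gradient.
  float grad_input[input_height * input_width] = {};

  for (int64_t oh = 0; oh < output_height; oh++) {
    for (int64_t ow = 0; ow < output_width; ow++) {
      const int64_t index = oh * output_width + ow;
      const int64_t maxindex = indices[index];
      if (maxindex != -1) {
        grad_input[maxindex] += grad_output[index]; // scatter-add
      }
    }
  }

  // Prints a 4x4 grid with 1, 2, 3, 4 at flattened positions
  // 5, 7, 13, 15 and zeros everywhere else.
  for (int64_t h = 0; h < input_height; h++) {
    for (int64_t w = 0; w < input_width; w++) {
      printf("%.0f ", grad_input[h * input_width + w]);
    }
    printf("\n");
  }
  return 0;
}

Overlapping windows (stride smaller than the kernel) can record the same input position for several outputs, which is why the kernel accumulates with += rather than assigning.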

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 2 additions & 2 deletions
The shared checker's out and indices parameters are const-qualified here so the backward kernel can pass its const grad_output and indices tensors.

@@ -470,8 +470,8 @@ bool check_max_pool2d_with_indices_args(
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode,
-    Tensor& out,
-    Tensor& indices) {
+    const Tensor& out,
+    const Tensor& indices) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_CHECK_OR_RETURN_FALSE(
       indices.scalar_type() == ScalarType::Long,

kernels/portable/cpu/util/kernel_ops_util.h

Lines changed: 2 additions & 2 deletions
@@ -442,8 +442,8 @@ bool check_max_pool2d_with_indices_args(
     IntArrayRef padding,
     IntArrayRef dilation,
     bool ceil_mode,
-    Tensor& out,
-    Tensor& indices);
+    const Tensor& out,
+    const Tensor& indices);
 
 void get_max_pool2d_with_indices_out_target_size(
     const Tensor& in,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -572,6 +572,11 @@
     - arg_meta: null
       kernel_name: torch::executor::max_pool2d_with_indices_out
 
+- op: max_pool2d_with_indices_backward.grad_input
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::max_pool2d_with_indices_backward_out
+
 - op: mean.out
   kernels:
     - arg_meta: null
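With the kernel registered in the portable library, the out-variant can be exercised directly in a unit-test style. The following is a hypothetical sketch: it assumes ExecuTorch's TensorFactory testing utility and a default-constructed KernelRuntimeContext, and it re-declares the kernel's signature instead of including the generated operator headers; the shapes and values mirror the scatter example above and are illustrative only.

#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using torch::executor::testing::TensorFactory;

namespace torch {
namespace executor {
namespace native {
// Mirrors the signature of the new portable kernel; a real test would
// pull this declaration in from the generated operator headers.
Tensor& max_pool2d_with_indices_backward_out(
    KernelRuntimeContext& ctx,
    const Tensor& grad_output,
    const Tensor& input,
    executorch::aten::ArrayRef<int64_t> kernel_size,
    executorch::aten::ArrayRef<int64_t> stride,
    executorch::aten::ArrayRef<int64_t> padding,
    executorch::aten::ArrayRef<int64_t> dilation,
    bool ceil_mode,
    const Tensor& indices,
    Tensor& grad_input);
} // namespace native
} // namespace executor
} // namespace torch

int main() {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_long;

  // CHW layout: one channel, 4x4 input, 2x2 kernel, stride 2 -> 2x2 output.
  Tensor input = tf.zeros({1, 4, 4});
  Tensor grad_output = tf.make({1, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
  Tensor indices = tf_long.make({1, 2, 2}, {5, 7, 13, 15});
  Tensor grad_input = tf.zeros({1, 4, 4});

  const int64_t kernel[] = {2, 2};
  const int64_t stride[] = {2, 2};
  const int64_t padding[] = {0, 0};
  const int64_t dilation[] = {1, 1};

  torch::executor::KernelRuntimeContext ctx{};
  torch::executor::native::max_pool2d_with_indices_backward_out(
      ctx,
      grad_output,
      input,
      {kernel, 2},
      {stride, 2},
      {padding, 2},
      {dilation, 2},
      /*ceil_mode=*/false,
      indices,
      grad_input);

  // grad_input now holds 1, 2, 3, 4 at flattened positions 5, 7, 13, 15.
  return 0;
}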
