Skip to content

Commit 380f372

Browse files
authored
removed temp memory from mean dim out
Differential Revision: D70795553 Pull Request resolved: #9052
1 parent 2094659 commit 380f372

File tree

1 file changed

+4
-15
lines changed

1 file changed

+4
-15
lines changed

backends/cadence/fusion_g3/operators/op_mean.cpp

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ Tensor& mean_out(
118118
for (int i = 0; i < kNnlibMaxDim; i++) {
119119
out_shape[i] = 1;
120120
inp_shape[i] = 1;
121-
p_axis[i] = 1;
121+
p_axis[i] = -1;
122122
}
123123

124124
int num_axis_dims = prepare_data(
@@ -135,20 +135,10 @@ Tensor& mean_out(
135135
num_out_dims = 1;
136136
}
137137

138-
int inp_shape_max = inp_shape[p_axis[0]];
139-
for (int i = 1; i < num_axis_dims; i++) {
140-
if (inp_shape[p_axis[i]] > inp_shape_max) {
141-
inp_shape_max = inp_shape[p_axis[i]];
142-
}
138+
if ((out.dim() == 0) && (out.numel())) {
139+
num_out_dims = 1;
143140
}
144141

145-
int scratch_size = in.numel() / inp_shape_max;
146-
147-
executorch::runtime::Result<void*> temp_mem =
148-
ctx.allocate_temp(scratch_size * sizeof(float));
149-
150-
void* __restrict__ p_scratch_in = (void* __restrict__)(temp_mem.get());
151-
152142
XT_KERNEL_CHECK(
153143
ctx,
154144
out,
@@ -160,8 +150,7 @@ Tensor& mean_out(
160150
inp_shape,
161151
num_inp_dims,
162152
p_axis,
163-
num_axis_dims,
164-
p_scratch_in);
153+
num_axis_dims);
165154
} else {
166155
ET_KERNEL_CHECK(
167156
ctx,

0 commit comments

Comments
 (0)