Skip to content

Commit 88d9616

Browse files
authored
Fix scalar (single element tensor) binary ops on HiFi
Differential Revision: D71667998 Pull Request resolved: #9523
1 parent 90f0843 commit 88d9616

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

backends/cadence/hifi/operators/op_add.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,13 @@ Tensor& add_out(
142142
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
143143

144144
if ((a_dim == 0) && float_types) {
145-
for (int i = 0; i < max_dim; i++)
145+
for (int i = 0; i < b.numel(); i++)
146146
out.mutable_data_ptr<float>()[i] =
147147
a.const_data_ptr<float>()[0] + b.const_data_ptr<float>()[i];
148148
return out;
149149
}
150150
if ((b_dim == 0) && float_types) {
151-
for (int i = 0; i < max_dim; i++)
151+
for (int i = 0; i < a.numel(); i++)
152152
out.mutable_data_ptr<float>()[i] =
153153
a.const_data_ptr<float>()[i] + b.const_data_ptr<float>()[0];
154154
return out;

backends/cadence/hifi/operators/op_div.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
9090
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
9191

9292
if ((a_dim == 0) && float_types) {
93-
for (int i = 0; i < max_dim; i++)
93+
for (int i = 0; i < b.numel(); i++)
9494
out.mutable_data_ptr<float>()[i] =
9595
a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
9696
return out;
9797
}
9898
if ((b_dim == 0) && float_types) {
99-
for (int i = 0; i < max_dim; i++)
99+
for (int i = 0; i < a.numel(); i++)
100100
out.mutable_data_ptr<float>()[i] =
101101
a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
102102
return out;
@@ -218,13 +218,13 @@ Tensor& div_out_mode(
218218
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
219219

220220
if ((a_dim == 0) && float_types) {
221-
for (int i = 0; i < max_dim; i++)
221+
for (int i = 0; i < b.numel(); i++)
222222
out.mutable_data_ptr<float>()[i] =
223223
a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
224224
return out;
225225
}
226226
if ((b_dim == 0) && float_types) {
227-
for (int i = 0; i < max_dim; i++)
227+
for (int i = 0; i < a.numel(); i++)
228228
out.mutable_data_ptr<float>()[i] =
229229
a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
230230
return out;

backends/cadence/hifi/operators/op_mul.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
108108
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
109109

110110
if ((a_dim == 0) && float_types) {
111-
for (int i = 0; i < max_dim; i++)
111+
for (int i = 0; i < b.numel(); i++)
112112
out.mutable_data_ptr<float>()[i] =
113113
a.const_data_ptr<float>()[0] * b.const_data_ptr<float>()[i];
114114
return out;
115115
}
116116
if ((b_dim == 0) && float_types) {
117-
for (int i = 0; i < max_dim; i++)
117+
for (int i = 0; i < a.numel(); i++)
118118
out.mutable_data_ptr<float>()[i] =
119119
a.const_data_ptr<float>()[i] * b.const_data_ptr<float>()[0];
120120
return out;

backends/cadence/hifi/operators/op_sub.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,13 @@ Tensor& sub_out(
137137
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
138138

139139
if ((a_dim == 0) && float_types) {
140-
for (int i = 0; i < max_dim; i++)
140+
for (int i = 0; i < b.numel(); i++)
141141
out.mutable_data_ptr<float>()[i] =
142142
a.const_data_ptr<float>()[0] - b.const_data_ptr<float>()[i];
143143
return out;
144144
}
145145
if ((b_dim == 0) && float_types) {
146-
for (int i = 0; i < max_dim; i++)
146+
for (int i = 0; i < a.numel(); i++)
147147
out.mutable_data_ptr<float>()[i] =
148148
a.const_data_ptr<float>()[i] - b.const_data_ptr<float>()[0];
149149
return out;

0 commit comments

Comments
 (0)