@@ -58,17 +58,17 @@ Tensor& div_out(
58
58
static constexpr const char op_name[] = " div.out" ;
59
59
60
60
ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
61
- utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
62
- [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
63
- return val_a / val_b;
64
- },
61
+ utils::apply_bitensor_elementwise_fn<
62
+ CTYPE_COMPUTE,
63
+ op_name,
64
+ utils::SupportedTensorDtypes::FLOATHBF16>(
65
+ [](const auto val_a, const auto val_b) { return val_a / val_b; },
65
66
ctx,
66
67
a,
67
68
utils::SupportedTensorDtypes::REALHBBF16,
68
69
b,
69
70
utils::SupportedTensorDtypes::REALHBBF16,
70
- out,
71
- utils::SupportedTensorDtypes::FLOATHBF16);
71
+ out);
72
72
});
73
73
74
74
return out;
@@ -122,9 +122,13 @@ Tensor& div_out_mode(
122
122
bool div_by_zero_error = false ;
123
123
124
124
ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
125
- utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
125
+ utils::apply_bitensor_elementwise_fn<
126
+ CTYPE_COMPUTE,
127
+ op_name,
128
+ utils::SupportedTensorDtypes::REALHBF16>(
126
129
[mode_is_trunc, &div_by_zero_error](
127
130
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
131
+ // TODO: rewrite this to be vectorization-capable.
128
132
if (is_integral_type<CTYPE_COMPUTE, /* includeBool=*/ true >::value) {
129
133
if (val_b == 0 ) {
130
134
div_by_zero_error = true ;
@@ -146,8 +150,7 @@ Tensor& div_out_mode(
146
150
utils::SupportedTensorDtypes::REALHBBF16,
147
151
b,
148
152
utils::SupportedTensorDtypes::REALHBBF16,
149
- out,
150
- utils::SupportedTensorDtypes::REALHBF16);
153
+ out);
151
154
});
152
155
153
156
ET_KERNEL_CHECK_MSG (
@@ -188,13 +191,15 @@ Tensor& div_scalar_out(
188
191
189
192
ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
190
193
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
191
- utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
192
- [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
194
+ utils::apply_unitensor_elementwise_fn<
195
+ CTYPE_COMPUTE,
196
+ op_name,
197
+ utils::SupportedTensorDtypes::SAME_AS_COMMON>(
198
+ [val_b](const auto val_a) { return val_a / val_b; },
193
199
ctx,
194
200
a,
195
201
utils::SupportedTensorDtypes::REALHBBF16,
196
- out,
197
- utils::SupportedTensorDtypes::SAME_AS_COMMON);
202
+ out);
198
203
});
199
204
200
205
return out;
0 commit comments