Skip to content

Commit c294920

Browse files
author
Zonglin Peng
committed
init
1 parent 3b45e88 commit c294920

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

backends/cadence/reference/operators/quantized_relu_out.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/backends/cadence/reference/kernels/kernels.h>
10+
#include <executorch/backends/cadence/reference/operators/operators.h>
1011
#include <executorch/runtime/kernel/kernel_includes.h>
1112

1213
namespace impl {
@@ -75,6 +76,59 @@ void quantized_relu_out(
7576
}
7677
}
7778

79+
template <typename T>
80+
void quantized_relu_per_tensor_out_(
81+
__ET_UNUSED KernelRuntimeContext& ctx,
82+
const Tensor& input,
83+
const int64_t in_zero_point,
84+
const int64_t out_zero_point,
85+
const int64_t out_multiplier,
86+
const int64_t out_shift,
87+
Tensor& output) {
88+
const T* __restrict__ in = input.const_data_ptr<T>();
89+
T* __restrict__ out = output.mutable_data_ptr<T>();
90+
91+
// Compute the out_scale from out_multiplier and out_shift
92+
const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift);
93+
94+
for (size_t i = 0, e = input.numel(); i < e; ++i) {
95+
const float temp = in[i] > in_zero_point ? (in[i] - in_zero_point) : 0;
96+
out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
97+
}
98+
}
99+
100+
// Type-dispatching entry point for per-tensor quantized ReLU.
//
// Inspects the input's ScalarType and forwards to the typed implementation
// quantized_relu_per_tensor_out_<ctype> for every dtype enumerated by
// ET_FORALL_CADENCE_QUANTIZED_TYPES. Unhandled dtypes trip a debug check.
void quantized_relu_per_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const int64_t in_zero_point,
    const int64_t out_zero_point,
    const int64_t out_multiplier,
    const int64_t out_shift,
    Tensor& output) {
  // Expands to one switch case per supported (ctype, dtype) pair.
#define LOCAL_QUANTIZED_RELU_CASE(ctype, dtype)    \
  case executorch::aten::ScalarType::dtype:        \
    quantized_relu_per_tensor_out_<ctype>(         \
        ctx,                                       \
        input,                                     \
        in_zero_point,                             \
        out_zero_point,                            \
        out_multiplier,                            \
        out_shift,                                 \
        output);                                   \
    break;

  const executorch::aten::ScalarType dtype = input.scalar_type();
  switch (dtype) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES(LOCAL_QUANTIZED_RELU_CASE)
    default:
      ET_DCHECK_MSG(
          false, "Unhandled dtype %s", torch::executor::toString(dtype));
  }

#undef LOCAL_QUANTIZED_RELU_CASE
}
131+
78132
}; // namespace native
79133
}; // namespace reference
80134
}; // namespace impl

0 commit comments

Comments
 (0)