|
7 | 7 | */
|
8 | 8 |
|
9 | 9 | #include <executorch/backends/cadence/reference/kernels/kernels.h>
|
| 10 | +#include <executorch/backends/cadence/reference/operators/operators.h> |
10 | 11 | #include <executorch/runtime/kernel/kernel_includes.h>
|
11 | 12 |
|
12 | 13 | namespace impl {
|
@@ -75,6 +76,59 @@ void quantized_relu_out(
|
75 | 76 | }
|
76 | 77 | }
|
77 | 78 |
|
| 79 | +template <typename T> |
| 80 | +void quantized_relu_per_tensor_out_( |
| 81 | + __ET_UNUSED KernelRuntimeContext& ctx, |
| 82 | + const Tensor& input, |
| 83 | + const int64_t in_zero_point, |
| 84 | + const int64_t out_zero_point, |
| 85 | + const int64_t out_multiplier, |
| 86 | + const int64_t out_shift, |
| 87 | + Tensor& output) { |
| 88 | + const T* __restrict__ in = input.const_data_ptr<T>(); |
| 89 | + T* __restrict__ out = output.mutable_data_ptr<T>(); |
| 90 | + |
| 91 | + // Compute the out_scale from out_multiplier and out_shift |
| 92 | + const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); |
| 93 | + |
| 94 | + for (size_t i = 0, e = input.numel(); i < e; ++i) { |
| 95 | + const float temp = in[i] > in_zero_point ? (in[i] - in_zero_point) : 0; |
| 96 | + out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point); |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +void quantized_relu_per_tensor_out( |
| 101 | + KernelRuntimeContext& ctx, |
| 102 | + const Tensor& input, |
| 103 | + const int64_t in_zero_point, |
| 104 | + const int64_t out_zero_point, |
| 105 | + const int64_t out_multiplier, |
| 106 | + const int64_t out_shift, |
| 107 | + Tensor& output) { |
| 108 | +#define typed_quantized_relu(ctype, dtype) \ |
| 109 | + case executorch::aten::ScalarType::dtype: { \ |
| 110 | + quantized_relu_per_tensor_out_<ctype>( \ |
| 111 | + ctx, \ |
| 112 | + input, \ |
| 113 | + in_zero_point, \ |
| 114 | + out_zero_point, \ |
| 115 | + out_multiplier, \ |
| 116 | + out_shift, \ |
| 117 | + output); \ |
| 118 | + break; \ |
| 119 | + } |
| 120 | + |
| 121 | + executorch::aten::ScalarType dtype = input.scalar_type(); |
| 122 | + switch (dtype) { |
| 123 | + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) |
| 124 | + default: |
| 125 | + ET_DCHECK_MSG( |
| 126 | + false, "Unhandled dtype %s", torch::executor::toString(dtype)); |
| 127 | + } |
| 128 | + |
| 129 | +#undef typed_quantized_relu |
| 130 | +} |
| 131 | + |
78 | 132 | }; // namespace native
|
79 | 133 | }; // namespace reference
|
80 | 134 | }; // namespace impl
|
0 commit comments