@@ -29,16 +29,21 @@ inline scalar_t vec_reduce_all(
29
29
30
30
template <typename scalar_t , typename Op>
31
31
struct VecReduceAllSIMD {
32
- static inline scalar_t apply (const Op& vec_fun, const Vectorized<scalar_t >& acc_vec) {
32
+ static inline scalar_t apply (
33
+ const Op& vec_fun,
34
+ const Vectorized<scalar_t >& acc_vec) {
33
35
return vec_reduce_all (vec_fun, acc_vec, Vectorized<scalar_t >::size ());
34
36
}
35
37
};
36
38
37
- #if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
39
+ #if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \
40
+ !defined(C10_MOBILE)
38
41
#if defined(CPU_CAPABILITY_AVX2)
39
42
template <typename Op>
40
43
struct VecReduceAllSIMD <float , Op> {
41
- static inline float apply (const Op& vec_fun, const Vectorized<float >& acc_vec) {
44
+ static inline float apply (
45
+ const Op& vec_fun,
46
+ const Vectorized<float >& acc_vec) {
42
47
using Vec = Vectorized<float >;
43
48
Vec v = acc_vec;
44
49
// 128-bit shuffle
@@ -57,7 +62,9 @@ struct VecReduceAllSIMD<float, Op> {
57
62
#if defined(CPU_CAPABILITY_AVX512)
58
63
template <typename Op>
59
64
struct VecReduceAllSIMD <float , Op> {
60
- static inline float apply (const Op& vec_fun, const Vectorized<float >& acc_vec) {
65
+ static inline float apply (
66
+ const Op& vec_fun,
67
+ const Vectorized<float >& acc_vec) {
61
68
using Vec = Vectorized<float >;
62
69
Vec v = acc_vec;
63
70
// 256-bit shuffle
@@ -76,36 +83,47 @@ struct VecReduceAllSIMD<float, Op> {
76
83
}
77
84
};
78
85
#endif // defined(CPU_CAPABILITY_AVX512)
79
- #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
86
+ #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
87
+ // !defined(C10_MOBILE)
80
88
81
- #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
89
+ #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
90
+ !defined(CPU_CAPABILITY_SVE)
82
91
template <typename Op>
83
92
struct VecReduceAllSIMD <float , Op> {
84
- static inline float apply (const Op& vec_fun, const Vectorized<float >& acc_vec) {
93
+ static inline float apply (
94
+ const Op& vec_fun,
95
+ const Vectorized<float >& acc_vec) {
85
96
using Vec = Vectorized<float >;
86
97
Vec v = acc_vec;
87
98
88
- // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
99
+ // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7,
100
+ // a4+a8, a1+a5, a2+a6, -, -, -, -]
89
101
float32x4_t v1_1 = vextq_f32 (v, v, 2 );
90
102
Vec v1 = v1_1;
91
103
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
92
104
v = vec_fun (v, v1);
93
105
94
- // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
106
+ // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -,
107
+ // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -,
108
+ // -]
95
109
v1_1 = vrev64q_f32 (v);
96
110
v1 = v1_1;
97
- // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
111
+ // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8,
112
+ // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
98
113
v = vec_fun (v, v1);
99
114
100
115
return v[0 ];
101
116
}
102
117
};
103
118
#endif // defined(__aarch64__)
104
119
105
- #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256)
120
+ #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
121
+ defined (CPU_CAPABILITY_SVE256)
106
122
template <typename Op>
107
123
struct VecReduceAllSIMD<float, Op> {
108
- static inline float apply (const Op& vec_fun, const Vectorized<float >& acc_vec) {
124
+ static inline float apply (
125
+ const Op& vec_fun,
126
+ const Vectorized<float >& acc_vec) {
109
127
using Vec = Vectorized<float >;
110
128
Vec v = acc_vec;
111
129
// 128-bit shuffle
@@ -125,15 +143,21 @@ struct VecReduceAllSIMD<float, Op> {
125
143
};
126
144
#endif // defined(__aarch64__)
127
145
128
-
129
146
template <typename scalar_t , typename Op>
130
- inline scalar_t vec_reduce_all (const Op& vec_fun, const Vectorized<scalar_t >& acc_vec) {
147
+ inline scalar_t vec_reduce_all (
148
+ const Op& vec_fun,
149
+ const Vectorized<scalar_t >& acc_vec) {
131
150
return VecReduceAllSIMD<scalar_t , Op>::apply (vec_fun, acc_vec);
132
151
}
133
152
134
- template <typename scalar_t , typename Op,
135
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
136
- inline scalar_t reduce_all (const Op& vec_fun, const scalar_t * data, int64_t size) {
153
+ template <
154
+ typename scalar_t ,
155
+ typename Op,
156
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
157
+ inline scalar_t reduce_all (
158
+ const Op& vec_fun,
159
+ const scalar_t * data,
160
+ int64_t size) {
137
161
using Vec = vec::Vectorized<scalar_t >;
138
162
if (size < Vec::size ())
139
163
return vec_reduce_all (vec_fun, Vec::loadu (data, size), size);
@@ -151,16 +175,22 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size
151
175
}
152
176
153
177
// similar to reduce_all, but reduces into two outputs
154
- template <typename scalar_t , typename Op1, typename Op2,
155
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
156
- inline std::pair<scalar_t , scalar_t > reduce2_all (const Op1& vec_fun1, const Op2& vec_fun2,
157
- const scalar_t * data, int64_t size) {
178
+ template <
179
+ typename scalar_t ,
180
+ typename Op1,
181
+ typename Op2,
182
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
183
+ inline std::pair<scalar_t , scalar_t > reduce2_all (
184
+ const Op1& vec_fun1,
185
+ const Op2& vec_fun2,
186
+ const scalar_t * data,
187
+ int64_t size) {
158
188
using Vec = vec::Vectorized<scalar_t >;
159
189
if (size < Vec::size ()) {
160
190
auto loaded_data = Vec::loadu (data, size);
161
191
return std::pair<scalar_t , scalar_t >(
162
- vec_reduce_all (vec_fun1, loaded_data, size),
163
- vec_reduce_all (vec_fun2, loaded_data, size));
192
+ vec_reduce_all (vec_fun1, loaded_data, size),
193
+ vec_reduce_all (vec_fun2, loaded_data, size));
164
194
}
165
195
int64_t d = Vec::size ();
166
196
Vec acc_vec1 = Vec::loadu (data);
@@ -176,12 +206,14 @@ inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2&
176
206
acc_vec2 = Vec::set (acc_vec2, vec_fun2 (acc_vec2, data_vec), size - d);
177
207
}
178
208
return std::pair<scalar_t , scalar_t >(
179
- vec_reduce_all (vec_fun1, acc_vec1),
180
- vec_reduce_all (vec_fun2, acc_vec2));
209
+ vec_reduce_all (vec_fun1, acc_vec1), vec_reduce_all (vec_fun2, acc_vec2));
181
210
}
182
211
183
- template <typename scalar_t , typename MapOp, typename ReduceOp,
184
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
212
+ template <
213
+ typename scalar_t ,
214
+ typename MapOp,
215
+ typename ReduceOp,
216
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
185
217
inline scalar_t map_reduce_all (
186
218
const MapOp& map_fun,
187
219
const ReduceOp& red_fun,
@@ -205,8 +237,11 @@ inline scalar_t map_reduce_all(
205
237
return vec_reduce_all (red_fun, acc_vec);
206
238
}
207
239
208
- template <typename scalar_t , typename MapOp, typename ReduceOp,
209
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
240
+ template <
241
+ typename scalar_t ,
242
+ typename MapOp,
243
+ typename ReduceOp,
244
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
210
245
inline scalar_t map2_reduce_all (
211
246
const MapOp& map_fun,
212
247
const ReduceOp& red_fun,
@@ -237,8 +272,11 @@ inline scalar_t map2_reduce_all(
237
272
return vec_reduce_all (red_fun, acc_vec);
238
273
}
239
274
240
- template <typename scalar_t , typename MapOp, typename ReduceOp,
241
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
275
+ template <
276
+ typename scalar_t ,
277
+ typename MapOp,
278
+ typename ReduceOp,
279
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
242
280
inline scalar_t map3_reduce_all (
243
281
const MapOp& map_fun,
244
282
const ReduceOp& red_fun,
@@ -274,8 +312,10 @@ inline scalar_t map3_reduce_all(
274
312
return vec_reduce_all (red_fun, acc_vec);
275
313
}
276
314
277
- template <typename scalar_t , typename Op,
278
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
315
+ template <
316
+ typename scalar_t ,
317
+ typename Op,
318
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
279
319
inline void map (
280
320
const Op& vec_fun,
281
321
scalar_t * output_data,
@@ -293,8 +333,10 @@ inline void map(
293
333
}
294
334
}
295
335
296
- template <typename scalar_t , typename Op,
297
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
336
+ template <
337
+ typename scalar_t ,
338
+ typename Op,
339
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
298
340
inline void map2 (
299
341
const Op& vec_fun,
300
342
scalar_t * output_data,
@@ -317,8 +359,10 @@ inline void map2(
317
359
}
318
360
}
319
361
320
- template <typename scalar_t , typename Op,
321
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
362
+ template <
363
+ typename scalar_t ,
364
+ typename Op,
365
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
322
366
inline void map3 (
323
367
const Op& vec_fun,
324
368
scalar_t * output_data,
@@ -344,8 +388,10 @@ inline void map3(
344
388
}
345
389
}
346
390
347
- template <typename scalar_t , typename Op,
348
- typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
391
+ template <
392
+ typename scalar_t ,
393
+ typename Op,
394
+ typename std::enable_if_t <!is_reduced_floating_point_v<scalar_t >, int > = 0 >
349
395
inline void map4 (
350
396
const Op& vec_fun,
351
397
scalar_t * output_data,
0 commit comments