@@ -740,6 +740,25 @@ struct Vectorized {
740
740
}
741
741
};
742
742
743
// Scalar/vector overloads for binary functions and operators.
//
// There is an implicit conversion that would make this work if
// these operators weren't template functions, but they are template
// functions (and can't be moved to be non-member friends defined in
// the class body as suggested in
// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255
// because we have a lot of disparate specializations of
// Vectorized). So, just explicitly make scalars work.
//
// VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) emits two overloads of
// `name`, taking (Vectorized<T>, T) and (T, Vectorized<T>). Each one wraps
// the scalar as Vectorized<T>(scalar) and forwards to the overload of `name`
// that takes two Vectorized<T> arguments. The emitted overloads are declared
// to return Vectorized<T>, so the result of the two-vector overload must be
// convertible to Vectorized<T>.
//
// NOTE: in a #define there must be no whitespace between the macro name and
// its opening parenthesis; otherwise the preprocessor defines an object-like
// macro whose replacement text begins with "(name)".
#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \
  template <class T> \
  Vectorized<T> inline name(const Vectorized<T>& a, T b) { \
    return name(a, Vectorized<T>(b)); \
  } \
  template <class T> \
  Vectorized<T> inline name(T a, const Vectorized<T>& b) { \
    return name(Vectorized<T>(a), b); \
  }

// VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) emits the same pair of
// overloads for `operator op`.
#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \
  VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op)
761
+
743
762
template <class T>
744
763
Vectorized<T> inline operator+(const Vectorized<T>& a, const Vectorized<T>& b) {
745
764
Vectorized<T> c;
@@ -749,6 +768,8 @@ Vectorized<T> inline operator+(const Vectorized<T>& a, const Vectorized<T>& b) {
749
768
return c;
750
769
}
751
770
771
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (+)
772
+
752
773
template <class T>
753
774
Vectorized<T> inline operator-(const Vectorized<T>& a, const Vectorized<T>& b) {
754
775
Vectorized<T> c;
@@ -758,6 +779,8 @@ Vectorized<T> inline operator-(const Vectorized<T>& a, const Vectorized<T>& b) {
758
779
return c;
759
780
}
760
781
782
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (-)
783
+
761
784
template <class T>
762
785
Vectorized<T> inline operator*(const Vectorized<T>& a, const Vectorized<T>& b) {
763
786
Vectorized<T> c;
@@ -767,6 +790,8 @@ Vectorized<T> inline operator*(const Vectorized<T>& a, const Vectorized<T>& b) {
767
790
return c;
768
791
}
769
792
793
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (*)
794
+
770
795
template <class T>
771
796
Vectorized<T> inline operator/(const Vectorized<T>& a, const Vectorized<T>& b)
772
797
__ubsan_ignore_float_divide_by_zero__ {
@@ -777,12 +802,16 @@ Vectorized<T> inline operator/(const Vectorized<T>& a, const Vectorized<T>& b)
777
802
return c;
778
803
}
779
804
805
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (/)
806
+
780
807
template <class T, typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
781
808
Vectorized<T> inline operator%(const Vectorized<T>& a, const Vectorized<T>& b)
782
809
__ubsan_ignore_float_divide_by_zero__ {
783
810
return a - a / b * b;
784
811
}
785
812
813
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (%)
814
+
786
815
template <class T>
787
816
Vectorized<T> inline operator||(
788
817
const Vectorized<T>& a,
@@ -794,6 +823,8 @@ Vectorized<T> inline operator||(
794
823
return c;
795
824
}
796
825
826
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (||)
827
+
797
828
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
798
829
// either input is a NaN.
799
830
template <
@@ -830,6 +861,8 @@ Vectorized<T> inline maximum(const Vectorized<T>& a, const Vectorized<T>& b) {
830
861
return c;
831
862
}
832
863
864
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC (maximum)
865
+
833
866
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
834
867
// either input is a NaN.
835
868
template <
@@ -866,6 +899,8 @@ Vectorized<T> inline minimum(const Vectorized<T>& a, const Vectorized<T>& b) {
866
899
return c;
867
900
}
868
901
902
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC (minimum)
903
+
869
904
template <
870
905
class T,
871
906
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
@@ -880,6 +915,42 @@ Vectorized<T> inline clamp(
880
915
return c;
881
916
}
882
917
918
// VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) emits the six overloads
// of `name` in which one or two of the three arguments are scalars of type T
// (every vector/scalar mix except all-vector and all-scalar). Each overload
// wraps its scalar arguments as Vectorized<T>(scalar) and forwards to the
// overload of `name` that takes three Vectorized<T> arguments. The emitted
// overloads are declared to return Vectorized<T>.
//
// NOTE: in a #define there must be no whitespace between the macro name and
// its opening parenthesis; otherwise the preprocessor defines an object-like
// macro whose replacement text begins with "(name)".
#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \
  template <class T> \
  Vectorized<T> inline name( \
      const Vectorized<T>& a, const Vectorized<T>& b, T c) { \
    return name(a, b, Vectorized<T>(c)); \
  } \
  \
  template <class T> \
  Vectorized<T> inline name( \
      const Vectorized<T>& a, T b, const Vectorized<T>& c) { \
    return name(a, Vectorized<T>(b), c); \
  } \
  \
  template <class T> \
  Vectorized<T> inline name(const Vectorized<T>& a, T b, T c) { \
    return name(a, Vectorized<T>(b), Vectorized<T>(c)); \
  } \
  \
  template <class T> \
  Vectorized<T> inline name( \
      T a, const Vectorized<T>& b, const Vectorized<T>& c) { \
    return name(Vectorized<T>(a), b, c); \
  } \
  \
  template <class T> \
  Vectorized<T> inline name(T a, const Vectorized<T>& b, T c) { \
    return name(Vectorized<T>(a), b, Vectorized<T>(c)); \
  } \
  \
  template <class T> \
  Vectorized<T> inline name(T a, T b, const Vectorized<T>& c) { \
    return name(Vectorized<T>(a), Vectorized<T>(b), c); \
  }
951
+
952
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC (clamp)
953
+
883
954
template <
884
955
class T,
885
956
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
@@ -893,6 +964,8 @@ Vectorized<T> inline clamp_max(
893
964
return c;
894
965
}
895
966
967
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC (clamp_max)
968
+
896
969
template <
897
970
class T,
898
971
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
@@ -906,6 +979,8 @@ Vectorized<T> inline clamp_min(
906
979
return c;
907
980
}
908
981
982
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC (clamp_min)
983
+
909
984
struct Vectorizedi;
910
985
911
986
#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
@@ -1049,6 +1124,10 @@ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
1049
1124
1050
1125
#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
1051
1126
1127
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (&)
1128
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|)
1129
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^)
1130
+
1052
1131
template <
1053
1132
class T,
1054
1133
typename std::
@@ -1142,6 +1221,8 @@ inline Vectorized<T> fmadd(
1142
1221
return a * b + c;
1143
1222
}
1144
1223
1224
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC (fmadd)
1225
+
1145
1226
template <typename T>
1146
1227
inline Vectorized<T> fmsub(
1147
1228
const Vectorized<T>& a,
@@ -1150,6 +1231,8 @@ inline Vectorized<T> fmsub(
1150
1231
return a * b - c;
1151
1232
}
1152
1233
1234
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC (fmsub)
1235
+
1153
1236
template <typename T>
1154
1237
Vectorized<T> inline operator&&(
1155
1238
const Vectorized<T>& a,
@@ -1161,6 +1244,8 @@ Vectorized<T> inline operator&&(
1161
1244
return ret;
1162
1245
}
1163
1246
1247
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP (&&)
1248
+
1164
1249
template <int64_t scale = 1, typename T = void>
1165
1250
std::enable_if_t<
1166
1251
scale == 1 || scale == 2 || scale == 4 || scale == 8,
@@ -1298,6 +1383,8 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1298
1383
Vectorized<T>::loadu (static_cast <void *>(buffer2)));
1299
1384
}
1300
1385
1386
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC (deinterleave2)
1387
+
1301
1388
// clang-format off
1302
1389
// inverse operation of deinterleave2
1303
1390
// Example inputs for AVX512:
@@ -1335,6 +1422,12 @@ interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1335
1422
Vectorized<T>::loadu (static_cast <void *>(buffer2)));
1336
1423
}
1337
1424
1425
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC (interleave2)
1426
+
1427
+ #undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC
1428
+ #undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP
1429
+ #undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC
1430
+
1338
1431
template <typename src_T, typename dst_T>
1339
1432
inline void convert (const src_T* src, dst_T* dst, int64_t n) {
1340
1433
#ifndef _MSC_VER
0 commit comments