Skip to content

Commit c83da62

Browse files
swolchok authored and timocafe committed
Make at::vec::Vectorized ops work with scalars (pytorch#150380)
I noticed that I couldn't use `vec::Vectorized` operations with scalars, even though there is an implicit conversion from `T` to `vec::Vectorized<T>`, so I made it work.

Test Plan: Added tests. Reverted vec_base.h, left the new tests in place, and confirmed that the new tests don't compile in that state.

Pull Request resolved: pytorch#150380
Approved by: https://github.com/Skylion007
1 parent 989b692 commit c83da62

File tree

3 files changed

+149
-18
lines changed

3 files changed

+149
-18
lines changed

aten/src/ATen/cpu/vec/vec_base.h

+93
Original file line number | Diff line number | Diff line change
@@ -740,6 +740,25 @@ struct Vectorized {
740740
}
741741
};
742742

743+
// There is an implicit conversion that would make this work if
744+
// these operators weren't template functions, but they are template
745+
// functions (and can't be moved to be non-member friends defined in
746+
// the class body as suggested in
747+
// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255
748+
// because we have a lot of disparate specializations of
749+
// Vectorized). So, just explicitly make scalars work.
750+
#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \
751+
template <class T> \
752+
Vectorized<T> inline name(const Vectorized<T>& a, T b) { \
753+
return name(a, Vectorized<T>(b)); \
754+
} \
755+
template <class T> \
756+
Vectorized<T> inline name(T a, const Vectorized<T>& b) { \
757+
return name(Vectorized<T>(a), b); \
758+
}
759+
#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \
760+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op)
761+
743762
template <class T>
744763
Vectorized<T> inline operator+(const Vectorized<T>& a, const Vectorized<T>& b) {
745764
Vectorized<T> c;
@@ -749,6 +768,8 @@ Vectorized<T> inline operator+(const Vectorized<T>& a, const Vectorized<T>& b) {
749768
return c;
750769
}
751770

771+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+)
772+
752773
template <class T>
753774
Vectorized<T> inline operator-(const Vectorized<T>& a, const Vectorized<T>& b) {
754775
Vectorized<T> c;
@@ -758,6 +779,8 @@ Vectorized<T> inline operator-(const Vectorized<T>& a, const Vectorized<T>& b) {
758779
return c;
759780
}
760781

782+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-)
783+
761784
template <class T>
762785
Vectorized<T> inline operator*(const Vectorized<T>& a, const Vectorized<T>& b) {
763786
Vectorized<T> c;
@@ -767,6 +790,8 @@ Vectorized<T> inline operator*(const Vectorized<T>& a, const Vectorized<T>& b) {
767790
return c;
768791
}
769792

793+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*)
794+
770795
template <class T>
771796
Vectorized<T> inline operator/(const Vectorized<T>& a, const Vectorized<T>& b)
772797
__ubsan_ignore_float_divide_by_zero__ {
@@ -777,12 +802,16 @@ Vectorized<T> inline operator/(const Vectorized<T>& a, const Vectorized<T>& b)
777802
return c;
778803
}
779804

805+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/)
806+
780807
template <class T, typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
781808
Vectorized<T> inline operator%(const Vectorized<T>& a, const Vectorized<T>& b)
782809
__ubsan_ignore_float_divide_by_zero__ {
783810
return a - a / b * b;
784811
}
785812

813+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%)
814+
786815
template <class T>
787816
Vectorized<T> inline operator||(
788817
const Vectorized<T>& a,
@@ -794,6 +823,8 @@ Vectorized<T> inline operator||(
794823
return c;
795824
}
796825

826+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||)
827+
797828
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
798829
// either input is a NaN.
799830
template <
@@ -830,6 +861,8 @@ Vectorized<T> inline maximum(const Vectorized<T>& a, const Vectorized<T>& b) {
830861
return c;
831862
}
832863

864+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum)
865+
833866
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
834867
// either input is a NaN.
835868
template <
@@ -866,6 +899,8 @@ Vectorized<T> inline minimum(const Vectorized<T>& a, const Vectorized<T>& b) {
866899
return c;
867900
}
868901

902+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum)
903+
869904
template <
870905
class T,
871906
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
@@ -880,6 +915,42 @@ Vectorized<T> inline clamp(
880915
return c;
881916
}
882917

918+
#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \
919+
template <class T> \
920+
Vectorized<T> inline name( \
921+
const Vectorized<T>& a, const Vectorized<T>& b, T c) { \
922+
return name(a, b, Vectorized<T>(c)); \
923+
} \
924+
\
925+
template <class T> \
926+
Vectorized<T> inline name( \
927+
const Vectorized<T>& a, T b, const Vectorized<T>& c) { \
928+
return name(a, Vectorized<T>(b), c); \
929+
} \
930+
\
931+
template <class T> \
932+
Vectorized<T> inline name(const Vectorized<T>& a, T b, T c) { \
933+
return name(a, Vectorized<T>(b), Vectorized<T>(c)); \
934+
} \
935+
\
936+
template <class T> \
937+
Vectorized<T> inline name( \
938+
T a, const Vectorized<T>& b, const Vectorized<T>& c) { \
939+
return name(Vectorized<T>(a), b, c); \
940+
} \
941+
\
942+
template <class T> \
943+
Vectorized<T> inline name(T a, const Vectorized<T>& b, T c) { \
944+
return name(Vectorized<T>(a), b, Vectorized<T>(c)); \
945+
} \
946+
\
947+
template <class T> \
948+
Vectorized<T> inline name(T a, T b, const Vectorized<T>& c) { \
949+
return name(Vectorized<T>(a), Vectorized<T>(b), c); \
950+
}
951+
952+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp)
953+
883954
template <
884955
class T,
885956
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
@@ -893,6 +964,8 @@ Vectorized<T> inline clamp_max(
893964
return c;
894965
}
895966

967+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max)
968+
896969
template <
897970
class T,
898971
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
@@ -906,6 +979,8 @@ Vectorized<T> inline clamp_min(
906979
return c;
907980
}
908981

982+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min)
983+
909984
struct Vectorizedi;
910985

911986
#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
@@ -1049,6 +1124,10 @@ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
10491124

10501125
#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
10511126

1127+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&)
1128+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|)
1129+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^)
1130+
10521131
template <
10531132
class T,
10541133
typename std::
@@ -1142,6 +1221,8 @@ inline Vectorized<T> fmadd(
11421221
return a * b + c;
11431222
}
11441223

1224+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd)
1225+
11451226
template <typename T>
11461227
inline Vectorized<T> fmsub(
11471228
const Vectorized<T>& a,
@@ -1150,6 +1231,8 @@ inline Vectorized<T> fmsub(
11501231
return a * b - c;
11511232
}
11521233

1234+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub)
1235+
11531236
template <typename T>
11541237
Vectorized<T> inline operator&&(
11551238
const Vectorized<T>& a,
@@ -1161,6 +1244,8 @@ Vectorized<T> inline operator&&(
11611244
return ret;
11621245
}
11631246

1247+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&)
1248+
11641249
template <int64_t scale = 1, typename T = void>
11651250
std::enable_if_t<
11661251
scale == 1 || scale == 2 || scale == 4 || scale == 8,
@@ -1298,6 +1383,8 @@ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
12981383
Vectorized<T>::loadu(static_cast<void*>(buffer2)));
12991384
}
13001385

1386+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2)
1387+
13011388
// clang-format off
13021389
// inverse operation of deinterleave2
13031390
// Example inputs for AVX512:
@@ -1335,6 +1422,12 @@ interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
13351422
Vectorized<T>::loadu(static_cast<void*>(buffer2)));
13361423
}
13371424

1425+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2)
1426+
1427+
#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC
1428+
#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP
1429+
#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC
1430+
13381431
template <typename src_T, typename dst_T>
13391432
inline void convert(const src_T* src, dst_T* dst, int64_t n) {
13401433
#ifndef _MSC_VER

aten/src/ATen/test/vec_test_all_types.cpp

+17 -17
Original file line number | Diff line number | Diff line change
@@ -329,7 +329,7 @@ namespace {
329329
test_binary<vec>(
330330
NAME_INFO(fmod),
331331
RESOLVE_OVERLOAD(std::fmod),
332-
[](vec v0, vec v1) { return v0.fmod(v1); },
332+
[](const auto& v0, const auto& v1) { return vec(v0).fmod(v1); },
333333
createDefaultBinaryTestCase<vec>(TestSeed()),
334334
RESOLVE_OVERLOAD(filter_fmod));
335335
}
@@ -599,8 +599,8 @@ namespace {
599599
test_binary<vec>(
600600
NAME_INFO(atan2),
601601
RESOLVE_OVERLOAD(std::atan2),
602-
[](vec v0, vec v1) {
603-
return v0.atan2(v1);
602+
[](const auto& v0, const auto& v1) {
603+
return vec(v0).atan2(v1);
604604
},
605605
createDefaultBinaryTestCase<vec>(TestSeed()));
606606
}
@@ -609,23 +609,23 @@ namespace {
609609
test_binary<vec>(
610610
NAME_INFO(pow),
611611
RESOLVE_OVERLOAD(std::pow),
612-
[](vec v0, vec v1) { return v0.pow(v1); },
612+
[](const auto& v0, const auto& v1) { return vec(v0).pow(v1); },
613613
createDefaultBinaryTestCase<vec>(TestSeed(), false, true));
614614
}
615615
TYPED_TEST(RealTests, Hypot) {
616616
using vec = TypeParam;
617617
test_binary<vec>(
618618
NAME_INFO(hypot),
619619
RESOLVE_OVERLOAD(std::hypot),
620-
[](vec v0, vec v1) { return v0.hypot(v1); },
620+
[](const auto& v0, const auto& v1) { return vec(v0).hypot(v1); },
621621
createDefaultBinaryTestCase<vec>(TestSeed(), false, true));
622622
}
623623
TYPED_TEST(RealTests, NextAfter) {
624624
using vec = TypeParam;
625625
test_binary<vec>(
626626
NAME_INFO(nextafter),
627627
RESOLVE_OVERLOAD(std::nextafter),
628-
[](vec v0, vec v1) { return v0.nextafter(v1); },
628+
[](const auto& v0, const auto& v1) { return vec(v0).nextafter(v1); },
629629
createDefaultBinaryTestCase<vec>(TestSeed(), false, true));
630630
}
631631
TYPED_TEST(Interleave, Interleave) {
@@ -675,7 +675,7 @@ namespace {
675675
test_binary<vec>(
676676
NAME_INFO(plus),
677677
std::plus<VT>(),
678-
[](const vec& v0, const vec& v1) -> vec {
678+
[](const auto& v0, const auto& v1) -> vec {
679679
return v0 + v1;
680680
},
681681
createDefaultBinaryTestCase<vec>(TestSeed()),
@@ -687,7 +687,7 @@ namespace {
687687
test_binary<vec>(
688688
NAME_INFO(minus),
689689
std::minus<VT>(),
690-
[](const vec& v0, const vec& v1) -> vec {
690+
[](const auto& v0, const auto& v1) -> vec {
691691
return v0 - v1;
692692
},
693693
createDefaultBinaryTestCase<vec>(TestSeed()),
@@ -698,7 +698,7 @@ namespace {
698698
test_binary<vec>(
699699
NAME_INFO(mult),
700700
RESOLVE_OVERLOAD(local_multiply),
701-
[](const vec& v0, const vec& v1) { return v0 * v1; },
701+
[](const auto& v0, const auto& v1) { return v0 * v1; },
702702
createDefaultBinaryTestCase<vec>(TestSeed(), false, true),
703703
RESOLVE_OVERLOAD(filter_mult_overflow));
704704
}
@@ -708,7 +708,7 @@ namespace {
708708
test_binary<vec>(
709709
NAME_INFO(division),
710710
RESOLVE_OVERLOAD(local_division),
711-
[](const vec& v0, const vec& v1) { return v0 / v1; },
711+
[](const auto& v0, const auto& v1) { return v0 / v1; },
712712
createDefaultBinaryTestCase<vec>(seed),
713713
RESOLVE_OVERLOAD(filter_div_ub));
714714
}
@@ -717,23 +717,23 @@ namespace {
717717
test_binary<vec>(
718718
NAME_INFO(bit_and),
719719
RESOLVE_OVERLOAD(local_and),
720-
[](const vec& v0, const vec& v1) { return v0 & v1; },
720+
[](const auto& v0, const auto& v1) { return v0 & v1; },
721721
createDefaultBinaryTestCase<vec>(TestSeed(), true));
722722
}
723723
TYPED_TEST(Bitwise, BitOr) {
724724
using vec = TypeParam;
725725
test_binary<vec>(
726726
NAME_INFO(bit_or),
727727
RESOLVE_OVERLOAD(local_or),
728-
[](const vec& v0, const vec& v1) { return v0 | v1; },
728+
[](const auto& v0, const auto& v1) { return v0 | v1; },
729729
createDefaultBinaryTestCase<vec>(TestSeed(), true));
730730
}
731731
TYPED_TEST(Bitwise, BitXor) {
732732
using vec = TypeParam;
733733
test_binary<vec>(
734734
NAME_INFO(bit_xor),
735735
RESOLVE_OVERLOAD(local_xor),
736-
[](const vec& v0, const vec& v1) { return v0 ^ v1; },
736+
[](const auto& v0, const auto& v1) { return v0 ^ v1; },
737737
createDefaultBinaryTestCase<vec>(TestSeed(), true));
738738
}
739739
TYPED_TEST(Comparison, Equal) {
@@ -796,7 +796,7 @@ namespace {
796796
test_binary<vec>(
797797
NAME_INFO(minimum),
798798
minimum<VT>,
799-
[](const vec& v0, const vec& v1) {
799+
[](const auto& v0, const auto& v1) {
800800
return minimum(v0, v1);
801801
},
802802
createDefaultBinaryTestCase<vec>(TestSeed()));
@@ -807,7 +807,7 @@ namespace {
807807
test_binary<vec>(
808808
NAME_INFO(maximum),
809809
maximum<VT>,
810-
[](const vec& v0, const vec& v1) {
810+
[](const auto& v0, const auto& v1) {
811811
return maximum(v0, v1);
812812
},
813813
createDefaultBinaryTestCase<vec>(TestSeed()));
@@ -818,7 +818,7 @@ namespace {
818818
test_binary<vec>(
819819
NAME_INFO(clamp min),
820820
clamp_min<VT>,
821-
[](const vec& v0, const vec& v1) {
821+
[](const auto& v0, const auto& v1) {
822822
return clamp_min(v0, v1);
823823
},
824824
createDefaultBinaryTestCase<vec>(TestSeed()));
@@ -829,7 +829,7 @@ namespace {
829829
test_binary<vec>(
830830
NAME_INFO(clamp max),
831831
clamp_max<VT>,
832-
[](const vec& v0, const vec& v1) {
832+
[](const auto& v0, const auto& v1) {
833833
return clamp_max(v0, v1);
834834
},
835835
createDefaultBinaryTestCase<vec>(TestSeed()));

0 commit comments

Comments
 (0)