Skip to content

Commit e85e29c

Browse files
authored
[HLSL] select scalar overloads for vector conditions (#129396)
This PR adds scalar/vector overloads for vector conditions to the `select` builtin, and updates the sema checking and codegen to allow scalars to extend to vectors. Fixes #126570
1 parent 74ca579 commit e85e29c

File tree

11 files changed

+217
-162
lines changed

11 files changed

+217
-162
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12707,6 +12707,10 @@ def err_hlsl_param_qualifier_mismatch :
1270712707
def err_hlsl_vector_compound_assignment_truncation : Error<
1270812708
"left hand operand of type %0 to compound assignment cannot be truncated "
1270912709
"when used with right hand operand of type %1">;
12710+
def err_hlsl_builtin_scalar_vector_mismatch
12711+
: Error<
12712+
"%select{all|second and third}0 arguments to %1 must be of scalar or "
12713+
"vector type with matching scalar element type%diff{: $ vs $|}2,3">;
1271012714

1271112715
def warn_hlsl_impcast_vector_truncation : Warning<
1271212716
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19836,6 +19836,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1983619836
RValFalse.isScalar()
1983719837
? RValFalse.getScalarVal()
1983819838
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
19839+
if (auto *VTy = E->getType()->getAs<VectorType>()) {
19840+
if (!OpTrue->getType()->isVectorTy())
19841+
OpTrue =
19842+
Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
19843+
if (!OpFalse->getType()->isVectorTy())
19844+
OpFalse =
19845+
Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
19846+
}
1983919847

1984019848
Value *SelectVal =
1984119849
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");

clang/lib/Headers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ set(hlsl_h
8787
set(hlsl_subdir_files
8888
hlsl/hlsl_basic_types.h
8989
hlsl/hlsl_alias_intrinsics.h
90+
hlsl/hlsl_intrinsic_helpers.h
9091
hlsl/hlsl_intrinsics.h
9192
hlsl/hlsl_detail.h
9293
)

clang/lib/Headers/hlsl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
#pragma clang diagnostic ignored "-Whlsl-dxc-compatability"
1717
#endif
1818

19+
// Basic types, type traits and type-independent templates.
1920
#include "hlsl/hlsl_basic_types.h"
21+
#include "hlsl/hlsl_detail.h"
22+
23+
// HLSL standard library function declarations/definitions.
24+
#include "hlsl/hlsl_alias_intrinsics.h"
2025
#include "hlsl/hlsl_intrinsics.h"
2126

2227
#if defined(__clang__)

clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,6 +2123,41 @@ template <typename T, int Sz>
21232123
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
21242124
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
21252125

2126+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
2127+
/// vector<T,Sz> FalseVals)
2128+
/// \brief ternary operator for vectors. All vectors must be the same size.
2129+
/// \param Conds The Condition input values.
2130+
/// \param TrueVal The scalar value to splat from when conditions are true.
2131+
/// \param FalseVals The vector values are chosen from when conditions are
2132+
/// false.
2133+
2134+
template <typename T, int Sz>
2135+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2136+
vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
2137+
2138+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
2139+
/// T FalseVal)
2140+
/// \brief ternary operator for vectors. All vectors must be the same size.
2141+
/// \param Conds The Condition input values.
2142+
/// \param TrueVals The vector values are chosen from when conditions are true.
2143+
/// \param FalseVal The scalar value to splat from when conditions are false.
2144+
2145+
template <typename T, int Sz>
2146+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2147+
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
2148+
2149+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
2150+
/// T FalseVal)
2151+
/// \brief ternary operator for vectors. All vectors must be the same size.
2152+
/// \param Conds The Condition input values.
2153+
/// \param TrueVal The scalar value to splat from when conditions are true.
2154+
/// \param FalseVal The scalar value to splat from when conditions are false.
2155+
2156+
template <typename T, int Sz>
2157+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2158+
__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
2159+
vector<bool, Sz>, T, T);
2160+
21262161
//===----------------------------------------------------------------------===//
21272162
// sin builtins
21282163
//===----------------------------------------------------------------------===//

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===----- detail.h - HLSL definitions for intrinsics ----------===//
1+
//===----- hlsl_detail.h - HLSL definitions for intrinsics ----------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,8 +9,6 @@
99
#ifndef _HLSL_HLSL_DETAILS_H_
1010
#define _HLSL_HLSL_DETAILS_H_
1111

12-
#include "hlsl_alias_intrinsics.h"
13-
1412
namespace hlsl {
1513

1614
namespace __detail {
@@ -43,59 +41,9 @@ constexpr enable_if_t<sizeof(U) == sizeof(T), U> bit_cast(T F) {
4341
return __builtin_bit_cast(U, F);
4442
}
4543

46-
constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
47-
// Use the same scaling factor used by FXC, and DXC for DXIL
48-
// (i.e., 255.001953)
49-
// https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
50-
// The DXC implementation refers to a comment on the following stackoverflow
51-
// discussion to justify the scaling factor: "Built-in rounding, necessary
52-
// because of truncation. 0.001953 * 256 = 0.5"
53-
// https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
54-
return V.zyxw * 255.001953f;
55-
}
56-
57-
template <typename T>
58-
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
59-
length_impl(T X) {
60-
return abs(X);
61-
}
62-
63-
template <typename T, int N>
64-
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
65-
length_vec_impl(vector<T, N> X) {
66-
#if (__has_builtin(__builtin_spirv_length))
67-
return __builtin_spirv_length(X);
68-
#else
69-
return sqrt(dot(X, X));
70-
#endif
71-
}
72-
73-
template <typename T>
74-
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
75-
distance_impl(T X, T Y) {
76-
return length_impl(X - Y);
77-
}
78-
79-
template <typename T, int N>
80-
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
81-
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
82-
return length_vec_impl(X - Y);
83-
}
84-
85-
template <typename T>
86-
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
87-
reflect_impl(T I, T N) {
88-
return I - 2 * N * I * N;
89-
}
90-
91-
template <typename T, int L>
92-
constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
93-
#if (__has_builtin(__builtin_spirv_reflect))
94-
return __builtin_spirv_reflect(I, N);
95-
#else
96-
return I - 2 * N * dot(I, N);
97-
#endif
98-
}
44+
template <typename T> struct is_arithmetic {
45+
static const bool Value = __is_arithmetic(T);
46+
};
9947

10048
} // namespace __detail
10149
} // namespace hlsl
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===----- hlsl_intrinsic_helpers.h - HLSL helpers intrinsics -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _HLSL_HLSL_INTRINSIC_HELPERS_H_
10+
#define _HLSL_HLSL_INTRINSIC_HELPERS_H_
11+
12+
namespace hlsl {
13+
namespace __detail {
14+
15+
constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
16+
// Use the same scaling factor used by FXC, and DXC for DXIL
17+
// (i.e., 255.001953)
18+
// https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
19+
// The DXC implementation refers to a comment on the following stackoverflow
20+
// discussion to justify the scaling factor: "Built-in rounding, necessary
21+
// because of truncation. 0.001953 * 256 = 0.5"
22+
// https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
23+
return V.zyxw * 255.001953f;
24+
}
25+
26+
template <typename T>
27+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
28+
length_impl(T X) {
29+
return abs(X);
30+
}
31+
32+
template <typename T, int N>
33+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
34+
length_vec_impl(vector<T, N> X) {
35+
#if (__has_builtin(__builtin_spirv_length))
36+
return __builtin_spirv_length(X);
37+
#else
38+
return sqrt(dot(X, X));
39+
#endif
40+
}
41+
42+
template <typename T>
43+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
44+
distance_impl(T X, T Y) {
45+
return length_impl(X - Y);
46+
}
47+
48+
template <typename T, int N>
49+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
50+
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
51+
return length_vec_impl(X - Y);
52+
}
53+
54+
template <typename T>
55+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
56+
reflect_impl(T I, T N) {
57+
return I - 2 * N * I * N;
58+
}
59+
60+
template <typename T, int L>
61+
constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
62+
#if (__has_builtin(__builtin_spirv_reflect))
63+
return __builtin_spirv_reflect(I, N);
64+
#else
65+
return I - 2 * N * dot(I, N);
66+
#endif
67+
}
68+
} // namespace __detail
69+
} // namespace hlsl
70+
71+
#endif // _HLSL_HLSL_INTRINSIC_HELPERS_H_

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef _HLSL_HLSL_INTRINSICS_H_
1010
#define _HLSL_HLSL_INTRINSICS_H_
1111

12-
#include "hlsl_detail.h"
12+
#include "hlsl/hlsl_intrinsic_helpers.h"
1313

1414
namespace hlsl {
1515

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,40 +2225,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
22252225
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
22262226
assert(TheCall->getNumArgs() == 3);
22272227
Expr *Arg1 = TheCall->getArg(1);
2228+
QualType Arg1Ty = Arg1->getType();
22282229
Expr *Arg2 = TheCall->getArg(2);
2229-
if (!Arg1->getType()->isVectorType()) {
2230-
S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
2231-
<< "Second" << TheCall->getDirectCallee() << Arg1->getType()
2230+
QualType Arg2Ty = Arg2->getType();
2231+
2232+
QualType Arg1ScalarTy = Arg1Ty;
2233+
if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2234+
Arg1ScalarTy = VTy->getElementType();
2235+
2236+
QualType Arg2ScalarTy = Arg2Ty;
2237+
if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2238+
Arg2ScalarTy = VTy->getElementType();
2239+
2240+
if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
2241+
S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
2242+
<< /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2243+
2244+
QualType Arg0Ty = TheCall->getArg(0)->getType();
2245+
unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2246+
unsigned Arg1Length = Arg1Ty->isVectorType()
2247+
? Arg1Ty->getAs<VectorType>()->getNumElements()
2248+
: 0;
2249+
unsigned Arg2Length = Arg2Ty->isVectorType()
2250+
? Arg2Ty->getAs<VectorType>()->getNumElements()
2251+
: 0;
2252+
if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2253+
S->Diag(TheCall->getBeginLoc(),
2254+
diag::err_typecheck_vector_lengths_not_equal)
2255+
<< Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
22322256
<< Arg1->getSourceRange();
22332257
return true;
22342258
}
22352259

2236-
if (!Arg2->getType()->isVectorType()) {
2237-
S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
2238-
<< "Third" << TheCall->getDirectCallee() << Arg2->getType()
2239-
<< Arg2->getSourceRange();
2240-
return true;
2241-
}
2242-
2243-
if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2260+
if (Arg2Length > 0 && Arg0Length != Arg2Length) {
22442261
S->Diag(TheCall->getBeginLoc(),
2245-
diag::err_typecheck_call_different_arg_types)
2246-
<< Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2262+
diag::err_typecheck_vector_lengths_not_equal)
2263+
<< Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
22472264
<< Arg2->getSourceRange();
22482265
return true;
22492266
}
22502267

2251-
// caller has checked that Arg0 is a vector.
2252-
// check all three args have the same length.
2253-
if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
2254-
Arg1->getType()->getAs<VectorType>()->getNumElements()) {
2255-
S->Diag(TheCall->getBeginLoc(),
2256-
diag::err_typecheck_vector_lengths_not_equal)
2257-
<< TheCall->getArg(0)->getType() << Arg1->getType()
2258-
<< TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
2259-
return true;
2260-
}
2261-
TheCall->setType(Arg1->getType());
2268+
TheCall->setType(
2269+
S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
22622270
return false;
22632271
}
22642272

clang/test/CodeGenHLSL/builtins/select.hlsl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
5252
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
5353
return select(cond0, tVals, fVals);
5454
}
55+
56+
// CHECK-LABEL: test_select_vector_scalar_vector
57+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
58+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
59+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
60+
// CHECK: ret <4 x i32> [[SELECT]]
61+
int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
62+
return select(cond0, tVal, fVals);
63+
}
64+
65+
// CHECK-LABEL: test_select_vector_vector_scalar
66+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
67+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
68+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
69+
// CHECK: ret <4 x i32> [[SELECT]]
70+
int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
71+
return select(cond0, tVals, fVal);
72+
}
73+
74+
// CHECK-LABEL: test_select_vector_scalar_scalar
75+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
76+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
77+
// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0
78+
// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
79+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
80+
// CHECK: ret <4 x i32> [[SELECT]]
81+
int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
82+
return select(cond0, tVal, fVal);
83+
}

0 commit comments

Comments
 (0)