Skip to content

Commit 82acec1

Browse files
authored
[HLSL] Implementation of dot intrinsic (llvm#81190)
This change implements llvm#70073 HLSL has a dot intrinsic defined here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td. This is used to associate all the dot product typdef defined hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp. In IntrinsicsDirectX.td we define the llvmIR for the dot product. A few goals were in mind for this IR. First it should operate on only vectors. Second the return type should be the vector element type. Third the second parameter vector should be of the same size as the first parameter. Finally `a dot b` should be the same as `b dot a`. In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot product though is language specific intrinsic and so is guarded behind getLangOpts().HLSL. The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp EmitHLSLBuiltinExp dot product intrinsics makes a destinction between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply. Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language specific semantic validation that can be expanded for other hlsl specific intrinsics. Fixes llvm#70073
1 parent b2ebd8b commit 82acec1

File tree

11 files changed

+703
-8
lines changed

11 files changed

+703
-8
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
45244524
let Prototype = "void*(unsigned char)";
45254525
}
45264526

4527+
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
4528+
let Spellings = ["__builtin_hlsl_dot"];
4529+
let Attributes = [NoThrow, Const];
4530+
let Prototype = "void(...)";
4531+
}
4532+
45274533
// Builtins for XRay.
45284534
def XRayCustomEvent : Builtin {
45294535
let Spellings = ["__xray_customevent"];

clang/include/clang/Sema/Sema.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14063,6 +14063,7 @@ class Sema final {
1406314063
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1406414064
CallExpr *TheCall);
1406514065
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
14066+
bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
1406614067
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
1406714068
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1406814069
CallExpr *TheCall);
@@ -14128,6 +14129,8 @@ class Sema final {
1412814129

1412914130
bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
1413014131

14132+
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
14133+
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
1413114134
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
1413214135
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
1413314136
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/IR/IntrinsicsAMDGPU.h"
4545
#include "llvm/IR/IntrinsicsARM.h"
4646
#include "llvm/IR/IntrinsicsBPF.h"
47+
#include "llvm/IR/IntrinsicsDirectX.h"
4748
#include "llvm/IR/IntrinsicsHexagon.h"
4849
#include "llvm/IR/IntrinsicsNVPTX.h"
4950
#include "llvm/IR/IntrinsicsPowerPC.h"
@@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
59825983
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
59835984
}
59845985

5986+
// EmitHLSLBuiltinExpr will check getLangOpts().HLSL
5987+
if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
5988+
return RValue::get(V);
5989+
59855990
if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
59865991
return EmitHipStdParUnsupportedBuiltin(this, FD);
59875992

@@ -17959,6 +17964,52 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1795917964
return Arg;
1796017965
}
1796117966

17967+
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
17968+
const CallExpr *E) {
17969+
if (!getLangOpts().HLSL)
17970+
return nullptr;
17971+
17972+
switch (BuiltinID) {
17973+
case Builtin::BI__builtin_hlsl_dot: {
17974+
Value *Op0 = EmitScalarExpr(E->getArg(0));
17975+
Value *Op1 = EmitScalarExpr(E->getArg(1));
17976+
llvm::Type *T0 = Op0->getType();
17977+
llvm::Type *T1 = Op1->getType();
17978+
if (!T0->isVectorTy() && !T1->isVectorTy()) {
17979+
if (T0->isFloatingPointTy())
17980+
return Builder.CreateFMul(Op0, Op1, "dx.dot");
17981+
17982+
if (T0->isIntegerTy())
17983+
return Builder.CreateMul(Op0, Op1, "dx.dot");
17984+
17985+
// Bools should have been promoted
17986+
llvm_unreachable(
17987+
"Scalar dot product is only supported on ints and floats.");
17988+
}
17989+
// A VectorSplat should have happened
17990+
assert(T0->isVectorTy() && T1->isVectorTy() &&
17991+
"Dot product of vector and scalar is not supported.");
17992+
17993+
// A vector sext or sitofp should have happened
17994+
assert(T0->getScalarType() == T1->getScalarType() &&
17995+
"Dot product of vectors need the same element types.");
17996+
17997+
[[maybe_unused]] auto *VecTy0 =
17998+
E->getArg(0)->getType()->getAs<VectorType>();
17999+
[[maybe_unused]] auto *VecTy1 =
18000+
E->getArg(1)->getType()->getAs<VectorType>();
18001+
// A HLSLVectorTruncation should have happend
18002+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
18003+
"Dot product requires vectors to be of the same size.");
18004+
18005+
return Builder.CreateIntrinsic(
18006+
/*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
18007+
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
18008+
} break;
18009+
}
18010+
return nullptr;
18011+
}
18012+
1796218013
Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1796318014
const CallExpr *E) {
1796418015
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
44054405
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44064406
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44074407
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
4408+
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
44084409
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
44094410
const CallExpr *E);
44104411
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,104 @@ double3 cos(double3);
179179
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
180180
double4 cos(double4);
181181

182+
//===----------------------------------------------------------------------===//
183+
// dot product builtins
184+
//===----------------------------------------------------------------------===//
185+
186+
/// \fn K dot(T X, T Y)
187+
/// \brief Return the dot product (a scalar value) of \a X and \a Y.
188+
/// \param X The X input value.
189+
/// \param Y The Y input value.
190+
191+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
192+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
193+
half dot(half, half);
194+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
195+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
196+
half dot(half2, half2);
197+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
198+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
199+
half dot(half3, half3);
200+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
201+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
202+
half dot(half4, half4);
203+
204+
#ifdef __HLSL_ENABLE_16_BIT
205+
_HLSL_AVAILABILITY(shadermodel, 6.2)
206+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
207+
int16_t dot(int16_t, int16_t);
208+
_HLSL_AVAILABILITY(shadermodel, 6.2)
209+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
210+
int16_t dot(int16_t2, int16_t2);
211+
_HLSL_AVAILABILITY(shadermodel, 6.2)
212+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
213+
int16_t dot(int16_t3, int16_t3);
214+
_HLSL_AVAILABILITY(shadermodel, 6.2)
215+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
216+
int16_t dot(int16_t4, int16_t4);
217+
218+
_HLSL_AVAILABILITY(shadermodel, 6.2)
219+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
220+
uint16_t dot(uint16_t, uint16_t);
221+
_HLSL_AVAILABILITY(shadermodel, 6.2)
222+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
223+
uint16_t dot(uint16_t2, uint16_t2);
224+
_HLSL_AVAILABILITY(shadermodel, 6.2)
225+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
226+
uint16_t dot(uint16_t3, uint16_t3);
227+
_HLSL_AVAILABILITY(shadermodel, 6.2)
228+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
229+
uint16_t dot(uint16_t4, uint16_t4);
230+
#endif
231+
232+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
233+
float dot(float, float);
234+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
235+
float dot(float2, float2);
236+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
237+
float dot(float3, float3);
238+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
239+
float dot(float4, float4);
240+
241+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
242+
double dot(double, double);
243+
244+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
245+
int dot(int, int);
246+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
247+
int dot(int2, int2);
248+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
249+
int dot(int3, int3);
250+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
251+
int dot(int4, int4);
252+
253+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
254+
uint dot(uint, uint);
255+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
256+
uint dot(uint2, uint2);
257+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
258+
uint dot(uint3, uint3);
259+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
260+
uint dot(uint4, uint4);
261+
262+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
263+
int64_t dot(int64_t, int64_t);
264+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
265+
int64_t dot(int64_t2, int64_t2);
266+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
267+
int64_t dot(int64_t3, int64_t3);
268+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
269+
int64_t dot(int64_t4, int64_t4);
270+
271+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
272+
uint64_t dot(uint64_t, uint64_t);
273+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
274+
uint64_t dot(uint64_t2, uint64_t2);
275+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
276+
uint64_t dot(uint64_t3, uint64_t3);
277+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
278+
uint64_t dot(uint64_t4, uint64_t4);
279+
182280
//===----------------------------------------------------------------------===//
183281
// floor builtins
184282
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 95 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
21202120
// not a valid type, emit an error message and return true. Otherwise return
21212121
// false.
21222122
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
2123-
QualType Ty) {
2124-
if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
2123+
QualType ArgTy, int ArgIndex) {
2124+
if (!ArgTy->getAs<VectorType>() &&
2125+
!ConstantMatrixType::isValidElementType(ArgTy)) {
21252126
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2126-
<< 1 << /* vector, integer or float ty*/ 0 << Ty;
2127+
<< ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
21272128
}
21282129

21292130
return false;
@@ -2961,6 +2962,9 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
29612962
}
29622963
}
29632964

2965+
if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall))
2966+
return ExprError();
2967+
29642968
// Since the target specific builtins for each arch overlap, only check those
29652969
// of the arch we are compiling for.
29662970
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
@@ -5161,6 +5165,70 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51615165
return false;
51625166
}
51635167

5168+
// Helper function for CheckHLSLBuiltinFunctionCall
5169+
bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
5170+
assert(TheCall->getNumArgs() > 1);
5171+
ExprResult A = TheCall->getArg(0);
5172+
ExprResult B = TheCall->getArg(1);
5173+
QualType ArgTyA = A.get()->getType();
5174+
QualType ArgTyB = B.get()->getType();
5175+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5176+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5177+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5178+
if (VecTyA == nullptr && VecTyB == nullptr)
5179+
return false;
5180+
5181+
if (VecTyA && VecTyB) {
5182+
bool retValue = false;
5183+
if (VecTyA->getElementType() != VecTyB->getElementType()) {
5184+
// Note: type promotion is intended to be handeled via the intrinsics
5185+
// and not the builtin itself.
5186+
S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
5187+
<< TheCall->getDirectCallee()
5188+
<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
5189+
retValue = true;
5190+
}
5191+
if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
5192+
// if we get here a HLSLVectorTruncation is needed.
5193+
S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5194+
<< TheCall->getDirectCallee()
5195+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5196+
TheCall->getArg(1)->getEndLoc());
5197+
retValue = true;
5198+
}
5199+
5200+
if (retValue)
5201+
TheCall->setType(VecTyA->getElementType());
5202+
5203+
return retValue;
5204+
}
5205+
5206+
// Note: if we get here one of the args is a scalar which
5207+
// requires a VectorSplat on Arg0 or Arg1
5208+
S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5209+
<< TheCall->getDirectCallee()
5210+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5211+
TheCall->getArg(1)->getEndLoc());
5212+
return true;
5213+
}
5214+
5215+
// Note: returning true in this case results in CheckBuiltinFunctionCall
5216+
// returning an ExprError
5217+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5218+
switch (BuiltinID) {
5219+
case Builtin::BI__builtin_hlsl_dot: {
5220+
if (checkArgCount(*this, TheCall, 2))
5221+
return true;
5222+
if (CheckVectorElementCallArgs(this, TheCall))
5223+
return true;
5224+
if (SemaBuiltinVectorToScalarMath(TheCall))
5225+
return true;
5226+
break;
5227+
}
5228+
}
5229+
return false;
5230+
}
5231+
51645232
bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
51655233
CallExpr *TheCall) {
51665234
// position of memory order and scope arguments in the builtin
@@ -19594,23 +19662,43 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1959419662
TheCall->setArg(0, A.get());
1959519663
QualType TyA = A.get()->getType();
1959619664

19597-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19665+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1959819666
return true;
1959919667

1960019668
TheCall->setType(TyA);
1960119669
return false;
1960219670
}
1960319671

1960419672
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19673+
QualType Res;
19674+
if (SemaBuiltinVectorMath(TheCall, Res))
19675+
return true;
19676+
TheCall->setType(Res);
19677+
return false;
19678+
}
19679+
19680+
bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
19681+
QualType Res;
19682+
if (SemaBuiltinVectorMath(TheCall, Res))
19683+
return true;
19684+
19685+
if (auto *VecTy0 = Res->getAs<VectorType>())
19686+
TheCall->setType(VecTy0->getElementType());
19687+
else
19688+
TheCall->setType(Res);
19689+
19690+
return false;
19691+
}
19692+
19693+
bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
1960519694
if (checkArgCount(*this, TheCall, 2))
1960619695
return true;
1960719696

1960819697
ExprResult A = TheCall->getArg(0);
1960919698
ExprResult B = TheCall->getArg(1);
1961019699
// Do standard promotions between the two arguments, returning their common
1961119700
// type.
19612-
QualType Res =
19613-
UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19701+
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
1961419702
if (A.isInvalid() || B.isInvalid())
1961519703
return true;
1961619704

@@ -19622,12 +19710,11 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1962219710
diag::err_typecheck_call_different_arg_types)
1962319711
<< TyA << TyB;
1962419712

19625-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
19713+
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
1962619714
return true;
1962719715

1962819716
TheCall->setArg(0, A.get());
1962919717
TheCall->setArg(1, B.get());
19630-
TheCall->setType(Res);
1963119718
return false;
1963219719
}
1963319720

0 commit comments

Comments
 (0)