Skip to content

Commit 16c84c4

Browse files
authored
[DirectX] Add target builtins (#134439)
- fixes #132303 - Moves dot2add from a language builtin to a target builtin. - Sets the scaffolding for Sema checks for DX builtins - Setup DirectX backend as able to have target builtins - Adds a DX TargetBuiltins emitter in `clang/lib/CodeGen/TargetBuiltins/DirectX.cpp`
1 parent 7fe6e70 commit 16c84c4

File tree

20 files changed

+187
-23
lines changed

20 files changed

+187
-23
lines changed

.github/new-prs-labeler.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,12 @@ backend:DirectX:
660660
- '**/*dxil*/**'
661661
- '**/*DXContainer*'
662662
- '**/*DXContainer*/**'
663+
- clang/lib/Sema/SemaDirectX.cpp
664+
- clang/include/clang/Sema/SemaDirectX.h
665+
- clang/include/clang/Basic/BuiltinsDirectX.td
666+
- clang/lib/CodeGen/TargetBuiltins/DirectX.cpp
667+
- clang/test/CodeGenDirectX/**
668+
- clang/test/SemaDirectX/**
663669

664670
backend:SPIR-V:
665671
- clang/lib/Driver/ToolChains/SPIRV.*

clang/include/clang/Basic/Builtins.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4891,12 +4891,6 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
48914891
let Prototype = "void(...)";
48924892
}
48934893

4894-
def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> {
4895-
let Spellings = ["__builtin_hlsl_dot2add"];
4896-
let Attributes = [NoThrow, Const];
4897-
let Prototype = "float(_ExtVector<2, _Float16>, _ExtVector<2, _Float16>, float)";
4898-
}
4899-
49004894
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
49014895
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
49024896
let Attributes = [NoThrow, Const];
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===--- BuiltinsDirectX.td - DirectX Builtin function database -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
include "clang/Basic/BuiltinsBase.td"
10+
11+
def DxDot2Add : Builtin {
12+
let Spellings = ["__builtin_dx_dot2add"];
13+
let Attributes = [NoThrow, Const];
14+
let Prototype = "float(_ExtVector<2, _Float16>, _ExtVector<2, _Float16>, float)";
15+
}

clang/include/clang/Basic/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ clang_tablegen(BuiltinsBPF.inc -gen-clang-builtins
8282
SOURCE BuiltinsBPF.td
8383
TARGET ClangBuiltinsBPF)
8484

85+
clang_tablegen(BuiltinsDirectX.inc -gen-clang-builtins
86+
SOURCE BuiltinsDirectX.td
87+
TARGET ClangBuiltinsDirectX)
88+
8589
clang_tablegen(BuiltinsHexagon.inc -gen-clang-builtins
8690
SOURCE BuiltinsHexagon.td
8791
TARGET ClangBuiltinsHexagon)

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ namespace clang {
141141
};
142142
}
143143

144+
/// DirectX builtins
145+
namespace DirectX {
146+
enum {
147+
LastTIBuiltin = clang::Builtin::FirstTSBuiltin - 1,
148+
#define GET_BUILTIN_ENUMERATORS
149+
#include "clang/Basic/BuiltinsDirectX.inc"
150+
#undef GET_BUILTIN_ENUMERATORS
151+
LastTSBuiltin
152+
};
153+
} // namespace DirectX
154+
144155
/// SPIRV builtins
145156
namespace SPIRV {
146157
enum {

clang/include/clang/Sema/Sema.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class SemaAVR;
160160
class SemaBPF;
161161
class SemaCodeCompletion;
162162
class SemaCUDA;
163+
class SemaDirectX;
163164
class SemaHLSL;
164165
class SemaHexagon;
165166
class SemaLoongArch;
@@ -1074,6 +1075,11 @@ class Sema final : public SemaBase {
10741075
return *CUDAPtr;
10751076
}
10761077

1078+
SemaDirectX &DirectX() {
1079+
assert(DirectXPtr);
1080+
return *DirectXPtr;
1081+
}
1082+
10771083
SemaHLSL &HLSL() {
10781084
assert(HLSLPtr);
10791085
return *HLSLPtr;
@@ -1212,6 +1218,7 @@ class Sema final : public SemaBase {
12121218
std::unique_ptr<SemaBPF> BPFPtr;
12131219
std::unique_ptr<SemaCodeCompletion> CodeCompletionPtr;
12141220
std::unique_ptr<SemaCUDA> CUDAPtr;
1221+
std::unique_ptr<SemaDirectX> DirectXPtr;
12151222
std::unique_ptr<SemaHLSL> HLSLPtr;
12161223
std::unique_ptr<SemaHexagon> HexagonPtr;
12171224
std::unique_ptr<SemaLoongArch> LoongArchPtr;
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===----- SemaDirectX.h ----- Semantic Analysis for DirectX constructs----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
/// \file
9+
/// This file declares semantic analysis for DirectX constructs.
10+
///
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CLANG_SEMA_SEMADIRECTX_H
14+
#define LLVM_CLANG_SEMA_SEMADIRECTX_H
15+
16+
#include "clang/AST/ASTFwd.h"
17+
#include "clang/Sema/SemaBase.h"
18+
19+
namespace clang {
20+
class SemaDirectX : public SemaBase {
21+
public:
22+
SemaDirectX(Sema &S);
23+
24+
bool CheckDirectXBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
25+
};
26+
} // namespace clang
27+
28+
#endif // LLVM_CLANG_SEMA_SEMADIRECTX_H

clang/lib/Basic/Targets/DirectX.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,31 @@
1212

1313
#include "DirectX.h"
1414
#include "Targets.h"
15+
#include "clang/Basic/TargetBuiltins.h"
1516

1617
using namespace clang;
1718
using namespace clang::targets;
1819

20+
static constexpr int NumBuiltins =
21+
clang::DirectX::LastTSBuiltin - Builtin::FirstTSBuiltin;
22+
23+
#define GET_BUILTIN_STR_TABLE
24+
#include "clang/Basic/BuiltinsDirectX.inc"
25+
#undef GET_BUILTIN_STR_TABLE
26+
27+
static constexpr Builtin::Info BuiltinInfos[] = {
28+
#define GET_BUILTIN_INFOS
29+
#include "clang/Basic/BuiltinsDirectX.inc"
30+
#undef GET_BUILTIN_INFOS
31+
};
32+
static_assert(std::size(BuiltinInfos) == NumBuiltins);
33+
1934
void DirectXTargetInfo::getTargetDefines(const LangOptions &Opts,
2035
MacroBuilder &Builder) const {
2136
DefineStd(Builder, "DIRECTX", Opts);
2237
}
38+
39+
llvm::SmallVector<Builtin::InfosShard>
40+
DirectXTargetInfo::getTargetBuiltins() const {
41+
return {{&BuiltinStrings, BuiltinInfos}};
42+
}

clang/lib/Basic/Targets/DirectX.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ class LLVM_LIBRARY_VISIBILITY DirectXTargetInfo : public TargetInfo {
7373
return Feature == "directx";
7474
}
7575

76-
llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override {
77-
return {};
78-
}
76+
llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override;
7977

8078
std::string_view getClobbers() const override { return ""; }
8179

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ static Value *EmitTargetArchBuiltinExpr(CodeGenFunction *CGF,
7070
case llvm::Triple::bpfeb:
7171
case llvm::Triple::bpfel:
7272
return CGF->EmitBPFBuiltinExpr(BuiltinID, E);
73+
case llvm::Triple::dxil:
74+
return CGF->EmitDirectXBuiltinExpr(BuiltinID, E);
7375
case llvm::Triple::x86:
7476
case llvm::Triple::x86_64:
7577
return CGF->EmitX86BuiltinExpr(BuiltinID, E);

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -380,18 +380,6 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
380380
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
381381
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
382382
}
383-
case Builtin::BI__builtin_hlsl_dot2add: {
384-
assert(CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil &&
385-
"Intrinsic dot2add is only allowed for dxil architecture");
386-
Value *A = EmitScalarExpr(E->getArg(0));
387-
Value *B = EmitScalarExpr(E->getArg(1));
388-
Value *C = EmitScalarExpr(E->getArg(2));
389-
390-
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
391-
return Builder.CreateIntrinsic(
392-
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
393-
"dx.dot2add");
394-
}
395383
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
396384
Value *A = EmitScalarExpr(E->getArg(0));
397385
Value *B = EmitScalarExpr(E->getArg(1));

clang/lib/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ add_clang_library(clangCodeGen
118118
SwiftCallingConv.cpp
119119
TargetBuiltins/ARM.cpp
120120
TargetBuiltins/AMDGPU.cpp
121+
TargetBuiltins/DirectX.cpp
121122
TargetBuiltins/Hexagon.cpp
122123
TargetBuiltins/NVPTX.cpp
123124
TargetBuiltins/PPC.cpp

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4809,6 +4809,7 @@ class CodeGenFunction : public CodeGenTypeCache {
48094809
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48104810
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
48114811
ReturnValueSlot ReturnValue);
4812+
llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48124813
llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48134814
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
48144815
const CallExpr *E);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===--------- DirectX.cpp - Emit LLVM Code for builtins ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This contains code to emit Builtin calls as LLVM code.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "CGHLSLRuntime.h"
14+
#include "CodeGenFunction.h"
15+
#include "clang/Basic/TargetBuiltins.h"
16+
#include "llvm/IR/Intrinsics.h"
17+
18+
using namespace clang;
19+
using namespace CodeGen;
20+
using namespace llvm;
21+
22+
Value *CodeGenFunction::EmitDirectXBuiltinExpr(unsigned BuiltinID,
23+
const CallExpr *E) {
24+
switch (BuiltinID) {
25+
case DirectX::BI__builtin_dx_dot2add: {
26+
Value *A = EmitScalarExpr(E->getArg(0));
27+
Value *B = EmitScalarExpr(E->getArg(1));
28+
Value *C = EmitScalarExpr(E->getArg(2));
29+
30+
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
31+
return Builder.CreateIntrinsic(
32+
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
33+
"dx.dot2add");
34+
}
35+
}
36+
return nullptr;
37+
}

clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
4646
}
4747

4848
constexpr float dot2add_impl(half2 a, half2 b, float c) {
49-
#if defined(__DIRECTX__)
50-
return __builtin_hlsl_dot2add(a, b, c);
49+
#if (__has_builtin(__builtin_dx_dot2add))
50+
return __builtin_dx_dot2add(a, b, c);
5151
#else
5252
return dot(a, b) + c;
5353
#endif

clang/lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ add_clang_library(clangSema
4747
SemaConsumer.cpp
4848
SemaCoroutine.cpp
4949
SemaCUDA.cpp
50+
SemaDirectX.cpp
5051
SemaDecl.cpp
5152
SemaDeclAttr.cpp
5253
SemaDeclCXX.cpp

clang/lib/Sema/Sema.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "clang/Sema/SemaCUDA.h"
4848
#include "clang/Sema/SemaCodeCompletion.h"
4949
#include "clang/Sema/SemaConsumer.h"
50+
#include "clang/Sema/SemaDirectX.h"
5051
#include "clang/Sema/SemaHLSL.h"
5152
#include "clang/Sema/SemaHexagon.h"
5253
#include "clang/Sema/SemaLoongArch.h"
@@ -226,6 +227,7 @@ Sema::Sema(Preprocessor &pp, ASTContext &ctxt, ASTConsumer &consumer,
226227
CodeCompletionPtr(
227228
std::make_unique<SemaCodeCompletion>(*this, CodeCompleter)),
228229
CUDAPtr(std::make_unique<SemaCUDA>(*this)),
230+
DirectXPtr(std::make_unique<SemaDirectX>(*this)),
229231
HLSLPtr(std::make_unique<SemaHLSL>(*this)),
230232
HexagonPtr(std::make_unique<SemaHexagon>(*this)),
231233
LoongArchPtr(std::make_unique<SemaLoongArch>(*this)),

clang/lib/Sema/SemaChecking.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
#include "clang/Sema/SemaAMDGPU.h"
6262
#include "clang/Sema/SemaARM.h"
6363
#include "clang/Sema/SemaBPF.h"
64+
#include "clang/Sema/SemaDirectX.h"
6465
#include "clang/Sema/SemaHLSL.h"
6566
#include "clang/Sema/SemaHexagon.h"
6667
#include "clang/Sema/SemaLoongArch.h"
@@ -1930,6 +1931,8 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
19301931
case llvm::Triple::bpfeb:
19311932
case llvm::Triple::bpfel:
19321933
return BPF().CheckBPFBuiltinFunctionCall(BuiltinID, TheCall);
1934+
case llvm::Triple::dxil:
1935+
return DirectX().CheckDirectXBuiltinFunctionCall(BuiltinID, TheCall);
19331936
case llvm::Triple::hexagon:
19341937
return Hexagon().CheckHexagonBuiltinFunctionCall(BuiltinID, TheCall);
19351938
case llvm::Triple::mips:

clang/lib/Sema/SemaDirectX.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- SemaDirectX.cpp - Semantic Analysis for DirectX constructs----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This implements Semantic Analysis for DirectX constructs.
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "clang/Sema/SemaDirectX.h"
12+
#include "clang/Basic/TargetBuiltins.h"
13+
#include "clang/Sema/Sema.h"
14+
15+
namespace clang {
16+
17+
SemaDirectX::SemaDirectX(Sema &S) : SemaBase(S) {}
18+
19+
bool SemaDirectX::CheckDirectXBuiltinFunctionCall(unsigned BuiltinID,
20+
CallExpr *TheCall) {
21+
return false;
22+
}
23+
} // namespace clang
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
3+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library %s -emit-llvm -o - | FileCheck %s
4+
5+
typedef _Float16 half;
6+
typedef half half2 __attribute__((ext_vector_type(2)));
7+
8+
// CHECK-LABEL: define float @test_dot2add(
9+
// CHECK-SAME: <2 x half> noundef [[X:%.*]], <2 x half> noundef [[Y:%.*]], float noundef [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
10+
// CHECK-NEXT: [[ENTRY:.*:]]
11+
// CHECK-NEXT: [[X_ADDR:%.*]] = alloca <2 x half>, align 4
12+
// CHECK-NEXT: [[Y_ADDR:%.*]] = alloca <2 x half>, align 4
13+
// CHECK-NEXT: [[Z_ADDR:%.*]] = alloca float, align 4
14+
// CHECK-NEXT: store <2 x half> [[X]], ptr [[X_ADDR]], align 4
15+
// CHECK-NEXT: store <2 x half> [[Y]], ptr [[Y_ADDR]], align 4
16+
// CHECK-NEXT: store float [[Z]], ptr [[Z_ADDR]], align 4
17+
// CHECK-NEXT: [[TMP0:%.*]] = load <2 x half>, ptr [[X_ADDR]], align 4
18+
// CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[Y_ADDR]], align 4
19+
// CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[Z_ADDR]], align 4
20+
// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add.v2f16(<2 x half> [[TMP0]], <2 x half> [[TMP1]], float [[TMP2]])
21+
// CHECK-NEXT: ret float [[DX_DOT2ADD]]
22+
//
23+
float test_dot2add(half2 X, half2 Y, float Z) { return __builtin_dx_dot2add(X, Y, Z); }

0 commit comments

Comments
 (0)