Skip to content

Commit 5fd4f32

Browse files
authored
[HLSL] Implement SV_GroupID semantic (#115911)
Support SV_GroupID attribute. Translate it into dx.group.id in clang codeGen. Fixes: #70120
1 parent 4ab298b commit 5fd4f32

File tree

11 files changed

+134
-12
lines changed

11 files changed

+134
-12
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4621,6 +4621,13 @@ def HLSLNumThreads: InheritableAttr {
46214621
let Documentation = [NumThreadsDocs];
46224622
}
46234623

4624+
def HLSLSV_GroupID: HLSLAnnotationAttr {
4625+
let Spellings = [HLSLAnnotation<"SV_GroupID">];
4626+
let Subjects = SubjectList<[ParmVar, Field]>;
4627+
let LangOpts = [HLSL];
4628+
let Documentation = [HLSLSV_GroupIDDocs];
4629+
}
4630+
46244631
def HLSLSV_GroupIndex: HLSLAnnotationAttr {
46254632
let Spellings = [HLSLAnnotation<"SV_GroupIndex">];
46264633
let Subjects = SubjectList<[ParmVar, GlobalVar]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7934,6 +7934,16 @@ randomized.
79347934
}];
79357935
}
79367936

7937+
def HLSLSV_GroupIDDocs : Documentation {
7938+
let Category = DocCatFunction;
7939+
let Content = [{
7940+
The ``SV_GroupID`` semantic, when applied to an input parameter, specifies which
7941+
thread group a shader is executing in. This attribute is only supported in compute shaders.
7942+
7943+
The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid
7944+
}];
7945+
}
7946+
79377947
def HLSLSV_GroupIndexDocs : Documentation {
79387948
let Category = DocCatFunction;
79397949
let Content = [{

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
119119
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
120120
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
121121
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
122+
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
122123
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
123124
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
124125
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
@@ -136,6 +137,9 @@ class SemaHLSL : public SemaBase {
136137

137138
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old);
138139

140+
// Diagnose whether the input ID is uint/unit2/uint3 type.
141+
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);
142+
139143
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
140144

141145
QualType getInoutParameterType(QualType Ty);

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
389389
CGM.getIntrinsic(getThreadIdIntrinsic());
390390
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
391391
}
392+
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
393+
llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
394+
return buildVectorInput(B, GroupIDIntrinsic, Ty);
395+
}
392396
assert(false && "Unhandled parameter attribute");
393397
return nullptr;
394398
}

clang/lib/Parse/ParseHLSL.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
280280
case ParsedAttr::UnknownAttribute:
281281
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
282282
return;
283+
case ParsedAttr::AT_HLSLSV_GroupID:
283284
case ParsedAttr::AT_HLSLSV_GroupIndex:
284285
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
285286
break;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7103,6 +7103,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
71037103
case ParsedAttr::AT_HLSLWaveSize:
71047104
S.HLSL().handleWaveSizeAttr(D, AL);
71057105
break;
7106+
case ParsedAttr::AT_HLSLSV_GroupID:
7107+
S.HLSL().handleSV_GroupIDAttr(D, AL);
7108+
break;
71067109
case ParsedAttr::AT_HLSLSV_GroupIndex:
71077110
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
71087111
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
434434
switch (AnnotationAttr->getKind()) {
435435
case attr::HLSLSV_DispatchThreadID:
436436
case attr::HLSLSV_GroupIndex:
437+
case attr::HLSLSV_GroupID:
437438
if (ST == llvm::Triple::Compute)
438439
return;
439440
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
@@ -764,26 +765,36 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
764765
D->addAttr(NewAttr);
765766
}
766767

767-
static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
768-
if (!T->hasUnsignedIntegerRepresentation())
768+
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
769+
const auto *VT = T->getAs<VectorType>();
770+
771+
if (!T->hasUnsignedIntegerRepresentation() ||
772+
(VT && VT->getNumElements() > 3)) {
773+
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
774+
<< AL << "uint/uint2/uint3";
769775
return false;
770-
if (const auto *VT = T->getAs<VectorType>())
771-
return VT->getNumElements() <= 3;
776+
}
777+
772778
return true;
773779
}
774780

775781
void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
776782
auto *VD = cast<ValueDecl>(D);
777-
if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
778-
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
779-
<< AL << "uint/uint2/uint3";
783+
if (!diagnoseInputIDType(VD->getType(), AL))
780784
return;
781-
}
782785

783786
D->addAttr(::new (getASTContext())
784787
HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
785788
}
786789

790+
void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
791+
auto *VD = cast<ValueDecl>(D);
792+
if (!diagnoseInputIDType(VD->getType(), AL))
793+
return;
794+
795+
D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
796+
}
797+
787798
void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
788799
if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
789800
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
2+
3+
// Make sure SV_GroupID translated into dx.group.id.
4+
5+
// CHECK: define void @foo()
6+
// CHECK: %[[#ID:]] = call i32 @llvm.dx.group.id(i32 0)
7+
// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
8+
[shader("compute")]
9+
[numthreads(8,8,1)]
10+
void foo(uint Idx : SV_GroupID) {}
11+
12+
// CHECK: define void @bar()
13+
// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0)
14+
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
15+
// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1)
16+
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
17+
// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
18+
[shader("compute")]
19+
[numthreads(8,8,1)]
20+
void bar(uint2 Idx : SV_GroupID) {}
21+
22+
// CHECK: define void @test()
23+
// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0)
24+
// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
25+
// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1)
26+
// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
27+
// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.group.id(i32 2)
28+
// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
29+
// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
30+
[shader("compute")]
31+
[numthreads(8,8,1)]
32+
void test(uint3 Idx : SV_GroupID) {}

clang/test/SemaHLSL/Semantics/entry_parameter.hlsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s
33

44
[numthreads(8,8,1)]
5-
// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
6-
// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
7-
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
8-
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
5+
// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
6+
// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
7+
// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
8+
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
9+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
910
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
1011
// CHECK-NEXT: HLSLSV_GroupIndexAttr
1112
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
1213
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
14+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
15+
// CHECK-NEXT: HLSLSV_GroupIDAttr
1316
}

clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,25 @@ struct ST2 {
2727
static uint X : SV_DispatchThreadID;
2828
uint s : SV_DispatchThreadID;
2929
};
30+
31+
[numthreads(8,8,1)]
32+
// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
33+
void CSMain_GID(float ID : SV_GroupID) {
34+
}
35+
36+
[numthreads(8,8,1)]
37+
// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
38+
void CSMain2_GID(ST GID : SV_GroupID) {
39+
40+
}
41+
42+
void foo_GID() {
43+
// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
44+
uint GIS : SV_GroupID;
45+
}
46+
47+
struct ST2_GID {
48+
// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
49+
static uint GID : SV_GroupID;
50+
uint s_gid : SV_GroupID;
51+
};

clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,28 @@ void CSMain3(uint3 : SV_DispatchThreadID) {
2424
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 'uint3'
2525
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
2626
}
27+
28+
[numthreads(8,8,1)]
29+
void CSMain_GID(uint ID : SV_GroupID) {
30+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GID 'void (uint)'
31+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:22 ID 'uint'
32+
// CHECK-NEXT: HLSLSV_GroupIDAttr
33+
}
34+
[numthreads(8,8,1)]
35+
void CSMain1_GID(uint2 ID : SV_GroupID) {
36+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GID 'void (uint2)'
37+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint2'
38+
// CHECK-NEXT: HLSLSV_GroupIDAttr
39+
}
40+
[numthreads(8,8,1)]
41+
void CSMain2_GID(uint3 ID : SV_GroupID) {
42+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GID 'void (uint3)'
43+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint3'
44+
// CHECK-NEXT: HLSLSV_GroupIDAttr
45+
}
46+
[numthreads(8,8,1)]
47+
void CSMain3_GID(uint3 : SV_GroupID) {
48+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GID 'void (uint3)'
49+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
50+
// CHECK-NEXT: HLSLSV_GroupIDAttr
51+
}

0 commit comments

Comments
 (0)