Skip to content

Commit 0789430

Browse files
committed
[HLSL][SPIR-V] Add SV_DispatchThreadID semantic support
Add SPIR-V backend support for the HLSL SV_DispatchThreadID semantic attribute, which is lowered to a @llvm.dx.thread.id intrinsic. Fixes #82534
1 parent d052148 commit 0789430

File tree

3 files changed

+148
-1
lines changed

3 files changed

+148
-1
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
525525

526526
// Output decorations for the GV.
527527
// TODO: maybe move to GenerateDecorations pass.
528-
if (IsConst)
528+
const SPIRVSubtarget &ST =
529+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
530+
if (IsConst && ST.isOpenCLEnv())
529531
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
530532

531533
if (GVar && GVar->getAlign().valueOrOne().value() != 1) {

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
2828
#include "llvm/CodeGen/MachineInstrBuilder.h"
2929
#include "llvm/CodeGen/MachineRegisterInfo.h"
30+
#include "llvm/IR/IntrinsicsDirectX.h"
3031
#include "llvm/IR/IntrinsicsSPIRV.h"
3132
#include "llvm/Support/Debug.h"
3233

@@ -194,6 +195,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
194195
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
195196
MachineInstr &I) const;
196197

198+
bool selectDXThreadId(Register ResVReg, const SPIRVType *ResType,
199+
MachineInstr &I) const;
200+
197201
bool selectUnmergeValues(MachineInstr &I) const;
198202

199203
Register buildI32Constant(uint32_t Val, MachineInstr &I,
@@ -301,6 +305,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
301305
case TargetOpcode::G_FREEZE:
302306
return selectFreeze(ResVReg, ResType, I);
303307

308+
case TargetOpcode::G_INTRINSIC:
304309
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
305310
case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
306311
return selectIntrinsic(ResVReg, ResType, I);
@@ -1614,6 +1619,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
16141619
.addUse(I.getOperand(2).getReg())
16151620
.addUse(I.getOperand(3).getReg());
16161621
break;
1622+
case Intrinsic::dx_thread_id:
1623+
return selectDXThreadId(ResVReg, ResType, I);
16171624
default:
16181625
llvm_unreachable("Intrinsic selection not implemented");
16191626
}
@@ -1864,6 +1871,68 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
18641871
return Result;
18651872
}
18661873

1874+
bool SPIRVInstructionSelector::selectDXThreadId(Register ResVReg,
1875+
const SPIRVType *ResType,
1876+
MachineInstr &I) const {
1877+
// DX intrinsic: @llvm.dx.thread.id(i32)
1878+
// ID Name Description
1879+
// 93 ThreadId reads the thread ID
1880+
1881+
MachineIRBuilder MIRBuilder(I);
1882+
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
1883+
const SPIRVType *Vec3Ty =
1884+
GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
1885+
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
1886+
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
1887+
1888+
// Create new register for GlobalInvocationID builtin variable.
1889+
Register NewRegister =
1890+
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
1891+
MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 32));
1892+
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
1893+
1894+
// Build GlobalInvocationID global variable with the necessary decorations.
1895+
Register Variable = GR.buildGlobalVariable(
1896+
NewRegister, PtrType,
1897+
getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
1898+
SPIRV::StorageClass::Input, nullptr, true, true,
1899+
SPIRV::LinkageType::Import, MIRBuilder, false);
1900+
1901+
// Create new register for loading value.
1902+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1903+
Register LoadedRegister = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1904+
MIRBuilder.getMRI()->setType(LoadedRegister, LLT::pointer(0, 32));
1905+
GR.assignSPIRVTypeToVReg(Vec3Ty, LoadedRegister, MIRBuilder.getMF());
1906+
1907+
// Load v3uint value from the global variable.
1908+
BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
1909+
.addDef(LoadedRegister)
1910+
.addUse(GR.getSPIRVTypeID(Vec3Ty))
1911+
.addUse(Variable);
1912+
1913+
// Get Thread ID index. Expecting operand is a constant immediate value,
1914+
// wrapped in a type assignment.
1915+
assert(I.getOperand(2).isReg());
1916+
Register ThreadIdReg = I.getOperand(2).getReg();
1917+
SPIRVType *ConstTy = this->MRI->getVRegDef(ThreadIdReg);
1918+
assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
1919+
ConstTy->getOperand(1).isReg());
1920+
Register ConstReg = ConstTy->getOperand(1).getReg();
1921+
const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
1922+
assert(Const && Const->getOpcode() == TargetOpcode::G_CONSTANT);
1923+
const llvm::APInt &Val = Const->getOperand(1).getCImm()->getValue();
1924+
const uint32_t ThreadId = Val.getZExtValue();
1925+
1926+
// Extract the thread ID from the loaded vector value.
1927+
MachineBasicBlock &BB = *I.getParent();
1928+
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1929+
.addDef(ResVReg)
1930+
.addUse(GR.getSPIRVTypeID(ResType))
1931+
.addUse(LoadedRegister)
1932+
.addImm(ThreadId);
1933+
return MIB.constrainAllUses(TII, TRI, RBI);
1934+
}
1935+
18671936
namespace llvm {
18681937
InstructionSelector *
18691938
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; This file generated from the following HLSL:
5+
; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header -o - DispatchThreadID.hlsl
6+
;
7+
; [shader("compute")]
8+
; [numthreads(1,1,1)]
9+
; void main(uint3 ID : SV_DispatchThreadID) {}
10+
11+
; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
12+
; CHECK-DAG: %[[#v3int:]] = OpTypeVector %[[#int]] 3
13+
; CHECK-DAG: %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]]
14+
; CHECK-DAG: %[[#tempvar:]] = OpUndef %[[#v3int]]
15+
; CHECK-DAG: %[[#GlobalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input
16+
17+
; CHECK-DAG: OpEntryPoint GLCompute {{.*}} %[[#GlobalInvocationId]]
18+
; CHECK-DAG: OpName %[[#GlobalInvocationId]] "__spirv_BuiltInGlobalInvocationId"
19+
; CHECK-DAG: OpDecorate %[[#GlobalInvocationId]] LinkageAttributes "__spirv_BuiltInGlobalInvocationId" Import
20+
; CHECK-DAG: OpDecorate %[[#GlobalInvocationId]] BuiltIn GlobalInvocationId
21+
22+
; ModuleID = 'DispatchThreadID.hlsl'
23+
source_filename = "DispatchThreadID.hlsl"
24+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
25+
target triple = "spirv-unknown-vulkan-library"
26+
27+
; Function Attrs: noinline norecurse nounwind optnone
28+
define internal spir_func void @main(<3 x i32> noundef %ID) #0 {
29+
entry:
30+
%ID.addr = alloca <3 x i32>, align 16
31+
store <3 x i32> %ID, ptr %ID.addr, align 16
32+
ret void
33+
}
34+
35+
; Function Attrs: norecurse
36+
define void @main.1() #1 {
37+
entry:
38+
39+
; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
40+
; CHECK: %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0
41+
%0 = call i32 @llvm.dx.thread.id(i32 0)
42+
43+
; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0
44+
%1 = insertelement <3 x i32> poison, i32 %0, i64 0
45+
46+
; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
47+
; CHECK: %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1
48+
%2 = call i32 @llvm.dx.thread.id(i32 1)
49+
50+
; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1
51+
%3 = insertelement <3 x i32> %1, i32 %2, i64 1
52+
53+
; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
54+
; CHECK: %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2
55+
%4 = call i32 @llvm.dx.thread.id(i32 2)
56+
57+
; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2
58+
%5 = insertelement <3 x i32> %3, i32 %4, i64 2
59+
60+
call void @main(<3 x i32> %5)
61+
ret void
62+
}
63+
64+
; Function Attrs: nounwind willreturn memory(none)
65+
declare i32 @llvm.dx.thread.id(i32) #2
66+
67+
attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
68+
attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
69+
attributes #2 = { nounwind willreturn memory(none) }
70+
71+
!llvm.module.flags = !{!0, !1}
72+
!llvm.ident = !{!2}
73+
74+
!0 = !{i32 1, !"wchar_size", i32 4}
75+
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
76+
!2 = !{!"clang version 19.0.0git ([email protected]:llvm/llvm-project.git c9afeaa6434a61b3b3a57c8eda6d2cfb25ab675b)"}

0 commit comments

Comments
 (0)