Skip to content

[SPIR-V] Validate and fix bit width of scalar registers #95147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR improves legalization process of SPIR-V instructions. Namely, it introduces validation and fixing of bit width of scalar registers as a part of pre-legalizer. A test case is added that demonstrates ability to legalize instructions with non 8/16/32/64 bit width both with and without vendor-specific SPIR-V extension (SPV_INTEL_arbitrary_precision_integers). In the case of absence of the extension, a generated SPIR-V code will fallback to 8/16/32/64 bit width in OpTypeInt, but SPIR-V Backend still is able to legalize operations with original integer sizes.

@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR improves legalization process of SPIR-V instructions. Namely, it introduces validation and fixing of bit width of scalar registers as a part of pre-legalizer. A test case is added that demonstrates ability to legalize instructions with non 8/16/32/64 bit width both with and without vendor-specific SPIR-V extension (SPV_INTEL_arbitrary_precision_integers). In the case of absence of the extension, a generated SPIR-V code will fallback to 8/16/32/64 bit width in OpTypeInt, but SPIR-V Backend still is able to legalize operations with original integer sizes.


Full diff: https://github.com/llvm/llvm-project/pull/95147.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+20-5)
  • (added) llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll (+56)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index adc5b36af6f18..aaba6e873e2c1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -271,6 +271,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
   return SpirvTy;
 }
 
+// To support current approach and limitations wrt. bit width here we widen a
+// scalar register with a bit width greater than 1 to valid sizes and cap it to
+// 64 width.
+static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
+  LLT RegType = MRI.getType(Reg);
+  if (!RegType.isScalar())
+    return;
+  unsigned Sz = RegType.getScalarSizeInBits();
+  if (Sz == 1)
+    return;
+  unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
+  if (NewSz != Sz)
+    MRI.setType(Reg, LLT::scalar(NewSz));
+}
+
 static std::pair<Register, unsigned>
 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
                const SPIRVGlobalRegistry &GR) {
@@ -406,6 +421,11 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       MachineInstr &MI = *MII;
       unsigned MIOp = MI.getOpcode();
 
+      // validate bit width of scalar registers
+      for (const auto& MOP : MI.operands())
+        if (MOP.isReg())
+          widenScalarLLTNextPow2(MOP.getReg(), MRI);
+
       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
         Register Reg = MI.getOperand(1).getReg();
         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
@@ -475,11 +495,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
         propagateSPIRVType(&MI, GR, MRI, MIB);
-      } else if (MIOp == TargetOpcode::G_BITREVERSE) {
-        Register Reg = MI.getOperand(0).getReg();
-        LLT RegType = MRI.getType(Reg);
-        if (RegType.getSizeInBits() < 32)
-          MRI.setType(Reg, LLT::scalar(32));
       }
 
       if (MII == Begin)
diff --git a/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll b/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
new file mode 100644
index 0000000000000..437e161864eca
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
@@ -0,0 +1,56 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXT
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXT
+
+; CHECK-DAG: OpName %[[#Struct:]] "struct"
+; CHECK-DAG: OpName %[[#Arg:]] "arg"
+; CHECK-DAG: OpName %[[#QArg:]] "qarg"
+; CHECK-DAG: OpName %[[#R:]] "r"
+; CHECK-DAG: OpName %[[#Q:]] "q"
+; CHECK-DAG: OpName %[[#Tr:]] "tr"
+; CHECK-DAG: OpName %[[#Tq:]] "tq"
+; CHECK-DAG: %[[#Struct]] = OpTypeStruct %[[#]] %[[#]] %[[#]]
+; CHECK-DAG: %[[#PtrStruct:]] = OpTypePointer CrossWorkgroup %[[#Struct]]
+; CHECK-EXT-DAG: %[[#Int40:]] = OpTypeInt 40 0
+; CHECK-EXT-DAG: %[[#Int50:]] = OpTypeInt 50 0
+; CHECK-NOEXT-DAG: %[[#Int40:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#PtrInt40:]] = OpTypePointer CrossWorkgroup %[[#Int40]]
+
+; CHECK: OpFunction
+
+; CHECK-EXT: %[[#Tr]] = OpUConvert %[[#Int40]] %[[#R]]
+; CHECK-EXT: %[[#Store:]] = OpInBoundsPtrAccessChain %[[#PtrStruct]] %[[#Arg]] %[[#]]
+; CHECK-EXT: %[[#StoreAsInt40:]] = OpBitcast %[[#PtrInt40]] %[[#Store]]
+; CHECK-EXT: OpStore %[[#StoreAsInt40]] %[[#Tr]]
+
+; CHECK-NOEXT: %[[#Store:]] = OpInBoundsPtrAccessChain %[[#PtrStruct]] %[[#Arg]] %[[#]]
+; CHECK-NOEXT: %[[#StoreAsInt40:]] = OpBitcast %[[#PtrInt40]] %[[#Store]]
+; CHECK-NOEXT: OpStore %[[#StoreAsInt40]] %[[#R]]
+
+; CHECK: OpFunction
+
+; CHECK-EXT: %[[#Tq]] = OpUConvert %[[#Int40]] %[[#Q]]
+; CHECK-EXT: OpStore %[[#QArg]] %[[#Tq]]
+
+; CHECK-NOEXT: OpStore %[[#QArg]] %[[#Q]]
+
+%struct = type <{ i32, i8, [3 x i8] }>
+
+define spir_kernel void @foo(ptr addrspace(1) %arg, i64 %r) {
+  %tr = trunc i64 %r to i40
+  %addr = getelementptr inbounds %struct, ptr addrspace(1) %arg, i64 0
+  store i40 %tr, ptr addrspace(1) %addr
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %qarg, i50 %q) {
+  %tq = trunc i50 %q to i40
+  store i40 %tq, ptr addrspace(1) %qarg
+  ret void
+}

Copy link

github-actions bot commented Jun 11, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 163d036 into llvm:main Jun 11, 2024
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants