Skip to content

Commit 66ebda4

Browse files
Add support for the SPIR-V extension SPV_KHR_uniform_group_instructions (#82064)
This PR adds support for the SPIR-V extension SPV_KHR_uniform_group_instructions, which adds new instructions to SPIR-V to support additional group operations within uniform control flow.
1 parent f8cbb67 commit 66ebda4

File tree

7 files changed

+276
-1
lines changed

7 files changed

+276
-1
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@ struct AtomicFloatingBuiltin {
100100

101101
#define GET_AtomicFloatingBuiltins_DECL
102102
#define GET_AtomicFloatingBuiltins_IMPL
103+
struct GroupUniformBuiltin {
104+
StringRef Name;
105+
uint32_t Opcode;
106+
bool IsLogical;
107+
};
108+
109+
#define GET_GroupUniformBuiltins_DECL
110+
#define GET_GroupUniformBuiltins_IMPL
103111

104112
struct GetBuiltin {
105113
StringRef Name;
@@ -1014,6 +1022,57 @@ static bool generateIntelSubgroupsInst(const SPIRV::IncomingCall *Call,
10141022
return true;
10151023
}
10161024

1025+
static bool generateGroupUniformInst(const SPIRV::IncomingCall *Call,
1026+
MachineIRBuilder &MIRBuilder,
1027+
SPIRVGlobalRegistry *GR) {
1028+
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1029+
MachineFunction &MF = MIRBuilder.getMF();
1030+
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1031+
if (!ST->canUseExtension(
1032+
SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1033+
std::string DiagMsg = std::string(Builtin->Name) +
1034+
": the builtin requires the following SPIR-V "
1035+
"extension: SPV_KHR_uniform_group_instructions";
1036+
report_fatal_error(DiagMsg.c_str(), false);
1037+
}
1038+
const SPIRV::GroupUniformBuiltin *GroupUniform =
1039+
SPIRV::lookupGroupUniformBuiltin(Builtin->Name);
1040+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1041+
1042+
Register GroupResultReg = Call->ReturnRegister;
1043+
MRI->setRegClass(GroupResultReg, &SPIRV::IDRegClass);
1044+
1045+
// Scope
1046+
Register ScopeReg = Call->Arguments[0];
1047+
MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
1048+
1049+
// Group Operation
1050+
Register ConstGroupOpReg = Call->Arguments[1];
1051+
const MachineInstr *Const = getDefInstrMaybeConstant(ConstGroupOpReg, MRI);
1052+
if (!Const || Const->getOpcode() != TargetOpcode::G_CONSTANT)
1053+
report_fatal_error(
1054+
"expect a constant group operation for a uniform group instruction",
1055+
false);
1056+
const MachineOperand &ConstOperand = Const->getOperand(1);
1057+
if (!ConstOperand.isCImm())
1058+
report_fatal_error("uniform group instructions: group operation must be an "
1059+
"integer constant",
1060+
false);
1061+
1062+
// Value
1063+
Register ValueReg = Call->Arguments[2];
1064+
MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
1065+
1066+
auto MIB = MIRBuilder.buildInstr(GroupUniform->Opcode)
1067+
.addDef(GroupResultReg)
1068+
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
1069+
.addUse(ScopeReg);
1070+
addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1071+
MIB.addUse(ValueReg);
1072+
1073+
return true;
1074+
}
1075+
10171076
// These queries ask for a single size_t result for a given dimension index, e.g
10181077
// size_t get_global_id(uint dimindex). In SPIR-V, the builtins corresponding to
10191078
// these values are all vec3 types, so we need to extract the correct index or
@@ -2112,6 +2171,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
21122171
return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
21132172
case SPIRV::IntelSubgroups:
21142173
return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
2174+
case SPIRV::GroupUniform:
2175+
return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
21152176
}
21162177
return false;
21172178
}

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def VectorLoadStore : BuiltinGroup;
5656
def LoadStore : BuiltinGroup;
5757
def IntelSubgroups : BuiltinGroup;
5858
def AtomicFloating : BuiltinGroup;
59+
def GroupUniform : BuiltinGroup;
5960

6061
//===----------------------------------------------------------------------===//
6162
// Class defining a demangled builtin record. The information in the record
@@ -605,7 +606,10 @@ class GroupBuiltin<string name, Op operation> {
605606
!eq(operation, OpGroupNonUniformBallotFindMSB));
606607
bit IsLogical = !or(!eq(operation, OpGroupNonUniformLogicalAnd),
607608
!eq(operation, OpGroupNonUniformLogicalOr),
608-
!eq(operation, OpGroupNonUniformLogicalXor));
609+
!eq(operation, OpGroupNonUniformLogicalXor),
610+
!eq(operation, OpGroupLogicalAndKHR),
611+
!eq(operation, OpGroupLogicalOrKHR),
612+
!eq(operation, OpGroupLogicalXorKHR));
609613
bit NoGroupOperation = !or(IsElect, IsAllOrAny, IsAllEqual,
610614
IsBallot, IsInverseBallot,
611615
IsBallotBitExtract, IsBallotFindBit,
@@ -873,6 +877,51 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
873877
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
874878
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
875879

880+
// cl_khr_work_group_uniform_arithmetic / SPV_KHR_uniform_group_instructions
881+
defm : DemangledGroupBuiltin<"group_reduce_imul", OnlyWork, OpGroupIMulKHR>;
882+
defm : DemangledGroupBuiltin<"group_reduce_mulu", OnlyWork, OpGroupIMulKHR>;
883+
defm : DemangledGroupBuiltin<"group_reduce_muls", OnlyWork, OpGroupIMulKHR>;
884+
defm : DemangledGroupBuiltin<"group_scan_inclusive_imul", OnlyWork, OpGroupIMulKHR>;
885+
defm : DemangledGroupBuiltin<"group_scan_inclusive_mulu", OnlyWork, OpGroupIMulKHR>;
886+
defm : DemangledGroupBuiltin<"group_scan_inclusive_muls", OnlyWork, OpGroupIMulKHR>;
887+
defm : DemangledGroupBuiltin<"group_scan_exclusive_imul", OnlyWork, OpGroupIMulKHR>;
888+
defm : DemangledGroupBuiltin<"group_scan_exclusive_mulu", OnlyWork, OpGroupIMulKHR>;
889+
defm : DemangledGroupBuiltin<"group_scan_exclusive_muls", OnlyWork, OpGroupIMulKHR>;
890+
891+
defm : DemangledGroupBuiltin<"group_reduce_mulf", OnlyWork, OpGroupFMulKHR>;
892+
defm : DemangledGroupBuiltin<"group_reduce_mulh", OnlyWork, OpGroupFMulKHR>;
893+
defm : DemangledGroupBuiltin<"group_reduce_muld", OnlyWork, OpGroupFMulKHR>;
894+
defm : DemangledGroupBuiltin<"group_scan_inclusive_mulf", OnlyWork, OpGroupFMulKHR>;
895+
defm : DemangledGroupBuiltin<"group_scan_inclusive_mulh", OnlyWork, OpGroupFMulKHR>;
896+
defm : DemangledGroupBuiltin<"group_scan_inclusive_muld", OnlyWork, OpGroupFMulKHR>;
897+
defm : DemangledGroupBuiltin<"group_scan_exclusive_mulf", OnlyWork, OpGroupFMulKHR>;
898+
defm : DemangledGroupBuiltin<"group_scan_exclusive_mulh", OnlyWork, OpGroupFMulKHR>;
899+
defm : DemangledGroupBuiltin<"group_scan_exclusive_muld", OnlyWork, OpGroupFMulKHR>;
900+
901+
defm : DemangledGroupBuiltin<"group_scan_exclusive_and", OnlyWork, OpGroupBitwiseAndKHR>;
902+
defm : DemangledGroupBuiltin<"group_scan_inclusive_and", OnlyWork, OpGroupBitwiseAndKHR>;
903+
defm : DemangledGroupBuiltin<"group_reduce_and", OnlyWork, OpGroupBitwiseAndKHR>;
904+
905+
defm : DemangledGroupBuiltin<"group_scan_exclusive_or", OnlyWork, OpGroupBitwiseOrKHR>;
906+
defm : DemangledGroupBuiltin<"group_scan_inclusive_or", OnlyWork, OpGroupBitwiseOrKHR>;
907+
defm : DemangledGroupBuiltin<"group_reduce_or", OnlyWork, OpGroupBitwiseOrKHR>;
908+
909+
defm : DemangledGroupBuiltin<"group_scan_exclusive_xor", OnlyWork, OpGroupBitwiseXorKHR>;
910+
defm : DemangledGroupBuiltin<"group_scan_inclusive_xor", OnlyWork, OpGroupBitwiseXorKHR>;
911+
defm : DemangledGroupBuiltin<"group_reduce_xor", OnlyWork, OpGroupBitwiseXorKHR>;
912+
913+
defm : DemangledGroupBuiltin<"group_scan_exclusive_logical_and", OnlyWork, OpGroupLogicalAndKHR>;
914+
defm : DemangledGroupBuiltin<"group_scan_inclusive_logical_and", OnlyWork, OpGroupLogicalAndKHR>;
915+
defm : DemangledGroupBuiltin<"group_reduce_logical_and", OnlyWork, OpGroupLogicalAndKHR>;
916+
917+
defm : DemangledGroupBuiltin<"group_scan_exclusive_logical_or", OnlyWork, OpGroupLogicalOrKHR>;
918+
defm : DemangledGroupBuiltin<"group_scan_inclusive_logical_or", OnlyWork, OpGroupLogicalOrKHR>;
919+
defm : DemangledGroupBuiltin<"group_reduce_logical_or", OnlyWork, OpGroupLogicalOrKHR>;
920+
921+
defm : DemangledGroupBuiltin<"group_scan_exclusive_logical_xor", OnlyWork, OpGroupLogicalXorKHR>;
922+
defm : DemangledGroupBuiltin<"group_scan_inclusive_logical_xor", OnlyWork, OpGroupLogicalXorKHR>;
923+
defm : DemangledGroupBuiltin<"group_reduce_logical_xor", OnlyWork, OpGroupLogicalXorKHR>;
924+
876925
//===----------------------------------------------------------------------===//
877926
// Class defining an atomic instruction on floating-point numbers.
878927
//
@@ -967,6 +1016,52 @@ foreach i = ["", "2", "4", "8", "16"] in {
9671016
}
9681017
// OpSubgroupImageBlockReadINTEL and OpSubgroupImageBlockWriteINTEL are to be resolved later on (in code)
9691018

1019+
//===----------------------------------------------------------------------===//
1020+
// Class defining a builtin for group operations within uniform control flow.
1021+
// It should be translated into a SPIR-V instruction using
1022+
// the SPV_KHR_uniform_group_instructions extension.
1023+
//
1024+
// name is the demangled name of the given builtin.
1025+
// opcode specifies the SPIR-V operation code of the generated instruction.
1026+
//===----------------------------------------------------------------------===//
1027+
class GroupUniformBuiltin<string name, Op operation> {
1028+
string Name = name;
1029+
Op Opcode = operation;
1030+
bit IsLogical = !or(!eq(operation, OpGroupLogicalAndKHR),
1031+
!eq(operation, OpGroupLogicalOrKHR),
1032+
!eq(operation, OpGroupLogicalXorKHR));
1033+
}
1034+
1035+
// Table gathering all the group uniform builtins.
1036+
def GroupUniformBuiltins : GenericTable {
1037+
let FilterClass = "GroupUniformBuiltin";
1038+
let Fields = ["Name", "Opcode", "IsLogical"];
1039+
}
1040+
1041+
// Function to lookup group uniform builtins by their name.
1042+
def lookupGroupUniformBuiltin : SearchIndex {
1043+
let Table = GroupUniformBuiltins;
1044+
let Key = ["Name"];
1045+
}
1046+
1047+
// Multiclass used to define incoming builtin records for
1048+
// the SPV_KHR_uniform_group_instructions extension
1049+
// and the corresponding group uniform builtin records.
1050+
multiclass DemangledGroupUniformBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
1051+
def : DemangledBuiltin<!strconcat("__spirv_Group", name), OpenCL_std, GroupUniform, minNumArgs, maxNumArgs>;
1052+
def : GroupUniformBuiltin<!strconcat("__spirv_Group", name), operation>;
1053+
}
1054+
1055+
// cl_khr_work_group_uniform_arithmetic / SPV_KHR_uniform_group_instructions
1056+
defm : DemangledGroupUniformBuiltin<"IMulKHR", 3, 3, OpGroupIMulKHR>;
1057+
defm : DemangledGroupUniformBuiltin<"FMulKHR", 3, 3, OpGroupFMulKHR>;
1058+
defm : DemangledGroupUniformBuiltin<"BitwiseAndKHR", 3, 3, OpGroupBitwiseAndKHR>;
1059+
defm : DemangledGroupUniformBuiltin<"BitwiseOrKHR", 3, 3, OpGroupBitwiseOrKHR>;
1060+
defm : DemangledGroupUniformBuiltin<"BitwiseXorKHR", 3, 3, OpGroupBitwiseXorKHR>;
1061+
defm : DemangledGroupUniformBuiltin<"LogicalAndKHR", 3, 3, OpGroupLogicalAndKHR>;
1062+
defm : DemangledGroupUniformBuiltin<"LogicalOrKHR", 3, 3, OpGroupLogicalOrKHR>;
1063+
defm : DemangledGroupUniformBuiltin<"LogicalXorKHR", 3, 3, OpGroupLogicalXorKHR>;
1064+
9701065
//===----------------------------------------------------------------------===//
9711066
// Class defining a get builtin record used for lowering builtin calls such as
9721067
// "get_sub_group_eq_mask" or "get_global_id" to SPIR-V instructions.

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,8 @@ def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$
776776
def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;
777777

778778
// 3.49.21. Group and Subgroup Instructions
779+
780+
// - SPV_INTEL_subgroups
779781
def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId),
780782
"$res = OpSubgroupShuffleINTEL $type $data $invocationId">;
781783
def OpSubgroupShuffleDownINTEL: Op<5572, (outs ID:$res), (ins TYPE:$type, ID:$current, ID:$next, ID:$delta),
@@ -792,3 +794,21 @@ def OpSubgroupImageBlockReadINTEL: Op<5577, (outs ID:$res), (ins TYPE:$type, ID:
792794
"$res = OpSubgroupImageBlockReadINTEL $type $image $coordinate">;
793795
def OpSubgroupImageBlockWriteINTEL: Op<5578, (outs), (ins ID:$image, ID:$coordinate, ID:$data),
794796
"OpSubgroupImageBlockWriteINTEL $image $coordinate $data">;
797+
798+
// - SPV_KHR_uniform_group_instructions
799+
def OpGroupIMulKHR: Op<6401, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
800+
"$res = OpGroupIMulKHR $type $scope $groupOp $value">;
801+
def OpGroupFMulKHR: Op<6402, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
802+
"$res = OpGroupFMulKHR $type $scope $groupOp $value">;
803+
def OpGroupBitwiseAndKHR: Op<6403, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
804+
"$res = OpGroupBitwiseAndKHR $type $scope $groupOp $value">;
805+
def OpGroupBitwiseOrKHR: Op<6404, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
806+
"$res = OpGroupBitwiseOrKHR $type $scope $groupOp $value">;
807+
def OpGroupBitwiseXorKHR: Op<6405, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
808+
"$res = OpGroupBitwiseXorKHR $type $scope $groupOp $value">;
809+
def OpGroupLogicalAndKHR: Op<6406, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
810+
"$res = OpGroupLogicalAndKHR $type $scope $groupOp $value">;
811+
def OpGroupLogicalOrKHR: Op<6407, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
812+
"$res = OpGroupLogicalOrKHR $type $scope $groupOp $value">;
813+
def OpGroupLogicalXorKHR: Op<6408, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
814+
"$res = OpGroupLogicalXorKHR $type $scope $groupOp $value">;

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,20 @@ void addInstrRequirements(const MachineInstr &MI,
10691069
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
10701070
}
10711071
break;
1072+
case SPIRV::OpGroupIMulKHR:
1073+
case SPIRV::OpGroupFMulKHR:
1074+
case SPIRV::OpGroupBitwiseAndKHR:
1075+
case SPIRV::OpGroupBitwiseOrKHR:
1076+
case SPIRV::OpGroupBitwiseXorKHR:
1077+
case SPIRV::OpGroupLogicalAndKHR:
1078+
case SPIRV::OpGroupLogicalOrKHR:
1079+
case SPIRV::OpGroupLogicalXorKHR:
1080+
if (ST.canUseExtension(
1081+
SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1082+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1083+
Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1084+
}
1085+
break;
10721086
case SPIRV::OpFunctionPointerCallINTEL:
10731087
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
10741088
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);

llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
5454
"use of local memory and work group barriers, and to "
5555
"utilize specialized hardware to load and store blocks of "
5656
"data from images or buffers."),
57+
clEnumValN(SPIRV::Extension::SPV_KHR_uniform_group_instructions,
58+
"SPV_KHR_uniform_group_instructions",
59+
"Allows support for additional group operations within "
60+
"uniform control flow."),
5761
clEnumValN(SPIRV::Extension::SPV_KHR_no_integer_wrap_decoration,
5862
"SPV_KHR_no_integer_wrap_decoration",
5963
"Adds decorations to indicate that a given instruction does "

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_
461461
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
462462
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
463463
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
464+
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
464465

465466
//===----------------------------------------------------------------------===//
466467
// Multiclass used to define SourceLanguage enum values and at the same time

0 commit comments

Comments
 (0)