Skip to content

Add support for the SPIR-V extension SPV_KHR_uniform_group_instructions #82064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ struct AtomicFloatingBuiltin {

#define GET_AtomicFloatingBuiltins_DECL
#define GET_AtomicFloatingBuiltins_IMPL
// One row of the TableGen-generated GroupUniformBuiltins table. The member
// order matches the `Fields` list of that table in SPIRVBuiltins.td
// (Name, Opcode, IsLogical).
struct GroupUniformBuiltin {
// Demangled builtin name; this is the key used by lookupGroupUniformBuiltin.
StringRef Name;
// SPIR-V opcode of the instruction the builtin is lowered to.
uint32_t Opcode;
// Set for the OpGroupLogical{And,Or,Xor}KHR operations.
bool IsLogical;
};

// Select the declaration and implementation parts of the generated
// GroupUniformBuiltins table (presumably pulled in by a generated-tables
// include further down; that include is not visible in this hunk).
#define GET_GroupUniformBuiltins_DECL
#define GET_GroupUniformBuiltins_IMPL

struct GetBuiltin {
StringRef Name;
Expand Down Expand Up @@ -1014,6 +1022,57 @@ static bool generateIntelSubgroupsInst(const SPIRV::IncomingCall *Call,
return true;
}

// Lowers a call to one of the "__spirv_Group*KHR" builtins to the SPIR-V
// instruction recorded for it in the GroupUniformBuiltins table.
//
// The call's arguments are read as: [0] the scope, [1] the group operation and
// [2] the value. The group operation has to be defined by a G_CONSTANT whose
// operand is an integer immediate, because it is appended to the instruction
// as a literal instead of a register use.
//
// Terminates through report_fatal_error when the subtarget cannot use
// SPV_KHR_uniform_group_instructions or when the group operation is not such
// a constant; otherwise returns true.
static bool generateGroupUniformInst(const SPIRV::IncomingCall *Call,
                                     MachineIRBuilder &MIRBuilder,
                                     SPIRVGlobalRegistry *GR) {
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;

  // Refuse to lower the builtin unless the extension that provides the
  // instruction may be used on this subtarget.
  const auto &Subtarget =
      static_cast<const SPIRVSubtarget &>(MIRBuilder.getMF().getSubtarget());
  const bool HasUniformGroupExt = Subtarget.canUseExtension(
      SPIRV::Extension::SPV_KHR_uniform_group_instructions);
  if (!HasUniformGroupExt) {
    std::string Message(Builtin->Name);
    Message += ": the builtin requires the following SPIR-V "
               "extension: SPV_KHR_uniform_group_instructions";
    report_fatal_error(Message.c_str(), false);
  }

  const SPIRV::GroupUniformBuiltin *UniformInfo =
      SPIRV::lookupGroupUniformBuiltin(Builtin->Name);
  MachineRegisterInfo *RegInfo = MIRBuilder.getMRI();

  // Result
  Register ResultReg = Call->ReturnRegister;
  RegInfo->setRegClass(ResultReg, &SPIRV::IDRegClass);

  // Argument 0: scope
  Register ScopeReg = Call->Arguments[0];
  RegInfo->setRegClass(ScopeReg, &SPIRV::IDRegClass);

  // Argument 1: group operation, which must resolve to an integer constant.
  Register GroupOpReg = Call->Arguments[1];
  const MachineInstr *GroupOpDef =
      getDefInstrMaybeConstant(GroupOpReg, RegInfo);
  const bool IsGConstant =
      GroupOpDef && GroupOpDef->getOpcode() == TargetOpcode::G_CONSTANT;
  if (!IsGConstant)
    report_fatal_error(
        "expect a constant group operation for a uniform group instruction",
        false);
  const MachineOperand &GroupOpImm = GroupOpDef->getOperand(1);
  if (!GroupOpImm.isCImm())
    report_fatal_error("uniform group instructions: group operation must be an "
                       "integer constant",
                       false);

  // Argument 2: value
  Register ValueReg = Call->Arguments[2];
  RegInfo->setRegClass(ValueReg, &SPIRV::IDRegClass);

  // Operand order: result, result type, scope, group operation literal, value.
  auto MIB = MIRBuilder.buildInstr(UniformInfo->Opcode);
  MIB.addDef(ResultReg);
  MIB.addUse(GR->getSPIRVTypeID(Call->ReturnType));
  MIB.addUse(ScopeReg);
  addNumImm(GroupOpImm.getCImm()->getValue(), MIB);
  MIB.addUse(ValueReg);

  return true;
}

// These queries ask for a single size_t result for a given dimension index, e.g
// size_t get_global_id(uint dimindex). In SPIR-V, the builtins corresponding to
// these values are all vec3 types, so we need to extract the correct index or
Expand Down Expand Up @@ -2112,6 +2171,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
case SPIRV::IntelSubgroups:
return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
case SPIRV::GroupUniform:
return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
}
return false;
}
Expand Down
97 changes: 96 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def VectorLoadStore : BuiltinGroup;
def LoadStore : BuiltinGroup;
def IntelSubgroups : BuiltinGroup;
def AtomicFloating : BuiltinGroup;
def GroupUniform : BuiltinGroup;

//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
Expand Down Expand Up @@ -605,7 +606,10 @@ class GroupBuiltin<string name, Op operation> {
!eq(operation, OpGroupNonUniformBallotFindMSB));
bit IsLogical = !or(!eq(operation, OpGroupNonUniformLogicalAnd),
!eq(operation, OpGroupNonUniformLogicalOr),
!eq(operation, OpGroupNonUniformLogicalXor));
!eq(operation, OpGroupNonUniformLogicalXor),
!eq(operation, OpGroupLogicalAndKHR),
!eq(operation, OpGroupLogicalOrKHR),
!eq(operation, OpGroupLogicalXorKHR));
bit NoGroupOperation = !or(IsElect, IsAllOrAny, IsAllEqual,
IsBallot, IsInverseBallot,
IsBallotBitExtract, IsBallotFindBit,
Expand Down Expand Up @@ -873,6 +877,51 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;

// cl_khr_work_group_uniform_arithmetic / SPV_KHR_uniform_group_instructions
//
// Every record below is named "group_<kind>_<operation>", where <kind> is one
// of reduce / scan_inclusive / scan_exclusive. The loops generate the records
// in the same order as the equivalent hand-written list would.

// Integer multiplication.
foreach scanKind = ["reduce", "scan_inclusive", "scan_exclusive"] in {
  foreach mulSuffix = ["imul", "mulu", "muls"] in {
    defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_", mulSuffix), OnlyWork, OpGroupIMulKHR>;
  }
}

// Floating-point multiplication.
foreach scanKind = ["reduce", "scan_inclusive", "scan_exclusive"] in {
  foreach mulSuffix = ["mulf", "mulh", "muld"] in {
    defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_", mulSuffix), OnlyWork, OpGroupFMulKHR>;
  }
}

// Bitwise operations.
foreach scanKind = ["scan_exclusive", "scan_inclusive", "reduce"] in {
  defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_and"), OnlyWork, OpGroupBitwiseAndKHR>;
}

foreach scanKind = ["scan_exclusive", "scan_inclusive", "reduce"] in {
  defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_or"), OnlyWork, OpGroupBitwiseOrKHR>;
}

foreach scanKind = ["scan_exclusive", "scan_inclusive", "reduce"] in {
  defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_xor"), OnlyWork, OpGroupBitwiseXorKHR>;
}

// Logical operations.
foreach scanKind = ["scan_exclusive", "scan_inclusive", "reduce"] in {
  defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_logical_and"), OnlyWork, OpGroupLogicalAndKHR>;
}

foreach scanKind = ["scan_exclusive", "scan_inclusive", "reduce"] in {
  defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_logical_or"), OnlyWork, OpGroupLogicalOrKHR>;
}

foreach scanKind = ["scan_exclusive", "scan_inclusive", "reduce"] in {
  defm : DemangledGroupBuiltin<!strconcat("group_", scanKind, "_logical_xor"), OnlyWork, OpGroupLogicalXorKHR>;
}


//===----------------------------------------------------------------------===//
// Class defining an atomic instruction on floating-point numbers.
//
Expand Down Expand Up @@ -967,6 +1016,52 @@ foreach i = ["", "2", "4", "8", "16"] in {
}
// OpSubgroupImageBlockReadINTEL and OpSubgroupImageBlockWriteINTEL are to be resolved later on (in code)

//===----------------------------------------------------------------------===//
// Class defining a builtin for group operations within uniform control flow.
// It should be translated into a SPIR-V instruction using
// the SPV_KHR_uniform_group_instructions extension.
//
// name is the demangled name of the given builtin.
// operation specifies the SPIR-V operation of the generated instruction; it is
// stored in the Opcode field.
//
// IsLogical is derived from operation: it is set only for
// OpGroupLogicalAndKHR, OpGroupLogicalOrKHR and OpGroupLogicalXorKHR.
//===----------------------------------------------------------------------===//
class GroupUniformBuiltin<string name, Op operation> {
string Name = name;
Op Opcode = operation;
bit IsLogical = !or(!eq(operation, OpGroupLogicalAndKHR),
!eq(operation, OpGroupLogicalOrKHR),
!eq(operation, OpGroupLogicalXorKHR));
}

// Table gathering all the group uniform builtins, i.e. every record of class
// GroupUniformBuiltin.
def GroupUniformBuiltins : GenericTable {
let FilterClass = "GroupUniformBuiltin";
let Fields = ["Name", "Opcode", "IsLogical"];
}

// Function to lookup group uniform builtins by their name.
def lookupGroupUniformBuiltin : SearchIndex {
let Table = GroupUniformBuiltins;
let Key = ["Name"];
}

// Multiclass used to define incoming builtin records for
// the SPV_KHR_uniform_group_instructions extension.
//
// For a given name it defines two records that share the full builtin name
// "__spirv_Group" + name:
//  - a DemangledBuiltin in the OpenCL_std set and the GroupUniform group,
//    accepting between minNumArgs and maxNumArgs arguments;
//  - a GroupUniformBuiltin mapping that builtin name to operation.
multiclass DemangledGroupUniformBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
def : DemangledBuiltin<!strconcat("__spirv_Group", name), OpenCL_std, GroupUniform, minNumArgs, maxNumArgs>;
def : GroupUniformBuiltin<!strconcat("__spirv_Group", name), operation>;
}

// cl_khr_work_group_uniform_arithmetic / SPV_KHR_uniform_group_instructions
//
// Each record declares the builtin "__spirv_Group<name>" with exactly three
// arguments (minNumArgs = maxNumArgs = 3) and ties it to the instruction that
// carries the same name with an "Op" prefix.
defm : DemangledGroupUniformBuiltin<"IMulKHR", 3, 3, OpGroupIMulKHR>;
defm : DemangledGroupUniformBuiltin<"FMulKHR", 3, 3, OpGroupFMulKHR>;
defm : DemangledGroupUniformBuiltin<"BitwiseAndKHR", 3, 3, OpGroupBitwiseAndKHR>;
defm : DemangledGroupUniformBuiltin<"BitwiseOrKHR", 3, 3, OpGroupBitwiseOrKHR>;
defm : DemangledGroupUniformBuiltin<"BitwiseXorKHR", 3, 3, OpGroupBitwiseXorKHR>;
defm : DemangledGroupUniformBuiltin<"LogicalAndKHR", 3, 3, OpGroupLogicalAndKHR>;
defm : DemangledGroupUniformBuiltin<"LogicalOrKHR", 3, 3, OpGroupLogicalOrKHR>;
defm : DemangledGroupUniformBuiltin<"LogicalXorKHR", 3, 3, OpGroupLogicalXorKHR>;

//===----------------------------------------------------------------------===//
// Class defining a get builtin record used for lowering builtin calls such as
// "get_sub_group_eq_mask" or "get_global_id" to SPIR-V instructions.
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,8 @@ def OpConstantFunctionPointerINTEL: Op<5600, (outs ID:$res), (ins TYPE:$ty, ID:$
def OpFunctionPointerCallINTEL: Op<5601, (outs ID:$res), (ins TYPE:$ty, ID:$funPtr, variable_ops), "$res = OpFunctionPointerCallINTEL $ty $funPtr">;

// 3.49.21. Group and Subgroup Instructions

// - SPV_INTEL_subgroups
def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId),
"$res = OpSubgroupShuffleINTEL $type $data $invocationId">;
def OpSubgroupShuffleDownINTEL: Op<5572, (outs ID:$res), (ins TYPE:$type, ID:$current, ID:$next, ID:$delta),
Expand All @@ -792,3 +794,21 @@ def OpSubgroupImageBlockReadINTEL: Op<5577, (outs ID:$res), (ins TYPE:$type, ID:
"$res = OpSubgroupImageBlockReadINTEL $type $image $coordinate">;
def OpSubgroupImageBlockWriteINTEL: Op<5578, (outs), (ins ID:$image, ID:$coordinate, ID:$data),
"OpSubgroupImageBlockWriteINTEL $image $coordinate $data">;

// - SPV_KHR_uniform_group_instructions
//
// All eight instructions (opcodes 6401-6408) share one operand layout: an ID
// result, then the result type, the scope as an ID, the group operation as a
// 32-bit immediate ($groupOp), and the value as an ID.
def OpGroupIMulKHR: Op<6401, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupIMulKHR $type $scope $groupOp $value">;
def OpGroupFMulKHR: Op<6402, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupFMulKHR $type $scope $groupOp $value">;
def OpGroupBitwiseAndKHR: Op<6403, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupBitwiseAndKHR $type $scope $groupOp $value">;
def OpGroupBitwiseOrKHR: Op<6404, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupBitwiseOrKHR $type $scope $groupOp $value">;
def OpGroupBitwiseXorKHR: Op<6405, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupBitwiseXorKHR $type $scope $groupOp $value">;
def OpGroupLogicalAndKHR: Op<6406, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupLogicalAndKHR $type $scope $groupOp $value">;
def OpGroupLogicalOrKHR: Op<6407, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupLogicalOrKHR $type $scope $groupOp $value">;
def OpGroupLogicalXorKHR: Op<6408, (outs ID:$res), (ins TYPE:$type, ID:$scope, i32imm:$groupOp, ID:$value),
"$res = OpGroupLogicalXorKHR $type $scope $groupOp $value">;
14 changes: 14 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,20 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
}
break;
case SPIRV::OpGroupIMulKHR:
case SPIRV::OpGroupFMulKHR:
case SPIRV::OpGroupBitwiseAndKHR:
case SPIRV::OpGroupBitwiseOrKHR:
case SPIRV::OpGroupBitwiseXorKHR:
case SPIRV::OpGroupLogicalAndKHR:
case SPIRV::OpGroupLogicalOrKHR:
case SPIRV::OpGroupLogicalXorKHR:
if (ST.canUseExtension(
SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
}
break;
case SPIRV::OpFunctionPointerCallINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
"use of local memory and work group barriers, and to "
"utilize specialized hardware to load and store blocks of "
"data from images or buffers."),
clEnumValN(SPIRV::Extension::SPV_KHR_uniform_group_instructions,
"SPV_KHR_uniform_group_instructions",
"Allows support for additional group operations within "
"uniform control flow."),
clEnumValN(SPIRV::Extension::SPV_KHR_no_integer_wrap_decoration,
"SPV_KHR_no_integer_wrap_decoration",
"Adds decorations to indicate that a given instruction does "
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
Loading