Skip to content

Add support for atomic instruction on floating-point numbers #81683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "SPIRVInstPrinter.h"
#include "SPIRV.h"
#include "SPIRVBaseInfo.h"
#include "SPIRVInstrInfo.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/MC/MCAsmInfo.h"
Expand Down Expand Up @@ -50,6 +51,7 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI,
void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
unsigned StartIndex,
raw_ostream &O) {
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::ASM_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;

assert((NumVarOps == 1 || NumVarOps == 2) &&
Expand All @@ -65,7 +67,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
}

// Format and print float values.
if (MI->getOpcode() == SPIRV::OpConstantF) {
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
: APFloat(APInt(64, Imm).bitsToDouble());

Expand Down
69 changes: 64 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ struct IntelSubgroupsBuiltin {
#define GET_IntelSubgroupsBuiltins_DECL
#define GET_IntelSubgroupsBuiltins_IMPL

// Row type for the atomic floating-point builtin records: pairs a builtin
// name with the opcode of the instruction generated for it. Entries are
// retrieved by name through SPIRV::lookupAtomicFloatingBuiltin.
struct AtomicFloatingBuiltin {
// Name of the builtin; this is the lookup key.
StringRef Name;
// Opcode of the instruction to emit for the builtin.
uint32_t Opcode;
};

#define GET_AtomicFloatingBuiltins_DECL
#define GET_AtomicFloatingBuiltins_IMPL

struct GetBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
Expand Down Expand Up @@ -402,7 +410,7 @@ getSPIRVMemSemantics(std::memory_order MemOrder) {
case std::memory_order::memory_order_seq_cst:
return SPIRV::MemorySemantics::SequentiallyConsistent;
default:
llvm_unreachable("Unknown CL memory scope");
report_fatal_error("Unknown CL memory scope");
}
}

Expand All @@ -419,7 +427,7 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
case SPIRV::CLMemoryScope::memory_scope_sub_group:
return SPIRV::Scope::Subgroup;
}
llvm_unreachable("Unknown CL memory scope");
report_fatal_error("Unknown CL memory scope");
}

static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
Expand Down Expand Up @@ -676,6 +684,38 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
return true;
}

/// Emits a floating-point atomic read-modify-write instruction for a builtin
/// call.
///
/// The call is expected to carry exactly four arguments, in this order:
/// pointer, scope, memory semantics and value. Every argument register is
/// assigned the ID register class and is then forwarded, unchanged, as an
/// operand of the new instruction, following the result register and the
/// result type id.
///
/// \param Call        the incoming builtin call being lowered.
/// \param Opcode      opcode of the instruction to build.
/// \param MIRBuilder  builder used to create the instruction.
/// \param GR          registry used to obtain the type id of the return type.
/// \returns true unconditionally.
static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
                                       unsigned Opcode,
                                       MachineIRBuilder &MIRBuilder,
                                       SPIRVGlobalRegistry *GR) {
  assert(Call->Arguments.size() == 4 &&
         "Wrong number of atomic floating-type builtin");

  // Argument layout: [0] pointer, [1] scope, [2] memory semantics, [3] value.
  constexpr unsigned NumArgs = 4;

  MachineRegisterInfo *RegInfo = MIRBuilder.getMRI();
  for (unsigned Idx = 0; Idx != NumArgs; ++Idx)
    RegInfo->setRegClass(Call->Arguments[Idx], &SPIRV::IDRegClass);

  // Result and result type come first, then the arguments in call order.
  auto Inst = MIRBuilder.buildInstr(Opcode);
  Inst.addDef(Call->ReturnRegister);
  Inst.addUse(GR->getSPIRVTypeID(Call->ReturnType));
  for (unsigned Idx = 0; Idx != NumArgs; ++Idx)
    Inst.addUse(Call->Arguments[Idx]);
  return true;
}

/// Helper function for building atomic flag instructions (e.g.
/// OpAtomicFlagTestAndSet).
static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
Expand Down Expand Up @@ -786,7 +826,7 @@ static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
case SPIRV::Dim::DIM_3D:
return 3;
default:
llvm_unreachable("Cannot get num components for given Dim");
report_fatal_error("Cannot get num components for given Dim");
}
}

Expand Down Expand Up @@ -1157,6 +1197,23 @@ static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
}
}

/// Lowers a call to a floating-point atomic builtin.
///
/// The opcode comes from the TableGen records, looked up by the name of the
/// call's demangled builtin. Only the add, min and max EXT opcodes are
/// lowered.
///
/// \returns the result of buildAtomicFloatingRMWInst for a supported opcode,
/// false for any other opcode.
static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
                                       MachineIRBuilder &MIRBuilder,
                                       SPIRVGlobalRegistry *GR) {
  // Resolve the builtin name to its opcode via the generated lookup function.
  const unsigned Opcode =
      SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name)->Opcode;

  const bool IsSupported = Opcode == SPIRV::OpAtomicFAddEXT ||
                           Opcode == SPIRV::OpAtomicFMinEXT ||
                           Opcode == SPIRV::OpAtomicFMaxEXT;
  if (!IsSupported)
    return false;

  return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
}

static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
Expand Down Expand Up @@ -1311,7 +1368,7 @@ getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
case SPIRV::CLK_ADDRESS_NONE:
return SPIRV::SamplerAddressingMode::None;
default:
llvm_unreachable("Unknown CL address mode");
report_fatal_error("Unknown CL address mode");
}
}

Expand Down Expand Up @@ -2021,6 +2078,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateBuiltinVar(Call.get(), MIRBuilder, GR);
case SPIRV::Atomic:
return generateAtomicInst(Call.get(), MIRBuilder, GR);
case SPIRV::AtomicFloating:
return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
case SPIRV::Barrier:
return generateBarrierInst(Call.get(), MIRBuilder, GR);
case SPIRV::Dot:
Expand Down Expand Up @@ -2089,7 +2148,7 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
return Type::getFloatTy(Context);
else if (Name.starts_with("half"))
return Type::getHalfTy(Context);
llvm_unreachable("Unable to recognize type!");
report_fatal_error("Unable to recognize type!");
}

//===----------------------------------------------------------------------===//
Expand Down
39 changes: 39 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def AsyncCopy : BuiltinGroup;
def VectorLoadStore : BuiltinGroup;
def LoadStore : BuiltinGroup;
def IntelSubgroups : BuiltinGroup;
def AtomicFloating : BuiltinGroup;

//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
Expand Down Expand Up @@ -872,6 +873,44 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;

//===----------------------------------------------------------------------===//
// Class defining an atomic instruction on floating-point numbers.
//
// Each record of this class pairs one builtin with the instruction generated
// for it.
//
// name is the demangled name of the given builtin.
// operation specifies the SPIR-V operation code of the generated instruction.
//===----------------------------------------------------------------------===//
class AtomicFloatingBuiltin<string name, Op operation> {
// Demangled name of the builtin.
string Name = name;
// SPIR-V operation generated for the builtin.
Op Opcode = operation;
}

// Table gathering all builtins for atomic instructions on floating-point
// numbers. It collects every record of class AtomicFloatingBuiltin and exposes
// the Name and Opcode fields of each record.
def AtomicFloatingBuiltins : GenericTable {
let FilterClass = "AtomicFloatingBuiltin";
let Fields = ["Name", "Opcode"];
}

// Function to lookup atomic floating-point builtins by their name. The name
// is the only key of this index.
def lookupAtomicFloatingBuiltin : SearchIndex {
let Table = AtomicFloatingBuiltins;
let Key = ["Name"];
}

// Multiclass used to define incoming demangled builtin records and
// corresponding builtin records for atomic instructions on floating-point numbers.
//
// name is the suffix appended to the "__spirv_AtomicF" prefix to form the
// builtin name shared by both records.
// minNumArgs and maxNumArgs are the argument-count bounds passed to the
// DemangledBuiltin record.
// operation is the SPIR-V operation stored in the AtomicFloatingBuiltin record.
multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
// Demangled builtin record in the OpenCL_std set and the AtomicFloating group.
def : DemangledBuiltin<!strconcat("__spirv_AtomicF", name), OpenCL_std, AtomicFloating, minNumArgs, maxNumArgs>;
// Record mapping the same builtin name to the operation to generate.
def : AtomicFloatingBuiltin<!strconcat("__spirv_AtomicF", name), operation>;
}

// SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_atomic_float16_add
// Atomic add, min and max instructions on floating-point numbers:
defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)

//===----------------------------------------------------------------------===//
// Class defining a sub group builtin that should be translated into a
// SPIR-V instruction using the SPV_INTEL_subgroups extension.
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ class SPIRVInstrInfo : public SPIRVGenInstrInfo {
bool KillSrc) const override;
bool expandPostRAPseudo(MachineInstr &MI) const override;
};

namespace SPIRV {
/// SPIR-V specific asm printer flags that can be attached to an instruction.
/// Values start at MachineInstr::TAsmComments.
enum AsmComments {
  /// The instruction operates on a 16-bit wide (half) type.
  ASM_PRINTER_WIDTH16 = MachineInstr::TAsmComments
};
} // namespace SPIRV

} // namespace llvm

#endif // LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def OpAtomicAnd: AtomicOpVal<"OpAtomicAnd", 240>;
def OpAtomicOr: AtomicOpVal<"OpAtomicOr", 241>;
def OpAtomicXor: AtomicOpVal<"OpAtomicXor", 242>;

def OpAtomicFAddEXT: AtomicOpVal<"OpAtomicFAddEXT", 6035>;
def OpAtomicFMinEXT: AtomicOpVal<"OpAtomicFMinEXT", 5614>;
def OpAtomicFMaxEXT: AtomicOpVal<"OpAtomicFMaxEXT", 5615>;

def OpAtomicFlagTestAndSet: AtomicOp<"OpAtomicFlagTestAndSet", 318>;
def OpAtomicFlagClear: Op<319, (outs), (ins ID:$ptr, ID:$sc, ID:$sem),
Expand Down
43 changes: 33 additions & 10 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectMemOperation(Register ResVReg, MachineInstr &I) const;

bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, unsigned NewOpcode) const;
MachineInstr &I, unsigned NewOpcode,
unsigned NegateOpcode = 0) const;

bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
Expand Down Expand Up @@ -489,6 +490,17 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_ATOMIC_CMPXCHG:
return selectAtomicCmpXchg(ResVReg, ResType, I);

case TargetOpcode::G_ATOMICRMW_FADD:
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT);
case TargetOpcode::G_ATOMICRMW_FSUB:
// Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
SPIRV::OpFNegate);
case TargetOpcode::G_ATOMICRMW_FMIN:
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
case TargetOpcode::G_ATOMICRMW_FMAX:
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMaxEXT);

case TargetOpcode::G_FENCE:
return selectFence(I);

Expand Down Expand Up @@ -686,7 +698,8 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
unsigned NewOpcode) const {
unsigned NewOpcode,
unsigned NegateOpcode) const {
assert(I.hasOneMemOperand());
const MachineMemOperand *MemOp = *I.memoperands_begin();
uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
Expand All @@ -700,14 +713,24 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);

return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Ptr)
.addUse(ScopeReg)
.addUse(MemSemReg)
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);
bool Result = false;
Register ValueReg = I.getOperand(2).getReg();
if (NegateOpcode != 0) {
// Translation with negative value operand is requested
Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= selectUnOpWithSrc(TmpReg, ResType, I, ValueReg, NegateOpcode);
ValueReg = TmpReg;
}

Result |= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Ptr)
.addUse(ScopeReg)
.addUse(MemSemReg)
.addUse(ValueReg)
.constrainAllUses(TII, TRI, RBI);
return Result;
}

bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {

auto allIntScalars = {s8, s16, s32, s64};

auto allFloatScalars = {s16, s32, s64};

auto allFloatScalarsAndVectors = {
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
Expand Down Expand Up @@ -205,6 +207,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
.legalForCartesianProduct(allIntScalars, allWritablePtrs);

getActionDefinitionsBuilder(
{G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
.legalForCartesianProduct(allFloatScalars, allWritablePtrs);

getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
.legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ using namespace llvm;
void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
SPIRV::ModuleAnalysisInfo *MAI) const {
OutMI.setOpcode(MI->getOpcode());
// Propagate previously set flags
OutMI.setFlags(MI->getAsmPrinterFlags());
const MachineFunction *MF = MI->getMF();
for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
const MachineOperand &MO = MI->getOperand(i);
Expand Down
Loading