Skip to content

AMDGPU: Fix buffer load/store of pointers #95379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 30 additions & 42 deletions llvm/lib/Target/AMDGPU/BUFInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -1419,27 +1419,21 @@ let OtherPredicates = [HasPackedD16VMem] in {
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load_format_d16, v4i16, "BUFFER_LOAD_FORMAT_D16_XYZW">;
} // End HasPackedD16VMem.

defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, f32, "BUFFER_LOAD_DWORD">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, i32, "BUFFER_LOAD_DWORD">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2i16, "BUFFER_LOAD_DWORD">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2f16, "BUFFER_LOAD_DWORD">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2bf16, "BUFFER_LOAD_DWORD">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2f32, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2i32, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v4i16, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v4f16, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, i64, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, f64, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v4bf16, "BUFFER_LOAD_DWORDX2">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v3f32, "BUFFER_LOAD_DWORDX3">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v3i32, "BUFFER_LOAD_DWORDX3">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v4f32, "BUFFER_LOAD_DWORDX4">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v4i32, "BUFFER_LOAD_DWORDX4">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2i64, "BUFFER_LOAD_DWORDX4">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v2f64, "BUFFER_LOAD_DWORDX4">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v8i16, "BUFFER_LOAD_DWORDX4">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v8f16, "BUFFER_LOAD_DWORDX4">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, v8bf16, "BUFFER_LOAD_DWORDX4">;
// Selection patterns for plain (unformatted) buffer loads. One pattern is
// instantiated per value type, and the opcode width follows the register-size
// list the type comes from:
//   Reg32Types  -> BUFFER_LOAD_DWORD
//   Reg64Types  -> BUFFER_LOAD_DWORDX2
//   Reg96Types  -> BUFFER_LOAD_DWORDX3
//   Reg128Types -> BUFFER_LOAD_DWORDX4
// Iterating the shared type lists instead of naming each type means every
// value type in those lists, including the pointer types they contain, gets a
// load pattern.
foreach vt = Reg32Types.types in {
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, vt, "BUFFER_LOAD_DWORD">;
}

foreach vt = Reg64Types.types in {
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, vt, "BUFFER_LOAD_DWORDX2">;
}

foreach vt = Reg96Types.types in {
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, vt, "BUFFER_LOAD_DWORDX3">;
}

foreach vt = Reg128Types.types in {
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load, vt, "BUFFER_LOAD_DWORDX4">;
}

defm : MUBUF_LoadIntrinsicPat<SIbuffer_load_byte, i32, "BUFFER_LOAD_SBYTE">;
defm : MUBUF_LoadIntrinsicPat<SIbuffer_load_short, i32, "BUFFER_LOAD_SSHORT">;
Expand Down Expand Up @@ -1530,27 +1524,21 @@ let OtherPredicates = [HasPackedD16VMem] in {
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_format_d16, v4i16, "BUFFER_STORE_FORMAT_D16_XYZW">;
} // End HasPackedD16VMem.

defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, f32, "BUFFER_STORE_DWORD">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, i32, "BUFFER_STORE_DWORD">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2i16, "BUFFER_STORE_DWORD">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2f16, "BUFFER_STORE_DWORD">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2bf16, "BUFFER_STORE_DWORD">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2f32, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2i32, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, i64, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, f64, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v4i16, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v4f16, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v4bf16, "BUFFER_STORE_DWORDX2">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v3f32, "BUFFER_STORE_DWORDX3">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v3i32, "BUFFER_STORE_DWORDX3">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v4f32, "BUFFER_STORE_DWORDX4">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v4i32, "BUFFER_STORE_DWORDX4">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2i64, "BUFFER_STORE_DWORDX4">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v2f64, "BUFFER_STORE_DWORDX4">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v8f16, "BUFFER_STORE_DWORDX4">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v8i16, "BUFFER_STORE_DWORDX4">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, v8bf16, "BUFFER_STORE_DWORDX4">;
// Selection patterns for plain (unformatted) buffer stores. One pattern is
// instantiated per value type, and the opcode width follows the register-size
// list the type comes from:
//   Reg32Types  -> BUFFER_STORE_DWORD
//   Reg64Types  -> BUFFER_STORE_DWORDX2
//   Reg96Types  -> BUFFER_STORE_DWORDX3
//   Reg128Types -> BUFFER_STORE_DWORDX4
// Iterating the shared type lists instead of naming each type means every
// value type in those lists, including the pointer types they contain, gets a
// store pattern.
foreach vt = Reg32Types.types in {
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, vt, "BUFFER_STORE_DWORD">;
}

foreach vt = Reg64Types.types in {
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, vt, "BUFFER_STORE_DWORDX2">;
}

foreach vt = Reg96Types.types in {
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, vt, "BUFFER_STORE_DWORDX3">;
}

foreach vt = Reg128Types.types in {
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store, vt, "BUFFER_STORE_DWORDX4">;
}

defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_byte, i32, "BUFFER_STORE_BYTE">;
defm : MUBUF_StoreIntrinsicPat<SIbuffer_store_short, i32, "BUFFER_STORE_SHORT">;
Expand Down
31 changes: 19 additions & 12 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,29 +1112,33 @@ unsigned SITargetLowering::getVectorTypeBreakdownForCallingConv(
Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
}

// Compute the memory EVT for the data of a load-style intrinsic whose IR type
// is Ty. MaxNumLanes must be non-zero; for a fixed vector type the result has
// min(MaxNumLanes, element count) elements, and any other type is converted
// as a whole.
// NOTE(review): this span shows both sides of the diff; the variant taking
// TLI and DL is the one that goes through TLI.getValueType(DL, ...).
static EVT memVTFromLoadIntrData(Type *Ty, unsigned MaxNumLanes) {
static EVT memVTFromLoadIntrData(const SITargetLowering &TLI,
const DataLayout &DL, Type *Ty,
unsigned MaxNumLanes) {
assert(MaxNumLanes != 0);

LLVMContext &Ctx = Ty->getContext();
// Fixed vectors: clamp the lane count, then rebuild a vector EVT from the
// element type's value type.
if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
unsigned NumElts = std::min(MaxNumLanes, VT->getNumElements());
return EVT::getVectorVT(Ty->getContext(),
EVT::getEVT(VT->getElementType()),
return EVT::getVectorVT(Ctx, TLI.getValueType(DL, VT->getElementType()),
NumElts);
}

// Everything that is not a fixed vector is converted directly.
return EVT::getEVT(Ty);
return TLI.getValueType(DL, Ty);
}

// Peek through TFE struct returns to only use the data size.
static EVT memVTFromLoadIntrReturn(Type *Ty, unsigned MaxNumLanes) {
static EVT memVTFromLoadIntrReturn(const SITargetLowering &TLI,
const DataLayout &DL, Type *Ty,
unsigned MaxNumLanes) {
auto *ST = dyn_cast<StructType>(Ty);
if (!ST)
return memVTFromLoadIntrData(Ty, MaxNumLanes);
return memVTFromLoadIntrData(TLI, DL, Ty, MaxNumLanes);

// TFE intrinsics return an aggregate type.
assert(ST->getNumContainedTypes() == 2 &&
ST->getContainedType(1)->isIntegerTy(32));
return memVTFromLoadIntrData(ST->getContainedType(0), MaxNumLanes);
return memVTFromLoadIntrData(TLI, DL, ST->getContainedType(0), MaxNumLanes);
}

/// Map address space 7 to MVT::v5i32 because that's its in-memory
Expand Down Expand Up @@ -1219,10 +1223,12 @@ bool SITargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
MaxNumLanes = DMask == 0 ? 1 : llvm::popcount(DMask);
}

Info.memVT = memVTFromLoadIntrReturn(CI.getType(), MaxNumLanes);
Info.memVT = memVTFromLoadIntrReturn(*this, MF.getDataLayout(),
CI.getType(), MaxNumLanes);
} else {
Info.memVT = memVTFromLoadIntrReturn(
CI.getType(), std::numeric_limits<unsigned>::max());
Info.memVT =
memVTFromLoadIntrReturn(*this, MF.getDataLayout(), CI.getType(),
std::numeric_limits<unsigned>::max());
}

// FIXME: What does alignment mean for an image?
Expand All @@ -1235,9 +1241,10 @@ bool SITargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
if (RsrcIntr->IsImage) {
unsigned DMask = cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
unsigned DMaskLanes = DMask == 0 ? 1 : llvm::popcount(DMask);
Info.memVT = memVTFromLoadIntrData(DataTy, DMaskLanes);
Info.memVT = memVTFromLoadIntrData(*this, MF.getDataLayout(), DataTy,
DMaskLanes);
} else
Info.memVT = EVT::getEVT(DataTy);
Info.memVT = getValueType(MF.getDataLayout(), DataTy);

Info.flags |= MachineMemOperand::MOStore;
} else {
Expand Down
16 changes: 9 additions & 7 deletions llvm/lib/Target/AMDGPU/SIRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ class RegisterTypes<list<ValueType> reg_types> {

def Reg16Types : RegisterTypes<[i16, f16, bf16]>;
def Reg32Types : RegisterTypes<[i32, f32, v2i16, v2f16, v2bf16, p2, p3, p5, p6]>;
def Reg64Types : RegisterTypes<[i64, f64, v2i32, v2f32, v4i16, v4f16, v4bf16, p0]>;
def Reg64Types : RegisterTypes<[i64, f64, v2i32, v2f32, p0, v4i16, v4f16, v4bf16]>;
def Reg96Types : RegisterTypes<[v3i32, v3f32]>;
def Reg128Types : RegisterTypes<[v4i32, v4f32, v2i64, v2f64, v8i16, v8f16, v8bf16]>;

let HasVGPR = 1 in {
// VOP3 and VINTERP can access 256 lo and 256 hi registers.
Expand Down Expand Up @@ -744,7 +746,7 @@ def Pseudo_SReg_32 : SIRegisterClass<"AMDGPU", [i32, f32, i16, f16, bf16, v2i16,
let BaseClassOrder = 10000;
}

def Pseudo_SReg_128 : SIRegisterClass<"AMDGPU", [v4i32, v2i64, v2f64, v8i16, v8f16, v8bf16], 32,
def Pseudo_SReg_128 : SIRegisterClass<"AMDGPU", Reg128Types.types, 32,
(add PRIVATE_RSRC_REG)> {
let isAllocatable = 0;
let CopyCost = -1;
Expand Down Expand Up @@ -815,7 +817,7 @@ def SRegOrLds_32 : SIRegisterClass<"AMDGPU", [i32, f32, i16, f16, bf16, v2i16, v
let HasSGPR = 1;
}

def SGPR_64 : SIRegisterClass<"AMDGPU", [v2i32, i64, v2f32, f64, v4i16, v4f16, v4bf16], 32,
def SGPR_64 : SIRegisterClass<"AMDGPU", Reg64Types.types, 32,
(add SGPR_64Regs)> {
let CopyCost = 1;
let AllocationPriority = 1;
Expand Down Expand Up @@ -905,8 +907,8 @@ multiclass SRegClass<int numRegs,
}
}

defm "" : SRegClass<3, [v3i32, v3f32], SGPR_96Regs, TTMP_96Regs>;
defm "" : SRegClass<4, [v4i32, v4f32, v2i64, v2f64, v8i16, v8f16, v8bf16], SGPR_128Regs, TTMP_128Regs>;
defm "" : SRegClass<3, Reg96Types.types, SGPR_96Regs, TTMP_96Regs>;
defm "" : SRegClass<4, Reg128Types.types, SGPR_128Regs, TTMP_128Regs>;
defm "" : SRegClass<5, [v5i32, v5f32], SGPR_160Regs, TTMP_160Regs>;
defm "" : SRegClass<6, [v6i32, v6f32, v3i64, v3f64], SGPR_192Regs, TTMP_192Regs>;
defm "" : SRegClass<7, [v7i32, v7f32], SGPR_224Regs, TTMP_224Regs>;
Expand Down Expand Up @@ -958,8 +960,8 @@ multiclass VRegClass<int numRegs, list<ValueType> regTypes, dag regList> {

defm VReg_64 : VRegClass<2, [i64, f64, v2i32, v2f32, v4f16, v4bf16, v4i16, p0, p1, p4],
(add VGPR_64)>;
defm VReg_96 : VRegClass<3, [v3i32, v3f32], (add VGPR_96)>;
defm VReg_128 : VRegClass<4, [v4i32, v4f32, v2i64, v2f64, v8i16, v8f16, v8bf16], (add VGPR_128)>;
defm VReg_96 : VRegClass<3, Reg96Types.types, (add VGPR_96)>;
defm VReg_128 : VRegClass<4, Reg128Types.types, (add VGPR_128)>;
defm VReg_160 : VRegClass<5, [v5i32, v5f32], (add VGPR_160)>;

defm VReg_192 : VRegClass<6, [v6i32, v6f32, v3i64, v3f64], (add VGPR_192)>;
Expand Down
Loading
Loading