Skip to content

Commit 62d946e

Browse files
committed
GlobalISel: Merge some AMDGPU ABI lowering code to generic code
AMDGPU currently has a lot of pre-processing code to pre-split argument types into 32-bit pieces before passing them to the generic code in handleAssignments. This is a bit sloppy and also requires some overly fancy iterator work when building the calls. It's better if all argument marshalling code is handled directly in handleAssignments. This handles more situations like decomposing large element vectors into sub-element sized pieces. This should mostly be NFC (no functional change), but does change the generated code by shifting where the initial argument packing instructions are placed. I think this is nicer looking, since it now emits the packing code directly after the relevant copies, rather than after the copies for the remaining arguments. This doubles down on gfx6/gfx7 using the gfx8+ ABI for 16-bit types. This is ultimately the better option, but incompatible with the DAG. Fixing this requires more work, especially for f16.
1 parent 70e3c9a commit 62d946e

File tree

71 files changed

+1944
-2012
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+1944
-2012
lines changed

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

+119-9
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,119 @@ void CallLowering::unpackRegs(ArrayRef<Register> DstRegs, Register SrcReg,
228228
MIRBuilder.buildExtract(DstRegs[i], SrcReg, Offsets[i]);
229229
}
230230

231+
/// Pack values \p SrcRegs to cover the vector type result \p DstRegs.
///
/// The destination and part types are read from the first register of each
/// list. When getLCMType() of the two types equals the destination type, the
/// parts are concatenated directly into DstRegs[0]. Otherwise the parts are
/// padded with undef values up to the LCM type, concatenated, and unmerged
/// into \p DstRegs followed by freshly created registers for the remainder.
///
/// \returns the builder for the last instruction emitted: the concat in the
/// unpadded case, the unmerge in the padded case.
232+
static MachineInstrBuilder
233+
mergeVectorRegsToResultRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
234+
ArrayRef<Register> SrcRegs) {
235+
MachineRegisterInfo &MRI = *B.getMRI();
236+
LLT LLTy = MRI.getType(DstRegs[0]);
237+
LLT PartLLT = MRI.getType(SrcRegs[0]);
238+
239+
// Deal with v3s16 split into v2s16
240+
LLT LCMTy = getLCMType(LLTy, PartLLT);
241+
if (LCMTy == LLTy) {
242+
// Common case where no padding is needed.
243+
assert(DstRegs.size() == 1);
244+
return B.buildConcatVectors(DstRegs[0], SrcRegs);
245+
}
246+
247+
// Number of PartLLT-sized pieces needed to fill the LCM type.
const int NumWide = LCMTy.getSizeInBits() / PartLLT.getSizeInBits();
248+
// A single undef register is shared by every padding slot.
Register Undef = B.buildUndef(PartLLT).getReg(0);
249+
250+
// Build vector of undefs.
251+
SmallVector<Register, 8> WidenedSrcs(NumWide, Undef);
252+
253+
// Replace the first sources with the real registers.
// NOTE(review): assumes SrcRegs.size() <= NumWide; nothing here checks it.
254+
std::copy(SrcRegs.begin(), SrcRegs.end(), WidenedSrcs.begin());
255+
256+
auto Widened = B.buildConcatVectors(LCMTy, WidenedSrcs);
257+
// Number of LLTy-sized results the widened value is unmerged into.
int NumDst = LCMTy.getSizeInBits() / LLTy.getSizeInBits();
258+
259+
SmallVector<Register, 8> PadDstRegs(NumDst);
260+
// The caller's destinations occupy the leading unmerge results.
std::copy(DstRegs.begin(), DstRegs.end(), PadDstRegs.begin());
261+
262+
// Create the excess dead defs for the unmerge.
263+
for (int I = DstRegs.size(); I != NumDst; ++I)
264+
PadDstRegs[I] = MRI.createGenericVirtualRegister(LLTy);
265+
266+
return B.buildUnmerge(PadDstRegs, Widened);
267+
}
268+
269+
/// Create a sequence of instructions to combine pieces split into register
270+
/// typed values to the original IR value. \p OrigRegs contains the destination
271+
/// value registers of type \p LLTy, and \p Regs contains the legalized pieces
272+
/// with type \p PartLLT.
///
/// Three shapes are handled:
///  - scalar from scalar parts: a merge, or a merge into a scalar of the
///    combined part size followed by a truncate when the sizes differ;
///  - vector from vector parts: delegated to mergeVectorRegsToResultRegs;
///  - vector from scalar parts: a build_vector, preceded by per-element merges
///    when elements are wider than a part, or followed by a truncate when the
///    parts are wider than an element.
///
/// NOTE(review): despite the "ToParts" name, every path assembles \p Regs
/// into OrigRegs[0]; nothing is split here.
273+
static void buildCopyToParts(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
274+
ArrayRef<Register> Regs, LLT LLTy, LLT PartLLT) {
275+
MachineRegisterInfo &MRI = *B.getMRI();
276+
277+
if (!LLTy.isVector() && !PartLLT.isVector()) {
278+
assert(OrigRegs.size() == 1);
279+
LLT OrigTy = MRI.getType(OrigRegs[0]);
280+
281+
// Total number of bits supplied by all of the parts together.
unsigned SrcSize = PartLLT.getSizeInBits() * Regs.size();
282+
if (SrcSize == OrigTy.getSizeInBits())
283+
B.buildMerge(OrigRegs[0], Regs);
284+
else {
285+
// Sizes differ: merge into a scalar of the combined size, then truncate
// that value into the destination register.
auto Widened = B.buildMerge(LLT::scalar(SrcSize), Regs);
286+
B.buildTrunc(OrigRegs[0], Widened);
287+
}
288+
289+
return;
290+
}
291+
292+
if (LLTy.isVector() && PartLLT.isVector()) {
293+
assert(OrigRegs.size() == 1);
294+
assert(LLTy.getElementType() == PartLLT.getElementType());
295+
mergeVectorRegsToResultRegs(B, OrigRegs, Regs);
296+
return;
297+
}
298+
299+
// Only the vector-from-scalar-parts shape is left; a scalar destination
// built from vector parts is rejected by this assert.
assert(LLTy.isVector() && !PartLLT.isVector());
300+
301+
LLT DstEltTy = LLTy.getElementType();
302+
303+
// Pointer information was discarded. We'll need to coerce some register types
304+
// to avoid violating type constraints.
305+
LLT RealDstEltTy = MRI.getType(OrigRegs[0]).getElementType();
306+
307+
// The nominal and the actual element type must agree on width.
assert(DstEltTy.getSizeInBits() == RealDstEltTy.getSizeInBits());
308+
309+
if (DstEltTy == PartLLT) {
310+
// Vector was trivially scalarized.
311+
312+
if (RealDstEltTy.isPointer()) {
313+
// Retype the pieces so the build_vector sources carry the pointer
// element type of the real destination.
for (Register Reg : Regs)
314+
MRI.setType(Reg, RealDstEltTy);
315+
}
316+
317+
B.buildBuildVector(OrigRegs[0], Regs);
318+
} else if (DstEltTy.getSizeInBits() > PartLLT.getSizeInBits()) {
319+
// Deal with vector with 64-bit elements decomposed to 32-bit
320+
// registers. Need to create intermediate 64-bit elements.
321+
SmallVector<Register, 8> EltMerges;
322+
int PartsPerElt = DstEltTy.getSizeInBits() / PartLLT.getSizeInBits();
323+
324+
// Each element must be an exact multiple of the part size.
assert(DstEltTy.getSizeInBits() % PartLLT.getSizeInBits() == 0);
325+
326+
for (int I = 0, NumElts = LLTy.getNumElements(); I != NumElts; ++I) {
327+
auto Merge = B.buildMerge(RealDstEltTy, Regs.take_front(PartsPerElt));
328+
// Fix the type in case this is really a vector of pointers.
329+
MRI.setType(Merge.getReg(0), RealDstEltTy);
330+
EltMerges.push_back(Merge.getReg(0));
331+
// Advance past the parts consumed by this element.
Regs = Regs.drop_front(PartsPerElt);
332+
}
333+
334+
B.buildBuildVector(OrigRegs[0], EltMerges);
335+
} else {
336+
// Vector was split, and elements promoted to a wider type.
337+
// FIXME: Should handle floating point promotions.
// Assemble a vector of the wider part type, then truncate the whole vector
// into the destination register.
338+
LLT BVType = LLT::vector(LLTy.getNumElements(), PartLLT);
339+
auto BV = B.buildBuildVector(BVType, Regs);
340+
B.buildTrunc(OrigRegs[0], BV);
341+
}
342+
}
343+
231344
bool CallLowering::handleAssignments(MachineIRBuilder &MIRBuilder,
232345
SmallVectorImpl<ArgInfo> &Args,
233346
ValueHandler &Handler,
@@ -278,9 +391,6 @@ bool CallLowering::handleAssignments(CCState &CCInfo,
278391
}
279392

280393
assert(NumParts > 1);
281-
// For now only handle exact splits.
282-
if (NewVT.getSizeInBits() * NumParts != CurVT.getSizeInBits())
283-
return false;
284394

285395
// For incoming arguments (physregs to vregs), we could have values in
286396
// physregs (or memlocs) which we want to extract and copy to vregs.
@@ -379,6 +489,7 @@ bool CallLowering::handleAssignments(CCState &CCInfo,
379489
EVT OrigVT = EVT::getEVT(Args[i].Ty);
380490
EVT VAVT = VA.getValVT();
381491
const LLT OrigTy = getLLTForType(*Args[i].Ty, DL);
492+
const LLT VATy(VAVT.getSimpleVT());
382493

383494
// Expected to be multiple regs for a single incoming arg.
384495
// There should be Regs.size() ArgLocs per argument.
@@ -427,7 +538,6 @@ bool CallLowering::handleAssignments(CCState &CCInfo,
427538
}
428539

429540
// This ArgLoc covers multiple pieces, so we need to split it.
430-
const LLT VATy(VAVT.getSimpleVT());
431541
Register NewReg =
432542
MIRBuilder.getMRI()->createGenericVirtualRegister(VATy);
433543
Handler.assignValueToReg(NewReg, VA.getLocReg(), VA);
@@ -451,12 +561,12 @@ bool CallLowering::handleAssignments(CCState &CCInfo,
451561
// Now that all pieces have been handled, re-pack any arguments into any
452562
// wider, original registers.
453563
if (Handler.isIncomingArgumentHandler()) {
454-
if (VAVT.getFixedSizeInBits() < OrigVT.getFixedSizeInBits()) {
455-
assert(NumArgRegs >= 2);
564+
// Merge the split registers into the expected larger result vregs of
565+
// the original call.
456566

457-
// Merge the split registers into the expected larger result vreg
458-
// of the original call.
459-
MIRBuilder.buildMerge(Args[i].OrigRegs[0], Args[i].Regs);
567+
if (OrigTy != VATy && !Args[i].OrigRegs.empty()) {
568+
buildCopyToParts(MIRBuilder, Args[i].OrigRegs, Args[i].Regs, OrigTy,
569+
VATy);
460570
}
461571
}
462572

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

+2-138
Original file line numberDiff line numberDiff line change
@@ -735,117 +735,6 @@ bool AMDGPUCallLowering::lowerFormalArgumentsKernel(
735735
return true;
736736
}
737737

738-
/// Pack values \p SrcRegs to cover the vector type result \p DstRegs.
///
/// When getLCMType() of the destination and part types equals the destination
/// type, the parts are concatenated directly into DstRegs[0]. Otherwise the
/// parts are padded with undef up to the LCM type, concatenated, and unmerged
/// into \p DstRegs plus newly created registers for the remainder.
///
/// \returns the builder for the last instruction emitted.
739-
static MachineInstrBuilder mergeVectorRegsToResultRegs(
740-
MachineIRBuilder &B, ArrayRef<Register> DstRegs, ArrayRef<Register> SrcRegs) {
741-
MachineRegisterInfo &MRI = *B.getMRI();
742-
LLT LLTy = MRI.getType(DstRegs[0]);
743-
LLT PartLLT = MRI.getType(SrcRegs[0]);
744-
745-
// Deal with v3s16 split into v2s16
746-
LLT LCMTy = getLCMType(LLTy, PartLLT);
747-
if (LCMTy == LLTy) {
748-
// Common case where no padding is needed.
749-
assert(DstRegs.size() == 1);
750-
return B.buildConcatVectors(DstRegs[0], SrcRegs);
751-
}
752-
753-
// Number of PartLLT-sized pieces needed to fill the LCM type.
const int NumWide = LCMTy.getSizeInBits() / PartLLT.getSizeInBits();
754-
Register Undef = B.buildUndef(PartLLT).getReg(0);
755-
756-
// Build vector of undefs.
757-
SmallVector<Register, 8> WidenedSrcs(NumWide, Undef);
758-
759-
// Replace the first sources with the real registers.
760-
std::copy(SrcRegs.begin(), SrcRegs.end(), WidenedSrcs.begin());
761-
762-
auto Widened = B.buildConcatVectors(LCMTy, WidenedSrcs);
763-
// Number of LLTy-sized results the widened value is unmerged into.
int NumDst = LCMTy.getSizeInBits() / LLTy.getSizeInBits();
764-
765-
SmallVector<Register, 8> PadDstRegs(NumDst);
766-
std::copy(DstRegs.begin(), DstRegs.end(), PadDstRegs.begin());
767-
768-
// Create the excess dead defs for the unmerge.
769-
for (int I = DstRegs.size(); I != NumDst; ++I)
770-
PadDstRegs[I] = MRI.createGenericVirtualRegister(LLTy);
771-
772-
return B.buildUnmerge(PadDstRegs, Widened);
773-
}
774-
775-
// TODO: Move this to generic code
// Combine the pieces in Regs (each of type PartLLT) into the destination
// register OrigRegs[0], whose nominal type is LLTy. Handles scalar from scalar
// parts, vector from vector parts (via mergeVectorRegsToResultRegs), and
// vector from scalar parts.
776-
static void packSplitRegsToOrigType(MachineIRBuilder &B,
777-
ArrayRef<Register> OrigRegs,
778-
ArrayRef<Register> Regs,
779-
LLT LLTy,
780-
LLT PartLLT) {
781-
MachineRegisterInfo &MRI = *B.getMRI();
782-
783-
if (!LLTy.isVector() && !PartLLT.isVector()) {
784-
assert(OrigRegs.size() == 1);
785-
LLT OrigTy = MRI.getType(OrigRegs[0]);
786-
787-
// Total number of bits supplied by all of the parts together.
unsigned SrcSize = PartLLT.getSizeInBits() * Regs.size();
788-
if (SrcSize == OrigTy.getSizeInBits())
789-
B.buildMerge(OrigRegs[0], Regs);
790-
else {
791-
// Sizes differ: merge into a scalar of the combined size, then truncate.
auto Widened = B.buildMerge(LLT::scalar(SrcSize), Regs);
792-
B.buildTrunc(OrigRegs[0], Widened);
793-
}
794-
795-
return;
796-
}
797-
798-
if (LLTy.isVector() && PartLLT.isVector()) {
799-
assert(OrigRegs.size() == 1);
800-
assert(LLTy.getElementType() == PartLLT.getElementType());
801-
mergeVectorRegsToResultRegs(B, OrigRegs, Regs);
802-
return;
803-
}
804-
805-
// Only the vector-from-scalar-parts shape is left at this point.
assert(LLTy.isVector() && !PartLLT.isVector());
806-
807-
LLT DstEltTy = LLTy.getElementType();
808-
809-
// Pointer information was discarded. We'll need to coerce some register types
810-
// to avoid violating type constraints.
811-
LLT RealDstEltTy = MRI.getType(OrigRegs[0]).getElementType();
812-
813-
assert(DstEltTy.getSizeInBits() == RealDstEltTy.getSizeInBits());
814-
815-
if (DstEltTy == PartLLT) {
816-
// Vector was trivially scalarized.
817-
818-
if (RealDstEltTy.isPointer()) {
819-
for (Register Reg : Regs)
820-
MRI.setType(Reg, RealDstEltTy);
821-
}
822-
823-
B.buildBuildVector(OrigRegs[0], Regs);
824-
} else if (DstEltTy.getSizeInBits() > PartLLT.getSizeInBits()) {
825-
// Deal with vector with 64-bit elements decomposed to 32-bit
826-
// registers. Need to create intermediate 64-bit elements.
827-
SmallVector<Register, 8> EltMerges;
828-
int PartsPerElt = DstEltTy.getSizeInBits() / PartLLT.getSizeInBits();
829-
830-
assert(DstEltTy.getSizeInBits() % PartLLT.getSizeInBits() == 0);
831-
832-
for (int I = 0, NumElts = LLTy.getNumElements(); I != NumElts; ++I) {
833-
auto Merge = B.buildMerge(RealDstEltTy, Regs.take_front(PartsPerElt));
834-
// Fix the type in case this is really a vector of pointers.
835-
MRI.setType(Merge.getReg(0), RealDstEltTy);
836-
EltMerges.push_back(Merge.getReg(0));
837-
// Advance past the parts consumed by this element.
Regs = Regs.drop_front(PartsPerElt);
838-
}
839-
840-
B.buildBuildVector(OrigRegs[0], EltMerges);
841-
} else {
842-
// Vector was split, and elements promoted to a wider type.
// Assemble a vector of the wider part type, then truncate it into the
// destination register.
843-
LLT BVType = LLT::vector(LLTy.getNumElements(), PartLLT);
844-
auto BV = B.buildBuildVector(BVType, Regs);
845-
B.buildTrunc(OrigRegs[0], BV);
846-
}
847-
}
848-
849738
bool AMDGPUCallLowering::lowerFormalArguments(
850739
MachineIRBuilder &B, const Function &F, ArrayRef<ArrayRef<Register>> VRegs,
851740
FunctionLoweringInfo &FLI) const {
@@ -886,7 +775,6 @@ bool AMDGPUCallLowering::lowerFormalArguments(
886775
CCInfo.AllocateReg(ImplicitBufferPtrReg);
887776
}
888777

889-
SmallVector<ArgInfo, 8> SplitArg;
890778
SmallVector<ArgInfo, 32> SplitArgs;
891779
unsigned Idx = 0;
892780
unsigned PSInputNum = 0;
@@ -936,19 +824,7 @@ bool AMDGPUCallLowering::lowerFormalArguments(
936824
const unsigned OrigArgIdx = Idx + AttributeList::FirstArgIndex;
937825
setArgFlags(OrigArg, OrigArgIdx, DL, F);
938826

939-
SplitArg.clear();
940-
splitToValueTypes(B, OrigArg, SplitArg, DL, CC);
941-
942-
processSplitArgs(B, OrigArg, SplitArg, SplitArgs, DL, CC, false,
943-
// FIXME: We should probably be passing multiple registers
944-
// to handleAssignments to do this
945-
[&](ArrayRef<Register> Regs, Register DstReg, LLT LLTy,
946-
LLT PartLLT, int VTSplitIdx) {
947-
assert(DstReg == VRegs[Idx][VTSplitIdx]);
948-
packSplitRegsToOrigType(B, VRegs[Idx][VTSplitIdx], Regs,
949-
LLTy, PartLLT);
950-
});
951-
827+
splitToValueTypes(B, OrigArg, SplitArgs, DL, CC);
952828
++Idx;
953829
}
954830

@@ -1356,19 +1232,7 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
13561232
insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs,
13571233
Info.DemoteRegister, Info.DemoteStackIndex);
13581234
} else if (!Info.OrigRet.Ty->isVoidTy()) {
1359-
SmallVector<ArgInfo, 8> PreSplitRetInfos;
1360-
1361-
splitToValueTypes(
1362-
MIRBuilder, Info.OrigRet, PreSplitRetInfos/*InArgs*/, DL, Info.CallConv);
1363-
1364-
processSplitArgs(MIRBuilder, Info.OrigRet,
1365-
PreSplitRetInfos, InArgs/*SplitRetInfos*/, DL, Info.CallConv, false,
1366-
[&](ArrayRef<Register> Regs, Register DstReg,
1367-
LLT LLTy, LLT PartLLT, int VTSplitIdx) {
1368-
assert(DstReg == Info.OrigRet.Regs[VTSplitIdx]);
1369-
packSplitRegsToOrigType(MIRBuilder, Info.OrigRet.Regs[VTSplitIdx],
1370-
Regs, LLTy, PartLLT);
1371-
});
1235+
splitToValueTypes(MIRBuilder, Info.OrigRet, InArgs, DL, Info.CallConv);
13721236
}
13731237

13741238
// Make sure the raw argument copies are inserted before the marshalling to

0 commit comments

Comments
 (0)