Commit cad2c89
[RISCV] Handle RVV return type in calling convention correctly
Return values are handled in the same way as function arguments. One thing to note: if a type can be broken down into homogeneous vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}}, it is considered a vector tuple type and needs to be handled by the tuple type rule.
1 parent: 5258b46
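
The tuple rule in the message can be modeled in a few lines of standalone C++. This is a sketch, not LLVM's implementation: Ty, flatten, and the string-encoded vector types are all toy stand-ins. It flattens the aggregate to its leaf types and reports a tuple only when every leaf is the same scalable vector type.

// Toy model of the commit message's rule: {nxv4i32, {nxv4i32, nxv4i32}}
// flattens to three identical scalable-vector leaves, so it is treated as
// one 3-field vector tuple. All names here are stand-ins, not LLVM APIs.
#include <iostream>
#include <string>
#include <vector>

struct Ty {
  std::string Leaf;       // e.g. "nxv4i32"; empty for aggregates
  std::vector<Ty> Fields; // non-empty for struct types
};

static void flatten(const Ty &T, std::vector<std::string> &Leaves) {
  if (T.Fields.empty()) {
    Leaves.push_back(T.Leaf);
    return;
  }
  for (const Ty &F : T.Fields)
    flatten(F, Leaves); // recurse into nested structs
}

static bool isHomogeneousVectorTuple(const Ty &T) {
  std::vector<std::string> Leaves;
  flatten(T, Leaves);
  if (Leaves.size() < 2)
    return false; // single values take the normal vector rule
  for (const std::string &L : Leaves)
    if (L != Leaves[0] || L.rfind("nxv", 0) != 0)
      return false; // leaves must all be the same scalable vector type
  return true;
}

int main() {
  // {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}}
  Ty Nested{"", {Ty{"nxv4i32", {}},
                 Ty{"", {Ty{"nxv4i32", {}}, Ty{"nxv4i32", {}}}}}};
  std::cout << isHomogeneousVectorTuple(Nested) << '\n'; // prints 1
}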

5 files changed: +236 −18

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 10 additions & 2 deletions
@@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
     else if (attr.hasRetAttr(Attribute::ZExt))
       Flags.setZExt();
 
-    for (unsigned i = 0; i < NumParts; ++i)
-      Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0));
+    for (unsigned i = 0; i < NumParts; ++i) {
+      ISD::ArgFlagsTy OutFlags = Flags;
+      if (NumParts > 1 && i == 0)
+        OutFlags.setSplit();
+      else if (i == NumParts - 1 && i != 0)
+        OutFlags.setSplitEnd();
+
+      Outs.push_back(
+          ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0));
+    }
   }
 }
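
A note on the hunk above: the loop now tags the first part of a multi-part return value with Split and the last with SplitEnd, which is what later lets RVVArgDispatcher recover value boundaries inside Outs. A minimal standalone model of just that marking logic (FlagsTy stands in for ISD::ArgFlagsTy):

// Standalone model of the new GetReturnInfo loop: a value lowered into
// NumParts register-sized parts gets Split on its first part and SplitEnd
// on its last. FlagsTy is a stand-in for ISD::ArgFlagsTy.
#include <cstdio>

struct FlagsTy {
  bool Split = false, SplitEnd = false;
};

int main() {
  const unsigned NumParts = 3; // e.g. one value split across three registers
  for (unsigned i = 0; i < NumParts; ++i) {
    FlagsTy OutFlags;
    if (NumParts > 1 && i == 0)
      OutFlags.Split = true;              // first part opens the group
    else if (i == NumParts - 1 && i != 0)
      OutFlags.SplitEnd = true;           // last part closes the group
    std::printf("part %u: split=%d splitend=%d\n", i, OutFlags.Split,
                OutFlags.SplitEnd);
  }
  // A single-part value (NumParts == 1) carries neither flag.
}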

llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp

Lines changed: 6 additions & 4 deletions
@@ -409,7 +409,7 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
   RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
-                              F.getReturnType()};
+                              ArrayRef(F.getReturnType())};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, Dispatcher);
@@ -538,7 +538,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     ++Index;
   }
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              ArrayRef(TypeList)};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, Dispatcher);
@@ -603,7 +604,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
-  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 ArrayRef(TypeList)};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, ArgDispatcher);
@@ -635,7 +637,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
   RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
-                                 F.getReturnType()};
+                                 ArrayRef(F.getReturnType())};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, RetDispatcher);
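
The ArrayRef(...) wrapping at every call site is forced by the constructor change in RISCVISelLowering.h (further down): once the parameter becomes a templated ArrayRef<Arg>, template argument deduction no longer considers the implicit conversions from Type * or SmallVector, so callers must build the ArrayRef themselves. A self-contained sketch of the same deduction issue, with Span standing in for llvm::ArrayRef:

// Why call sites now spell out ArrayRef(...): template argument deduction
// does not look through implicit conversions, so a bare pointer or vector
// no longer binds to the templated Span<Arg>/ArrayRef<Arg> parameter.
// Span is a minimal stand-in for llvm::ArrayRef.
#include <cstddef>
#include <vector>

template <typename T> struct Span {
  const T *Data = nullptr;
  size_t Size = 0;
  Span(const T &Elt) : Data(&Elt), Size(1) {}                  // one element
  Span(const std::vector<T> &V) : Data(V.data()), Size(V.size()) {}
};

template <typename Arg> void dispatch(Span<Arg> Args) { (void)Args; }

int main() {
  int RetTy = 42;
  std::vector<int> TypeList{1, 2, 3};
  // dispatch(RetTy);       // error: cannot deduce Arg
  // dispatch(TypeList);    // error: cannot deduce Arg
  dispatch(Span(RetTy));    // OK: like ArrayRef(F.getReturnType())
  dispatch(Span(TypeList)); // OK: like ArrayRef(TypeList)
}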

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 98 additions & 9 deletions
@@ -18339,13 +18339,15 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  SmallVector<Type *, 4> TypeList;
-  if (IsRet)
-    TypeList.push_back(MF.getFunction().getReturnType());
-  else
+  RVVArgDispatcher Dispatcher;
+  if (IsRet) {
+    Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)};
+  } else {
+    SmallVector<Type *, 4> TypeList;
     for (const Argument &Arg : MF.getFunction().args())
       TypeList.push_back(Arg.getType());
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+    Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(TypeList)};
+  }
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18380,7 +18382,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
   else if (CLI)
     for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
       TypeList.push_back(Arg.Ty);
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(TypeList)};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -19284,7 +19286,7 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(Outs)};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
@@ -21089,7 +21091,91 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
-void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+// Handle single arg such as return value.
+template <typename Arg>
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Arg> ArgList) {
+  // This lambda determines whether an array of types are constructed by
+  // homogeneous vector types.
+  auto isHomogeneousScalableVectorType = [](ArrayRef<Arg> ArgList) {
+    // First, extract the first element in the argument type.
+    MVT FirstArgRegType;
+    unsigned FirstArgElements = 0;
+    auto It = ArgList.begin();
+    bool IsPart = false;
+
+    if (It == ArgList.end())
+      return false;
+
+    for (; It != ArgList.end(); ++It) {
+      FirstArgRegType = It->VT;
+      ++FirstArgElements;
+      if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd())
+        break;
+
+      IsPart = true;
+    }
+
+    assert(It != ArgList.end() && "It shouldn't reach the end of ArgList.");
+    ++It;
+
+    // Return if this argument type contains only 1 element, or it's not a
+    // vector type.
+    if (It == ArgList.end() || !FirstArgRegType.isScalableVector())
+      return false;
+
+    // Second, check if the following elements in this argument type are all
+    // the same.
+    MVT ArgRegType;
+    unsigned ArgElements = 0;
+    IsPart = false;
+    for (; It != ArgList.end(); ++It) {
+      ArgRegType = It->VT;
+      ++ArgElements;
+      if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd()) {
+        if (ArgRegType != FirstArgRegType || ArgElements != FirstArgElements)
+          return false;
+
+        IsPart = false;
+        ArgElements = 0;
+        continue;
+      }
+
+      IsPart = true;
+    }
+
+    return true;
+  };
+
+  if (isHomogeneousScalableVectorType(ArgList)) {
+    // Handle as tuple type
+    RVVArgInfos.push_back({(unsigned)ArgList.size(), ArgList[0].VT, false});
+  } else {
+    // Handle as normal vector type
+    bool FirstVMaskAssigned = false;
+    for (const auto &OutArg : ArgList) {
+      MVT RegisterVT = OutArg.VT;
+
+      // Skip non-RVV register type
+      if (!RegisterVT.isVector())
+        continue;
+
+      if (RegisterVT.isFixedLengthVector())
+        RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+      if (!FirstVMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+        RVVArgInfos.push_back({1, RegisterVT, true});
+        FirstVMaskAssigned = true;
+        continue;
+      }
+
+      RVVArgInfos.push_back({1, RegisterVT, false});
+    }
+  }
+}
+
+// Handle multiple args.
+template <>
+void RVVArgDispatcher::constructArgInfos<Type *>(ArrayRef<Type *> TypeList) {
   const DataLayout &DL = MF->getDataLayout();
   const Function &F = MF->getFunction();
   LLVMContext &Context = F.getContext();
@@ -21102,8 +21188,11 @@ void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
       EVT VT = TLI->getValueType(DL, ElemTy);
       MVT RegisterVT =
          TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+      unsigned NumRegs =
+          TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
 
-      RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+      RVVArgInfos.push_back(
+          {NumRegs * STy->getNumElements(), RegisterVT, false});
     } else {
       SmallVector<EVT, 4> ValueVTs;
       ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
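
The essential new logic is the isHomogeneousScalableVectorType lambda above. Below is a simplified, compilable rendition of its two passes, under stand-in types (Part with a string-valued VT instead of ISD::OutputArg/MVT): measure the first value's register type and part count using the Split/SplitEnd flags, then require every subsequent value to match.

// Simplified standalone version of the homogeneity check added to
// RVVArgDispatcher. Each element models one register-sized part of a return
// value, carrying its register type plus the Split/SplitEnd flags produced
// by GetReturnInfo. Part and the string VTs are stand-ins, not LLVM types.
#include <cstdio>
#include <string>
#include <vector>

struct Part {
  std::string VT;        // register type, e.g. "nxv4i32"
  bool Split = false;    // first part of a multi-part value
  bool SplitEnd = false; // last part of a multi-part value
};

static bool isHomogeneous(const std::vector<Part> &Parts) {
  if (Parts.empty())
    return false;
  size_t I = 0;
  // Pass 1: measure the first value -- scan until the part that ends it.
  std::string FirstVT;
  unsigned FirstCount = 0;
  bool InSplit = false;
  for (; I != Parts.size(); ++I) {
    FirstVT = Parts[I].VT;
    ++FirstCount;
    if ((!Parts[I].Split && !InSplit) || Parts[I].SplitEnd)
      break;
    InSplit = true;
  }
  ++I;
  if (I >= Parts.size() || FirstVT.rfind("nxv", 0) != 0)
    return false; // only one value, or not a scalable vector
  // Pass 2: every following value must match the first in type and count.
  unsigned Count = 0;
  InSplit = false;
  for (; I != Parts.size(); ++I) {
    ++Count;
    if ((!Parts[I].Split && !InSplit) || Parts[I].SplitEnd) {
      if (Parts[I].VT != FirstVT || Count != FirstCount)
        return false;
      InSplit = false;
      Count = 0;
      continue;
    }
    InSplit = true;
  }
  return true;
}

int main() {
  // Three nxv4i32 values, none split: treated as one 3-field tuple.
  std::vector<Part> Tuple{{"nxv4i32"}, {"nxv4i32"}, {"nxv4i32"}};
  std::printf("%d\n", isHomogeneous(Tuple)); // 1
  // Mixed types: falls back to the normal per-value rule.
  std::vector<Part> Mixed{{"nxv4i32"}, {"nxv2i64"}};
  std::printf("%d\n", isHomogeneous(Mixed)); // 0
}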

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 6 additions & 3 deletions
@@ -1041,13 +1041,16 @@ class RVVArgDispatcher {
     bool FirstVMask = false;
   };
 
+  template <typename Arg>
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
-                   ArrayRef<Type *> TypeList)
+                   ArrayRef<Arg> ArgList)
       : MF(MF), TLI(TLI) {
-    constructArgInfos(TypeList);
+    constructArgInfos(ArgList);
     compute();
   }
 
+  RVVArgDispatcher() = default;
+
   MCPhysReg getNextPhysReg();
 
 private:
@@ -1059,7 +1062,7 @@ class RVVArgDispatcher {
 
   unsigned CurIdx = 0;
 
-  void constructArgInfos(ArrayRef<Type *> TypeList);
+  template <typename Arg> void constructArgInfos(ArrayRef<Arg> Ret);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
                        unsigned StartReg = 0);
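
The defaulted constructor added above exists so analyzeInputArgs (in RISCVISelLowering.cpp) can declare the dispatcher before its IsRet branch and assign it inside; that pattern needs a default-constructible type. Reduced to a standalone sketch (Widget is a stand-in for RVVArgDispatcher):

// Why RVVArgDispatcher gained a defaulted constructor: "declare now,
// assign in a branch later" requires default construction plus assignment.
struct Widget {
  int Value = 0;
  Widget() = default;
  explicit Widget(int V) : Value(V) {}
};

int main(int argc, char **) {
  Widget W; // default-constructed, like "RVVArgDispatcher Dispatcher;"
  if (argc > 1)
    W = Widget{1}; // like Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)};
  else
    W = Widget{2}; // like the TypeList path
  return W.Value;
}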

llvm/test/CodeGen/RISCV/rvv/calling-conv.ll

Lines changed: 116 additions & 0 deletions
@@ -249,3 +249,119 @@ define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <
   %add = add <vscale x 1 x i64> %0, %2
   ret <vscale x 1 x i64> %add
 }
+
+declare <vscale x 1 x i64> @callee1()
+declare void @callee2(<vscale x 1 x i64>)
+declare void @callee3(<vscale x 4 x i32>)
+define void @caller() {
+; RV32-LABEL: caller:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee1
+; RV32-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v8
+; RV32-NEXT:    call callee2
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee1
+; RV64-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v8
+; RV64-NEXT:    call callee2
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call <vscale x 1 x i64> @callee1()
+  %add = add <vscale x 1 x i64> %a, %a
+  call void @callee2(<vscale x 1 x i64> %add)
+  ret void
+}
+
+declare {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+define void @caller_tuple() {
+; RV32-LABEL: caller_tuple:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee_tuple
+; RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v10
+; RV32-NEXT:    call callee3
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller_tuple:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee_tuple
+; RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v10
+; RV64-NEXT:    call callee3
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+  %b = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 0
+  %c = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 1
+  %add = add <vscale x 4 x i32> %b, %c
+  call void @callee3(<vscale x 4 x i32> %add)
+  ret void
+}
+
+declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+define void @caller_nested() {
+; RV32-LABEL: caller_nested:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee_nested
+; RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v10
+; RV32-NEXT:    vadd.vv v8, v8, v12
+; RV32-NEXT:    call callee3
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller_nested:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee_nested
+; RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v10
+; RV64-NEXT:    vadd.vv v8, v8, v12
+; RV64-NEXT:    call callee3
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+  %b = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 0
+  %c = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1
+  %c0 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 0
+  %c1 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 1
+  %add0 = add <vscale x 4 x i32> %b, %c0
+  %add1 = add <vscale x 4 x i32> %add0, %c1
+  call void @callee3(<vscale x 4 x i32> %add1)
+  ret void
+}
