Recommit [RISCV] RISCV vector calling convention (2/2) (#79096) #87736
Conversation
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
Changes: Return values are handled in the same way as function arguments.
Full diff: https://github.com/llvm/llvm-project/pull/87736.diff
5 Files Affected:
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f64ded4f2cf965..6e7b67ded23c84 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
else if (attr.hasRetAttr(Attribute::ZExt))
Flags.setZExt();
- for (unsigned i = 0; i < NumParts; ++i)
- Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0));
+ for (unsigned i = 0; i < NumParts; ++i) {
+ ISD::ArgFlagsTy OutFlags = Flags;
+ if (NumParts > 1 && i == 0)
+ OutFlags.setSplit();
+ else if (i == NumParts - 1 && i != 0)
+ OutFlags.setSplitEnd();
+
+ Outs.push_back(
+ ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0));
+ }
}
}
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 8af4bc658409d4..c18892ac62f247 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -409,7 +409,7 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
- F.getReturnType()};
+ ArrayRef(F.getReturnType())};
RISCVOutgoingValueAssigner Assigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/true, Dispatcher);
@@ -538,7 +538,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
++Index;
}
- RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+ RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+ ArrayRef(TypeList)};
RISCVIncomingValueAssigner Assigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/false, Dispatcher);
@@ -603,7 +604,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
- RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+ RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+ ArrayRef(TypeList)};
RISCVOutgoingValueAssigner ArgAssigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/false, ArgDispatcher);
@@ -635,7 +637,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
- F.getReturnType()};
+ ArrayRef(F.getReturnType())};
RISCVIncomingValueAssigner RetAssigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/true, RetDispatcher);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 279d8a435a04ca..fb5027406b0808 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18327,13 +18327,15 @@ void RISCVTargetLowering::analyzeInputArgs(
unsigned NumArgs = Ins.size();
FunctionType *FType = MF.getFunction().getFunctionType();
- SmallVector<Type *, 4> TypeList;
- if (IsRet)
- TypeList.push_back(MF.getFunction().getReturnType());
- else
+ RVVArgDispatcher Dispatcher;
+ if (IsRet) {
+ Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)};
+ } else {
+ SmallVector<Type *, 4> TypeList;
for (const Argument &Arg : MF.getFunction().args())
TypeList.push_back(Arg.getType());
- RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+ Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(TypeList)};
+ }
for (unsigned i = 0; i != NumArgs; ++i) {
MVT ArgVT = Ins[i].VT;
@@ -18368,7 +18370,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
else if (CLI)
for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
TypeList.push_back(Arg.Ty);
- RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+ RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(TypeList)};
for (unsigned i = 0; i != NumArgs; i++) {
MVT ArgVT = Outs[i].VT;
@@ -19272,7 +19274,7 @@ bool RISCVTargetLowering::CanLowerReturn(
SmallVector<CCValAssign, 16> RVLocs;
CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
- RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
+ RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(Outs)};
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
MVT VT = Outs[i].VT;
@@ -21077,7 +21079,82 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
return Subtarget.getMinimumJumpTableEntries();
}
-void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+// Handle single arg such as return value.
+template <typename Arg>
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Arg> ArgList) {
+ // This lambda determines whether an array of types are constructed by
+ // homogeneous vector types.
+ auto isHomogeneousScalableVectorType = [&](ArrayRef<Arg> ArgList) {
+ // First, extract the first element in the argument type.
+ MVT FirstArgRegType;
+ unsigned FirstArgElements = 0;
+ typename SmallVectorImpl<Arg>::const_iterator It;
+ for (It = ArgList.begin(); It != ArgList.end(); ++It) {
+ FirstArgRegType = It->VT;
+ ++FirstArgElements;
+ if (!It->Flags.isSplit() || It->Flags.isSplitEnd())
+ break;
+ }
+ ++It;
+
+ // Return if this argument type contains only 1 element, or it's not a
+ // vector type.
+ if (It == ArgList.end() || !FirstArgRegType.isScalableVector())
+ return false;
+
+ // Second, check if the following elements in this argument type are all the
+ // same.
+ MVT ArgRegType;
+ unsigned ArgElements = 0;
+ bool IsPart = false;
+ for (; It != ArgList.end(); ++It) {
+ ArgRegType = It->VT;
+ ++ArgElements;
+ if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd()) {
+ if (ArgRegType != FirstArgRegType || ArgElements != FirstArgElements)
+ return false;
+
+ IsPart = false;
+ ArgElements = 0;
+ continue;
+ }
+
+ IsPart = true;
+ }
+
+ return true;
+ };
+
+ if (isHomogeneousScalableVectorType(ArgList)) {
+ // Handle as tuple type
+ RVVArgInfos.push_back({(unsigned)ArgList.size(), ArgList[0].VT, false});
+ } else {
+ // Handle as normal vector type
+ bool FirstVMaskAssigned = false;
+ for (const auto &OutArg : ArgList) {
+ MVT RegisterVT = OutArg.VT;
+
+ // Skip non-RVV register type
+ if (!RegisterVT.isVector())
+ continue;
+
+ if (RegisterVT.isFixedLengthVector())
+ RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+ if (!FirstVMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+ RVVArgInfos.push_back({1, RegisterVT, true});
+ FirstVMaskAssigned = true;
+ continue;
+ }
+
+ RVVArgInfos.push_back({1, RegisterVT, false});
+ }
+ }
+}
+
+// Handle multiple args.
+template <>
+void RVVArgDispatcher::constructArgInfos<Type *>(ArrayRef<Type *> TypeList) {
const DataLayout &DL = MF->getDataLayout();
const Function &F = MF->getFunction();
LLVMContext &Context = F.getContext();
@@ -21090,8 +21167,11 @@ void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
EVT VT = TLI->getValueType(DL, ElemTy);
MVT RegisterVT =
TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+ unsigned NumRegs =
+ TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
- RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+ RVVArgInfos.push_back(
+ {NumRegs * STy->getNumElements(), RegisterVT, false});
} else {
SmallVector<EVT, 4> ValueVTs;
ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index c28552354bf422..a2456f2fab66b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1041,13 +1041,16 @@ class RVVArgDispatcher {
bool FirstVMask = false;
};
+ template <typename Arg>
RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
- ArrayRef<Type *> TypeList)
+ ArrayRef<Arg> ArgList)
: MF(MF), TLI(TLI) {
- constructArgInfos(TypeList);
+ constructArgInfos(ArgList);
compute();
}
+ RVVArgDispatcher() = default;
+
MCPhysReg getNextPhysReg();
private:
@@ -1059,7 +1062,7 @@ class RVVArgDispatcher {
unsigned CurIdx = 0;
- void constructArgInfos(ArrayRef<Type *> TypeList);
+ template <typename Arg> void constructArgInfos(ArrayRef<Arg> Ret);
void compute();
void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
unsigned StartReg = 0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 90edb994ce8222..647d3158b6167f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -249,3 +249,119 @@ define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <
%add = add <vscale x 1 x i64> %0, %2
ret <vscale x 1 x i64> %add
}
+
+declare <vscale x 1 x i64> @callee1()
+declare void @callee2(<vscale x 1 x i64>)
+declare void @callee3(<vscale x 4 x i32>)
+define void @caller() {
+; RV32-LABEL: caller:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call callee1
+; RV32-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; RV32-NEXT: vadd.vv v8, v8, v8
+; RV32-NEXT: call callee2
+; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: caller:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call callee1
+; RV64-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; RV64-NEXT: vadd.vv v8, v8, v8
+; RV64-NEXT: call callee2
+; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 16
+; RV64-NEXT: ret
+ %a = call <vscale x 1 x i64> @callee1()
+ %add = add <vscale x 1 x i64> %a, %a
+ call void @callee2(<vscale x 1 x i64> %add)
+ ret void
+}
+
+declare {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+define void @caller_tuple() {
+; RV32-LABEL: caller_tuple:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call callee_tuple
+; RV32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT: vadd.vv v8, v8, v10
+; RV32-NEXT: call callee3
+; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: caller_tuple:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call callee_tuple
+; RV64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT: vadd.vv v8, v8, v10
+; RV64-NEXT: call callee3
+; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 16
+; RV64-NEXT: ret
+ %a = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+ %b = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 0
+ %c = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 1
+ %add = add <vscale x 4 x i32> %b, %c
+ call void @callee3(<vscale x 4 x i32> %add)
+ ret void
+}
+
+declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+define void @caller_nested() {
+; RV32-LABEL: caller_nested:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call callee_nested
+; RV32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT: vadd.vv v8, v8, v10
+; RV32-NEXT: vadd.vv v8, v8, v12
+; RV32-NEXT: call callee3
+; RV32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: caller_nested:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call callee_nested
+; RV64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT: vadd.vv v8, v8, v10
+; RV64-NEXT: vadd.vv v8, v8, v12
+; RV64-NEXT: call callee3
+; RV64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 16
+; RV64-NEXT: ret
+ %a = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+ %b = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 0
+ %c = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1
+ %c0 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 0
+ %c1 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 1
+ %add0 = add <vscale x 4 x i32> %b, %c0
+ %add1 = add <vscale x 4 x i32> %add0, %c1
+ call void @callee3(<vscale x 4 x i32> %add1)
+ ret void
+}
This PR resolves #87323.
Integrated the "Recommit [RISCV] RISCV vector calling convention (2/2)" change and the bug-fix patch.
for (; It != ArgList.end(); ++It) {
  ArgRegType = It->VT;
  ++ArgElements;
  if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd()) {
Do we have tests for the split case? That should only happen for types larger than LMUL=8 I think? Do we even want to consider those as tuple types?
Are all types in a split guaranteed to be the same type or do we need to check all of the types?
> Do we have tests for the split case? That should only happen for types larger than LMUL=8 I think? Do we even want to consider those as tuple types?
You are right, and I think the split case doesn't need to be considered as a tuple type: we only have 16 argument vector registers, and the smallest type that gets split is already LMUL=16, so it's impossible for an argument with more than one split member to be passed entirely in registers.
I guess it can be handled as a normal vector type.
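For concreteness, a minimal sketch of the split case under discussion (the type and function names below are hypothetical and not part of this patch or its tests):

```llvm
; Hypothetical example: <vscale x 32 x i32> exceeds LMUL=8, so it is
; legalized into two LMUL=8 parts. Those two parts alone would need the
; v8m8 and v16m8 register groups, i.e. all 16 argument vector registers
; (v8-v23), so an aggregate containing more than one such split member
; can never be passed entirely in registers; per the discussion above it
; is treated as a normal vector type rather than a tuple.
declare void @takes_split(<vscale x 32 x i32>)

define void @pass_split(<vscale x 32 x i32> %v) {
  call void @takes_split(<vscale x 32 x i32> %v)
  ret void
}
```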
LGTM
LGTM
Return values are handled in the same way as function arguments. One thing to mention is that if a type can be broken down into homogeneous vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}}, it is considered a vector tuple type and needs to be handled by the tuple type rule.
…79096) (llvm#87736)" This reverts commit 91dd844.
…79096) (llvm#87736)" This reverts commit 91dd844.
…79096) (llvm#87736)" This reverts commit 91dd844.
…79096) (llvm#87736)" This reverts commit 91dd844.
…79096) (llvm#87736)" This reverts commit 91dd844.
…79096) (llvm#87736)" This reverts commit 91dd844.
…79096) (llvm#87736)" This reverts commit 91dd844.
Bug fix: Handle RVV return type in calling convention correctly.
Return values are handled in the same way as function arguments.
One thing to mention is that if a type can be broken down into homogeneous
vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}},
it is considered a vector tuple type and needs to be handled by the tuple
type rule.
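As a minimal sketch of that tuple-type rule (it mirrors the caller_nested test in the diff; the function names here are placeholders):

```llvm
; The nested aggregate flattens into three homogeneous <vscale x 4 x i32>
; parts (LMUL=2 each), so the tuple rule returns them in consecutive
; register groups v8, v10 and v12, as the caller_nested test checks.
declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @ret_nested()

define <vscale x 4 x i32> @use_nested() {
  %a = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @ret_nested()
  %b = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 0
  %c = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1, 0
  %d = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1, 1
  %s0 = add <vscale x 4 x i32> %b, %c
  %s1 = add <vscale x 4 x i32> %s0, %d
  ret <vscale x 4 x i32> %s1
}
```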