Skip to content

Commit d94f74f

Browse files
committed
split getSameExtensionOpcode
1 parent f0e6c8b commit d94f74f

File tree

1 file changed

+43
-22
lines changed

1 file changed

+43
-22
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13457,30 +13457,54 @@ struct NodeExtensionHelper {
1345713457
return NarrowVT;
1345813458
}
1345913459

13460-
/// Return the opcode required to materialize the folding for
13461-
/// both operands for \p Opcode.
13462-
/// Put differently, get the opcode to materialize:
13463-
/// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
13464-
/// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
13465-
/// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
13466-
static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
13460+
/// Get the opcode to materialize:
13461+
/// Opcode(sext(a), sext(b)) -> newOpcode(a, b)
13462+
static unsigned getSExtOpcode(unsigned Opcode) {
1346713463
switch (Opcode) {
1346813464
case ISD::ADD:
1346913465
case RISCVISD::ADD_VL:
1347013466
case RISCVISD::VWADD_W_VL:
1347113467
case RISCVISD::VWADDU_W_VL:
13472-
return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
13473-
: RISCVISD::VWADDU_VL;
13468+
return RISCVISD::VWADD_VL;
13469+
case ISD::SUB:
13470+
case RISCVISD::SUB_VL:
13471+
case RISCVISD::VWSUB_W_VL:
13472+
case RISCVISD::VWSUBU_W_VL:
13473+
return RISCVISD::VWSUB_VL;
1347413474
case ISD::MUL:
1347513475
case RISCVISD::MUL_VL:
13476-
return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
13477-
: RISCVISD::VWMULU_VL;
13476+
return RISCVISD::VWMUL_VL;
13477+
default:
13478+
llvm_unreachable("Unexpected opcode");
13479+
}
13480+
}
13481+
13482+
/// Get the opcode to materialize:
13483+
/// Opcode(zext(a), zext(b)) -> newOpcode(a, b)
13484+
static unsigned getZExtOpcode(unsigned Opcode) {
13485+
switch (Opcode) {
13486+
case ISD::ADD:
13487+
case RISCVISD::ADD_VL:
13488+
case RISCVISD::VWADD_W_VL:
13489+
case RISCVISD::VWADDU_W_VL:
13490+
return RISCVISD::VWADDU_VL;
1347813491
case ISD::SUB:
1347913492
case RISCVISD::SUB_VL:
1348013493
case RISCVISD::VWSUB_W_VL:
1348113494
case RISCVISD::VWSUBU_W_VL:
13482-
return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
13483-
: RISCVISD::VWSUBU_VL;
13495+
return RISCVISD::VWSUBU_VL;
13496+
case ISD::MUL:
13497+
case RISCVISD::MUL_VL:
13498+
return RISCVISD::VWMULU_VL;
13499+
default:
13500+
llvm_unreachable("Unexpected opcode");
13501+
}
13502+
}
13503+
13504+
/// Get the opcode to materialize:
13505+
/// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
13506+
static unsigned getFPExtOpcode(unsigned Opcode) {
13507+
switch (Opcode) {
1348413508
case RISCVISD::FADD_VL:
1348513509
case RISCVISD::VFWADD_W_VL:
1348613510
return RISCVISD::VFWADD_VL;
@@ -13835,19 +13859,16 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1383513859
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
1383613860
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
1383713861
return std::nullopt;
13838-
if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13839-
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13840-
Root->getOpcode(), ExtKind::ZExt),
13862+
if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
13863+
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
1384113864
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
1384213865
/*RHSExt=*/{ExtKind::ZExt});
13843-
if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
13844-
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13845-
Root->getOpcode(), ExtKind::SExt),
13866+
if ((AllowExtMask & ExtKind::SExt) && LHS.SupportsSExt && RHS.SupportsSExt)
13867+
return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
1384613868
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
1384713869
/*RHSExt=*/{ExtKind::SExt});
13848-
if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
13849-
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13850-
Root->getOpcode(), ExtKind::FPExt),
13870+
if ((AllowExtMask & ExtKind::FPExt) && RHS.SupportsFPExt)
13871+
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1385113872
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1385213873
/*RHSExt=*/{ExtKind::FPExt});
1385313874
return std::nullopt;

0 commit comments

Comments
 (0)