@@ -13457,30 +13457,54 @@ struct NodeExtensionHelper {
13457
13457
return NarrowVT;
13458
13458
}
13459
13459
13460
- /// Return the opcode required to materialize the folding for
13461
- /// both operands for \p Opcode.
13462
- /// Put differently, get the opcode to materialize:
13463
- /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
13464
- /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
13465
- /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
13466
- static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
13460
+ /// Get the opcode to materialize:
13461
+ /// Opcode(sext(a), sext(b)) -> newOpcode(a, b)
13462
+ static unsigned getSExtOpcode(unsigned Opcode) {
13467
13463
switch (Opcode) {
13468
13464
case ISD::ADD:
13469
13465
case RISCVISD::ADD_VL:
13470
13466
case RISCVISD::VWADD_W_VL:
13471
13467
case RISCVISD::VWADDU_W_VL:
13472
- return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
13473
- : RISCVISD::VWADDU_VL;
13468
+ return RISCVISD::VWADD_VL;
13469
+ case ISD::SUB:
13470
+ case RISCVISD::SUB_VL:
13471
+ case RISCVISD::VWSUB_W_VL:
13472
+ case RISCVISD::VWSUBU_W_VL:
13473
+ return RISCVISD::VWSUB_VL;
13474
13474
case ISD::MUL:
13475
13475
case RISCVISD::MUL_VL:
13476
- return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
13477
- : RISCVISD::VWMULU_VL;
13476
+ return RISCVISD::VWMUL_VL;
13477
+ default:
13478
+ llvm_unreachable("Unexpected opcode");
13479
+ }
13480
+ }
13481
+
13482
+ /// Get the opcode to materialize:
13483
+ /// Opcode(zext(a), zext(b)) -> newOpcode(a, b)
13484
+ static unsigned getZExtOpcode(unsigned Opcode) {
13485
+ switch (Opcode) {
13486
+ case ISD::ADD:
13487
+ case RISCVISD::ADD_VL:
13488
+ case RISCVISD::VWADD_W_VL:
13489
+ case RISCVISD::VWADDU_W_VL:
13490
+ return RISCVISD::VWADDU_VL;
13478
13491
case ISD::SUB:
13479
13492
case RISCVISD::SUB_VL:
13480
13493
case RISCVISD::VWSUB_W_VL:
13481
13494
case RISCVISD::VWSUBU_W_VL:
13482
- return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
13483
- : RISCVISD::VWSUBU_VL;
13495
+ return RISCVISD::VWSUBU_VL;
13496
+ case ISD::MUL:
13497
+ case RISCVISD::MUL_VL:
13498
+ return RISCVISD::VWMULU_VL;
13499
+ default:
13500
+ llvm_unreachable("Unexpected opcode");
13501
+ }
13502
+ }
13503
+
13504
+ /// Get the opcode to materialize:
13505
+ /// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
13506
+ static unsigned getFPExtOpcode(unsigned Opcode) {
13507
+ switch (Opcode) {
13484
13508
case RISCVISD::FADD_VL:
13485
13509
case RISCVISD::VFWADD_W_VL:
13486
13510
return RISCVISD::VFWADD_VL;
@@ -13835,19 +13859,16 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
13835
13859
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
13836
13860
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
13837
13861
return std::nullopt;
13838
- if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13839
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13840
- Root->getOpcode(), ExtKind::ZExt),
13862
+ if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
13863
+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
13841
13864
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
13842
13865
/*RHSExt=*/{ExtKind::ZExt});
13843
- if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
13844
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13845
- Root->getOpcode(), ExtKind::SExt),
13866
+ if ((AllowExtMask & ExtKind::SExt) && LHS.SupportsSExt && RHS.SupportsSExt)
13867
+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
13846
13868
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
13847
13869
/*RHSExt=*/{ExtKind::SExt});
13848
- if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
13849
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13850
- Root->getOpcode(), ExtKind::FPExt),
13870
+ if ((AllowExtMask & ExtKind::FPExt) && RHS.SupportsFPExt)
13871
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
13851
13872
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
13852
13873
/*RHSExt=*/{ExtKind::FPExt});
13853
13874
return std::nullopt;
0 commit comments