@@ -13292,8 +13292,7 @@ namespace {
13292
13292
// apply a combine.
13293
13293
struct CombineResult;
13294
13294
13295
- enum class ExtKind { ZExt, SExt, FPExt };
13296
-
13295
+ enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
13297
13296
/// Helper class for folding sign/zero extensions.
13298
13297
/// In particular, this class is used for the following combines:
13299
13298
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -13424,13 +13423,11 @@ struct NodeExtensionHelper {
13424
13423
// Determine the narrow size.
13425
13424
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
13426
13425
13427
- unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
13428
-
13429
13426
MVT EltVT = SupportsExt == ExtKind::FPExt
13430
13427
? MVT::getFloatingPointVT(NarrowSize)
13431
13428
: MVT::getIntegerVT(NarrowSize);
13432
13429
13433
- assert(NarrowSize >= NarrowMinSize &&
13430
+ assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
13434
13431
"Trying to extend something we can't represent");
13435
13432
MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
13436
13433
return NarrowVT;
@@ -13799,33 +13796,32 @@ struct CombineResult {
13799
13796
/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
13800
13797
/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
13801
13798
/// are zext) and LHS and RHS can be folded into Root.
13802
- /// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
13799
+ /// AllowExtMask define which form `ext` can take in this pattern.
13803
13800
///
13804
13801
/// \note If the pattern can match with both zext and sext, the returned
13805
13802
/// CombineResult will feature the zext result.
13806
13803
///
13807
13804
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
13808
13805
/// can be used to apply the pattern.
13809
- static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
13810
- SDNode *Root, const NodeExtensionHelper &LHS,
13811
- const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
13812
- bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
13813
- assert((AllowSExt || AllowZExt || AllowFPExt) &&
13814
- "Forgot to set what you want?");
13806
+ static std::optional<CombineResult>
13807
+ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
13808
+ const NodeExtensionHelper &RHS,
13809
+ uint8_t AllowExtMask, SelectionDAG &DAG,
13810
+ const RISCVSubtarget &Subtarget) {
13815
13811
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
13816
13812
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
13817
13813
return std::nullopt;
13818
- if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13814
+ if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13819
13815
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13820
13816
Root->getOpcode(), ExtKind::ZExt),
13821
13817
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
13822
13818
/*RHSExt=*/{ExtKind::ZExt});
13823
- if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
13819
+ if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
13824
13820
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13825
13821
Root->getOpcode(), ExtKind::SExt),
13826
13822
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
13827
13823
/*RHSExt=*/{ExtKind::SExt});
13828
- if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
13824
+ if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
13829
13825
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
13830
13826
Root->getOpcode(), ExtKind::FPExt),
13831
13827
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
@@ -13843,9 +13839,9 @@ static std::optional<CombineResult>
13843
13839
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
13844
13840
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13845
13841
const RISCVSubtarget &Subtarget) {
13846
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13847
- /*AllowZExt=*/true ,
13848
- /*AllowFPExt=*/true, DAG, Subtarget);
13842
+ return canFoldToVWWithSameExtensionImpl(
13843
+ Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG ,
13844
+ Subtarget);
13849
13845
}
13850
13846
13851
13847
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13887,9 +13883,8 @@ static std::optional<CombineResult>
13887
13883
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13888
13884
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13889
13885
const RISCVSubtarget &Subtarget) {
13890
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13891
- /*AllowZExt=*/false,
13892
- /*AllowFPExt=*/false, DAG, Subtarget);
13886
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
13887
+ Subtarget);
13893
13888
}
13894
13889
13895
13890
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13900,9 +13895,8 @@ static std::optional<CombineResult>
13900
13895
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13901
13896
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13902
13897
const RISCVSubtarget &Subtarget) {
13903
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13904
- /*AllowZExt=*/true,
13905
- /*AllowFPExt=*/false, DAG, Subtarget);
13898
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
13899
+ Subtarget);
13906
13900
}
13907
13901
13908
13902
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
@@ -13913,9 +13907,8 @@ static std::optional<CombineResult>
13913
13907
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
13914
13908
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
13915
13909
const RISCVSubtarget &Subtarget) {
13916
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13917
- /*AllowZExt=*/false,
13918
- /*AllowFPExt=*/true, DAG, Subtarget);
13910
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
13911
+ Subtarget);
13919
13912
}
13920
13913
13921
13914
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
0 commit comments