Skip to content

Commit 93ded43

Browse files
committed
add AllowExtMask
1 parent 22ecc6a commit 93ded43

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13292,8 +13292,7 @@ namespace {
1329213292
// apply a combine.
1329313293
struct CombineResult;
1329413294

13295-
enum class ExtKind { ZExt, SExt, FPExt };
13296-
13295+
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
1329713296
/// Helper class for folding sign/zero extensions.
1329813297
/// In particular, this class is used for the following combines:
1329913298
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -13424,13 +13423,11 @@ struct NodeExtensionHelper {
1342413423
// Determine the narrow size.
1342513424
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1342613425

13427-
unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
13428-
1342913426
MVT EltVT = SupportsExt == ExtKind::FPExt
1343013427
? MVT::getFloatingPointVT(NarrowSize)
1343113428
: MVT::getIntegerVT(NarrowSize);
1343213429

13433-
assert(NarrowSize >= NarrowMinSize &&
13430+
assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
1343413431
"Trying to extend something we can't represent");
1343513432
MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
1343613433
return NarrowVT;
@@ -13799,33 +13796,32 @@ struct CombineResult {
1379913796
/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
1380013797
/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
1380113798
/// are zext) and LHS and RHS can be folded into Root.
13802-
/// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
13799+
/// AllowExtMask define which form `ext` can take in this pattern.
1380313800
///
1380413801
/// \note If the pattern can match with both zext and sext, the returned
1380513802
/// CombineResult will feature the zext result.
1380613803
///
1380713804
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1380813805
/// can be used to apply the pattern.
13809-
static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
13810-
SDNode *Root, const NodeExtensionHelper &LHS,
13811-
const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
13812-
bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
13813-
assert((AllowSExt || AllowZExt || AllowFPExt) &&
13814-
"Forgot to set what you want?");
13806+
static std::optional<CombineResult>
13807+
canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
13808+
const NodeExtensionHelper &RHS,
13809+
uint8_t AllowExtMask, SelectionDAG &DAG,
13810+
const RISCVSubtarget &Subtarget) {
1381513811
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
1381613812
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
1381713813
return std::nullopt;
13818-
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
13814+
if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
1381913815
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
1382013816
Root->getOpcode(), ExtKind::ZExt),
1382113817
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
1382213818
/*RHSExt=*/{ExtKind::ZExt});
13823-
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
13819+
if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
1382413820
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
1382513821
Root->getOpcode(), ExtKind::SExt),
1382613822
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
1382713823
/*RHSExt=*/{ExtKind::SExt});
13828-
if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
13824+
if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
1382913825
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
1383013826
Root->getOpcode(), ExtKind::FPExt),
1383113827
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
@@ -13843,9 +13839,9 @@ static std::optional<CombineResult>
1384313839
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1384413840
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1384513841
const RISCVSubtarget &Subtarget) {
13846-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13847-
/*AllowZExt=*/true,
13848-
/*AllowFPExt=*/true, DAG, Subtarget);
13842+
return canFoldToVWWithSameExtensionImpl(
13843+
Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
13844+
Subtarget);
1384913845
}
1385013846

1385113847
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13887,9 +13883,8 @@ static std::optional<CombineResult>
1388713883
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1388813884
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1388913885
const RISCVSubtarget &Subtarget) {
13890-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
13891-
/*AllowZExt=*/false,
13892-
/*AllowFPExt=*/false, DAG, Subtarget);
13886+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
13887+
Subtarget);
1389313888
}
1389413889

1389513890
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13900,9 +13895,8 @@ static std::optional<CombineResult>
1390013895
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1390113896
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1390213897
const RISCVSubtarget &Subtarget) {
13903-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13904-
/*AllowZExt=*/true,
13905-
/*AllowFPExt=*/false, DAG, Subtarget);
13898+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
13899+
Subtarget);
1390613900
}
1390713901

1390813902
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
@@ -13913,9 +13907,8 @@ static std::optional<CombineResult>
1391313907
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1391413908
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1391513909
const RISCVSubtarget &Subtarget) {
13916-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
13917-
/*AllowZExt=*/false,
13918-
/*AllowFPExt=*/true, DAG, Subtarget);
13910+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
13911+
Subtarget);
1391913912
}
1392013913

1392113914
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))

0 commit comments

Comments
 (0)