Skip to content

Commit d8428df

Browse files
committed
[PatternMatching] Add generic API for matching constants using custom conditions
The new API is: `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)` - Matches non-undef constants s.t `Lambda(ele)` is true for all elements. `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)` - Matches constants/undef s.t `Lambda(ele)` is true for all elements. The goal with these is to be able to replace the common usage of: ``` match(X, m_APInt(C)) && CustomCheck(C) ``` with ``` match(X, m_CheckedInt(C, CustomChecks); ``` The rationale if we often ignore non-splat vectors because there are no good APIs to handle them with and its not worth increasing code complexity for such cases. The hope is the API creates a common method handling scalars/splat-vecs/non-splat-vecs to essentially make this a non-issue.
1 parent 285dbed commit d8428df

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,39 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
460460
//
461461
///////////////////////////////////////////////////////////////////////////////
462462

463+
template <typename APTy> struct custom_checkfn {
464+
function_ref<bool(const APTy &)> CheckFn;
465+
bool isValue(const APTy &C) { return CheckFn(C); }
466+
};
467+
468+
/// Match an integer or vector where CheckFn(ele) for each element is true.
469+
/// For vectors, poison elements are assumed to match.
470+
inline cst_pred_ty<custom_checkfn<APInt>>
471+
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
472+
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
473+
}
474+
475+
inline api_pred_ty<custom_checkfn<APInt>>
476+
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
477+
api_pred_ty<custom_checkfn<APInt>> P(V);
478+
P.CheckFn = CheckFn;
479+
return P;
480+
}
481+
482+
/// Match a float or vector where CheckFn(ele) for each element is true.
483+
/// For vectors, poison elements are assumed to match.
484+
inline cstfp_pred_ty<custom_checkfn<APFloat>>
485+
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
486+
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
487+
}
488+
489+
inline apf_pred_ty<custom_checkfn<APFloat>>
490+
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
491+
apf_pred_ty<custom_checkfn<APFloat>> P(V);
492+
P.CheckFn = CheckFn;
493+
return P;
494+
}
495+
463496
struct is_any_apint {
464497
bool isValue(const APInt &C) { return true; }
465498
};

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,134 @@ TEST_F(PatternMatchTest, BitCast) {
611611
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
612612
}
613613

614+
TEST_F(PatternMatchTest, CheckedInt) {
615+
Type *I8Ty = IRB.getInt8Ty();
616+
const APInt *Res = nullptr;
617+
618+
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
619+
auto CheckTrue = [](const APInt &) { return true; };
620+
auto CheckFalse = [](const APInt &) { return false; };
621+
auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
622+
auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
623+
624+
auto DoScalarCheck = [&](int8_t Val) {
625+
APInt APVal(8, Val);
626+
Constant *C = ConstantInt::get(I8Ty, Val);
627+
628+
Res = nullptr;
629+
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
630+
EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
631+
EXPECT_EQ(*Res, APVal);
632+
633+
Res = nullptr;
634+
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
635+
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
636+
637+
Res = nullptr;
638+
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
639+
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
640+
if (CheckUgt1(APVal)) {
641+
EXPECT_NE(Res, nullptr);
642+
EXPECT_EQ(*Res, APVal);
643+
}
644+
645+
Res = nullptr;
646+
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
647+
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
648+
if (CheckNonZero(APVal)) {
649+
EXPECT_NE(Res, nullptr);
650+
EXPECT_EQ(*Res, APVal);
651+
}
652+
653+
Res = nullptr;
654+
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
655+
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
656+
if (CheckPow2(APVal)) {
657+
EXPECT_NE(Res, nullptr);
658+
EXPECT_EQ(*Res, APVal);
659+
}
660+
661+
};
662+
663+
DoScalarCheck(0);
664+
DoScalarCheck(1);
665+
DoScalarCheck(2);
666+
DoScalarCheck(3);
667+
668+
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
669+
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
670+
EXPECT_EQ(Res, nullptr);
671+
672+
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
673+
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
674+
EXPECT_EQ(Res, nullptr);
675+
676+
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
677+
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
678+
EXPECT_EQ(Res, nullptr);
679+
680+
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
681+
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
682+
EXPECT_EQ(Res, nullptr);
683+
684+
auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
685+
function_ref<bool(const APInt &)> CheckFn,
686+
bool UndefAsPoison) {
687+
SmallVector<Constant *> VecElems;
688+
std::optional<bool> Okay;
689+
bool AllSame = true;
690+
bool HasUndef = false;
691+
std::optional<APInt> First;
692+
for (const std::optional<int8_t> &Val : Vals) {
693+
if (!Val.has_value()) {
694+
VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
695+
: UndefValue::get(I8Ty));
696+
HasUndef = true;
697+
} else {
698+
if (!Okay.has_value())
699+
Okay = true;
700+
APInt APVal(8, *Val);
701+
if (!First.has_value())
702+
First = APVal;
703+
else
704+
AllSame &= First->eq(APVal);
705+
Okay = *Okay && CheckFn(APVal);
706+
VecElems.push_back(ConstantInt::get(I8Ty, *Val));
707+
}
708+
}
709+
710+
Constant *C = ConstantVector::get(VecElems);
711+
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
712+
m_CheckedInt(CheckFn).match(C));
713+
714+
Res = nullptr;
715+
bool Expec =
716+
!(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
717+
EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
718+
if (Expec) {
719+
EXPECT_NE(Res, nullptr);
720+
EXPECT_EQ(*Res, *First);
721+
}
722+
};
723+
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
724+
DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
725+
DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
726+
DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
727+
DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
728+
DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
729+
DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
730+
DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
731+
};
732+
733+
DoVecCheck({0, 1});
734+
DoVecCheck({1, 1});
735+
DoVecCheck({1, 2});
736+
DoVecCheck({1, std::nullopt});
737+
DoVecCheck({1, std::nullopt, 1});
738+
DoVecCheck({1, std::nullopt, 2});
739+
DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
740+
}
741+
614742
TEST_F(PatternMatchTest, Power2) {
615743
Value *C128 = IRB.getInt32(128);
616744
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1397,21 +1525,58 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
13971525
EXPECT_FALSE(match(VectorInfPoison, m_Finite()));
13981526
EXPECT_FALSE(match(VectorNaNPoison, m_Finite()));
13991527

1528+
auto CheckTrue = [](const APFloat &) { return true; };
1529+
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
1530+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckTrue)));
1531+
EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
1532+
EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
1533+
EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
1534+
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
1535+
EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckTrue)));
1536+
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
1537+
EXPECT_TRUE(match(VectorNaNPoison, m_CheckedFp(CheckTrue)));
1538+
1539+
auto CheckFalse = [](const APFloat &) { return false; };
1540+
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
1541+
EXPECT_FALSE(match(VectorZeroPoison, m_CheckedFp(CheckFalse)));
1542+
EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
1543+
EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
1544+
EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
1545+
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
1546+
EXPECT_FALSE(match(VectorInfPoison, m_CheckedFp(CheckFalse)));
1547+
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
1548+
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckFalse)));
1549+
1550+
auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
1551+
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
1552+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckNonNaN)));
1553+
EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
1554+
EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
1555+
EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
1556+
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
1557+
EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckNonNaN)));
1558+
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
1559+
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
1560+
14001561
const APFloat *C;
14011562
// Regardless of whether poison is allowed,
14021563
// a fully undef/poison constant does not match.
14031564
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
14041565
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
14051566
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
1567+
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
14061568
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
14071569
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
14081570
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
1571+
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
14091572
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
14101573
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
14111574
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
1575+
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
14121576
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
14131577
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
14141578
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
1579+
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
14151580

14161581
// We can always match simple constants and simple splats.
14171582
C = nullptr;
@@ -1432,6 +1597,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
14321597
C = nullptr;
14331598
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
14341599
EXPECT_TRUE(C->isZero());
1600+
C = nullptr;
1601+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
1602+
EXPECT_TRUE(C->isZero());
1603+
C = nullptr;
1604+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
1605+
EXPECT_TRUE(C->isZero());
14351606

14361607
// Splats with undef are never allowed.
14371608
// Whether splats with poison can be matched depends on the matcher.
@@ -1456,6 +1627,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
14561627
C = nullptr;
14571628
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
14581629
EXPECT_TRUE(C->isZero());
1630+
C = nullptr;
1631+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
1632+
EXPECT_TRUE(C->isZero());
1633+
C = nullptr;
1634+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
1635+
EXPECT_TRUE(C->isZero());
14591636
}
14601637

14611638
TEST_F(PatternMatchTest, FloatingPointFNeg) {

0 commit comments

Comments
 (0)