Skip to content

Commit ffdcb0d

Browse files
committed
[PatternMatching] Add generic API for matching constants using custom conditions
The new API is: `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)` - Matches non-undef constants s.t `Lambda(ele)` is true for all elements. `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)` - Matches constants/undef s.t `Lambda(ele)` is true for all elements. The goal with these is to be able to replace the common usage of: ``` match(X, m_APInt(C)) && CustomCheck(C) ``` with ``` match(X, m_CheckedInt(C, CustomChecks); ``` The rationale if we often ignore non-splat vectors because there are no good APIs to handle them with and its not worth increasing code complexity for such cases. The hope is the API creates a common method handling scalars/splat-vecs/non-splat-vecs to essentially make this a non-issue.
1 parent 3005ca2 commit ffdcb0d

File tree

2 files changed

+288
-0
lines changed

2 files changed

+288
-0
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,39 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
460460
//
461461
///////////////////////////////////////////////////////////////////////////////
462462

463+
template <typename APTy> struct custom_checkfn {
464+
function_ref<bool(const APTy &)> CheckFn;
465+
bool isValue(const APTy &C) { return CheckFn(C); }
466+
};
467+
468+
// Match and integer or vector where CheckFn(ele) for each element is true.
469+
// For vectors, poison elements are assumed to match.
470+
inline cst_pred_ty<custom_checkfn<APInt>>
471+
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
472+
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
473+
}
474+
475+
inline api_pred_ty<custom_checkfn<APInt>>
476+
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
477+
api_pred_ty<custom_checkfn<APInt>> P(V);
478+
P.CheckFn = CheckFn;
479+
return P;
480+
}
481+
482+
// Match and float or vector where CheckFn(ele) for each element is true.
483+
// For vectors, poison elements are assumed to match.
484+
inline cstfp_pred_ty<custom_checkfn<APFloat>>
485+
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
486+
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
487+
}
488+
489+
inline apf_pred_ty<custom_checkfn<APFloat>>
490+
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
491+
apf_pred_ty<custom_checkfn<APFloat>> P(V);
492+
P.CheckFn = CheckFn;
493+
return P;
494+
}
495+
463496
struct is_any_apint {
464497
bool isValue(const APInt &C) { return true; }
465498
};

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,212 @@ TEST_F(PatternMatchTest, BitCast) {
611611
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
612612
}
613613

614+
TEST_F(PatternMatchTest, CustomCheckFn) {
615+
APInt I0(64, 0);
616+
APInt I1(64, 0);
617+
618+
auto CheckIsZeroI = [](const APInt &C) { return C.isZero(); };
619+
auto CheckIsEqI1 = [&I1](const APInt &C) { return C.eq(I1); };
620+
auto CheckIsNeI1 = [&I1](const APInt &C) { return !C.eq(I1); };
621+
622+
custom_checkfn<APInt> CustomCheckZeroI;
623+
CustomCheckZeroI.CheckFn = CheckIsZeroI;
624+
custom_checkfn<APInt> CustomCheckEqI1;
625+
CustomCheckEqI1.CheckFn = CheckIsEqI1;
626+
custom_checkfn<APInt> CustomCheckNeI1;
627+
CustomCheckNeI1.CheckFn = CheckIsNeI1;
628+
629+
EXPECT_TRUE(CustomCheckZeroI.isValue(I0));
630+
EXPECT_TRUE(CustomCheckEqI1.isValue(I0));
631+
EXPECT_FALSE(CustomCheckNeI1.isValue(I0));
632+
633+
I0.setBit(0);
634+
635+
EXPECT_FALSE(CustomCheckZeroI.isValue(I0));
636+
EXPECT_FALSE(CustomCheckEqI1.isValue(I0));
637+
EXPECT_TRUE(CustomCheckNeI1.isValue(I0));
638+
639+
I1.setBit(0);
640+
641+
EXPECT_FALSE(CustomCheckZeroI.isValue(I0));
642+
EXPECT_TRUE(CustomCheckEqI1.isValue(I0));
643+
EXPECT_FALSE(CustomCheckNeI1.isValue(I0));
644+
645+
APFloat F0(0.0);
646+
APFloat F1(0.0);
647+
648+
auto CheckIsZeroF = [](const APFloat &C) { return C.isZero(); };
649+
auto CheckIsEqF1 = [&F1](const APFloat &C) {
650+
return C.bitcastToAPInt().eq(F1.bitcastToAPInt());
651+
};
652+
auto CheckIsNeF1 = [&F1](const APFloat &C) {
653+
return !C.bitcastToAPInt().eq(F1.bitcastToAPInt());
654+
};
655+
656+
custom_checkfn<APFloat> CustomCheckZeroF;
657+
CustomCheckZeroF.CheckFn = CheckIsZeroF;
658+
custom_checkfn<APFloat> CustomCheckEqF1;
659+
CustomCheckEqF1.CheckFn = CheckIsEqF1;
660+
custom_checkfn<APFloat> CustomCheckNeF1;
661+
CustomCheckNeF1.CheckFn = CheckIsNeF1;
662+
663+
EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
664+
EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
665+
EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
666+
667+
F0 = -F0;
668+
669+
EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
670+
EXPECT_FALSE(CustomCheckEqF1.isValue(F0));
671+
EXPECT_TRUE(CustomCheckNeF1.isValue(F0));
672+
673+
F0 = -F0;
674+
675+
EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
676+
EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
677+
EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
678+
679+
F0 = F0 + APFloat(1.0);
680+
681+
EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
682+
EXPECT_FALSE(CustomCheckEqF1.isValue(F0));
683+
EXPECT_TRUE(CustomCheckNeF1.isValue(F0));
684+
685+
F1 = F1 + APFloat(1.0);
686+
687+
EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
688+
EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
689+
EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
690+
}
691+
692+
TEST_F(PatternMatchTest, CheckedInt) {
693+
Type *I8Ty = IRB.getInt8Ty();
694+
const APInt *Res = nullptr;
695+
696+
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
697+
auto CheckTrue = [](const APInt &) { return true; };
698+
auto CheckFalse = [](const APInt &) { return false; };
699+
auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
700+
auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
701+
702+
auto DoScalarCheck = [&](int8_t Val) {
703+
APInt APVal(8, Val);
704+
Constant *C = ConstantInt::get(I8Ty, Val);
705+
706+
Res = nullptr;
707+
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
708+
EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
709+
EXPECT_EQ(*Res, APVal);
710+
711+
Res = nullptr;
712+
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
713+
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
714+
715+
Res = nullptr;
716+
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
717+
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
718+
if (CheckUgt1(APVal)) {
719+
EXPECT_NE(Res, nullptr);
720+
EXPECT_EQ(*Res, APVal);
721+
}
722+
723+
Res = nullptr;
724+
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
725+
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
726+
if (CheckNonZero(APVal)) {
727+
EXPECT_NE(Res, nullptr);
728+
EXPECT_EQ(*Res, APVal);
729+
}
730+
731+
Res = nullptr;
732+
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
733+
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
734+
if (CheckPow2(APVal)) {
735+
EXPECT_NE(Res, nullptr);
736+
EXPECT_EQ(*Res, APVal);
737+
}
738+
739+
};
740+
741+
DoScalarCheck(0);
742+
DoScalarCheck(1);
743+
DoScalarCheck(2);
744+
DoScalarCheck(3);
745+
746+
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
747+
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
748+
EXPECT_EQ(Res, nullptr);
749+
750+
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
751+
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
752+
EXPECT_EQ(Res, nullptr);
753+
754+
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
755+
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
756+
EXPECT_EQ(Res, nullptr);
757+
758+
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
759+
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
760+
EXPECT_EQ(Res, nullptr);
761+
762+
auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
763+
function_ref<bool(const APInt &)> CheckFn,
764+
bool UndefAsPoison) {
765+
SmallVector<Constant *> VecElems;
766+
std::optional<bool> Okay;
767+
bool AllSame = true;
768+
bool HasUndef = false;
769+
std::optional<APInt> First;
770+
for (const std::optional<int8_t> &Val : Vals) {
771+
if (!Val.has_value()) {
772+
VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
773+
: UndefValue::get(I8Ty));
774+
HasUndef = true;
775+
} else {
776+
if (!Okay.has_value())
777+
Okay = true;
778+
APInt APVal(8, *Val);
779+
if (!First.has_value())
780+
First = APVal;
781+
else
782+
AllSame &= First->eq(APVal);
783+
Okay = *Okay && CheckFn(APVal);
784+
VecElems.push_back(ConstantInt::get(I8Ty, *Val));
785+
}
786+
}
787+
788+
Constant *C = ConstantVector::get(VecElems);
789+
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
790+
m_CheckedInt(CheckFn).match(C));
791+
792+
Res = nullptr;
793+
bool Expec =
794+
!(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
795+
EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
796+
if (Expec) {
797+
EXPECT_NE(Res, nullptr);
798+
EXPECT_EQ(*Res, *First);
799+
}
800+
};
801+
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
802+
DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
803+
DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
804+
DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
805+
DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
806+
DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
807+
DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
808+
DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
809+
};
810+
811+
DoVecCheck({0, 1});
812+
DoVecCheck({1, 1});
813+
DoVecCheck({1, 2});
814+
DoVecCheck({1, std::nullopt});
815+
DoVecCheck({1, std::nullopt, 1});
816+
DoVecCheck({1, std::nullopt, 2});
817+
DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
818+
}
819+
614820
TEST_F(PatternMatchTest, Power2) {
615821
Value *C128 = IRB.getInt32(128);
616822
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1397,21 +1603,58 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
13971603
EXPECT_FALSE(match(VectorInfPoison, m_Finite()));
13981604
EXPECT_FALSE(match(VectorNaNPoison, m_Finite()));
13991605

1606+
auto CheckTrue = [](const APFloat &) { return true; };
1607+
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
1608+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckTrue)));
1609+
EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
1610+
EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
1611+
EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
1612+
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
1613+
EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckTrue)));
1614+
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
1615+
EXPECT_TRUE(match(VectorNaNPoison, m_CheckedFp(CheckTrue)));
1616+
1617+
auto CheckFalse = [](const APFloat &) { return false; };
1618+
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
1619+
EXPECT_FALSE(match(VectorZeroPoison, m_CheckedFp(CheckFalse)));
1620+
EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
1621+
EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
1622+
EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
1623+
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
1624+
EXPECT_FALSE(match(VectorInfPoison, m_CheckedFp(CheckFalse)));
1625+
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
1626+
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckFalse)));
1627+
1628+
auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
1629+
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
1630+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckNonNaN)));
1631+
EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
1632+
EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
1633+
EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
1634+
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
1635+
EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckNonNaN)));
1636+
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
1637+
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
1638+
14001639
const APFloat *C;
14011640
// Regardless of whether poison is allowed,
14021641
// a fully undef/poison constant does not match.
14031642
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
14041643
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
14051644
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
1645+
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
14061646
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
14071647
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
14081648
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
1649+
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
14091650
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
14101651
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
14111652
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
1653+
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
14121654
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
14131655
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
14141656
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
1657+
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
14151658

14161659
// We can always match simple constants and simple splats.
14171660
C = nullptr;
@@ -1432,6 +1675,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
14321675
C = nullptr;
14331676
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
14341677
EXPECT_TRUE(C->isZero());
1678+
C = nullptr;
1679+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
1680+
EXPECT_TRUE(C->isZero());
1681+
C = nullptr;
1682+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
1683+
EXPECT_TRUE(C->isZero());
14351684

14361685
// Splats with undef are never allowed.
14371686
// Whether splats with poison can be matched depends on the matcher.
@@ -1456,6 +1705,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
14561705
C = nullptr;
14571706
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
14581707
EXPECT_TRUE(C->isZero());
1708+
C = nullptr;
1709+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
1710+
EXPECT_TRUE(C->isZero());
1711+
C = nullptr;
1712+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
1713+
EXPECT_TRUE(C->isZero());
14591714
}
14601715

14611716
TEST_F(PatternMatchTest, FloatingPointFNeg) {

0 commit comments

Comments
 (0)