-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[PatternMatching] Add generic API for matching constants using custom conditions #85676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-ir Author: None (goldsteinn) Changes
Patch is 26.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85676.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..4333d3e6e8da2a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
- if (isa<UndefValue>(Elt))
+ if (isa<UndefValue>(Elt)) {
+ if (!AllowUndefs)
+ return false;
continue;
+ }
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
};
/// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
/// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
const APInt *&Res;
api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
const APFloat *&Res;
apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantFP>(
- C->getSplatValue(/* AllowUndef */ true)))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
//
///////////////////////////////////////////////////////////////////////////////
+template <typename APTy> struct custom_checkfn {
+ function_ref<bool(const APTy &)> CheckFn;
+ bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+ function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+ function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
struct is_any_apint {
bool isValue(const APInt &C) { return true; }
};
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A >u C -> A == C+1 if max(a)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
+ // A <s C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
+ // A >s C -> A == C+1 if max(A)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
break;
}
}
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
}
+TEST_F(PatternMatchTest, CheckedInt) {
+ Type *I8Ty = IRB.getInt8Ty();
+ const APInt *Res = nullptr;
+
+ auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+ auto CheckTrue = [](const APInt &) { return true; };
+ auto CheckFalse = [](const APInt &) { return false; };
+ auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+ auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+ auto DoScalarCheck = [&](int8_t Val) {
+ APInt APVal(8, Val);
+ Constant *C = ConstantInt::get(I8Ty, Val);
+
+ Res = nullptr;
+ EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+ EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+ EXPECT_EQ(*Res, APVal);
+
+ Res = nullptr;
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal),
+ m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal),
+ m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+ };
+
+ DoScalarCheck(0);
+ DoScalarCheck(1);
+ DoScalarCheck(2);
+ DoScalarCheck(3);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+ function_ref<bool(const APInt &)> CheckFn,
+ bool UndefAsPoison) {
+ SmallVector<Constant *> VecElems;
+ std::optional<bool> Okay;
+ bool AllSame = true;
+ bool HasUndef = false;
+ std::optional<APInt> First;
+ for (const std::optional<int8_t> &Val : Vals) {
+ if (!Val.has_value()) {
+ VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+ : UndefValue::get(I8Ty));
+ HasUndef = true;
+ } else {
+ if (!Okay.has_value())
+ Okay = true;
+ APInt APVal(8, *Val);
+ if (!First.has_value())
+ First = APVal;
+ else
+ AllSame &= First->eq(APVal);
+ Okay = *Okay && CheckFn(APVal);
+ VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+ }
+ }
+
+ Constant *C = ConstantVector::get(VecElems);
+ EXPECT_EQ(!HasUndef && Okay.value_or(false),
+ m_CheckedInt(CheckFn).match(C));
+ EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+ Res = nullptr;
+ bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+
+ Res = nullptr;
+ Expec = AllSame && Okay.value_or(f...
[truncated]
|
You can test this locally with the following command:git-clang-format --diff a12622543de15df45fb9ad64e8ab723289d55169 436c5a2925171e6384e508893c96347c401b94ed -- llvm/include/llvm/IR/PatternMatch.h llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/unittests/IR/PatternMatch.cpp View the diff from clang-format here.diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index d5a4a6a056..a436d19c0d 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -657,7 +657,6 @@ TEST_F(PatternMatchTest, CheckedInt) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}
-
};
DoScalarCheck(0);
|
6a6c35f
to
d4de565
Compare
d4de565
to
4d7774d
Compare
ping |
Please check the compile-time impact of this change. |
Think basically no impact, maybe a slight regression stage2-O0 and slight improvement stage2-O3, but looks to be within err range and a bit scattered. |
ping |
llvm/include/llvm/IR/PatternMatch.h
Outdated
inline api_or_undef_pred_ty<custom_checkfn<APInt>> | ||
m_CheckedIntAllowUndef(const APInt *&V, | ||
function_ref<bool(const APInt &)> CheckFn) { | ||
api_or_undef_pred_ty<custom_checkfn<APInt>> P(V); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we pass CheckFn
into the constructor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No (other functions also do it like this).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They might, but why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe there is a syntax for it, but don't know the syntax for constructing fields in inherited + base type.
4d7774d
to
0f5619a
Compare
ping |
I'd like to defer this until after #88217, which would address my primary concern with this (the ever present undef footgun). |
Actually any issue getting just the default APIs (not supporting |
0f5619a
to
ea43726
Compare
rebased |
ea43726
to
e4b9b1e
Compare
llvm/unittests/IR/PatternMatch.cpp
Outdated
EXPECT_FALSE(CustomCheckZeroF.isValue(F0)); | ||
EXPECT_TRUE(CustomCheckEqF1.isValue(F0)); | ||
EXPECT_FALSE(CustomCheckNeF1.isValue(F0)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, I'd consider custom_checkfn to be a PatternMatch implementation detail -- do we really need unit tests for it if you already have tests for m_CheckedInt below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dtcxzyw requested them. I'm ambivalent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we really need unit tests for it if you already have tests for m_CheckedInt below?
The version I reviewed 4d7774d doesn't include tests for m_CheckedInt.
It is ok to remove tests for CustomCheckFn.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, dropped these tests.
… conditions The new API is: `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)` - Matches non-undef constants s.t `Lambda(ele)` is true for all elements. `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)` - Matches constants/undef s.t `Lambda(ele)` is true for all elements. The goal with these is to be able to replace the common usage of: ``` match(X, m_APInt(C)) && CustomCheck(C) ``` with ``` match(X, m_CheckedInt(C, CustomChecks); ``` The rationale if we often ignore non-splat vectors because there are no good APIs to handle them with and its not worth increasing code complexity for such cases. The hope is the API creates a common method handling scalars/splat-vecs/non-splat-vecs to essentially make this a non-issue.
There is no real motivation for this change other than to highlight a case where the new `Checked` matcher API can handle non-splat-vecs without increasing code complexity.
e4b9b1e
to
9705962
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
%v = ashr <4 x i32> %x, <i32 1, i32 2, i32 3, i32 4> | ||
%r = icmp sge <4 x i32> %v, %x | ||
ret <4 x i1> %r | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also add a negative non-splat test that does not pass the check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already had non-splat negative tests (i.e udiv_x_by_const_cmp_eq_value_neg
). In retrospect probably should have had those tests be scalars as they didn't really check the old logic, but we can use them here.
} | ||
|
||
inline api_pred_ty<custom_checkfn<APInt>> | ||
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this again just now, what's the point of these methods returning a const APInt *
? In that case, isn't this limited to splats anyway and you could just as well use m_APInt
+ a check and save you the lambda?
I think this should be providing the result as a Constant *
, which would support non-splat values matching the predicate...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missed this... yeah think you are right. Should be a constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting the norm case seems to be to always export these as APInt
, if you looks at something like m_Negative
/m_NonNegative
its the same deal.
Any objection to me replacing all of these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming there are not a lot of uses of them, replacing them should be fine. Something to be careful about is that the direct replacement here would be m_APIntAllowPoison, not m_APInt.
If there are many uses, it may make sense to retain them as a convenience (I don't think m_CheckedInt with APInt makes sense as a convenience though, it's strictly more complicated than just writing out the check).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See: #91377 decided not to do the other. Didn't seem like much upside and made code less compact.
Checked
matcher APIThe new API is:
m_CheckedInt(Lambda)
/m_CheckedFp(Lambda)
- Matches non-undef constants s.t
Lambda(ele)
is true for allelements.
m_CheckedIntAllowUndef(Lambda)
/m_CheckedFpAllowUndef(Lambda)
- Matches constants/undef s.t
Lambda(ele)
is true for allelements.
The goal with these is to be able to replace the common usage of:
with
The rationale if we often ignore non-splat vectors because there are
no good APIs to handle them with and its not worth increasing code
complexity for such cases.
The hope is the API creates a common method handling
scalars/splat-vecs/non-splat-vecs to essentially make this a
non-issue.