Skip to content

[PatternMatching] Add generic API for matching constants using custom conditions #85676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Mar 18, 2024

  • [PatternMatching] Add generic API for matching constants using custom conditions
  • [InstCombine] Add example usage for new Checked matcher API

The new API is:
m_CheckedInt(Lambda)/m_CheckedFp(Lambda)
- Matches non-undef constants s.t Lambda(ele) is true for all
elements.
m_CheckedIntAllowUndef(Lambda)/m_CheckedFpAllowUndef(Lambda)
- Matches constants/undef s.t Lambda(ele) is true for all
elements.

The goal with these is to be able to replace the common usage of:

    match(X, m_APInt(C)) && CustomCheck(C)

with

    match(X, m_CheckedInt(C, CustomChecks);

The rationale if we often ignore non-splat vectors because there are
no good APIs to handle them with and its not worth increasing code
complexity for such cases.

The hope is the API creates a common method handling
scalars/splat-vecs/non-splat-vecs to essentially make this a
non-issue.

@goldsteinn goldsteinn requested a review from nikic as a code owner March 18, 2024 18:14
@goldsteinn goldsteinn changed the title goldsteinn/pattern match api [PatternMatching] Add generic API for matching constants using custom conditions Mar 18, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 18, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: None (goldsteinn)

Changes
  • [PatternMatching] Add generic API for matching constants using custom conditions
  • [InstCombine] Add example usage for new Checked matcher API

Patch is 26.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85676.diff

4 Files Affected:

  • (modified) llvm/include/llvm/IR/PatternMatch.h (+80-11)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+30-36)
  • (modified) llvm/test/Transforms/InstCombine/signed-truncation-check.ll (+6-24)
  • (modified) llvm/unittests/IR/PatternMatch.cpp (+240)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..4333d3e6e8da2a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
 /// This helper class is used to match constant scalars, vector splats,
 /// and fixed width vectors that satisfy a specified predicate.
 /// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
 struct cstval_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
           Constant *Elt = C->getAggregateElement(i);
           if (!Elt)
             return false;
-          if (isa<UndefValue>(Elt))
+          if (isa<UndefValue>(Elt)) {
+            if (!AllowUndefs)
+              return false;
             continue;
+          }
           auto *CV = dyn_cast<ConstantVal>(Elt);
           if (!CV || !this->isValue(CV->getValue()))
             return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
 
 /// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
 
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
   const APInt *&Res;
 
   api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+        if (auto *CI =
+                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APFloat.
 /// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
   const APFloat *&Res;
 
   apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantFP>(
-                C->getSplatValue(/* AllowUndef */ true)))
+        if (auto *CI =
+                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
 //
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename APTy> struct custom_checkfn {
+  function_ref<bool(const APTy &)> CheckFn;
+  bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>, false> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+                       function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+                    function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
 struct is_any_apint {
   bool isValue(const APInt &C) { return true; }
 };
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
     case ICmpInst::ICMP_ULT: {
       if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A <u C -> A == C-1 if min(A)+1 == C
-        if (*CmpC == Op0Min + 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC - 1));
-        // X <u C --> X == 0, if the number of zero bits in the bottom of X
-        // exceeds the log2 of C.
-        if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              Constant::getNullValue(Op1->getType()));
-      }
+      // A <u C -> A == C-1 if min(A)+1 == C
+      if (match(Op1, m_SpecificInt(Op0Min + 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Min));
+      // X <u C --> X == 0, if the number of zero bits in the bottom of X
+      // exceeds the log2 of C.
+      if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+                  return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+                })))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            Constant::getNullValue(Op1->getType()));
       break;
     }
     case ICmpInst::ICMP_UGT: {
       if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A >u C -> A == C+1 if max(a)-1 == C
-        if (*CmpC == Op0Max - 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-        // X >u C --> X != 0, if the number of zero bits in the bottom of X
-        // exceeds the log2 of C.
-        if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
-          return new ICmpInst(ICmpInst::ICMP_NE, Op0,
-                              Constant::getNullValue(Op1->getType()));
-      }
+      // A >u C -> A == C+1 if max(a)-1 == C
+      if (match(Op1, m_SpecificInt(Op0Max - 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Max));
+      // X >u C --> X != 0, if the number of zero bits in the bottom of X
+      // exceeds the log2 of C.
+      if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+                  return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+                })))
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+                            Constant::getNullValue(Op1->getType()));
       break;
     }
     case ICmpInst::ICMP_SLT: {
       if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC - 1));
-      }
+      // A <s C -> A == C-1 if min(A)+1 == C
+      if (match(Op1, m_SpecificInt(Op0Min + 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Min));
       break;
     }
     case ICmpInst::ICMP_SGT: {
       if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-      }
+      // A >s C -> A == C+1 if max(A)-1 == C
+      if (match(Op1, m_SpecificInt(Op0Max - 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Max));
       break;
     }
     }
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
   EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
 }
 
+TEST_F(PatternMatchTest, CheckedInt) {
+  Type *I8Ty = IRB.getInt8Ty();
+  const APInt *Res = nullptr;
+
+  auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+  auto CheckTrue = [](const APInt &) { return true; };
+  auto CheckFalse = [](const APInt &) { return false; };
+  auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+  auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+  auto DoScalarCheck = [&](int8_t Val) {
+    APInt APVal(8, Val);
+    Constant *C = ConstantInt::get(I8Ty, Val);
+
+    Res = nullptr;
+    EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+    EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+    EXPECT_EQ(*Res, APVal);
+
+    Res = nullptr;
+    EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+    EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal),
+              m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal),
+              m_CheckedIntAllowUndef(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal),
+              m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal),
+              m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+  };
+
+  DoScalarCheck(0);
+  DoScalarCheck(1);
+  DoScalarCheck(2);
+  DoScalarCheck(3);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+                            function_ref<bool(const APInt &)> CheckFn,
+                            bool UndefAsPoison) {
+    SmallVector<Constant *> VecElems;
+    std::optional<bool> Okay;
+    bool AllSame = true;
+    bool HasUndef = false;
+    std::optional<APInt> First;
+    for (const std::optional<int8_t> &Val : Vals) {
+      if (!Val.has_value()) {
+        VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+                                         : UndefValue::get(I8Ty));
+        HasUndef = true;
+      } else {
+        if (!Okay.has_value())
+          Okay = true;
+        APInt APVal(8, *Val);
+        if (!First.has_value())
+          First = APVal;
+        else
+          AllSame &= First->eq(APVal);
+        Okay = *Okay && CheckFn(APVal);
+        VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+      }
+    }
+
+    Constant *C = ConstantVector::get(VecElems);
+    EXPECT_EQ(!HasUndef && Okay.value_or(false),
+              m_CheckedInt(CheckFn).match(C));
+    EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+    Res = nullptr;
+    bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+    if (Expec) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, *First);
+    }
+
+    Res = nullptr;
+    Expec = AllSame && Okay.value_or(f...
[truncated]

Copy link

github-actions bot commented Mar 18, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff a12622543de15df45fb9ad64e8ab723289d55169 436c5a2925171e6384e508893c96347c401b94ed -- llvm/include/llvm/IR/PatternMatch.h llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/unittests/IR/PatternMatch.cpp
View the diff from clang-format here.
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index d5a4a6a056..a436d19c0d 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -657,7 +657,6 @@ TEST_F(PatternMatchTest, CheckedInt) {
       EXPECT_NE(Res, nullptr);
       EXPECT_EQ(*Res, APVal);
     }
-
   };
 
   DoScalarCheck(0);

@goldsteinn goldsteinn force-pushed the goldsteinn/pattern-match-api branch from 6a6c35f to d4de565 Compare March 18, 2024 18:42
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 19, 2024
@goldsteinn
Copy link
Contributor Author

ping

@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 24, 2024

Please check the compile-time impact of this change.

@goldsteinn
Copy link
Contributor Author

Please check the compile-time impact of this change.

https://llvm-compile-time-tracker.com/compare.php?from=817f453aa576286aaca0a6b0244e6ab08516b80c&to=05871ee629407932bc6576224dcc2ae99db473fa&stat=instructions:u

Think basically no impact, maybe a slight regression stage2-O0 and slight improvement stage2-O3, but looks to be within err range and a bit scattered.

@goldsteinn
Copy link
Contributor Author

ping

inline api_or_undef_pred_ty<custom_checkfn<APInt>>
m_CheckedIntAllowUndef(const APInt *&V,
function_ref<bool(const APInt &)> CheckFn) {
api_or_undef_pred_ty<custom_checkfn<APInt>> P(V);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we pass CheckFn into the constructor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No (other functions also do it like this).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They might, but why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there is a syntax for it, but don't know the syntax for constructing fields in inherited + base type.

@goldsteinn goldsteinn force-pushed the goldsteinn/pattern-match-api branch from 4d7774d to 0f5619a Compare March 30, 2024 01:52
@goldsteinn
Copy link
Contributor Author

ping

@nikic
Copy link
Contributor

nikic commented Apr 10, 2024

I'd like to defer this until after #88217, which would address my primary concern with this (the ever present undef footgun).

@goldsteinn
Copy link
Contributor Author

I'd like to defer this until after #88217, which would address my primary concern with this (the ever present undef footgun).

Actually any issue getting just the default APIs (not supporting undef) in now? Don't imagine it will get in your way. If you think it will be a problem though, this can ofc be deferred.

@goldsteinn goldsteinn force-pushed the goldsteinn/pattern-match-api branch from 0f5619a to ea43726 Compare April 23, 2024 01:16
@goldsteinn
Copy link
Contributor Author

rebased

@goldsteinn goldsteinn force-pushed the goldsteinn/pattern-match-api branch from ea43726 to e4b9b1e Compare April 25, 2024 21:29
EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I'd consider custom_checkfn to be a PatternMatch implementation detail -- do we really need unit tests for it if you already have tests for m_CheckedInt below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dtcxzyw requested them. I'm ambivalent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need unit tests for it if you already have tests for m_CheckedInt below?

The version I reviewed 4d7774d doesn't include tests for m_CheckedInt.

It is ok to remove tests for CustomCheckFn.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, dropped these tests.

… conditions

The new API is:
    `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
        - Matches non-undef constants s.t `Lambda(ele)` is true for all
          elements.
    `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
        - Matches constants/undef s.t `Lambda(ele)` is true for all
          elements.

The goal with these is to be able to replace the common usage of:
```
    match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
    match(X, m_CheckedInt(C, CustomChecks);
```

The rationale if we often ignore non-splat vectors because there are
no good APIs to handle them with and its not worth increasing code
complexity for such cases.

The hope is the API creates a common method handling
scalars/splat-vecs/non-splat-vecs to essentially make this a
non-issue.
goldsteinn added 2 commits May 2, 2024 12:09
There is no real motivation for this change other than to highlight a
case where the new `Checked` matcher API can handle non-splat-vecs
without increasing code complexity.
@goldsteinn goldsteinn force-pushed the goldsteinn/pattern-match-api branch from e4b9b1e to 9705962 Compare May 2, 2024 17:15
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

%v = ashr <4 x i32> %x, <i32 1, i32 2, i32 3, i32 4>
%r = icmp sge <4 x i32> %v, %x
ret <4 x i1> %r
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add a negative non-splat test that does not pass the check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already had non-splat negative tests (i.e udiv_x_by_const_cmp_eq_value_neg). In retrospect probably should have had those tests be scalars as they didn't really check the old logic, but we can use them here.

}

inline api_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this again just now, what's the point of these methods returning a const APInt *? In that case, isn't this limited to splats anyway and you could just as well use m_APInt + a check and save you the lambda?

I think this should be providing the result as a Constant *, which would support non-splat values matching the predicate...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed this... yeah think you are right. Should be a constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting the norm case seems to be to always export these as APInt, if you looks at something like m_Negative/m_NonNegative its the same deal.

Any objection to me replacing all of these?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming there are not a lot of uses of them, replacing them should be fine. Something to be careful about is that the direct replacement here would be m_APIntAllowPoison, not m_APInt.

If there are many uses, it may make sense to retain them as a convenience (I don't think m_CheckedInt with APInt makes sense as a convenience though, it's strictly more complicated than just writing out the check).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See: #91377 decided not to do the other. Didn't seem like much upside and made code less compact.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants