Skip to content

Commit a78b248

Browse files
committed
[PatternMatching] Add generic API for matching constants using custom conditions
The new API is: `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)` - Matches non-undef constants s.t `Lambda(ele)` is true for all elements. `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)` - Matches constants/undef s.t `Lambda(ele)` is true for all elements. The goal with these is to be able to replace the common usage of: ``` match(X, m_APInt(C)) && CustomCheck(C) ``` with ``` match(X, m_CheckedInt(C, CustomChecks); ``` The rationale if we often ignore non-splat vectors because there are no good APIs to handle them with and its not worth increasing code complexity for such cases. The hope is the API creates a common method handling scalars/splat-vecs/non-splat-vecs to essentially make this a non-issue.
1 parent a1fb514 commit a78b248

File tree

2 files changed

+320
-11
lines changed

2 files changed

+320
-11
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
346346
/// This helper class is used to match constant scalars, vector splats,
347347
/// and fixed width vectors that satisfy a specified predicate.
348348
/// For fixed width vector constants, undefined elements are ignored.
349-
template <typename Predicate, typename ConstantVal>
349+
template <typename Predicate, typename ConstantVal, bool AllowUndefs>
350350
struct cstval_pred_ty : public Predicate {
351351
template <typename ITy> bool match(ITy *V) {
352352
if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
369369
Constant *Elt = C->getAggregateElement(i);
370370
if (!Elt)
371371
return false;
372-
if (isa<UndefValue>(Elt))
372+
if (isa<UndefValue>(Elt)) {
373+
if (!AllowUndefs)
374+
return false;
373375
continue;
376+
}
374377
auto *CV = dyn_cast<ConstantVal>(Elt);
375378
if (!CV || !this->isValue(CV->getValue()))
376379
return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
384387
};
385388

386389
/// specialization of cstval_pred_ty for ConstantInt
387-
template <typename Predicate>
388-
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
390+
template <typename Predicate, bool AllowUndefs = true>
391+
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
389392

390393
/// specialization of cstval_pred_ty for ConstantFP
391-
template <typename Predicate>
392-
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
394+
template <typename Predicate, bool AllowUndefs = true>
395+
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
393396

394397
/// This helper class is used to match scalar and vector constants that
395398
/// satisfy a specified predicate, and bind them to an APInt.
396-
template <typename Predicate> struct api_pred_ty : public Predicate {
399+
template <typename Predicate, bool AllowUndefs = true>
400+
struct api_pred_ty : public Predicate {
397401
const APInt *&Res;
398402

399403
api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
406410
}
407411
if (V->getType()->isVectorTy())
408412
if (const auto *C = dyn_cast<Constant>(V))
409-
if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
413+
if (auto *CI =
414+
dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
410415
if (this->isValue(CI->getValue())) {
411416
Res = &CI->getValue();
412417
return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
419424
/// This helper class is used to match scalar and vector constants that
420425
/// satisfy a specified predicate, and bind them to an APFloat.
421426
/// Undefs are allowed in splat vector constants.
422-
template <typename Predicate> struct apf_pred_ty : public Predicate {
427+
template <typename Predicate, bool AllowUndefs = true>
428+
struct apf_pred_ty : public Predicate {
423429
const APFloat *&Res;
424430

425431
apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
432438
}
433439
if (V->getType()->isVectorTy())
434440
if (const auto *C = dyn_cast<Constant>(V))
435-
if (auto *CI = dyn_cast_or_null<ConstantFP>(
436-
C->getSplatValue(/* AllowUndef */ true)))
441+
if (auto *CI =
442+
dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
437443
if (this->isValue(CI->getValue())) {
438444
Res = &CI->getValue();
439445
return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
452458
//
453459
///////////////////////////////////////////////////////////////////////////////
454460

461+
template <typename APTy> struct custom_checkfn {
462+
function_ref<bool(const APTy &)> CheckFn;
463+
bool isValue(const APTy &C) { return CheckFn(C); }
464+
};
465+
466+
// Match and integer or vector where CheckFn(ele) for each element is true.
467+
// For vectors, undefined elements are assumed NOT to match.
468+
inline cst_pred_ty<custom_checkfn<APInt>, false>
469+
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
470+
return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
471+
}
472+
473+
inline api_pred_ty<custom_checkfn<APInt>, false>
474+
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
475+
api_pred_ty<custom_checkfn<APInt>, false> P(V);
476+
P.CheckFn = CheckFn;
477+
return P;
478+
}
479+
480+
// Match and integer or vector where CheckFn(ele) for each element is true.
481+
// For vectors, undefined elements are assumed to match.
482+
inline cst_pred_ty<custom_checkfn<APInt>>
483+
m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
484+
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
485+
}
486+
487+
inline api_pred_ty<custom_checkfn<APInt>>
488+
m_CheckedIntAllowUndef(const APInt *&V,
489+
function_ref<bool(const APInt &)> CheckFn) {
490+
api_pred_ty<custom_checkfn<APInt>> P(V);
491+
P.CheckFn = CheckFn;
492+
return P;
493+
}
494+
495+
// Match and float or vector where CheckFn(ele) for each element is true.
496+
// For vectors, undefined elements are assumed NOT to match.
497+
inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
498+
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
499+
return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
500+
}
501+
502+
inline apf_pred_ty<custom_checkfn<APFloat>, false>
503+
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
504+
apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
505+
P.CheckFn = CheckFn;
506+
return P;
507+
}
508+
509+
// Match and float or vector where CheckFn(ele) for each element is true.
510+
// For vectors, undefined elements are assumed to match.
511+
inline cstfp_pred_ty<custom_checkfn<APFloat>>
512+
m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
513+
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
514+
}
515+
516+
inline apf_pred_ty<custom_checkfn<APFloat>>
517+
m_CheckedFpAllowUndef(const APFloat *&V,
518+
function_ref<bool(const APFloat &)> CheckFn) {
519+
apf_pred_ty<custom_checkfn<APFloat>> P(V);
520+
P.CheckFn = CheckFn;
521+
return P;
522+
}
523+
455524
struct is_any_apint {
456525
bool isValue(const APInt &C) { return true; }
457526
};

0 commit comments

Comments
 (0)