Skip to content

Commit a40388b

Browse files
committed
lazy eval of DenormalMode
1 parent 188c97e commit a40388b

File tree

3 files changed

+45
-27
lines changed

3 files changed

+45
-27
lines changed

llvm/include/llvm/ADT/FloatingPointModeUtils.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ static std::tuple<ValueTy, FPClassTest, FPClassTest> exactClass(ValueTy V,
4646
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
4747
/// element will always be LHS.
4848
///
49-
template <typename ValueTy, typename LookThroughFnTy>
49+
template <typename ValueTy, typename LookThroughFnTy,
50+
typename DenormalModeQueryTy>
5051
std::tuple<ValueTy, FPClassTest, FPClassTest>
51-
fcmpImpliesClass(CmpInst::Predicate Pred, DenormalMode Mode, ValueTy LHS,
52-
FPClassTest RHSClass, LookThroughFnTy LookThroughFn) {
52+
fcmpImpliesClass(CmpInst::Predicate Pred, DenormalModeQueryTy ModeQuery,
53+
ValueTy LHS, FPClassTest RHSClass,
54+
LookThroughFnTy LookThroughFn) {
5355
assert(RHSClass != fcNone);
5456

5557
constexpr ValueTy Invalid = {};
@@ -91,6 +93,7 @@ fcmpImpliesClass(CmpInst::Predicate Pred, DenormalMode Mode, ValueTy LHS,
9193
// Compares with fcNone are only exactly equal to fcZero if input denormals
9294
// are not flushed.
9395
// TODO: Handle DAZ by expanding masks to cover subnormal cases.
96+
DenormalMode Mode = ModeQuery(LHS);
9497
if (Mode.Input != DenormalMode::DenormalModeKind::IEEE)
9598
return {Invalid, fcAllFlags, fcAllFlags};
9699

@@ -372,10 +375,12 @@ fcmpImpliesClass(CmpInst::Predicate Pred, DenormalMode Mode, ValueTy LHS,
372375
return {Invalid, fcAllFlags, fcAllFlags};
373376
}
374377

375-
template <typename ValueTy, typename LookThroughFnTy>
378+
template <typename ValueTy, typename LookThroughFnTy,
379+
typename DenormalModeQueryTy>
376380
std::tuple<ValueTy, FPClassTest, FPClassTest>
377-
fcmpImpliesClass(CmpInst::Predicate Pred, DenormalMode Mode, ValueTy LHS,
378-
const APFloat &ConstRHS, LookThroughFnTy LookThroughFn) {
381+
fcmpImpliesClass(CmpInst::Predicate Pred, DenormalModeQueryTy ModeQuery,
382+
ValueTy LHS, const APFloat &ConstRHS,
383+
LookThroughFnTy LookThroughFn) {
379384
// We can refine checks against smallest normal / largest denormal to an
380385
// exact class test.
381386
if (!ConstRHS.isNegative() && ConstRHS.isSmallestNormalized()) {
@@ -409,7 +414,7 @@ fcmpImpliesClass(CmpInst::Predicate Pred, DenormalMode Mode, ValueTy LHS,
409414
break;
410415
}
411416
default:
412-
return fcmpImpliesClass(Pred, Mode, LHS, ConstRHS.classify(),
417+
return fcmpImpliesClass(Pred, ModeQuery, LHS, ConstRHS.classify(),
413418
LookThroughFn);
414419
}
415420

@@ -420,7 +425,8 @@ fcmpImpliesClass(CmpInst::Predicate Pred, DenormalMode Mode, ValueTy LHS,
420425
return exactClass(Src, Mask);
421426
}
422427

423-
return fcmpImpliesClass(Pred, Mode, LHS, ConstRHS.classify(), LookThroughFn);
428+
return fcmpImpliesClass(Pred, ModeQuery, LHS, ConstRHS.classify(),
429+
LookThroughFn);
424430
}
425431

426432
} // namespace llvm

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4579,10 +4579,12 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
45794579
return LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
45804580
};
45814581

4582-
Type *Ty = LHS->getType()->getScalarType();
4583-
DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
4582+
auto ModeQuery = [&](Value *LHS) {
4583+
Type *Ty = LHS->getType()->getScalarType();
4584+
return F.getDenormalMode(Ty->getFltSemantics());
4585+
};
45844586

4585-
return fcmpImpliesClass(Pred, Mode, LHS, RHSClass, LookThrough);
4587+
return fcmpImpliesClass(Pred, ModeQuery, LHS, RHSClass, LookThrough);
45864588
}
45874589

45884590
std::tuple<Value *, FPClassTest, FPClassTest>
@@ -4593,10 +4595,12 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
45934595
return LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
45944596
};
45954597

4596-
Type *Ty = LHS->getType()->getScalarType();
4597-
DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
4598+
auto ModeQuery = [&](Value *LHS) {
4599+
Type *Ty = LHS->getType()->getScalarType();
4600+
return F.getDenormalMode(Ty->getFltSemantics());
4601+
};
45984602

4599-
return fcmpImpliesClass(Pred, Mode, LHS, ConstRHS, LookThrough);
4603+
return fcmpImpliesClass(Pred, ModeQuery, LHS, ConstRHS, LookThrough);
46004604
}
46014605

46024606
std::tuple<Value *, FPClassTest, FPClassTest>
@@ -4606,15 +4610,17 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
46064610
if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
46074611
return {nullptr, fcAllFlags, fcAllFlags};
46084612

4609-
Type *Ty = LHS->getType()->getScalarType();
4610-
DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
4611-
46124613
auto LookThrough = [=](Value *LHS, Value *&Src) {
46134614
return LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
46144615
};
46154616

4617+
auto ModeQuery = [&](Value *LHS) {
4618+
Type *Ty = LHS->getType()->getScalarType();
4619+
return F.getDenormalMode(Ty->getFltSemantics());
4620+
};
4621+
46164622
// TODO: Just call computeKnownFPClass for RHS to handle non-constants.
4617-
return fcmpImpliesClass(Pred, Mode, LHS, *ConstRHS, LookThrough);
4623+
return fcmpImpliesClass(Pred, ModeQuery, LHS, *ConstRHS, LookThrough);
46184624
}
46194625

46204626
static void computeKnownFPClassFromCond(const Value *V, Value *Cond,

llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -681,10 +681,12 @@ GISelValueTracking::fcmpImpliesClass(CmpInst::Predicate Pred,
681681
return LookThroughSrc && mi_match(LHS, MRI, m_GFabs(m_Reg(Src)));
682682
};
683683

684-
LLT Ty = MRI.getType(LHS);
685-
DenormalMode Mode = MF.getDenormalMode(getFltSemanticForLLT(Ty));
684+
auto ModeQuery = [&](Register LHS) {
685+
LLT Ty = MRI.getType(LHS).getScalarType();
686+
return MF.getDenormalMode(getFltSemanticForLLT(Ty));
687+
};
686688

687-
return llvm::fcmpImpliesClass(Pred, Mode, LHS, RHSClass, LookThrough);
689+
return llvm::fcmpImpliesClass(Pred, ModeQuery, LHS, RHSClass, LookThrough);
688690
}
689691

690692
std::tuple<Register, FPClassTest, FPClassTest>
@@ -696,10 +698,12 @@ GISelValueTracking::fcmpImpliesClass(CmpInst::Predicate Pred,
696698
return LookThroughSrc && mi_match(LHS, MRI, m_GFabs(m_Reg(Src)));
697699
};
698700

699-
LLT Ty = MRI.getType(LHS);
700-
DenormalMode Mode = MF.getDenormalMode(getFltSemanticForLLT(Ty));
701+
auto ModeQuery = [&](Register LHS) {
702+
LLT Ty = MRI.getType(LHS).getScalarType();
703+
return MF.getDenormalMode(getFltSemanticForLLT(Ty));
704+
};
701705

702-
return llvm::fcmpImpliesClass(Pred, Mode, LHS, ConstRHS, LookThrough);
706+
return llvm::fcmpImpliesClass(Pred, ModeQuery, LHS, ConstRHS, LookThrough);
703707
}
704708

705709
std::tuple<Register, FPClassTest, FPClassTest>
@@ -715,11 +719,13 @@ GISelValueTracking::fcmpImpliesClass(CmpInst::Predicate Pred,
715719
return LookThroughSrc && mi_match(LHS, MRI, m_GFabs(m_Reg(Src)));
716720
};
717721

718-
LLT Ty = MRI.getType(LHS).getScalarType();
719-
DenormalMode Mode = MF.getDenormalMode(getFltSemanticForLLT(Ty));
722+
auto ModeQuery = [&](Register LHS) {
723+
LLT Ty = MRI.getType(LHS).getScalarType();
724+
return MF.getDenormalMode(getFltSemanticForLLT(Ty));
725+
};
720726

721727
// TODO: Just call computeKnownFPClass for RHS to handle non-constants.
722-
return llvm::fcmpImpliesClass(Pred, Mode, LHS, ConstRHS->getValueAPF(),
728+
return llvm::fcmpImpliesClass(Pred, ModeQuery, LHS, ConstRHS->getValueAPF(),
723729
LookThrough);
724730
}
725731

0 commit comments

Comments
 (0)