Skip to content

[GISel] Add KnownFPClass Analysis to GISelValueTrackingPass #134611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
433 changes: 433 additions & 0 deletions llvm/include/llvm/ADT/FloatingPointModeUtils.h

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
#ifndef LLVM_CODEGEN_GLOBALISEL_GISELVALUETRACKING_H
#define LLVM_CODEGEN_GLOBALISEL_GISELVALUETRACKING_H

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/KnownFPClass.h"

namespace llvm {

Expand All @@ -35,13 +38,72 @@ class GISelValueTracking : public GISelChangeObserver {
unsigned MaxDepth;
/// Cache maintained during a computeKnownBits request.
SmallDenseMap<Register, KnownBits, 16> ComputeKnownBitsCache;
SmallDenseMap<Register, KnownFPClass, 16> ComputeKnownFPClassCache;

void computeKnownBitsMin(Register Src0, Register Src1, KnownBits &Known,
const APInt &DemandedElts, unsigned Depth = 0);

unsigned computeNumSignBitsMin(Register Src0, Register Src1,
const APInt &DemandedElts, unsigned Depth = 0);

/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
/// same result as an fcmp with the given operands.
///
/// If \p LookThroughSrc is true, consider the input value when computing the
/// mask.
///
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first
/// pair element will always be LHS.
std::pair<Register, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
const MachineFunction &MF,
Register LHS, Value *RHS,
bool LookThroughSrc = true);
std::pair<Register, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
const MachineFunction &MF,
Register LHS,
const APFloat *ConstRHS,
bool LookThroughSrc = true);

/// Compute the possible floating-point classes that \p LHS could be based on
/// fcmp \Pred \p LHS, \p RHS.
///
/// \returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
///
/// If the compare returns an exact class test, ClassesIfTrue ==
/// ~ClassesIfFalse
///
/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
/// only succeed for a test of x > 0 implies positive, but not x > 1).
///
/// If \p LookThroughSrc is true, consider the input value when computing the
/// mask. This may look through sign bit operations.
///
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first
/// pair element will always be LHS.
///
std::tuple<Register, FPClassTest, FPClassTest>
fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
Register LHS, Register RHS, bool LookThroughSrc = true);
std::tuple<Register, FPClassTest, FPClassTest>
fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
Register LHS, FPClassTest RHS, bool LookThroughSrc = true);
std::tuple<Register, FPClassTest, FPClassTest>
fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
Register LHS, const APFloat &RHS,
bool LookThroughSrc = true);

void computeKnownFPClass(Register R, KnownFPClass &Known,
FPClassTest InterestedClasses, unsigned Depth);

void computeKnownFPClassForFPTrunc(const MachineInstr &MI,
const APInt &DemandedElts,
FPClassTest InterestedClasses,
KnownFPClass &Known, unsigned Depth);

void computeKnownFPClass(Register R, const APInt &DemandedElts,
FPClassTest InterestedClasses, KnownFPClass &Known,
unsigned Depth);

public:
GISelValueTracking(MachineFunction &MF, unsigned MaxDepth = 6);
virtual ~GISelValueTracking() = default;
Expand Down Expand Up @@ -87,6 +149,34 @@ class GISelValueTracking : public GISelChangeObserver {
/// \return The known alignment for the pointer-like value \p R.
Align computeKnownAlignment(Register R, unsigned Depth = 0);

/// Determine which floating-point classes are valid for \p V, and return them
/// in KnownFPClass bit sets.
///
/// This function is defined on values with floating-point type, values
/// vectors of floating-point type, and arrays of floating-point type.

/// \p InterestedClasses is a compile time optimization hint for which
/// floating point classes should be queried. Queries not specified in \p
/// InterestedClasses should be reliable if they are determined during the
/// query.
KnownFPClass computeKnownFPClass(Register R, const APInt &DemandedElts,
FPClassTest InterestedClasses,
unsigned Depth);

KnownFPClass computeKnownFPClass(Register R,
FPClassTest InterestedClasses = fcAllFlags,
unsigned Depth = 0);

/// Wrapper to account for known fast math flags at the use instruction.
KnownFPClass computeKnownFPClass(Register R, const APInt &DemandedElts,
uint32_t Flags,
FPClassTest InterestedClasses,
unsigned Depth);

KnownFPClass computeKnownFPClass(Register R, uint32_t Flags,
FPClassTest InterestedClasses,
unsigned Depth);

// Observer API. No-op for non-caching implementation.
void erasingInstr(MachineInstr &MI) override {}
void createdInstr(MachineInstr &MI) override {}
Expand Down
38 changes: 38 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
#define LLVM_CODEGEN_GLOBALISEL_MIPATTERNMATCH_H

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/InstrTypes.h"

namespace llvm {
Expand Down Expand Up @@ -393,6 +396,7 @@ inline bind_ty<const MachineInstr *> m_MInstr(const MachineInstr *&MI) {
inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
inline operand_type_match m_Pred() { return operand_type_match(); }
inline bind_ty<FPClassTest> m_FPClassTest(FPClassTest &T) { return T; }

template <typename BindTy> struct deferred_helper {
static bool match(const MachineRegisterInfo &MRI, BindTy &VR, BindTy &V) {
Expand Down Expand Up @@ -762,6 +766,32 @@ struct CompareOp_match {
}
};

template <typename LHS_P, typename Test_P, unsigned Opcode>
struct ClassifyOp_match {
LHS_P L;
Test_P T;

ClassifyOp_match(const LHS_P &LHS, const Test_P &Tst) : L(LHS), T(Tst) {}

template <typename OpTy>
bool match(const MachineRegisterInfo &MRI, OpTy &&Op) {
MachineInstr *TmpMI;
if (!mi_match(Op, MRI, m_MInstr(TmpMI)) || TmpMI->getOpcode() != Opcode)
return false;

Register LHS = TmpMI->getOperand(1).getReg();
if (!L.match(MRI, LHS))
return false;

FPClassTest TmpClass =
static_cast<FPClassTest>(TmpMI->getOperand(2).getImm());
if (T.match(MRI, TmpClass))
return true;

return false;
}
};

template <typename Pred, typename LHS, typename RHS>
inline CompareOp_match<Pred, LHS, RHS, TargetOpcode::G_ICMP>
m_GICmp(const Pred &P, const LHS &L, const RHS &R) {
Expand Down Expand Up @@ -804,6 +834,14 @@ m_c_GFCmp(const Pred &P, const LHS &L, const RHS &R) {
return CompareOp_match<Pred, LHS, RHS, TargetOpcode::G_FCMP, true>(P, L, R);
}

/// Matches the register and immediate used in a fpclass test
/// G_IS_FPCLASS %val, 96
template <typename LHS, typename Test>
inline ClassifyOp_match<LHS, Test, TargetOpcode::G_IS_FPCLASS>
m_GIsFPClass(const LHS &L, const Test &T) {
return ClassifyOp_match<LHS, Test, TargetOpcode::G_IS_FPCLASS>(L, T);
}

// Helper for checking if a Reg is of specific type.
struct CheckType {
LLT Ty;
Expand Down
20 changes: 20 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,9 @@ class GIConstant {
/// }
/// provides low-level access.
class GFConstant {
using VecTy = SmallVector<APFloat>;
using const_iterator = VecTy::const_iterator;

public:
enum class GFConstantKind { Scalar, FixedVector, ScalableVector };

Expand All @@ -672,6 +675,23 @@ class GFConstant {
/// Returns the kind of of this constant, e.g, Scalar.
GFConstantKind getKind() const { return Kind; }

const_iterator begin() const {
assert(Kind != GFConstantKind::ScalableVector &&
"Expected fixed vector or scalar constant");
return Values.begin();
}

const_iterator end() const {
assert(Kind != GFConstantKind::ScalableVector &&
"Expected fixed vector or scalar constant");
return Values.end();
}

size_t size() const {
assert(Kind == GFConstantKind::FixedVector && "Expected fixed vector");
return Values.size();
}
Comment on lines +678 to +693
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems unrelated?

Copy link
Member Author

@tgymnich tgymnich May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used in GISelValueTracking::computeKnownFPClass(GISelValueTracking.cpp:800) to handle constants.


/// Returns the value, if this constant is a scalar.
APFloat getScalarValue() const;

Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownFPClass.h"
#include <algorithm>
#include <cassert>
#include <climits>
Expand Down Expand Up @@ -4230,6 +4231,13 @@ class TargetLowering : public TargetLoweringBase {
const MachineRegisterInfo &MRI,
unsigned Depth = 0) const;

virtual void computeKnownFPClassForTargetInstr(GISelValueTracking &Analysis,
Register R,
KnownFPClass &Known,
const APInt &DemandedElts,
const MachineRegisterInfo &MRI,
unsigned Depth = 0) const;

/// Determine the known alignment for the pointer value \p R. This is can
/// typically be inferred from the number of low known 0 bits. However, for a
/// pointer with a non-integral address space, the alignment value may be
Expand Down
Loading