Skip to content

[mlir][Arith] Generalize and improve -int-range-optimizations #94712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
/// equivalent.
std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();

/// Add patterns for int range based optimizations.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver);

/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();

Expand Down
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
let summary = "Do optimizations based on integer range analysis";
let description = [{
This pass runs integer range analysis and apllies optimizations based on its
results. e.g. replace arith.cmpi with const if it can be inferred from
args ranges.
results. It replaces operations with known-constant results with said constants,
rewrites `(0 <= %x < D) mod D` to `%x`.
}];
// Explicitly depend on "arith" because this pass could create operations in
// `arith` out of thin air in some cases.
let dependentDialects = [
"::mlir::arith::ArithDialect"
];
}

def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
Expand Down
287 changes: 132 additions & 155 deletions mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/FoldUtils.h"

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
Expand All @@ -24,155 +25,145 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;

/// Returns true if 2 integer ranges have intersection.
static bool intersects(const ConstantIntRanges &lhs,
const ConstantIntRanges &rhs) {
return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
(lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
/// Patterned after SCCP
static LogicalResult replaceWithConstant(DataFlowSolver &solver,
RewriterBase &rewriter,
OperationFolder &folder, Value value) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return failure();
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
if (!maybeConstValue.has_value())
return failure();

Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr =
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
Value constant = folder.getOrCreateConstant(
rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType());
// Fall back to arith.constant if the dialect materializer doesn't know what
// to do with an integer constant.
if (!constant)
constant = folder.getOrCreateConstant(
rewriter.getInsertionBlock(),
rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
value.getType());
if (!constant)
return failure();

rewriter.replaceAllUsesWith(value, constant);
return success();
}

static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (!intersects(lhs, rhs))
return false;

return failure();
}

static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (!intersects(lhs, rhs))
return true;

return failure();
}

static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.smax().slt(rhs.smin()))
return true;

if (lhs.smin().sge(rhs.smax()))
return false;

return failure();
}

static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.smax().sle(rhs.smin()))
return true;

if (lhs.smin().sgt(rhs.smax()))
return false;

return failure();
}

static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleSlt(std::move(rhs), std::move(lhs));
}

static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleSle(std::move(rhs), std::move(lhs));
}

static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.umax().ult(rhs.umin()))
return true;

if (lhs.umin().uge(rhs.umax()))
return false;

return failure();
}

static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.umax().ule(rhs.umin()))
return true;

if (lhs.umin().ugt(rhs.umax()))
return false;

/// Rewrite any results of `op` that were inferred to be constant integers to
/// and replace their uses with that constant. Return success() if all results
/// where thus replaced and the operation is erased.
static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
RewriterBase &rewriter,
OperationFolder &folder,
Operation &op) {
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
replacedAll &=
succeeded(replaceWithConstant(solver, rewriter, folder, res));

// If all of the results of the operation were replaced, try to erase
// the operation completely.
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
assert(op.use_empty() && "expected all uses to be replaced");
rewriter.eraseOp(&op);
return success();
}
return failure();
}

static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleUlt(std::move(rhs), std::move(lhs));
/// This function hasn't come from anywhere and is relying on the overall
/// tests of the integer range inference implementation for its correctness.
static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
RewriterBase &rewriter,
Operation &op) {
if (!isa<RemSIOp, RemUIOp>(op))
return failure();
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
if (!rhsConstVal)
return failure();
int64_t modulus = rhsConstVal.value();
if (modulus <= 0)
return failure();
auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
return failure();
const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
const APInt &min = llvm::isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
const APInt &max = llvm::isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
// The minima and maxima here are given as closed ranges, we must be strictly
// less than the modulus.
if (min.isNegative() || min.uge(modulus))
return failure();
if (max.isNegative() || max.uge(modulus))
return failure();
if (!min.ule(max))
return failure();

// With all those conditions out of the way, we know thas this invocation of
// a remainder is a noop because the input is strictly within the range
// [0, modulus), so get rid of it.
rewriter.replaceOp(&op, ValueRange{lhs});
return success();
}

static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleUle(std::move(rhs), std::move(lhs));
static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
MutableArrayRef<Region> initialRegions) {
SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
for (Region &region : regions)
for (Block &block : llvm::reverse(region))
worklist.push_back(&block);
};

IRRewriter rewriter(context);
OperationFolder folder(context, rewriter.getListener());

addToWorklist(initialRegions);
while (!worklist.empty()) {
Block *block = worklist.pop_back_val();

for (Operation &op : llvm::make_early_inc_range(*block)) {
if (matchPattern(&op, m_Constant())) {
if (auto arithConstant = dyn_cast<ConstantOp>(op))
folder.insertKnownConstant(&op, arithConstant.getValue());
else
folder.insertKnownConstant(&op);
continue;
}
rewriter.setInsertionPoint(&op);

// Try rewrites. Success means that the underlying operation was erased.
if (succeeded(foldResultsToConstants(solver, rewriter, folder, op)))
continue;
if (isa<RemSIOp, RemUIOp>(op) &&
succeeded(deleteTrivialRemainder(solver, rewriter, op)))
continue;
// Add any the regions of this operation to the worklist.
addToWorklist(op.getRegions());
}

// Replace any block arguments with constants.
rewriter.setInsertionPointToStart(block);
for (BlockArgument arg : block->getArguments())
(void)replaceWithConstant(solver, rewriter, folder, arg);
}
}

namespace {
/// This class listens on IR transformations performed during a pass relying on
/// information from a `DataflowSolver`. It erases state associated with the
/// erased operation and its results from the `DataFlowSolver` so that Patterns
/// do not accidentally query old state information for newly created Ops.
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}

protected:
void notifyOperationErased(Operation *op) override {
s.eraseState(op);
for (Value res : op->getResults())
s.eraseState(res);
}

DataFlowSolver &s;
};

struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {

ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
: OpRewritePattern<arith::CmpIOp>(context), solver(s) {}

LogicalResult matchAndRewrite(arith::CmpIOp op,
PatternRewriter &rewriter) const override {
auto *lhsResult =
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
if (!lhsResult || lhsResult->getValue().isUninitialized())
return failure();

auto *rhsResult =
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
if (!rhsResult || rhsResult->getValue().isUninitialized())
return failure();

using HandlerFunc =
FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
handlers{};
using Pred = arith::CmpIPredicate;
handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
handlers[static_cast<size_t>(Pred::uge)] = &handleUge;

HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
if (!handler)
return failure();

ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
FailureOr<bool> result = handler(lhsValue, rhsValue);

if (failed(result))
return failure();

rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
op, static_cast<int64_t>(*result), /*width*/ 1);
return success();
}

private:
DataFlowSolver &solver;
};

struct IntRangeOptimizationsPass
: public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {

Expand All @@ -185,25 +176,11 @@ struct IntRangeOptimizationsPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();

DataFlowListener listener(solver);

RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);

GreedyRewriteConfig config;
config.listener = &listener;

if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
doRewrites(solver, ctx, op->getRegions());
}
};
} // namespace

void mlir::arith::populateIntRangeOptimizationsPatterns(
RewritePatternSet &patterns, DataFlowSolver &solver) {
patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
}

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
return std::make_unique<IntRangeOptimizationsPass>();
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Arith/int-range-interface.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s

// CHECK-LABEL: func @add_min_max
// CHECK: %[[c3:.*]] = arith.constant 3 : index
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Arith/int-range-opts.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,39 @@ func.func @test() -> i8 {
return %1: i8
}

// -----

// CHECK-LABEL: func @trivial_rem
// CHECK: [[val:%.+]] = test.with_bounds
// CHECK: return [[val]]
func.func @trivial_rem() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}

// -----

// CHECK-LABEL: func @non_const_rhs
// CHECK: [[mod:%.+]] = arith.remui
// CHECK: return [[mod]]
func.func @non_const_rhs() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
%rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
%mod = arith.remui %val, %rhs : i8
return %mod : i8
}

// -----

// CHECK-LABEL: func @wraps
// CHECK: [[mod:%.+]] = arith.remsi
// CHECK: return [[mod]]
func.func @wraps() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
Loading
Loading