Skip to content

Commit cca3217

Browse files
[mlir][SCF] Use Affine ops for indexing math. (llvm#108450)
For index type of induction variable, the indexing math is better represented using affine ops such as `affine.delinearize_index`. This also further demonstrates that some of these `affine` ops might need to move to a different dialect. For one these ops only support `IndexType` when they should be able to work with any integer type. This change also includes some canonicalization patterns for `affine.delinearize_index` operation to 1) Drop unit `basis` values 2) Remove the `delinearize_index` op when the `linear_index` is a loop induction variable of a normalized loop and the `basis` is of size 1 and is also the upper bound of the normalized loop. --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent d33fa70 commit cca3217

File tree

11 files changed

+416
-207
lines changed

11 files changed

+416
-207
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
10961096
];
10971097

10981098
let hasVerifier = 1;
1099+
let hasCanonicalizer = 1;
10991100
}
11001101

11011102
#endif // AFFINE_OPS

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
394394
let summary = "Coalesce nested loops with independent bounds into a single "
395395
"loop";
396396
let constructor = "mlir::affine::createLoopCoalescingPass()";
397-
let dependentDialects = ["arith::ArithDialect"];
397+
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
398398
}
399399

400400
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
5656
def TestSCFParallelLoopCollapsing : Pass<"test-scf-parallel-loop-collapsing"> {
5757
let summary = "Test parallel loops collapsing transformation";
5858
let constructor = "mlir::createTestSCFParallelLoopCollapsingPass()";
59+
let dependentDialects = ["affine::AffineDialect"];
5960
let description = [{
6061
This pass is purely for testing the scf::collapseParallelLoops
6162
transformation. The transformation does not have opinions on how a

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4537,6 +4537,133 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
45374537
return success();
45384538
}
45394539

4540+
namespace {
4541+
4542+
// Drops delinearization indices that correspond to unit-extent basis
4543+
struct DropUnitExtentBasis
4544+
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4545+
using OpRewritePattern::OpRewritePattern;
4546+
4547+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4548+
PatternRewriter &rewriter) const override {
4549+
SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4550+
std::optional<Value> zero = std::nullopt;
4551+
Location loc = delinearizeOp->getLoc();
4552+
auto getZero = [&]() -> Value {
4553+
if (!zero)
4554+
zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4555+
return zero.value();
4556+
};
4557+
4558+
// Replace all indices corresponding to unit-extent basis with 0.
4559+
// Remaining basis can be used to get a new `affine.delinearize_index` op.
4560+
SmallVector<Value> newOperands;
4561+
for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
4562+
if (matchPattern(basis, m_One()))
4563+
replacements[index] = getZero();
4564+
else
4565+
newOperands.push_back(basis);
4566+
}
4567+
4568+
if (newOperands.size() == delinearizeOp.getBasis().size())
4569+
return failure();
4570+
4571+
if (!newOperands.empty()) {
4572+
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4573+
loc, delinearizeOp.getLinearIndex(), newOperands);
4574+
int newIndex = 0;
4575+
// Map back the new delinearized indices to the values they replace.
4576+
for (auto &replacement : replacements) {
4577+
if (replacement)
4578+
continue;
4579+
replacement = newDelinearizeOp->getResult(newIndex++);
4580+
}
4581+
}
4582+
4583+
rewriter.replaceOp(delinearizeOp, replacements);
4584+
return success();
4585+
}
4586+
};
4587+
4588+
/// Drop delinearization pattern related to loops in the following way
4589+
///
4590+
/// ```
4591+
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4592+
/// %0 = affine.delinearize_index %iv into (%ub) : index
4593+
/// <some_use>(%0)
4594+
/// }
4595+
/// ```
4596+
///
4597+
/// can be canonicalized to
4598+
///
4599+
/// ```
4600+
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4601+
/// <some_use>(%iv)
4602+
/// }
4603+
/// ```
4604+
struct DropDelinearizeOfSingleLoop
4605+
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4606+
using OpRewritePattern::OpRewritePattern;
4607+
4608+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4609+
PatternRewriter &rewriter) const override {
4610+
auto basis = delinearizeOp.getBasis();
4611+
if (basis.size() != 1)
4612+
return failure();
4613+
4614+
// Check that the `linear_index` is an induction variable.
4615+
auto inductionVar = cast<BlockArgument>(delinearizeOp.getLinearIndex());
4616+
if (!inductionVar)
4617+
return failure();
4618+
4619+
// Check that the parent is a `LoopLikeOpInterface`.
4620+
auto loopLikeOp = cast<LoopLikeOpInterface>(
4621+
inductionVar.getParentRegion()->getParentOp());
4622+
if (!loopLikeOp)
4623+
return failure();
4624+
4625+
// Check that loop is unit-rank and that the `linear_index` is the induction
4626+
// variable.
4627+
auto inductionVars = loopLikeOp.getLoopInductionVars();
4628+
if (!inductionVars || inductionVars->size() != 1 ||
4629+
inductionVars->front() != inductionVar) {
4630+
return rewriter.notifyMatchFailure(
4631+
delinearizeOp, "`linear_index` is not loop induction variable");
4632+
}
4633+
4634+
// Check that the upper-bound is the basis.
4635+
auto upperBounds = loopLikeOp.getLoopUpperBounds();
4636+
if (!upperBounds || upperBounds->size() != 1 ||
4637+
upperBounds->front() != getAsOpFoldResult(basis.front())) {
4638+
return rewriter.notifyMatchFailure(delinearizeOp,
4639+
"`basis` is not upper bound");
4640+
}
4641+
4642+
// Check that the lower bound is zero.
4643+
auto lowerBounds = loopLikeOp.getLoopLowerBounds();
4644+
if (!lowerBounds || lowerBounds->size() != 1 ||
4645+
!isZeroIndex(lowerBounds->front())) {
4646+
return rewriter.notifyMatchFailure(delinearizeOp,
4647+
"loop lower bound is not zero");
4648+
}
4649+
4650+
// Check that the step is one.
4651+
auto steps = loopLikeOp.getLoopSteps();
4652+
if (!steps || steps->size() != 1 || !isConstantIntValue(steps->front(), 1))
4653+
return rewriter.notifyMatchFailure(delinearizeOp, "loop step is not one");
4654+
4655+
rewriter.replaceOp(delinearizeOp, inductionVar);
4656+
return success();
4657+
}
4658+
};
4659+
4660+
} // namespace
4661+
4662+
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4663+
RewritePatternSet &patterns, MLIRContext *context) {
4664+
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
4665+
}
4666+
45404667
//===----------------------------------------------------------------------===//
45414668
// TableGen'd op method definitions
45424669
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/SCF/Transforms/Passes.h"
1010

11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1112
#include "mlir/Dialect/SCF/IR/SCF.h"
1213
#include "mlir/Dialect/SCF/Utils/Utils.h"
1314
#include "mlir/Transforms/RegionUtils.h"

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/SCF/Utils/Utils.h"
1414
#include "mlir/Analysis/SliceAnalysis.h"
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Arith/Utils/Utils.h"
1718
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
671672
return success();
672673
}
673674

675+
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
676+
OpFoldResult lb, OpFoldResult ub,
677+
OpFoldResult step) {
678+
Range normalizedLoopBounds;
679+
normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
680+
normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
681+
AffineExpr s0, s1, s2;
682+
bindSymbols(rewriter.getContext(), s0, s1, s2);
683+
AffineExpr e = (s1 - s0).ceilDiv(s2);
684+
normalizedLoopBounds.size =
685+
affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
686+
return normalizedLoopBounds;
687+
}
688+
674689
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
675690
OpFoldResult lb, OpFoldResult ub,
676691
OpFoldResult step) {
692+
if (getType(lb).isIndex()) {
693+
return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
694+
}
677695
// For non-index types, generate `arith` instructions
678696
// Check if the loop is already known to have a constant zero lower bound or
679697
// a constant one step.
@@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
714732
return {newLowerBound, newUpperBound, newStep};
715733
}
716734

735+
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
736+
Location loc,
737+
Value normalizedIv,
738+
OpFoldResult origLb,
739+
OpFoldResult origStep) {
740+
AffineExpr d0, s0, s1;
741+
bindSymbols(rewriter.getContext(), s0, s1);
742+
bindDims(rewriter.getContext(), d0);
743+
AffineExpr e = d0 * s1 + s0;
744+
OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
745+
rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
746+
Value denormalizedIvVal =
747+
getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
748+
SmallPtrSet<Operation *, 1> preservedUses;
749+
// If an `affine.apply` operation is generated for denormalization, the use
750+
// of `origLb` in those ops must not be replaced. These arent not generated
751+
// when `origLb == 0` and `origStep == 1`.
752+
if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
753+
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
754+
preservedUses.insert(preservedUse);
755+
}
756+
}
757+
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
758+
}
759+
717760
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
718761
Value normalizedIv, OpFoldResult origLb,
719762
OpFoldResult origStep) {
763+
if (getType(origLb).isIndex()) {
764+
return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
765+
origLb, origStep);
766+
}
720767
Value denormalizedIv;
721768
SmallPtrSet<Operation *, 2> preserve;
722769
bool isStepOne = isConstantIntValue(origStep, 1);
@@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
739786
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
740787
}
741788

789+
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
790+
ArrayRef<OpFoldResult> values) {
791+
assert(!values.empty() && "unexecpted empty array");
792+
AffineExpr s0, s1;
793+
bindSymbols(rewriter.getContext(), s0, s1);
794+
AffineExpr mul = s0 * s1;
795+
OpFoldResult products = rewriter.getIndexAttr(1);
796+
for (auto v : values) {
797+
products = affine::makeComposedFoldedAffineApply(
798+
rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
799+
}
800+
return products;
801+
}
802+
742803
/// Helper function to multiply a sequence of values.
743804
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
744805
ArrayRef<Value> values) {
745806
assert(!values.empty() && "unexpected empty list");
807+
if (getType(values.front()).isIndex()) {
808+
SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
809+
OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
810+
return getValueOrCreateConstantIndexOp(rewriter, loc, product);
811+
}
746812
std::optional<Value> productOf;
747813
for (auto v : values) {
748814
auto vOne = getConstantIntValue(v);
@@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
757823
if (!productOf) {
758824
productOf = rewriter
759825
.create<arith::ConstantOp>(
760-
loc, rewriter.getOneAttr(values.front().getType()))
826+
loc, rewriter.getOneAttr(getType(values.front())))
761827
.getResult();
762828
}
763829
return productOf.value();
@@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
774840
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
775841
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
776842
Value linearizedIv, ArrayRef<Value> ubs) {
843+
844+
if (linearizedIv.getType().isIndex()) {
845+
Operation *delinearizedOp =
846+
rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
847+
ubs);
848+
auto resultVals = llvm::map_to_vector(
849+
delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
850+
return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
851+
}
852+
777853
SmallVector<Value> delinearizedIvs(ubs.size());
778854
SmallPtrSet<Operation *, 2> preservedUsers;
779855

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,3 +1466,51 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
14661466
}
14671467
return
14681468
}
1469+
1470+
// -----
1471+
1472+
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
1473+
(index, index, index, index, index, index) {
1474+
%c1 = arith.constant 1 : index
1475+
%0:6 = affine.delinearize_index %arg0 into (%c1, %arg1, %c1, %c1, %arg2, %c1)
1476+
: index, index, index, index, index, index
1477+
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
1478+
}
1479+
// CHECK-LABEL: func @drop_unit_basis_in_delinearize(
1480+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1481+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1482+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1483+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1484+
// CHECK-DAG: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[ARG2]])
1485+
// CHECK: return %[[C0]], %[[DELINEARIZE]]#0, %[[C0]], %[[C0]], %[[DELINEARIZE]]#1, %[[C0]]
1486+
1487+
// -----
1488+
1489+
func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
1490+
%c1 = arith.constant 1 : index
1491+
%0:2 = affine.delinearize_index %arg0 into (%c1, %c1) : index, index
1492+
return %0#0, %0#1 : index, index
1493+
}
1494+
// CHECK-LABEL: func @drop_all_unit_bases(
1495+
// CHECK-SAME: %[[ARG0:.+]]: index)
1496+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1497+
// CHECK-NOT: affine.delinearize_index
1498+
// CHECK: return %[[C0]], %[[C0]]
1499+
1500+
// -----
1501+
1502+
func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
1503+
%c0 = arith.constant 0 : index
1504+
%c1 = arith.constant 1 : index
1505+
%2 = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%arg2 = %c0) -> index {
1506+
%0 = affine.delinearize_index %iv into (%arg1) : index
1507+
%1 = "some_use"(%arg2, %0) : (index, index) -> (index)
1508+
scf.yield %1 : index
1509+
}
1510+
return %2 : index
1511+
}
1512+
// CHECK-LABEL: func @drop_single_loop_delinearize(
1513+
// CHECK-SAME: %[[ARG0:.+]]: index)
1514+
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
1515+
// CHECK-NOT: affine.delinearize_index
1516+
// CHECK: "some_use"(%{{.+}}, %[[IV]])

0 commit comments

Comments
 (0)