Skip to content

Commit 04ce103

Browse files
[mlir][SCF] Avoid generating unnecessary div/rem operations during coalescing (llvm#91562)
When coalescing, if some of the loops have unit trip counts we can avoid generating div/rem instructions during delinearization. Ideally we could use something like `affine.delinearize` to handle this, but that causes dependence issues.
1 parent 8466480 commit 04ce103

File tree

2 files changed

+124
-10
lines changed

2 files changed

+124
-10
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -544,11 +544,24 @@ static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
544544
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
545545
ArrayRef<Value> values) {
546546
assert(!values.empty() && "unexpected empty list");
547-
Value productOf = values.front();
548-
for (auto v : values.drop_front()) {
549-
productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
547+
std::optional<Value> productOf;
548+
for (auto v : values) {
549+
auto vOne = getConstantIntValue(v);
550+
if (vOne && vOne.value() == 1)
551+
continue;
552+
if (productOf)
553+
productOf =
554+
rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
555+
else
556+
productOf = v;
550557
}
551-
return productOf;
558+
if (!productOf) {
559+
productOf = rewriter
560+
.create<arith::ConstantOp>(
561+
loc, rewriter.getOneAttr(values.front().getType()))
562+
.getResult();
563+
}
564+
return productOf.value();
552565
}
553566

554567
/// For each original loop, the value of the
@@ -562,19 +575,43 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
562575
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
563576
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
564577
Value linearizedIv, ArrayRef<Value> ubs) {
565-
Value previous = linearizedIv;
566578
SmallVector<Value> delinearizedIvs(ubs.size());
567579
SmallPtrSet<Operation *, 2> preservedUsers;
568-
for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
569-
unsigned idx = ubs.size() - i - 1;
570-
if (i != 0) {
580+
581+
llvm::BitVector isUbOne(ubs.size());
582+
for (auto [index, ub] : llvm::enumerate(ubs)) {
583+
auto ubCst = getConstantIntValue(ub);
584+
if (ubCst && ubCst.value() == 1)
585+
isUbOne.set(index);
586+
}
587+
588+
// Prune the leading ubs that are all one.
589+
unsigned numLeadingOneUbs = 0;
590+
for (auto [index, ub] : llvm::enumerate(ubs)) {
591+
if (!isUbOne.test(index)) {
592+
break;
593+
}
594+
delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
595+
loc, rewriter.getZeroAttr(ub.getType()));
596+
numLeadingOneUbs++;
597+
}
598+
599+
Value previous = linearizedIv;
600+
for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
601+
unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
602+
if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
571603
previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
572604
preservedUsers.insert(previous.getDefiningOp());
573605
}
574606
Value iv = previous;
575607
if (i != e - 1) {
576-
iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
577-
preservedUsers.insert(iv.getDefiningOp());
608+
if (!isUbOne.test(idx)) {
609+
iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
610+
preservedUsers.insert(iv.getDefiningOp());
611+
} else {
612+
iv = rewriter.create<arith::ConstantOp>(
613+
loc, rewriter.getZeroAttr(ubs[idx].getType()));
614+
}
578615
}
579616
delinearizedIvs[idx] = iv;
580617
}

mlir/test/Dialect/SCF/transform-op-coalesce.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,80 @@ module attributes {transform.with_named_sequence} {
299299
// CHECK-NOT: scf.for
300300
// CHECK: transform.named_sequence
301301

302+
// -----
303+
304+
// Check that unnecessary operations are avoided while collapsing trip-1 loops.
305+
func.func @trip_one_loops(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
306+
%c0 = arith.constant 0 : index
307+
%c1 = arith.constant 1 : index
308+
%0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
309+
%1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
310+
%2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
311+
%3 = scf.for %iv3 = %c0 to %c1 step %c1 iter_args(%iter3 = %iter2) -> tensor<?x?xf32> {
312+
%4 = scf.for %iv4 = %c0 to %arg2 step %c1 iter_args(%iter4 = %iter3) -> tensor<?x?xf32> {
313+
%5 = "some_use"(%iter4, %iv0, %iv1, %iv2, %iv3, %iv4)
314+
: (tensor<?x?xf32>, index, index, index, index, index) -> (tensor<?x?xf32>)
315+
scf.yield %5 : tensor<?x?xf32>
316+
}
317+
scf.yield %4 : tensor<?x?xf32>
318+
}
319+
scf.yield %3 : tensor<?x?xf32>
320+
}
321+
scf.yield %2 : tensor<?x?xf32>
322+
}
323+
scf.yield %1 : tensor<?x?xf32>
324+
} {coalesce}
325+
return %0 : tensor<?x?xf32>
326+
}
327+
module attributes {transform.with_named_sequence} {
328+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
329+
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
330+
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
331+
%2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
332+
transform.yield
333+
}
334+
}
335+
// CHECK-LABEL: func @trip_one_loops
336+
// CHECK-SAME: , %[[ARG1:.+]]: index,
337+
// CHECK-SAME: %[[ARG2:.+]]: index)
338+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
339+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
340+
// CHECK: %[[UB:.+]] = arith.muli %[[ARG1]], %[[ARG2]]
341+
// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[UB]] step %[[C1]]
342+
// CHECK: %[[IV1:.+]] = arith.remsi %[[IV]], %[[ARG2]]
343+
// CHECK: %[[IV2:.+]] = arith.divsi %[[IV]], %[[ARG2]]
344+
// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV2]], %[[C0]], %[[IV1]])
345+
346+
// -----
347+
348+
// Check that no delinearization instructions are generated when all loops except one have unit trip counts.
349+
func.func @all_outer_trip_one(%arg0 : tensor<?x?xf32>, %arg1 : index) -> tensor<?x?xf32> {
350+
%c0 = arith.constant 0 : index
351+
%c1 = arith.constant 1 : index
352+
%0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
353+
%1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
354+
%2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
355+
%3 = "some_use"(%iter2, %iv0, %iv1, %iv2)
356+
: (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
357+
scf.yield %3 : tensor<?x?xf32>
358+
}
359+
scf.yield %2 : tensor<?x?xf32>
360+
}
361+
scf.yield %1 : tensor<?x?xf32>
362+
} {coalesce}
363+
return %0 : tensor<?x?xf32>
364+
}
365+
module attributes {transform.with_named_sequence} {
366+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
367+
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
368+
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
369+
%2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
370+
transform.yield
371+
}
372+
}
373+
// CHECK-LABEL: func @all_outer_trip_one
374+
// CHECK-SAME: , %[[ARG1:.+]]: index)
375+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
376+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
377+
// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[ARG1]] step %[[C1]]
378+
// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV]])

0 commit comments

Comments
 (0)