Skip to content

Commit eacda36

Browse files
author
Rolf Morel
authored
[SCF][Transform] Add support for scf.for in LoopFuseSibling op (#81495)
Adds support for fusing two scf.for loops occurring in the same block. Uses the rudimentary checks already in place for scf.forall (like the target loop's operands being dominated by the source loop). - Fixes a bug in the dominance check whereby it was checked that values in the target loop themselves dominated the source loop rather than the ops that define these operands. - Renames the LoopFuseSibling op to LoopFuseSiblingOp. - Updates LoopFuseSiblingOp's description. - Adds tests for using LoopFuseSiblingOp on scf.for loops, including one which fails without the fix for the dominance check. - Adds tests checking the different failure modes of the dominance checker. - Adds test for case whereby scf.yield is automatically generated when there are no loop-carried variables.
1 parent 91856b3 commit eacda36

File tree

5 files changed

+357
-92
lines changed

5 files changed

+357
-92
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -333,23 +333,24 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
333333
}];
334334
}
335335

336-
def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
336+
def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
337337
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
338338
DeclareOpInterfaceMethods<TransformOpInterface>]> {
339339
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
340340

341341
let description = [{
342342
Fuses the `target` loop into the `source` loop assuming they are
343-
independent of each other. It is the responsibility of the user to ensure
344-
that the given two loops are independent of each other, this operation will
345-
not performa any legality checks and will simply fuse the two given loops.
343+
independent of each other. In the fused loop, the arguments, body and
344+
results of `target` are placed _before_ those of `source`.
346345

347-
Currently, the only fusion supported is when both `target` and `source`
348-
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
349-
mapping must match, otherwise a silencable failure is produced.
346+
For fusion of two `scf.for` loops, the bounds and step size must match. For
347+
fusion of two `scf.forall` loops, the bounds and the mapping must match.
348+
Otherwise a silencable failure is produced.
350349

351-
The input handles `target` and `source` must map to exactly one operation,
352-
a definite failure is produced otherwise.
350+
The `target` and `source` handles must refer to exactly one operation,
351+
otherwise a definite failure is produced. It is the responsibility of the
352+
user to ensure that the `target` and `source` loops are independent of each
353+
other -- this op will only perform rudimentary legality checks.
353354

354355
#### Return modes
355356

@@ -362,10 +363,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
362363
let results = (outs TransformHandleTypeInterface:$fused_loop);
363364
let assemblyFormat = "$target `into` $source attr-dict "
364365
" `:` functional-type(operands, results)";
365-
366-
let builders = [
367-
OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
368-
];
369366
}
370367

371368
#endif // SCF_TRANSFORM_OPS

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
162162
scf::ForallOp source,
163163
RewriterBase &rewriter);
164164

165+
/// Given two scf.for loops, `target` and `source`, fuses `target` into
166+
/// `source`. Assumes that the given loops are siblings and are independent of
167+
/// each other.
168+
///
169+
/// This function does not perform any legality checks and simply fuses the
170+
/// loops. The caller is responsible for ensuring that the loops are legal to
171+
/// fuse.
172+
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
173+
RewriterBase &rewriter);
174+
165175
} // namespace mlir
166176

167177
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
384384
}
385385

386386
//===----------------------------------------------------------------------===//
387-
// LoopFuseSibling
387+
// LoopFuseSiblingOp
388388
//===----------------------------------------------------------------------===//
389389

390390
/// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
408408
// Check if fusion will violate dominance.
409409
DominanceInfo domInfo(source);
410410
if (target->isBeforeInBlock(source)) {
411-
// Since, `target` is before `source`, all users of results of `target`
411+
// Since `target` is before `source`, all users of results of `target`
412412
// need to be dominated by `source`.
413413
for (Operation *user : target->getUsers()) {
414414
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
424424
// Check if operands of `target` are dominated by `source`.
425425
for (Value operand : target->getOperands()) {
426426
Operation *operandOp = operand.getDefiningOp();
427-
// If operand does not have a defining operation, it is a block arguement,
428-
// which will always dominate `source`, since `target` and `source` are in
429-
// the same block and the operand dominated `source` before.
427+
// Operands without defining operations are block arguments. When `target`
428+
// and `source` occur in the same block, these operands dominate `source`.
430429
if (!operandOp)
431430
continue;
432431

@@ -441,8 +440,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
441440
bool failed = false;
442441
OpOperand *failedValue = nullptr;
443442
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
444-
if (!domInfo.properlyDominates(operand->getOwner(), source,
445-
/*enclosingOpOk=*/false)) {
443+
Operation *operandOp = operand->get().getDefiningOp();
444+
if (operandOp && !domInfo.properlyDominates(operandOp, source,
445+
/*enclosingOpOk=*/false)) {
446+
// `operand` is not an argument of an enclosing block and the defining
447+
// op of `operand` is outside `target` but does not dominate `source`.
446448
failed = true;
447449
failedValue = operand;
448450
}
@@ -457,12 +459,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
457459
return DiagnosedSilenceableFailure::success();
458460
}
459461

460-
/// Check if `target` can be fused into `source`.
462+
/// Check if `target` scf.forall can be fused into `source` scf.forall.
461463
///
462-
/// This is a simple check that just checks if both loops have same
463-
/// bounds, steps and mapping. This check does not ensure that the side effects
464-
/// of `target` are independent of `source` or vice-versa. It is the
465-
/// responsibility of the caller to ensure that.
464+
/// This simply checks if both loops have the same bounds, steps and mapping.
465+
/// No attempt is made at checking that the side effects of `target` and
466+
/// `source` are independent of each other.
466467
static bool isForallWithIdenticalConfiguration(Operation *target,
467468
Operation *source) {
468469
auto targetOp = dyn_cast<scf::ForallOp>(target);
@@ -476,21 +477,27 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
476477
targetOp.getMapping() == sourceOp.getMapping();
477478
}
478479

479-
/// Fuse `target` into `source` assuming they are siblings and indepndent.
480-
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
481-
static Operation *fuseSiblings(Operation *target, Operation *source,
482-
RewriterBase &rewriter) {
483-
auto targetOp = dyn_cast<scf::ForallOp>(target);
484-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
480+
/// Check if `target` scf.for can be fused into `source` scf.for.
481+
///
482+
/// This simply checks if both loops have the same bounds and steps. No attempt
483+
/// is made at checking that the side effects of `target` and `source` are
484+
/// independent of each other.
485+
static bool isForWithIdenticalConfiguration(Operation *target,
486+
Operation *source) {
487+
auto targetOp = dyn_cast<scf::ForOp>(target);
488+
auto sourceOp = dyn_cast<scf::ForOp>(source);
485489
if (!targetOp || !sourceOp)
486-
return nullptr;
487-
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
490+
return false;
491+
492+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
493+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
494+
targetOp.getStep() == sourceOp.getStep();
488495
}
489496

490497
DiagnosedSilenceableFailure
491-
transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
492-
transform::TransformResults &results,
493-
transform::TransformState &state) {
498+
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
499+
transform::TransformResults &results,
500+
transform::TransformState &state) {
494501
auto targetOps = state.getPayloadOps(getTarget());
495502
auto sourceOps = state.getPayloadOps(getSource());
496503

@@ -510,13 +517,18 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
510517
if (!diag.succeeded())
511518
return diag;
512519

513-
// Check if the target can be fused into source.
514-
if (!isForallWithIdenticalConfiguration(target, source)) {
520+
Operation *fusedLoop;
521+
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
522+
if (isForWithIdenticalConfiguration(target, source)) {
523+
fusedLoop = fuseIndependentSiblingForLoops(
524+
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
525+
} else if (isForallWithIdenticalConfiguration(target, source)) {
526+
fusedLoop = fuseIndependentSiblingForallLoops(
527+
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
528+
} else
515529
return emitSilenceableFailure(target->getLoc())
516530
<< "operations cannot be fused";
517-
}
518531

519-
Operation *fusedLoop = fuseSiblings(target, source, rewriter);
520532
assert(fusedLoop && "failed to fuse operations");
521533

522534
results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
910910
unsigned numTargetOuts = target.getNumResults();
911911
unsigned numSourceOuts = source.getNumResults();
912912

913-
OperandRange targetOuts = target.getOutputs();
914-
OperandRange sourceOuts = source.getOutputs();
915-
916913
// Create fused shared_outs.
917914
SmallVector<Value> fusedOuts;
918-
fusedOuts.reserve(numTargetOuts + numSourceOuts);
919-
fusedOuts.append(targetOuts.begin(), targetOuts.end());
920-
fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
915+
llvm::append_range(fusedOuts, target.getOutputs());
916+
llvm::append_range(fusedOuts, source.getOutputs());
921917

922-
// Create a new scf::forall op after the source loop.
918+
// Create a new scf.forall op after the source loop.
923919
rewriter.setInsertionPointAfter(source);
924920
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
925921
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
926922
source.getMixedStep(), fusedOuts, source.getMapping());
927923

928924
// Map control operands.
929-
IRMapping fusedMapping;
930-
fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
931-
fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
925+
IRMapping mapping;
926+
mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
927+
mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
932928

933929
// Map shared outs.
934-
fusedMapping.map(target.getRegionIterArgs(),
935-
fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
936-
fusedMapping.map(
937-
source.getRegionIterArgs(),
938-
fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
930+
mapping.map(target.getRegionIterArgs(),
931+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
932+
mapping.map(source.getRegionIterArgs(),
933+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
939934

940935
// Append everything except the terminator into the fused operation.
941936
rewriter.setInsertionPointToStart(fusedLoop.getBody());
942937
for (Operation &op : target.getBody()->without_terminator())
943-
rewriter.clone(op, fusedMapping);
938+
rewriter.clone(op, mapping);
944939
for (Operation &op : source.getBody()->without_terminator())
945-
rewriter.clone(op, fusedMapping);
940+
rewriter.clone(op, mapping);
946941

947942
// Fuse the old terminator in_parallel ops into the new one.
948943
scf::InParallelOp targetTerm = target.getTerminator();
949944
scf::InParallelOp sourceTerm = source.getTerminator();
950945
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
951-
952946
rewriter.setInsertionPointToStart(fusedTerm.getBody());
953947
for (Operation &op : targetTerm.getYieldingOps())
954-
rewriter.clone(op, fusedMapping);
948+
rewriter.clone(op, mapping);
955949
for (Operation &op : sourceTerm.getYieldingOps())
956-
rewriter.clone(op, fusedMapping);
957-
958-
// Replace all uses of the old loops with the fused loop.
959-
rewriter.replaceAllUsesWith(target.getResults(),
960-
fusedLoop.getResults().slice(0, numTargetOuts));
961-
rewriter.replaceAllUsesWith(
962-
source.getResults(),
963-
fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
964-
965-
// Erase the old loops.
966-
rewriter.eraseOp(target);
967-
rewriter.eraseOp(source);
950+
rewriter.clone(op, mapping);
951+
952+
// Replace old loops by substituting their uses by results of the fused loop.
953+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
954+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
955+
956+
return fusedLoop;
957+
}
958+
959+
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
960+
scf::ForOp source,
961+
RewriterBase &rewriter) {
962+
unsigned numTargetOuts = target.getNumResults();
963+
unsigned numSourceOuts = source.getNumResults();
964+
965+
// Create fused init_args, with target's init_args before source's init_args.
966+
SmallVector<Value> fusedInitArgs;
967+
llvm::append_range(fusedInitArgs, target.getInitArgs());
968+
llvm::append_range(fusedInitArgs, source.getInitArgs());
969+
970+
// Create a new scf.for op after the source loop (with scf.yield terminator
971+
// (without arguments) only in case its init_args is empty).
972+
rewriter.setInsertionPointAfter(source);
973+
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
974+
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
975+
source.getStep(), fusedInitArgs);
976+
977+
// Map original induction variables and operands to those of the fused loop.
978+
IRMapping mapping;
979+
mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
980+
mapping.map(target.getRegionIterArgs(),
981+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
982+
mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
983+
mapping.map(source.getRegionIterArgs(),
984+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
985+
986+
// Merge target's body into the new (fused) for loop and then source's body.
987+
rewriter.setInsertionPointToStart(fusedLoop.getBody());
988+
for (Operation &op : target.getBody()->without_terminator())
989+
rewriter.clone(op, mapping);
990+
for (Operation &op : source.getBody()->without_terminator())
991+
rewriter.clone(op, mapping);
992+
993+
// Build fused yield results by appropriately mapping original yield operands.
994+
SmallVector<Value> yieldResults;
995+
for (Value operand : target.getBody()->getTerminator()->getOperands())
996+
yieldResults.push_back(mapping.lookupOrDefault(operand));
997+
for (Value operand : source.getBody()->getTerminator()->getOperands())
998+
yieldResults.push_back(mapping.lookupOrDefault(operand));
999+
if (!yieldResults.empty())
1000+
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1001+
1002+
// Replace old loops by substituting their uses by results of the fused loop.
1003+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1004+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
9681005

9691006
return fusedLoop;
9701007
}

0 commit comments

Comments
 (0)