Skip to content

Commit 5310be5

Browse files
committed
[mlir] make fuse_into_containing_op preserve the containing op handle
This partially undoes the intent of https://reviews.llvm.org/D151418 by cheating its way to keep the "containing op" (aka loop) handle read-only in fusion. It is crucial to do so for composability of tiling and fusion. Specfically, after the "containing op" handle started being consumed, it became impossible to perform additional tiling after fusion except tiling the last-fused op: %tiled1, %loop1 = tile %op %producer1, %loop2 = fuse %producer into %loop1 // invalid, because %tiled1 is invalidated by consuming %loop1 // that points to its parent tile %tiled1 or %tiled1, %loop1 = tile %op %tiled2, %loop2 = tile %tiled1 %p2 = fuse %producer into %loop1 // invalid, because %loop2 is invalidated by consuming %loop1 // that points to its parent fuse %p2 into %loop2 The approach here makes creative use of the state extension mechanism to update the payload operation associted with the operand handle. Further investigation is necessary to understand if is consistent with the overall execution model of the transform dialect, but it is crucial to restore composability ASAP. Reviewed By: springerm, nicolasvasilache Differential Revision: https://reviews.llvm.org/D151555
1 parent 44f6e86 commit 5310be5

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/IR/TypeUtilities.h"
3535
#include "mlir/Interfaces/TilingInterface.h"
3636
#include "mlir/Support/LLVM.h"
37+
#include "mlir/Support/TypeID.h"
3738
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3839
#include "llvm/ADT/STLExtras.h"
3940
#include "llvm/ADT/ScopeExit.h"
@@ -663,6 +664,36 @@ bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
663664
return true;
664665
}
665666

667+
namespace {
668+
/// Unsafely exposes an internal protected method of TransformState::Extension
669+
/// as public.
670+
///
671+
/// MUST NOT be used directly.
672+
class UnsafeOpReplacementStateExtension : public TransformState::Extension {
673+
public:
674+
UnsafeOpReplacementStateExtension(TransformState &state)
675+
: TransformState::Extension(state) {}
676+
677+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
678+
UnsafeOpReplacementStateExtension)
679+
680+
LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
681+
return replacePayloadOp(op, replacement);
682+
}
683+
};
684+
} // namespace
685+
686+
/// Replaces `payload` with `replacement` in all handles stored in the state.
687+
/// MUST NOT be used except for the case immediately below.
688+
static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
689+
Operation *payload,
690+
Operation *replacement) {
691+
UnsafeOpReplacementStateExtension extension(state);
692+
// This may return failure if the payload is not associated with any handle,
693+
// ignore that.
694+
(void)extension.doReplacePayloadOp(payload, replacement);
695+
}
696+
666697
DiagnosedSilenceableFailure
667698
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
668699
transform::TransformState &state) {
@@ -757,6 +788,14 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
757788
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
758789
}
759790

791+
// Update handles associated with the containing op so we don't need to
792+
// invalidate them. This is a hack to support better composability between
793+
// tiling and fusion while a proper mechanism is being investigated.
794+
//
795+
// DO NOT replicate this elsewhere unless you understand what you are doing.
796+
forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
797+
containingOp);
798+
760799
results.set(cast<OpResult>(getFusedOp()), fusedOps);
761800
results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
762801
return DiagnosedSilenceableFailure::success();
@@ -765,7 +804,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
765804
void transform::FuseIntoContainingOp::getEffects(
766805
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
767806
consumesHandle(getProducerOp(), effects);
768-
consumesHandle(getContainingOp(), effects);
807+
onlyReadsHandle(getContainingOp(), effects);
769808
producesHandle(getResults(), effects);
770809
modifiesPayload(effects);
771810
}

mlir/test/Dialect/Linalg/transform-ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,15 @@ transform.sequence failures(propagate) {
3535
// CHECK: transform.structured.scalarize
3636
%0 = transform.structured.scalarize %arg0 : (!transform.any_op) -> !transform.any_op
3737
}
38+
39+
// Check that the second argument of `fuse_into_containing_op` is not consumed
40+
// (if it had been, we would have seen a diagnostic about multiple consumers).
41+
transform.sequence failures(propagate) {
42+
^bb1(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
43+
%loop = transform.structured.match ops{["scf.forall"]} in %arg0
44+
: (!transform.any_op) -> !transform.any_op
45+
%0:2 = transform.structured.fuse_into_containing_op %arg1 into %loop
46+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
47+
%1:2 = transform.structured.fuse_into_containing_op %arg2 into %loop
48+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
49+
}

0 commit comments

Comments
 (0)