Skip to content

[mlir] Introduction of LocalEffectsOpInterface #130341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
RecursiveMemoryEffects,
LocalEffectsOpInterface,
DestinationStyleOpInterface,
LinalgStructuredInterface,
ReifyRankedShapedTypeOpInterface], props)> {
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ def CopyOp : MemRef_Op<"copy", [CopyOpInterface, SameOperandsElementType,
// DeallocOp
//===----------------------------------------------------------------------===//

def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
def MemRef_DeallocOp
: MemRef_Op<"dealloc", [MemRefsNormalizable, LocalEffectsOpInterface]> {
let summary = "memory deallocation operation";
let description = [{
The `dealloc` operation frees the region of memory referenced by a memref
Expand Down Expand Up @@ -1180,6 +1181,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
LocalEffectsOpInterface,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
Expand Down Expand Up @@ -1813,6 +1815,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
LocalEffectsOpInterface,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir {
class BranchOpInterface;
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef MLIR_INTERFACES_CONTROLFLOWINTERFACES
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -115,7 +116,8 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//

def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
def RegionBranchOpInterface
: OpInterface<"RegionBranchOpInterface", [LocalEffectsOpInterface]> {
let description = [{
This interface provides information for region operations that exhibit
branching behavior between held regions. I.e., this interface allows for
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,15 @@ bool isSpeculatable(Operation *op);
/// This function is the C++ equivalent of the `Pure` trait.
bool isPure(Operation *op);

//===----------------------------------------------------------------------===//
// LocalEffects Utilities
//===----------------------------------------------------------------------===//

namespace detail {
/// Default implementation of `hasLocalEffects` method.
bool hasLocalEffectsDefaultImpl(Operation *op);
} // namespace detail

} // namespace mlir

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 25 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,29 @@ def RecursivelySpeculatable : TraitList<[
// are always legal to hoist or sink.
def Pure : TraitList<[AlwaysSpeculatable, NoMemoryEffect]>;

//===----------------------------------------------------------------------===//
// LocalEffects
//===----------------------------------------------------------------------===//

// Interface which could be implemented by imperative operators that have no
// effects on state outside of what’s directly available through their operands
// (for example, they can’t access a `memref.global`, can’t make a call to
// another function that can potentially do so, can’t perform a
// synchronization/wait on other pending memory operations, etc.), including
// through operators in their regions.
def LocalEffectsOpInterface : OpInterface<"LocalEffectsOpInterface"> {
let description = [{An interface for operators which have no effects on state
outside of what's directly available through their own
operands or operands of the operators inside their regions.
}];
let cppNamespace = "::mlir";

let methods =
[InterfaceMethod<[{ Returns true if operator has only local effects. }],
"bool", "hasLocalEffects", (ins), [{}], [{
return mlir::detail::hasLocalEffectsDefaultImpl(
$_op.getOperation());
}]>];
}

#endif // MLIR_INTERFACES_SIDEEFFECTS
29 changes: 13 additions & 16 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ bool MemRefDependenceGraph::init() {
// Create graph nodes.
DenseMap<Operation *, unsigned> forToNodeMap;
for (Operation &op : block) {
auto localEffectsOp = dyn_cast<LocalEffectsOpInterface>(op);
if (auto forOp = dyn_cast<AffineForOp>(op)) {
Node *node = addNodeToMDG(&op, *this, memrefAccesses);
if (!node)
Expand All @@ -277,27 +278,23 @@ bool MemRefDependenceGraph::init() {
Node *node = addNodeToMDG(&op, *this, memrefAccesses);
if (!node)
return false;
} else if (!isMemoryEffectFree(&op) &&
(op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
// Create graph node for top-level op unless it is known to be
// memory-effect free. This covers all unknown/unregistered ops,
// non-affine ops with memory effects, and region-holding ops with a
// well-defined control flow. During the fusion validity checks, edges
// to/from these ops get looked at.
} else if (isMemoryEffectFree(&op)) {
// Do not create nodes for memory-effect free ops w/o uses.
;
} else if (localEffectsOp && localEffectsOp.hasLocalEffects()) {
// Create graph node for top-level op which are known to have only local
// effects.
Node *node = addNodeToMDG(&op, *this, memrefAccesses);
if (!node)
return false;
} else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
// Return false if non-handled/unknown region-holding ops are found. We
// won't know what such ops do or what its regions mean; for e.g., it may
// not be an imperative op.
LLVM_DEBUG(llvm::dbgs()
<< "MDG init failed; unknown region-holding op found!\n");
} else {
// Return false if non-handled/unknown ops are found. We won't know what
// such ops do or what its regions mean; for e.g., it may not be an
// imperative op.
LLVM_DEBUG(llvm::dbgs() << "MDG init failed; unknown operator found:\n"
<< op << "\n");
return false;
}
// We aren't creating nodes for memory-effect free ops either with no
// regions (unless it has results being used) or those with branch op
// interface.
}

LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Interfaces/SideEffectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,24 @@ bool mlir::isSpeculatable(Operation *op) {
bool mlir::isPure(Operation *op) {
return isSpeculatable(op) && isMemoryEffectFree(op);
}

//===----------------------------------------------------------------------===//
// LocalEffects Utilities
//===----------------------------------------------------------------------===//

bool mlir::detail::hasLocalEffectsDefaultImpl(Operation *op) {
assert(isa<LocalEffectsOpInterface>(op) &&
"Operator does not implement LocalEffectsOpInterface");

// Recurse into the regions and ensure that all nested ops have local effects.
for (auto &region : op->getRegions()) {
for (auto &nestedOp : region.getOps()) {
auto localEffectsOp = dyn_cast<LocalEffectsOpInterface>(nestedOp);
auto hasLocalEffects = localEffectsOp && localEffectsOp.hasLocalEffects();
if (!isPure(&nestedOp) && !hasLocalEffects) {
return false;
}
}
}
return true;
}
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Affine/loop-fusion-4.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,36 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,

// -----

// Check that presence of a Linalg operator in a block does not prevent
// fusion from happening in this block.

// ALL-LABEL: func @fusion_in_block_containing_linalg
func.func @fusion_in_block_containing_linalg(%arg0: memref<5xi8>, %arg1: memref<5xi8>) {
%c15_i8 = arith.constant 15 : i8
%alloc = memref.alloc() : memref<5xi8>
affine.for %arg3 = 0 to 5 {
affine.store %c15_i8, %alloc[%arg3] : memref<5xi8>
}
affine.for %arg3 = 0 to 5 {
%0 = affine.load %alloc[%arg3] : memref<5xi8>
%1 = affine.load %arg0[%arg3] : memref<5xi8>
%2 = arith.muli %0, %1 : i8
affine.store %2, %alloc[%arg3] : memref<5xi8>
}
// ALL: affine.for
// ALL-NEXT: affine.store
// ALL-NEXT: affine.load
// ALL-NEXT: affine.load
// ALL-NEXT: arith.muli
// ALL-NEXT: affine.store
// ALL-NEXT: }
linalg.elemwise_binary ins(%alloc, %alloc: memref<5xi8>, memref<5xi8>) outs(%arg1: memref<5xi8>)
// ALL-NEXT: linalg.elemwise_binary
return
}

// -----

// From https://github.com/llvm/llvm-project/issues/54541

#map = affine_map<(d0) -> (d0 mod 65536)>
Expand Down
Loading