-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR] Add bufferization state class to OneShotBufferization pass #138143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
adf34f0
to
45e0383
Compare
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-bufferization Author: Michele Scuttari (mscuttari) ChangesThis PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR. The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the Patch is 55.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138143.diff 27 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cb6ef8bc17220..d644f49573a35 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,6 +578,81 @@ class AnalysisState {
insideMutuallyExclusiveRegionsCache;
};
+/// BufferizationState provides information about the state of the IR during the
+/// bufferization process.
+class BufferizationState {
+public:
+ /// Base class for BufferizationState extensions that allow BufferizationState
+ /// to contain user-specified information in the state object. The extension
+ /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState.
+ class Extension {
+ public:
+ /// Base virtual destructor.
+ // Out-of-line definition ensures symbols are emitted in a single object
+ // file.
+ virtual ~Extension();
+
+ protected:
+ /// Constructs an extension of the given state object.
+ Extension(BufferizationState &state) : state(state) {}
+
+ /// Provides read-only access to the parent OneShotAnalysisState object.
+ const BufferizationState &getBufferizationState() const { return state; }
+
+ private:
+ /// Back-reference to the state that is being extended.
+ BufferizationState &state;
+ };
+
+ /// Adds a new Extension of the type specified as template parameter,
+ /// constructing it with the arguments provided. The extension is owned by the
+ /// BufferizationState. It is expected that the state does not already have an
+ /// extension of the same type. Extension constructors are expected to take a
+ /// reference to BufferizationState as first argument, automatically supplied
+ /// by this call.
+ template <typename Ty, typename... Args>
+ Ty &addExtension(Args &&...args) {
+ static_assert(std::is_base_of<Extension, Ty>::value,
+ "only a class derived from "
+ "BufferizationState::Extension is allowed");
+ auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+ auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+ assert(result.second && "extension already added");
+ return *static_cast<Ty *>(result.first->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ Ty *getExtension() {
+ static_assert(std::is_base_of<Extension, Ty>::value,
+ "only a class derived from "
+ "BufferizationState::Extension is allowed");
+ auto iter = extensions.find(TypeID::get<Ty>());
+ if (iter == extensions.end())
+ return nullptr;
+ return static_cast<Ty *>(iter->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ const Ty *getExtension() const {
+ return const_cast<BufferizationState *>(this)->getExtension<Ty>();
+ }
+
+ /// Get a reference to the collection of cached symbol tables.
+ SymbolTableCollection &getSymbolTables();
+
+private:
+ /// Extensions attached to the state, identified by the TypeID of their type.
+ /// Only one extension of any given type is allowed.
+ DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+
+ /// The cached symbol tables.
+ /// The user is expected to update / invalidate the cached symbol tables if
+ /// the bufferized operation has the Symbol or SymbolTable traits.
+ SymbolTableCollection symbolTables;
+};
+
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..b599a9f053215 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::BufferizationOptions &":$options),
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 7a1a701bea6dc..dafa4b9b183f2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool resultBufferizesToMemoryWrite(OpResult opResult,
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool bufferizesToMemoryRead(OpOperand &opOperand,
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
}
@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
//===------------------------------------------------------------------===//
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
// to_tensor/to_buffer pairs fold away after bufferization.
return success();
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index e5f3b6d571f43..c08bd6c436133 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,6 +29,7 @@ class GlobalOp;
} // namespace memref
namespace bufferization {
+class BufferizationState;
/// A simple analysis that detects allocation operations.
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
uint64_t alignment,
Attribute memorySpace = {});
+void removeSymbol(Operation *op, BufferizationState &state);
+
+void insertSymbol(Operation *op, BufferizationState &state);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index d5cb8d8eb673c..70e3defee0867 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
/// additional buffer copies or set "options.copyBeforeWrite = true". The
/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 673027f76190d..15189d2c1cb87 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 4e5f5e9c730fa..2cf801dd1d951 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,6 +20,7 @@ namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
+class BufferizationState;
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics = nullptr);
+ BufferizationState &state, BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4f90fc8831bc6..2eef0a06d0eb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
+class BufferizationState;
} // namespace bufferization
namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..f646326ffc58f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,7 +24,8 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
@@ -46,7 +47,8 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
+ getGlobalFor(constantOp, state.getSymbolTables(),
+ options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
@@ -131,7 +134,8 @@ struct SelectOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1fc34051680f1..d6224b012ac95 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,6 +125,12 @@ void AnalysisState::resetCache() {
insideMutuallyExclusiveRegionsCache.clear();
}
+BufferizationState::Extension::~Extension() = default;
+
+SymbolTableCollection &BufferizationState::getSymbolTables() {
+ return symbolTables;
+}
+
Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ecd2ef15546a4..91eccb0ab7430 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();
@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
if (failed(buffer))
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToBufferToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a1d7bb995fc73..db1eb20512033 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,6 +83,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
}
auto payloadOps = state.getPayloadOps(getTarget());
+ BufferizationState bufferizationState;
+
for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
@@ -90,10 +92,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
+ if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
- if (failed(bufferization::runOneShotBufferize(target, options)))
+ if (failed(bufferization::runOneShotBufferize(target, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
}
}
@@ -162,6 +166,7 @@ class BufferizationTransformDialectExtension
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+
>();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index c2e90764b1335..ff2c83d228dbb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
//===----------------------------------------------------------------------===//
FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
- Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
+ uint64_t alignment, Attribute memorySpace) {
auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(moduleOp.getContext());
- Symbo...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Michele Scuttari (mscuttari) ChangesThis PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR. The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the Patch is 55.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138143.diff 27 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cb6ef8bc17220..d644f49573a35 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,6 +578,81 @@ class AnalysisState {
insideMutuallyExclusiveRegionsCache;
};
+/// BufferizationState provides information about the state of the IR during the
+/// bufferization process.
+class BufferizationState {
+public:
+ /// Base class for BufferizationState extensions that allow BufferizationState
+ /// to contain user-specified information in the state object. The extension
+ /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState.
+ class Extension {
+ public:
+ /// Base virtual destructor.
+ // Out-of-line definition ensures symbols are emitted in a single object
+ // file.
+ virtual ~Extension();
+
+ protected:
+ /// Constructs an extension of the given state object.
+ Extension(BufferizationState &state) : state(state) {}
+
+ /// Provides read-only access to the parent OneShotAnalysisState object.
+ const BufferizationState &getBufferizationState() const { return state; }
+
+ private:
+ /// Back-reference to the state that is being extended.
+ BufferizationState &state;
+ };
+
+ /// Adds a new Extension of the type specified as template parameter,
+ /// constructing it with the arguments provided. The extension is owned by the
+ /// BufferizationState. It is expected that the state does not already have an
+ /// extension of the same type. Extension constructors are expected to take a
+ /// reference to BufferizationState as first argument, automatically supplied
+ /// by this call.
+ template <typename Ty, typename... Args>
+ Ty &addExtension(Args &&...args) {
+ static_assert(std::is_base_of<Extension, Ty>::value,
+ "only a class derived from "
+ "BufferizationState::Extension is allowed");
+ auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+ auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+ assert(result.second && "extension already added");
+ return *static_cast<Ty *>(result.first->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ Ty *getExtension() {
+ static_assert(std::is_base_of<Extension, Ty>::value,
+ "only a class derived from "
+ "BufferizationState::Extension is allowed");
+ auto iter = extensions.find(TypeID::get<Ty>());
+ if (iter == extensions.end())
+ return nullptr;
+ return static_cast<Ty *>(iter->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ const Ty *getExtension() const {
+ return const_cast<BufferizationState *>(this)->getExtension<Ty>();
+ }
+
+ /// Get a reference to the collection of cached symbol tables.
+ SymbolTableCollection &getSymbolTables();
+
+private:
+ /// Extensions attached to the state, identified by the TypeID of their type.
+ /// Only one extension of any given type is allowed.
+ DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+
+ /// The cached symbol tables.
+ /// The user is expected to update / invalidate the cached symbol tables if
+ /// the bufferized operation has the Symbol or SymbolTable traits.
+ SymbolTableCollection symbolTables;
+};
+
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..b599a9f053215 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::BufferizationOptions &":$options),
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 7a1a701bea6dc..dafa4b9b183f2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool resultBufferizesToMemoryWrite(OpResult opResult,
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool bufferizesToMemoryRead(OpOperand &opOperand,
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
}
@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
//===------------------------------------------------------------------===//
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
// to_tensor/to_buffer pairs fold away after bufferization.
return success();
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index e5f3b6d571f43..c08bd6c436133 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,6 +29,7 @@ class GlobalOp;
} // namespace memref
namespace bufferization {
+class BufferizationState;
/// A simple analysis that detects allocation operations.
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
uint64_t alignment,
Attribute memorySpace = {});
+void removeSymbol(Operation *op, BufferizationState &state);
+
+void insertSymbol(Operation *op, BufferizationState &state);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index d5cb8d8eb673c..70e3defee0867 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
/// additional buffer copies or set "options.copyBeforeWrite = true". The
/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 673027f76190d..15189d2c1cb87 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 4e5f5e9c730fa..2cf801dd1d951 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,6 +20,7 @@ namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
+class BufferizationState;
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics = nullptr);
+ BufferizationState &state, BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4f90fc8831bc6..2eef0a06d0eb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
+class BufferizationState;
} // namespace bufferization
namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..f646326ffc58f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,7 +24,8 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
@@ -46,7 +47,8 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
+ getGlobalFor(constantOp, state.getSymbolTables(),
+ options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
@@ -131,7 +134,8 @@ struct SelectOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1fc34051680f1..d6224b012ac95 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,6 +125,12 @@ void AnalysisState::resetCache() {
insideMutuallyExclusiveRegionsCache.clear();
}
+BufferizationState::Extension::~Extension() = default;
+
+SymbolTableCollection &BufferizationState::getSymbolTables() {
+ return symbolTables;
+}
+
Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ecd2ef15546a4..91eccb0ab7430 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();
@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
if (failed(buffer))
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToBufferToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a1d7bb995fc73..db1eb20512033 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,6 +83,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
}
auto payloadOps = state.getPayloadOps(getTarget());
+ BufferizationState bufferizationState;
+
for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
@@ -90,10 +92,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
+ if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
- if (failed(bufferization::runOneShotBufferize(target, options)))
+ if (failed(bufferization::runOneShotBufferize(target, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
}
}
@@ -162,6 +166,7 @@ class BufferizationTransformDialectExtension
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+
>();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index c2e90764b1335..ff2c83d228dbb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
//===----------------------------------------------------------------------===//
FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
- Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
+ uint64_t alignment, Attribute memorySpace) {
auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(moduleOp.getContext());
- Symbo...
[truncated]
|
@llvm/pr-subscribers-mlir-arith Author: Michele Scuttari (mscuttari) ChangesThis PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR. The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the Patch is 55.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138143.diff 27 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cb6ef8bc17220..d644f49573a35 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,6 +578,81 @@ class AnalysisState {
insideMutuallyExclusiveRegionsCache;
};
+/// BufferizationState provides information about the state of the IR during the
+/// bufferization process.
+class BufferizationState {
+public:
+ /// Base class for BufferizationState extensions that allow BufferizationState
+ /// to contain user-specified information in the state object. The extension
+ /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState.
+ class Extension {
+ public:
+ /// Base virtual destructor.
+ // Out-of-line definition ensures symbols are emitted in a single object
+ // file.
+ virtual ~Extension();
+
+ protected:
+ /// Constructs an extension of the given state object.
+ Extension(BufferizationState &state) : state(state) {}
+
+ /// Provides read-only access to the parent OneShotAnalysisState object.
+ const BufferizationState &getBufferizationState() const { return state; }
+
+ private:
+ /// Back-reference to the state that is being extended.
+ BufferizationState &state;
+ };
+
+ /// Adds a new Extension of the type specified as template parameter,
+ /// constructing it with the arguments provided. The extension is owned by the
+ /// BufferizationState. It is expected that the state does not already have an
+ /// extension of the same type. Extension constructors are expected to take a
+ /// reference to BufferizationState as first argument, automatically supplied
+ /// by this call.
+ template <typename Ty, typename... Args>
+ Ty &addExtension(Args &&...args) {
+ static_assert(std::is_base_of<Extension, Ty>::value,
+ "only a class derived from "
+ "BufferizationState::Extension is allowed");
+ auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+ auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+ assert(result.second && "extension already added");
+ return *static_cast<Ty *>(result.first->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ Ty *getExtension() {
+ static_assert(std::is_base_of<Extension, Ty>::value,
+ "only a class derived from "
+ "BufferizationState::Extension is allowed");
+ auto iter = extensions.find(TypeID::get<Ty>());
+ if (iter == extensions.end())
+ return nullptr;
+ return static_cast<Ty *>(iter->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ const Ty *getExtension() const {
+ return const_cast<BufferizationState *>(this)->getExtension<Ty>();
+ }
+
+ /// Get a reference to the collection of cached symbol tables.
+ SymbolTableCollection &getSymbolTables();
+
+private:
+ /// Extensions attached to the state, identified by the TypeID of their type.
+ /// Only one extension of any given type is allowed.
+ DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+
+ /// The cached symbol tables.
+ /// The user is expected to update / invalidate the cached symbol tables if
+ /// the bufferized operation has the Symbol or SymbolTable traits.
+ SymbolTableCollection symbolTables;
+};
+
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..b599a9f053215 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::BufferizationOptions &":$options),
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 7a1a701bea6dc..dafa4b9b183f2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool resultBufferizesToMemoryWrite(OpResult opResult,
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool bufferizesToMemoryRead(OpOperand &opOperand,
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
}
@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
//===------------------------------------------------------------------===//
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
// to_tensor/to_buffer pairs fold away after bufferization.
return success();
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index e5f3b6d571f43..c08bd6c436133 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,6 +29,7 @@ class GlobalOp;
} // namespace memref
namespace bufferization {
+class BufferizationState;
/// A simple analysis that detects allocation operations.
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
uint64_t alignment,
Attribute memorySpace = {});
+void removeSymbol(Operation *op, BufferizationState &state);
+
+void insertSymbol(Operation *op, BufferizationState &state);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index d5cb8d8eb673c..70e3defee0867 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
/// additional buffer copies or set "options.copyBeforeWrite = true". The
/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 673027f76190d..15189d2c1cb87 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 4e5f5e9c730fa..2cf801dd1d951 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,6 +20,7 @@ namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
+class BufferizationState;
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics = nullptr);
+ BufferizationState &state, BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4f90fc8831bc6..2eef0a06d0eb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
+class BufferizationState;
} // namespace bufferization
namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..f646326ffc58f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,7 +24,8 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
@@ -46,7 +47,8 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
+ getGlobalFor(constantOp, state.getSymbolTables(),
+ options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
@@ -131,7 +134,8 @@ struct SelectOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1fc34051680f1..d6224b012ac95 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,6 +125,12 @@ void AnalysisState::resetCache() {
insideMutuallyExclusiveRegionsCache.clear();
}
+BufferizationState::Extension::~Extension() = default;
+
+SymbolTableCollection &BufferizationState::getSymbolTables() {
+ return symbolTables;
+}
+
Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ecd2ef15546a4..91eccb0ab7430 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();
@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
if (failed(buffer))
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToBufferToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a1d7bb995fc73..db1eb20512033 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,6 +83,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
}
auto payloadOps = state.getPayloadOps(getTarget());
+ BufferizationState bufferizationState;
+
for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
@@ -90,10 +92,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
+ if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
- if (failed(bufferization::runOneShotBufferize(target, options)))
+ if (failed(bufferization::runOneShotBufferize(target, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
}
}
@@ -162,6 +166,7 @@ class BufferizationTransformDialectExtension
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+
>();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index c2e90764b1335..ff2c83d228dbb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
//===----------------------------------------------------------------------===//
FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
- Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
+ uint64_t alignment, Attribute memorySpace) {
auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(moduleOp.getContext());
- Symbo...
[truncated]
|
@matthias-springer the current state of the PR should reflect the decisions we came to in the Discourse thread. Please let me know if something is missing or improvable. P.S.: after the rebase, buildkite seems not to like the PR anymore. Any idea on how to fix it? Tests pass locally, but I'd like also the CI to succeed before merging. |
The buildkite error looks unrelated to me. |
/// Base class for BufferizationState extensions that allow BufferizationState | ||
/// to contain user-specified information in the state object. The extension | ||
/// mechanism of BufferizationState mirrors the one of OneShotAnalysisState. | ||
class Extension { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this entire extension mechanism is not needed anymore. Let's remove it for now, and bring it back if we find a use case for it in the future.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/10078 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/10056 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/11265 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/32523 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/21223 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/28717 Here is the relevant piece of the build log for the reference
|
…ferization pass" (#141012) Reverts llvm/llvm-project#138143 The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by PR checks.
This PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR.
The information currently consists of a cached list of symbol tables, which aims to solve the quadratic scaling of the bufferization task with respect to the number of symbols.
The PR breaks API compatibility: the
bufferize
method of theBufferizableOpInterface
has been enriched with a reference to aBufferizationState
object.The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the
Symbol
trait is inserted or replaced, its parentSymbolTable
must be updated accordingly (see, for example, the bufferization ofarith::ConstantOp
, where the symbol table of the module gets the new global symbol inserted). Similarly, the invalidation of a symbol table must be performed if an operation with theSymbolTable
trait is removed (this can be performed using theinvalidateSymbolTable
method, introduced in #138014).