Skip to content

Commit 65341b0

Browse files
authored
[mlir][bufferization][NFC] Move memref specific implementation of AllocationOpInterface to memref dialect directory (#66637)
Follow-up on #65578
1 parent e88a64f commit 65341b0

File tree

9 files changed

+93
-62
lines changed

9 files changed

+93
-62
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,6 @@ std::unique_ptr<Pass> createBufferizationBufferizePass();
211211
// Registration
212212
//===----------------------------------------------------------------------===//
213213

214-
/// Register external models for AllocationOpInterface.
215-
void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
216-
217214
/// Generate the code for registering passes.
218215
#define GEN_PASS_REGISTRATION
219216
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- AllocationOpInterfaceImpl.h - Impl. of AllocationOpInterface -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace memref {
16+
void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace memref
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "mlir/Dialect/MemRef/IR/MemRef.h"
5252
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
5353
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
54+
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
5455
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
5556
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
5657
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -149,6 +150,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
149150
linalg::registerBufferizableOpInterfaceExternalModels(registry);
150151
linalg::registerTilingInterfaceExternalModels(registry);
151152
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
153+
memref::registerAllocationOpInterfaceExternalModels(registry);
152154
memref::registerBufferizableOpInterfaceExternalModels(registry);
153155
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
154156
memref::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,4 @@ class BufferizationTransformDialectExtension
174174
void mlir::bufferization::registerTransformDialectExtension(
175175
DialectRegistry &registry) {
176176
registry.addExtensions<BufferizationTransformDialectExtension>();
177-
bufferization::registerAllocationOpInterfaceExternalModels(registry);
178177
}

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,6 @@ struct BufferDeallocationPass
634634
void getDependentDialects(DialectRegistry &registry) const override {
635635
registry.insert<bufferization::BufferizationDialect>();
636636
registry.insert<memref::MemRefDialect>();
637-
registerAllocationOpInterfaceExternalModels(registry);
638637
}
639638

640639
void runOnOperation() override {

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ struct OneShotBufferizePass
195195
void getDependentDialects(DialectRegistry &registry) const override {
196196
registry
197197
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
198-
registerAllocationOpInterfaceExternalModels(registry);
199198
}
200199

201200
void runOnOperation() override {
@@ -672,59 +671,3 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
672671
options.opFilter.allowDialect<BufferizationDialect>();
673672
return options;
674673
}
675-
676-
//===----------------------------------------------------------------------===//
677-
// Default AllocationOpInterface implementation and registration
678-
//===----------------------------------------------------------------------===//
679-
680-
namespace {
681-
struct DefaultAllocationInterface
682-
: public bufferization::AllocationOpInterface::ExternalModel<
683-
DefaultAllocationInterface, memref::AllocOp> {
684-
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
685-
Value alloc) {
686-
return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
687-
.getOperation();
688-
}
689-
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
690-
return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
691-
.getResult();
692-
}
693-
static ::mlir::HoistingKind getHoistingKind() {
694-
return HoistingKind::Loop | HoistingKind::Block;
695-
}
696-
static ::std::optional<::mlir::Operation *>
697-
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
698-
Operation *definingOp = alloc.getDefiningOp();
699-
return builder.create<memref::AllocaOp>(
700-
definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
701-
definingOp->getOperands(), definingOp->getAttrs());
702-
}
703-
};
704-
705-
struct DefaultAutomaticAllocationHoistingInterface
706-
: public bufferization::AllocationOpInterface::ExternalModel<
707-
DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
708-
static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
709-
};
710-
711-
struct DefaultReallocationInterface
712-
: public bufferization::AllocationOpInterface::ExternalModel<
713-
DefaultAllocationInterface, memref::ReallocOp> {
714-
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
715-
Value realloc) {
716-
return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
717-
.getOperation();
718-
}
719-
};
720-
} // namespace
721-
722-
void bufferization::registerAllocationOpInterfaceExternalModels(
723-
DialectRegistry &registry) {
724-
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
725-
memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
726-
memref::AllocaOp::attachInterface<
727-
DefaultAutomaticAllocationHoistingInterface>(*ctx);
728-
memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
729-
});
730-
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- AllocationOpInterfaceImpl.cpp - Impl. of AllocationOpInterface -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
12+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
#include "mlir/IR/Dialect.h"
15+
#include "mlir/IR/Operation.h"
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
struct DefaultAllocationInterface
21+
: public bufferization::AllocationOpInterface::ExternalModel<
22+
DefaultAllocationInterface, memref::AllocOp> {
23+
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
24+
Value alloc) {
25+
return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
26+
.getOperation();
27+
}
28+
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
29+
return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
30+
.getResult();
31+
}
32+
static ::mlir::HoistingKind getHoistingKind() {
33+
return HoistingKind::Loop | HoistingKind::Block;
34+
}
35+
static ::std::optional<::mlir::Operation *>
36+
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
37+
Operation *definingOp = alloc.getDefiningOp();
38+
return builder.create<memref::AllocaOp>(
39+
definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
40+
definingOp->getOperands(), definingOp->getAttrs());
41+
}
42+
};
43+
44+
struct DefaultAutomaticAllocationHoistingInterface
45+
: public bufferization::AllocationOpInterface::ExternalModel<
46+
DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
47+
static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
48+
};
49+
50+
struct DefaultReallocationInterface
51+
: public bufferization::AllocationOpInterface::ExternalModel<
52+
DefaultAllocationInterface, memref::ReallocOp> {
53+
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
54+
Value realloc) {
55+
return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
56+
.getOperation();
57+
}
58+
};
59+
} // namespace
60+
61+
void mlir::memref::registerAllocationOpInterfaceExternalModels(
62+
DialectRegistry &registry) {
63+
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
64+
memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
65+
memref::AllocaOp::attachInterface<
66+
DefaultAutomaticAllocationHoistingInterface>(*ctx);
67+
memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
68+
});
69+
}

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRMemRefTransforms
2+
AllocationOpInterfaceImpl.cpp
23
BufferizableOpInterfaceImpl.cpp
34
ComposeSubView.cpp
45
ExpandOps.cpp

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11738,6 +11738,7 @@ cc_library(
1173811738
":AffineDialect",
1173911739
":AffineTransforms",
1174011740
":AffineUtils",
11741+
":AllocationOpInterface",
1174111742
":ArithDialect",
1174211743
":ArithTransforms",
1174311744
":ArithUtils",

0 commit comments

Comments
 (0)