Skip to content

Commit 1a5aa77

Browse files
[mlir][linalg] BufferizeToAllocationOp: Add option to specify custom alloc op
Supported ops are "memref.alloc" and "memref.alloca". Differential Revision: https://reviews.llvm.org/D155282
1 parent 88f4292 commit 1a5aa77

File tree

5 files changed

+49
-14
lines changed

5 files changed

+49
-14
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
128128
a fully dynamic layout is assumed for best compatibility. Users should use
129129
"memref.tensor_store" when possible.
130130

131+
"memref.alloc" is used for new buffer allocations. The buffer is deallocated
132+
at the end of the block. Custom allocation ops can be specified via
133+
`alloc_op`. Currently supported are "memref.alloc" and "memref.alloca". In
134+
case of a "memref.alloca", the buffer is not deallocated.
135+
131136
#### Return modes
132137

133138
This operation consumes the `target` handle and produces the
@@ -137,7 +142,9 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
137142
let arguments = (ins TransformHandleTypeInterface:$target,
138143
OptionalAttr<AnyAttr>:$memory_space,
139144
DefaultValuedAttr<StrAttr, "\"memref.tensor_store\"">:
140-
$memcpy_op);
145+
$memcpy_op,
146+
DefaultValuedAttr<StrAttr, "\"memref.alloc\"">:
147+
$alloc_op);
141148
let hasVerifier = 1;
142149
let results = (outs Transform_AnyValue:$allocated_buffer,
143150
Transform_AnyOpType:$new_ops);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
4747
//===----------------------------------------------------------------------===//
4848

4949
struct BufferizeToAllocationOptions {
50-
enum class MemcpyOp { MemrefTensorStore = 0, MemrefCopy = 1, LinalgCopy = 2 };
50+
enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
51+
AllocOp allocOp = AllocOp::MemrefAlloc;
5152

53+
enum class MemcpyOp { MemrefTensorStore = 0, MemrefCopy = 1, LinalgCopy = 2 };
5254
MemcpyOp memcpyOp = MemcpyOp::MemrefTensorStore;
5355
};
5456

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,15 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
248248
} else {
249249
llvm_unreachable("invalid memcpy op");
250250
}
251+
if (getAllocOp() == "memref.alloc") {
252+
options.allocOp =
253+
linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
254+
} else if (getAllocOp() == "memref.alloca") {
255+
options.allocOp =
256+
linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
257+
} else {
258+
llvm_unreachable("invalid alloc op");
259+
}
251260

252261
// Bufferize ops.
253262
Attribute memorySpace =
@@ -283,6 +292,8 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
283292
if (getMemcpyOp() != "memref.tensor_store" &&
284293
getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
285294
return emitOpError() << "unsupported memcpy op";
295+
if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
296+
return emitOpError() << "unsupported alloc op";
286297
return success();
287298
}
288299

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,10 @@ static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
185185
return dynSizes;
186186
}
187187

188-
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
189-
Value value,
190-
Attribute memorySpace = {}) {
188+
static Value
189+
createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
190+
const linalg::BufferizeToAllocationOptions &options,
191+
Attribute memorySpace = {}) {
191192
OpBuilder::InsertionGuard g(rewriter);
192193
auto tensorType = cast<RankedTensorType>(value.getType());
193194

@@ -196,11 +197,19 @@ static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
196197
cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
197198
tensorType, memorySpace));
198199
SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
199-
Value alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
200200

201-
// Place deallocation at the end of the block.
202-
rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
203-
rewriter.create<memref::DeallocOp>(loc, alloc);
201+
Value alloc;
202+
if (options.allocOp ==
203+
linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
204+
alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
205+
// Place deallocation at the end of the block.
206+
rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
207+
rewriter.create<memref::DeallocOp>(loc, alloc);
208+
} else if (options.allocOp ==
209+
linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
210+
alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
211+
// No dealloc is needed.
212+
}
204213

205214
return alloc;
206215
}
@@ -213,8 +222,8 @@ Value linalg::bufferizeToAllocation(
213222
Location loc = padOp.getLoc();
214223

215224
// Create buffer allocation.
216-
Value alloc =
217-
createAllocationForTensor(rewriter, loc, padOp.getResult(), memorySpace);
225+
Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
226+
options, memorySpace);
218227
rewriter.setInsertionPoint(padOp);
219228

220229
if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
@@ -491,8 +500,8 @@ Value linalg::bufferizeToAllocation(
491500
rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
492501
SmallVector<Value> allocs;
493502
for (OpOperand *operand : outOfPlaceOperands) {
494-
Value alloc = createAllocationForTensor(rewriter, op->getLoc(),
495-
operand->get(), memorySpace);
503+
Value alloc = createAllocationForTensor(
504+
rewriter, op->getLoc(), operand->get(), options, memorySpace);
496505
allocs.push_back(alloc);
497506
if (!state.findDefinitions(operand->get()).empty()) {
498507
// Initialize buffer with a copy of the operand data. Not needed if the

mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ transform.sequence failures(propagate) {
5050
// CHECK-LABEL: func @tensor_pad_constant_with_custom_copy(
5151
// CHECK-NOT: memref.tensor_store
5252
// CHECK-NOT: memref.copy
53+
// CHECK: memref.alloca
5354
// CHECK: linalg.copy
5455
func.func @tensor_pad_constant_with_custom_copy(
5556
%t: tensor<?x10xindex>, %l2: index, %h1: index, %h2: index)
@@ -66,7 +67,7 @@ func.func @tensor_pad_constant_with_custom_copy(
6667
transform.sequence failures(propagate) {
6768
^bb1(%arg1: !transform.any_op):
6869
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
69-
%2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 3, memcpy_op = "linalg.copy"}: !transform.any_op
70+
%2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 3, alloc_op = "memref.alloca", memcpy_op = "linalg.copy"}: !transform.any_op
7071

7172
// Ensure that one linalg.fill was generated.
7273
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
@@ -78,6 +79,11 @@ transform.sequence failures(propagate) {
7879
// expected-remark @below{{1}}
7980
test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
8081

82+
// Ensure that one memref.alloca was generated.
83+
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
84+
// expected-remark @below{{1}}
85+
test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
86+
8187
// Make sure that One-Shot Bufferize can bufferize the rest.
8288
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
8389
}

0 commit comments

Comments
 (0)