Skip to content

Commit ba727ac

Browse files
authored
[mlir][bufferization][scf] Implement BufferDeallocationOpInterface for scf.reduce.return (#66886)
This is necessary to run the new buffer deallocation pipeline as part of the sparse compiler pipeline.
1 parent 57a5548 commit ba727ac

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,27 @@ struct InParallelOpInterface
5656
}
5757
};
5858

59+
/// External model that attaches BufferDeallocationOpInterface to
/// scf::ReduceReturnOp. Only non-MemRef operands are handled; a MemRef-typed
/// operand is rejected with an error.
struct ReduceReturnOpInterface
60+
: public BufferDeallocationOpInterface::ExternalModel<
61+
ReduceReturnOpInterface, scf::ReduceReturnOp> {
62+
/// Processes `op` (which must be an scf::ReduceReturnOp) for buffer
/// deallocation.
///
/// Returns the result of deallocation_impl::insertDeallocOpForReturnLike, or
/// the failure produced by emitError when the operand has a MemRef type.
/// `options` is not used by this implementation.
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
63+
const DeallocationOptions &options) const {
64+
auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
65+
// Returning a MemRef from the reduction region is not supported here; bail
// out with a diagnostic instead of handling it.
if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
66+
return op->emitError("only supported when operand is not a MemRef");
67+
68+
// This vector is handed to the helper but never read afterwards in this
// function. NOTE(review): presumably an out-parameter that the helper fills
// in; confirm against the declaration of insertDeallocOpForReturnLike.
SmallVector<Value> updatedOperandOwnership;
69+
// An empty list is passed as the third argument. NOTE(review): presumably
// the set of operands to treat specially (none here, since the operand is
// not a MemRef); confirm against the helper's declaration.
return deallocation_impl::insertDeallocOpForReturnLike(
70+
state, op, {}, updatedOperandOwnership);
71+
}
72+
};
73+
5974
} // namespace
6075

6176
/// Adds an extension to `registry` whose callback attaches the
/// BufferDeallocationOpInterface external models defined in this file to
/// their SCF ops: InParallelOpInterface to InParallelOp and
/// ReduceReturnOpInterface to ReduceReturnOp.
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
6277
DialectRegistry &registry) {
6378
// The callback receives the context and the SCF dialect; only the context is
// used, to attach each external model.
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
6479
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
80+
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
6581
});
6682
}

mlir/test/Dialect/SCF/buffer-deallocation.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,31 @@ func.func @parallel_insert_slice(%arg0: index) {
2222
// CHECK: }
2323
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
2424
// CHECK-NOT: retain
25+
26+
// -----
27+
28+
// Test input: an scf.parallel loop with an f32 reduction whose reduction
// region allocates a temporary memref. The expectations that follow this
// function look for a bufferization.dealloc of that allocation placed before
// the scf.reduce.return.
func.func @reduce(%buffer: memref<100xf32>) {
29+
%init = arith.constant 0.0 : f32
30+
%c0 = arith.constant 0 : index
31+
%c1 = arith.constant 1 : index
32+
scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 {
33+
%elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32>
34+
scf.reduce(%elem_to_reduce) : f32 {
35+
^bb0(%lhs : f32, %rhs: f32):
36+
// Temporary buffer local to the reduction region: both reduction operands
// are stored into it, loaded back, and added.
%alloc = memref.alloc() : memref<2xf32>
37+
memref.store %lhs, %alloc [%c0] : memref<2xf32>
38+
memref.store %rhs, %alloc [%c1] : memref<2xf32>
39+
%0 = memref.load %alloc[%c0] : memref<2xf32>
40+
%1 = memref.load %alloc[%c1] : memref<2xf32>
41+
%res = arith.addf %0, %1 : f32
42+
// The returned value is a plain f32, not a memref.
scf.reduce.return %res : f32
43+
}
44+
}
45+
func.return
46+
}
47+
48+
// CHECK-LABEL: func @reduce
49+
// CHECK: scf.reduce
50+
// CHECK: [[ALLOC:%.+]] = memref.alloc(
51+
// CHECK: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
52+
// CHECK: scf.reduce.return

0 commit comments

Comments
 (0)