
Commit 6699807

[mlir][linalg] Add bufferization for linalg.softmax (#97019)
Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch-all" `LinalgOp` interface implementation.
Parent: 022d15c
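
The commit wires `linalg.softmax` into bufferization so that a tensor-level softmax lowers to a memref-level one. As a sketch of the intended rewrite, reconstructed from the new test case below rather than taken verbatim from pass output (allocation, casts, and exact types can differ with other bufferization options):

```mlir
// Before bufferization: softmax in destination-passing style on tensors.
%1 = linalg.softmax dimension(2)
    ins(%arg0 : tensor<2x16x32xf32>)
    outs(%arg1 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>

// After bufferization (approximate): the op runs on memrefs and writes
// into a fresh allocation; no tensor result remains.
%m0 = bufferization.to_memref %arg0 : memref<2x16x32xf32>
%alloc = memref.alloc() : memref<2x16x32xf32>
linalg.softmax dimension(2)
    ins(%m0 : memref<2x16x32xf32>)
    outs(%alloc : memref<2x16x32xf32>)
%result = bufferization.to_tensor %alloc : memref<2x16x32xf32>
```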

2 files changed (+48, -0)

mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 31 additions & 0 deletions

```diff
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
   }
 };
+
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Output operand is not read.
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    return &opOperand == &softmaxOp.getInputMutable();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    FailureOr<Value> inputBuffer =
+        getBuffer(rewriter, softmaxOp.getInput(), options);
+    if (failed(inputBuffer))
+      return failure();
+    FailureOr<Value> outputBuffer =
+        getBuffer(rewriter, softmaxOp.getOutput(), options);
+    if (failed(outputBuffer))
+      return failure();
+    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
+                                       /*result=*/TypeRange(), *inputBuffer,
+                                       *outputBuffer, softmaxOp.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
         >::registerOpInterface(ctx);
+
+    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
   });
 }
```

mlir/test/Dialect/Linalg/bufferize.mlir

Lines changed: 17 additions & 0 deletions

```diff
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
 // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
 // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+// CHECK-SAME:  %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+// CHECK:       %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+// CHECK:       %[[alloc:.*]] = memref.alloc()
+// CHECK-NOT:   memref.copy
+// CHECK:       linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc:.*]] : {{.*}})
+// CHECK:       %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK:       return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
```
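
The `CHECK-NOT: memref.copy` line captures the payoff of `bufferizesToMemoryRead` returning false for the init operand: since `%arg1` is never read, its buffer does not need to be copied into the new allocation. A hypothetical sketch of what bufferization would have to emit if the init operand did count as a read (illustrative IR, not output of the pass):

```mlir
// Hypothetical: if the init operand were treated as read, its contents
// would have to be preserved in the destination buffer before the write.
%m0 = bufferization.to_memref %arg0 : memref<2x16x32xf32>
%m1 = bufferization.to_memref %arg1 : memref<2x16x32xf32>
%alloc = memref.alloc() : memref<2x16x32xf32>
memref.copy %m1, %alloc : memref<2x16x32xf32> to memref<2x16x32xf32>
linalg.softmax dimension(2)
    ins(%m0 : memref<2x16x32xf32>)
    outs(%alloc : memref<2x16x32xf32>)
```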
