Skip to content

Commit a9f607f

Browse files
[mlir][bufferization] Support bufferization of external functions
This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value. This commit is in preparation of removing the `func-bufferize` pass.
1 parent 00ca207 commit a9f607f

File tree

4 files changed

+51
-48
lines changed

4 files changed

+51
-48
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ struct AliasingValue {
6060
bool isDefinite;
6161
};
6262

63-
template <typename T> class AliasList {
63+
template <typename T>
64+
class AliasList {
6465
public:
6566
/// Create an empty list of aliases.
6667
AliasList() = default;
@@ -259,7 +260,7 @@ struct BufferizationOptions {
259260
/// Initializer function for analysis state.
260261
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
261262
/// Tensor -> MemRef type converter.
262-
/// Parameters: Value, memory space, func op, bufferization options
263+
/// Parameters: tensor type, memory space, func op, bufferization options
263264
using FunctionArgTypeConverterFn =
264265
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
265266
func::FuncOp, const BufferizationOptions &)>;
@@ -344,9 +345,9 @@ struct BufferizationOptions {
344345
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
345346

346347
/// Type converter from tensors to memrefs. This type converter is used to
347-
/// determine bufferized function argument types. By default, a type
348-
/// converter that returns a memref type with a fully dynamic layout map is
349-
/// used.
348+
/// determine bufferized function argument and result types. By default, a
349+
/// type converter that returns a memref type with a fully dynamic layout map
350+
/// is used.
350351
///
351352
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
352353
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
8282

8383
/// Return the FuncOp called by `callOp`.
8484
static FuncOp getCalledFunction(CallOpInterface callOp) {
85-
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
85+
SymbolRefAttr sym =
86+
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
8687
if (!sym)
8788
return nullptr;
8889
return dyn_cast_or_null<FuncOp>(
@@ -392,36 +393,45 @@ struct FuncOpInterface
392393
auto funcOp = cast<FuncOp>(op);
393394
FunctionType funcType = funcOp.getFunctionType();
394395

395-
// Construct the bufferized function type.
396+
// Construct the bufferized function type. Compute the argument types.
396397
SmallVector<Type> argTypes;
397398
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
398399
Type argType = it.value();
399-
if (dyn_cast<TensorType>(argType)) {
400+
if (isa<TensorType>(argType)) {
400401
argTypes.push_back(
401402
getBufferizedFunctionArgType(funcOp, it.index(), options));
402403
continue;
403404
}
404405
argTypes.push_back(argType);
405406
}
406407

407-
// Bodiless functions are assumed opaque and we cannot know the
408-
// bufferization contract they want to enforce. As a consequence, only
409-
// support functions that don't return any tensors atm.
410-
if (funcOp.isExternal()) {
411-
SmallVector<Type> retTypes;
412-
for (Type resultType : funcType.getResults()) {
413-
if (isa<TensorType>(resultType))
414-
return funcOp->emitError() << "cannot bufferize bodiless function "
415-
<< "that returns a tensor";
408+
// Compute the result types.
409+
SmallVector<Type> retTypes;
410+
for (Type resultType : funcType.getResults()) {
411+
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
412+
BaseMemRefType resultType = options.functionArgTypeConverterFn(
413+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
414+
options);
416415
retTypes.push_back(resultType);
416+
continue;
417417
}
418-
funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
418+
retTypes.push_back(resultType);
419+
}
420+
421+
// Compute the new function type.
422+
auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
423+
424+
// If the function has no body, set the new function type and we are done.
425+
if (funcOp.isExternal()) {
426+
funcOp.setType(newFuncType);
419427
return success();
420428
}
421429

422430
// TODO: Support functions with multiple returns.
423431
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
424432
assert(returnOp && "expected func with single return op");
433+
assert(returnOp->getNumOperands() == retTypes.size() &&
434+
"incorrect number of return values");
425435
Location loc = returnOp.getLoc();
426436

427437
// 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
430440
options)))
431441
return failure();
432442

433-
// 2. For each result, keep track of which inplace argument it reuses.
443+
// 2. Bufferize all operands of the return op.
434444
SmallVector<Value> returnValues;
435-
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
436-
Value returnVal = returnOperand.get();
445+
for (auto [returnVal, bufferizedType] :
446+
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
437447
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
438448
rewriter.setInsertionPoint(returnOp);
439449

@@ -443,23 +453,17 @@ struct FuncOpInterface
443453
continue;
444454
}
445455

446-
// Note: If `inferFunctionResultLayout = true`, cast are later folded
456+
// Note: If `inferFunctionResultLayout = true`, casts are later folded
447457
// away.
448-
BaseMemRefType resultType = options.functionArgTypeConverterFn(
449-
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
450-
options);
451458
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
452-
loc, resultType, returnVal);
459+
loc, bufferizedType, returnVal);
453460
returnValues.push_back(toMemrefOp);
454461
}
455462

456-
// 3. Rewrite the terminator without the in-place bufferizable values.
457463
returnOp.getOperandsMutable().assign(returnValues);
458464

459-
// 4. Rewrite the FuncOp type to buffer form.
460-
funcOp.setType(FunctionType::get(op->getContext(), argTypes,
461-
ValueRange(returnValues).getTypes()));
462-
465+
// 3. Set the new function type.
466+
funcOp.setType(newFuncType);
463467
return success();
464468
}
465469

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
22

3-
// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
4-
// expected-error @+1 {{failed to bufferize op}}
5-
func.func private @foo() -> tensor<?xf32>
6-
7-
// -----
8-
93
// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
104
func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
115
-> (tensor<f32>, tensor<f32>)
@@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
123117

124118
// -----
125119

126-
// expected-error @+2 {{failed to bufferize op}}
127-
// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
128-
func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
129-
130-
func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
131-
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
132-
return
133-
}
134-
135-
// -----
136-
137120
func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
138121
%idx2 : index) -> f32 {
139122
%1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
4242

4343
// -----
4444

45+
// Bufferization of bodiless function that returns a tensor.
46+
47+
// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
48+
func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
49+
50+
// CHECK: func.func @call_to_unknown_tensor_returning_func(
51+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
52+
func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
53+
// CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
54+
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
55+
return
56+
}
57+
58+
// -----
59+
4560
// A function that returns a non-equivalent tensor with layout map.
4661

4762
// CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>

0 commit comments

Comments
 (0)