@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
82
82
83
83
// / Return the FuncOp called by `callOp`.
84
84
static FuncOp getCalledFunction (CallOpInterface callOp) {
85
- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
85
+ SymbolRefAttr sym =
86
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
86
87
if (!sym)
87
88
return nullptr ;
88
89
return dyn_cast_or_null<FuncOp>(
@@ -392,36 +393,45 @@ struct FuncOpInterface
392
393
auto funcOp = cast<FuncOp>(op);
393
394
FunctionType funcType = funcOp.getFunctionType ();
394
395
395
- // Construct the bufferized function type.
396
+ // Construct the bufferized function type. Compute the argument types.
396
397
SmallVector<Type> argTypes;
397
398
for (const auto &it : llvm::enumerate (funcType.getInputs ())) {
398
399
Type argType = it.value ();
399
- if (dyn_cast <TensorType>(argType)) {
400
+ if (isa <TensorType>(argType)) {
400
401
argTypes.push_back (
401
402
getBufferizedFunctionArgType (funcOp, it.index (), options));
402
403
continue ;
403
404
}
404
405
argTypes.push_back (argType);
405
406
}
406
407
407
- // Bodiless functions are assumed opaque and we cannot know the
408
- // bufferization contract they want to enforce. As a consequence, only
409
- // support functions that don't return any tensors atm.
410
- if (funcOp.isExternal ()) {
411
- SmallVector<Type> retTypes;
412
- for (Type resultType : funcType.getResults ()) {
413
- if (isa<TensorType>(resultType))
414
- return funcOp->emitError () << " cannot bufferize bodiless function "
415
- << " that returns a tensor" ;
408
+ // Compute the result types.
409
+ SmallVector<Type> retTypes;
410
+ for (Type resultType : funcType.getResults ()) {
411
+ if (auto tensorType = dyn_cast<TensorType>(resultType)) {
412
+ BaseMemRefType resultType = options.functionArgTypeConverterFn (
413
+ tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
414
+ options);
416
415
retTypes.push_back (resultType);
416
+ continue ;
417
417
}
418
- funcOp.setType (FunctionType::get (op->getContext (), argTypes, retTypes));
418
+ retTypes.push_back (resultType);
419
+ }
420
+
421
+ // Compute the new function type.
422
+ auto newFuncType = FunctionType::get (op->getContext (), argTypes, retTypes);
423
+
424
+ // If the function has no body, set the new function type and we are done.
425
+ if (funcOp.isExternal ()) {
426
+ funcOp.setType (newFuncType);
419
427
return success ();
420
428
}
421
429
422
430
// TODO: Support functions with multiple returns.
423
431
func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
424
432
assert (returnOp && " expected func with single return op" );
433
+ assert (returnOp->getNumOperands () == retTypes.size () &&
434
+ " incorrect number of return values" );
425
435
Location loc = returnOp.getLoc ();
426
436
427
437
// 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
430
440
options)))
431
441
return failure ();
432
442
433
- // 2. For each result, keep track of which inplace argument it reuses .
443
+ // 2. Bufferize all operands of the return op .
434
444
SmallVector<Value> returnValues;
435
- for (OpOperand &returnOperand : returnOp-> getOpOperands ()) {
436
- Value returnVal = returnOperand. get ();
445
+ for (auto [returnVal, bufferizedType] :
446
+ llvm::zip_equal (returnOp-> getOperands (), retTypes)) {
437
447
auto tensorType = dyn_cast<TensorType>(returnVal.getType ());
438
448
rewriter.setInsertionPoint (returnOp);
439
449
@@ -443,23 +453,17 @@ struct FuncOpInterface
443
453
continue ;
444
454
}
445
455
446
- // Note: If `inferFunctionResultLayout = true`, cast are later folded
456
+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
447
457
// away.
448
- BaseMemRefType resultType = options.functionArgTypeConverterFn (
449
- tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
450
- options);
451
458
Value toMemrefOp = rewriter.create <bufferization::ToMemrefOp>(
452
- loc, resultType , returnVal);
459
+ loc, bufferizedType , returnVal);
453
460
returnValues.push_back (toMemrefOp);
454
461
}
455
462
456
- // 3. Rewrite the terminator without the in-place bufferizable values.
457
463
returnOp.getOperandsMutable ().assign (returnValues);
458
464
459
- // 4. Rewrite the FuncOp type to buffer form.
460
- funcOp.setType (FunctionType::get (op->getContext (), argTypes,
461
- ValueRange (returnValues).getTypes ()));
462
-
465
+ // 3. Set the new function type.
466
+ funcOp.setType (newFuncType);
463
467
return success ();
464
468
}
465
469
0 commit comments