@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
 using namespace mlir::bufferization::func_ext;
 
 /// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
 
 /// Get or create FuncAnalysisState.
 static FuncAnalysisState &
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
 
 /// Return the unique ReturnOp that terminates `funcOp`.
 /// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
+static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
+  Operation *returnOp = nullptr;
+  for (Block &b : funcOp.getFunctionBody()) {
+    auto candidateOp = b.getTerminator();
+    if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
       if (returnOp)
         return nullptr;
       returnOp = candidateOp;
@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
 /// Store function BlockArguments that are equivalent to/aliasing a returned
 /// value in FuncAnalysisState.
 static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
+                             OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  if (funcOp.getBody().empty()) {
+  if (funcOp.getFunctionBody().empty()) {
     // No function body available. Conservatively assume that every tensor
     // return value may alias with any tensor bbArg.
-    FunctionType type = funcOp.getFunctionType();
-    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+    for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
       if (!isa<TensorType>(inputIt.value()))
         continue;
-      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+      for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
         if (!isa<TensorType>(resultIt.value()))
           continue;
         int64_t returnIdx = resultIt.index();
@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   }
 
   // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
   assert(returnOp && "expected func with single return op");
 
   for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   return success();
 }
 
-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
-                                  bool isWritten) {
+static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
+                                  bool isRead, bool isWritten) {
   OpBuilder b(funcOp.getContext());
   Attribute accessType;
   if (isRead && isWritten) {
@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
 /// function with unknown ops, we conservatively assume that such ops bufferize
 /// to a read + write.
 static LogicalResult
-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
+                             OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
-       ++idx) {
+  for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
     // Skip non-tensor arguments.
-    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+    if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
      continue;
     bool isRead;
     bool isWritten;
@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
       StringRef str = accessAttr.getValue();
       isRead = str == "read" || str == "read-write";
       isWritten = str == "write" || str == "read-write";
-    } else if (funcOp.getBody().empty()) {
+    } else if (funcOp.getFunctionBody().empty()) {
       // If the function has no body, conservatively assume that all args are
       // read + written.
       isRead = true;
@@ -230,33 +231,32 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
 
 /// Remove bufferization attributes on FuncOp arguments.
 static void removeBufferizationAttributes(BlockArgument bbArg) {
-  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+  auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kBufferLayoutAttrName);
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kWritableAttrName);
 }
 
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
-  return dyn_cast_or_null<func::FuncOp>(
+  return dyn_cast_or_null<FunctionOpInterface>(
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Gather equivalence info of CallOps.
 /// Note: This only adds new equivalence info if the called function was already
 /// analyzed.
 // TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(func::FuncOp funcOp,
+static void equivalenceAnalysis(FunctionOpInterface funcOp,
                                 OneShotAnalysisState &state,
                                 FuncAnalysisState &funcState) {
-  funcOp->walk([&](func::CallOp callOp) {
-    func::FuncOp calledFunction = getCalledFunction(callOp);
-    assert(calledFunction && "could not retrieved called func::FuncOp");
+  funcOp->walk([&](CallOpInterface callOp) {
+    FunctionOpInterface calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieved called FunctionOpInterface");
 
     // No equivalence info available for the called function.
     if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -267,7 +267,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
       int64_t bbargIdx = it.second;
       if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
         continue;
-      Value returnVal = callOp.getResult(returnIdx);
+      Value returnVal = callOp->getResult(returnIdx);
       Value argVal = callOp->getOperand(bbargIdx);
       state.unionEquivalenceClasses(returnVal, argVal);
     }
@@ -277,11 +277,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 }
 
 /// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
-  return llvm::any_of(funcOp.getFunctionType().getInputs(),
-                      llvm::IsaPred<TensorType>) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(),
-                      llvm::IsaPred<TensorType>);
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+  return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -291,16 +289,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+                         SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
-  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+  DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
-  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+  WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+    if (!funcOp.getFunctionBody().empty()) {
+      Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
       if (!returnOp)
         return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
@@ -309,9 +307,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
-      assert(calledFunction && "could not retrieved called func::FuncOp");
+    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+      FunctionOpInterface calledFunction = getCalledFunction(callOp);
+      assert(calledFunction &&
+             "could not retrieved called FunctionOpInterface");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.
       if (!hasTensorSignature(calledFunction))
@@ -349,11 +348,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// most generic layout map as function return types. After bufferizing the
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
-static void foldMemRefCasts(func::FuncOp funcOp) {
-  if (funcOp.getBody().empty())
+static void foldMemRefCasts(FunctionOpInterface funcOp) {
+  if (funcOp.getFunctionBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
   SmallVector<Type> resultTypes;
 
   for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -365,8 +364,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
     }
   }
 
-  auto newFuncType = FunctionType::get(
-      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+  auto newFuncType = FunctionType::get(funcOp.getContext(),
+                                       funcOp.getArgumentTypes(), resultTypes);
   funcOp.setType(newFuncType);
 }
 
@@ -379,7 +378,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -388,7 +387,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     return failure();
 
   // Analyze ops.
-  for (func::FuncOp funcOp : orderedFuncOps) {
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
     if (!state.getOptions().isOpAllowed(funcOp))
       continue;
 
@@ -416,7 +415,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  moduleOp.walk([&](FunctionOpInterface op) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
   });
@@ -430,7 +429,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -439,11 +438,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     return failure();
 
   // Bufferize functions.
-  for (func::FuncOp funcOp : orderedFuncOps) {
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
 
-    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
       // This function was not analyzed and RaW conflicts were not resolved.
       // Buffer copies must be inserted before every write.
       OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +462,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<FunctionOpInterface>(&op))
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
@@ -490,12 +489,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
   // FuncOps whose names are specified in options.noAnalysisFuncFilter will
   // not be analyzed. Ops in these FuncOps will not be analyzed as well.
   OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
-    auto func = dyn_cast<func::FuncOp>(op);
+    auto func = dyn_cast<FunctionOpInterface>(op);
     if (!func)
-      func = op->getParentOfType<func::FuncOp>();
+      func = op->getParentOfType<FunctionOpInterface>();
     if (func)
       return llvm::is_contained(options.noAnalysisFuncFilter,
-                                func.getSymName());
+                                func.getName());
     return false;
   };
   OneShotBufferizationOptions updatedOptions(options);