@@ -1215,20 +1215,23 @@ struct EmboxCharOpConversion : public FIROpConversion<fir::EmboxCharOp> {
1215
1215
} // namespace
1216
1216
1217
1217
// / Return the LLVMFuncOp corresponding to the standard malloc call.
1218
- static mlir::LLVM::LLVMFuncOp
1218
+ static mlir::SymbolRefAttr
1219
1219
getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1220
+ static constexpr char mallocName[] = " malloc" ;
1220
1221
auto module = op->getParentOfType <mlir::ModuleOp>();
1221
- if (mlir::LLVM::LLVMFuncOp mallocFunc =
1222
- module.lookupSymbol <mlir::LLVM::LLVMFuncOp>(" malloc" ))
1223
- return mallocFunc;
1222
+ if (auto mallocFunc = module.lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
1223
+ return mlir::SymbolRefAttr::get (mallocFunc);
1224
+ if (auto userMalloc = module.lookupSymbol <mlir::func::FuncOp>(mallocName))
1225
+ return mlir::SymbolRefAttr::get (userMalloc);
1224
1226
mlir::OpBuilder moduleBuilder (
1225
1227
op->getParentOfType <mlir::ModuleOp>().getBodyRegion ());
1226
1228
auto indexType = mlir::IntegerType::get (op.getContext (), 64 );
1227
- return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1228
- rewriter. getUnknownLoc (), " malloc " ,
1229
+ auto mallocDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1230
+ op. getLoc (), mallocName ,
1229
1231
mlir::LLVM::LLVMFunctionType::get (getLlvmPtrType (op.getContext ()),
1230
1232
indexType,
1231
1233
/* isVarArg=*/ false ));
1234
+ return mlir::SymbolRefAttr::get (mallocDecl);
1232
1235
}
1233
1236
1234
1237
// / Helper function for generating the LLVM IR that computes the distance
@@ -1276,7 +1279,6 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
1276
1279
matchAndRewrite (fir::AllocMemOp heap, OpAdaptor adaptor,
1277
1280
mlir::ConversionPatternRewriter &rewriter) const override {
1278
1281
mlir::Type heapTy = heap.getType ();
1279
- mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc (heap, rewriter);
1280
1282
mlir::Location loc = heap.getLoc ();
1281
1283
auto ity = lowerTy ().indexType ();
1282
1284
mlir::Type dataTy = fir::unwrapRefType (heapTy);
@@ -1289,7 +1291,7 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
1289
1291
for (mlir::Value opnd : adaptor.getOperands ())
1290
1292
size = rewriter.create <mlir::LLVM::MulOp>(
1291
1293
loc, ity, size, integerCast (loc, rewriter, ity, opnd));
1292
- heap->setAttr (" callee" , mlir::SymbolRefAttr::get (mallocFunc ));
1294
+ heap->setAttr (" callee" , getMalloc (heap, rewriter ));
1293
1295
rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
1294
1296
heap, ::getLlvmPtrType (heap.getContext ()), size, heap->getAttrs ());
1295
1297
return mlir::success ();
@@ -1307,19 +1309,25 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
1307
1309
} // namespace
1308
1310
1309
1311
// / Return the LLVMFuncOp corresponding to the standard free call.
1310
- static mlir::LLVM::LLVMFuncOp
1311
- getFree (fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1312
+ static mlir::SymbolRefAttr getFree (fir::FreeMemOp op,
1313
+ mlir::ConversionPatternRewriter &rewriter) {
1314
+ static constexpr char freeName[] = " free" ;
1312
1315
auto module = op->getParentOfType <mlir::ModuleOp>();
1313
- if (mlir::LLVM::LLVMFuncOp freeFunc =
1314
- module.lookupSymbol <mlir::LLVM::LLVMFuncOp>(" free" ))
1315
- return freeFunc;
1316
+ // Check if free already defined in the module.
1317
+ if (auto freeFunc = module.lookupSymbol <mlir::LLVM::LLVMFuncOp>(freeName))
1318
+ return mlir::SymbolRefAttr::get (freeFunc);
1319
+ if (auto freeDefinedByUser =
1320
+ module.lookupSymbol <mlir::func::FuncOp>(freeName))
1321
+ return mlir::SymbolRefAttr::get (freeDefinedByUser);
1322
+ // Create llvm declaration for free.
1316
1323
mlir::OpBuilder moduleBuilder (module.getBodyRegion ());
1317
1324
auto voidType = mlir::LLVM::LLVMVoidType::get (op.getContext ());
1318
- return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1319
- rewriter.getUnknownLoc (), " free " ,
1325
+ auto freeDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1326
+ rewriter.getUnknownLoc (), freeName ,
1320
1327
mlir::LLVM::LLVMFunctionType::get (voidType,
1321
1328
getLlvmPtrType (op.getContext ()),
1322
1329
/* isVarArg=*/ false ));
1330
+ return mlir::SymbolRefAttr::get (freeDecl);
1323
1331
}
1324
1332
1325
1333
static unsigned getDimension (mlir::LLVM::LLVMArrayType ty) {
@@ -1339,9 +1347,8 @@ struct FreeMemOpConversion : public FIROpConversion<fir::FreeMemOp> {
1339
1347
mlir::LogicalResult
1340
1348
matchAndRewrite (fir::FreeMemOp freemem, OpAdaptor adaptor,
1341
1349
mlir::ConversionPatternRewriter &rewriter) const override {
1342
- mlir::LLVM::LLVMFuncOp freeFunc = getFree (freemem, rewriter);
1343
1350
mlir::Location loc = freemem.getLoc ();
1344
- freemem->setAttr (" callee" , mlir::SymbolRefAttr::get (freeFunc ));
1351
+ freemem->setAttr (" callee" , getFree (freemem, rewriter ));
1345
1352
rewriter.create <mlir::LLVM::CallOp>(loc, mlir::TypeRange{},
1346
1353
mlir::ValueRange{adaptor.getHeapref ()},
1347
1354
freemem->getAttrs ());
0 commit comments