Skip to content

Commit 4f62a18

Browse files
authored
[flang] Allow user to define free via BIND(C) (#78428)
A user defining and using free/malloc via BIND(C) would previously cause flang to crash when generating LLVM IR with error "redefinition of symbol named 'free'". This was caused by flang codegen not expecting to find a mlir::func::FuncOp definition of these function and emitting a new mlir::LLVM::FuncOp that later conflicted when translating the mlir::func::FuncOp.
1 parent 1d286ad commit 4f62a18

File tree

2 files changed

+46
-17
lines changed

2 files changed

+46
-17
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,20 +1215,23 @@ struct EmboxCharOpConversion : public FIROpConversion<fir::EmboxCharOp> {
12151215
} // namespace
12161216

12171217
/// Return the LLVMFuncOp corresponding to the standard malloc call.
1218-
static mlir::LLVM::LLVMFuncOp
1218+
static mlir::SymbolRefAttr
12191219
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1220+
static constexpr char mallocName[] = "malloc";
12201221
auto module = op->getParentOfType<mlir::ModuleOp>();
1221-
if (mlir::LLVM::LLVMFuncOp mallocFunc =
1222-
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("malloc"))
1223-
return mallocFunc;
1222+
if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
1223+
return mlir::SymbolRefAttr::get(mallocFunc);
1224+
if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
1225+
return mlir::SymbolRefAttr::get(userMalloc);
12241226
mlir::OpBuilder moduleBuilder(
12251227
op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
12261228
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
1227-
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1228-
rewriter.getUnknownLoc(), "malloc",
1229+
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1230+
op.getLoc(), mallocName,
12291231
mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()),
12301232
indexType,
12311233
/*isVarArg=*/false));
1234+
return mlir::SymbolRefAttr::get(mallocDecl);
12321235
}
12331236

12341237
/// Helper function for generating the LLVM IR that computes the distance
@@ -1276,7 +1279,6 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
12761279
matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor,
12771280
mlir::ConversionPatternRewriter &rewriter) const override {
12781281
mlir::Type heapTy = heap.getType();
1279-
mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc(heap, rewriter);
12801282
mlir::Location loc = heap.getLoc();
12811283
auto ity = lowerTy().indexType();
12821284
mlir::Type dataTy = fir::unwrapRefType(heapTy);
@@ -1289,7 +1291,7 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
12891291
for (mlir::Value opnd : adaptor.getOperands())
12901292
size = rewriter.create<mlir::LLVM::MulOp>(
12911293
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1292-
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
1294+
heap->setAttr("callee", getMalloc(heap, rewriter));
12931295
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
12941296
heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
12951297
return mlir::success();
@@ -1307,19 +1309,25 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
13071309
} // namespace
13081310

13091311
/// Return the LLVMFuncOp corresponding to the standard free call.
1310-
static mlir::LLVM::LLVMFuncOp
1311-
getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1312+
static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
1313+
mlir::ConversionPatternRewriter &rewriter) {
1314+
static constexpr char freeName[] = "free";
13121315
auto module = op->getParentOfType<mlir::ModuleOp>();
1313-
if (mlir::LLVM::LLVMFuncOp freeFunc =
1314-
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("free"))
1315-
return freeFunc;
1316+
// Check if free already defined in the module.
1317+
if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
1318+
return mlir::SymbolRefAttr::get(freeFunc);
1319+
if (auto freeDefinedByUser =
1320+
module.lookupSymbol<mlir::func::FuncOp>(freeName))
1321+
return mlir::SymbolRefAttr::get(freeDefinedByUser);
1322+
// Create llvm declaration for free.
13161323
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
13171324
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
1318-
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1319-
rewriter.getUnknownLoc(), "free",
1325+
auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1326+
rewriter.getUnknownLoc(), freeName,
13201327
mlir::LLVM::LLVMFunctionType::get(voidType,
13211328
getLlvmPtrType(op.getContext()),
13221329
/*isVarArg=*/false));
1330+
return mlir::SymbolRefAttr::get(freeDecl);
13231331
}
13241332

13251333
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
@@ -1339,9 +1347,8 @@ struct FreeMemOpConversion : public FIROpConversion<fir::FreeMemOp> {
13391347
mlir::LogicalResult
13401348
matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor,
13411349
mlir::ConversionPatternRewriter &rewriter) const override {
1342-
mlir::LLVM::LLVMFuncOp freeFunc = getFree(freemem, rewriter);
13431350
mlir::Location loc = freemem.getLoc();
1344-
freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
1351+
freemem->setAttr("callee", getFree(freemem, rewriter));
13451352
rewriter.create<mlir::LLVM::CallOp>(loc, mlir::TypeRange{},
13461353
mlir::ValueRange{adaptor.getHeapref()},
13471354
freemem->getAttrs());
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Test that FIR codegen handles cases when free and malloc have
2+
// already been defined in FIR (either by the user in Fortran via
3+
// BIND(C) or by some FIR pass in between).
4+
// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s
5+
6+
7+
func.func @already_declared_free_malloc() {
8+
%c4 = arith.constant 4 : index
9+
%0 = fir.call @malloc(%c4) : (index) -> !fir.heap<i32>
10+
fir.call @free(%0) : (!fir.heap<i32>) -> ()
11+
%1 = fir.allocmem i32
12+
fir.freemem %1 : !fir.heap<i32>
13+
return
14+
}
15+
16+
// CHECK: llvm.call @malloc(%{{.*}})
17+
// CHECK: llvm.call @free(%{{.*}})
18+
// CHECK: llvm.call @malloc(%{{.*}})
19+
// CHECK: llvm.call @free(%{{.*}})
20+
21+
func.func private @free(!fir.heap<i32>)
22+
func.func private @malloc(index) -> !fir.heap<i32>

0 commit comments

Comments
 (0)