Skip to content

Commit d7e48fb

Browse files
authored
[llvm][OpenMP] Add implicit cast to omp.atomic.read (#114659)
Should the operands of `omp.atomic.read` differ, emit an implicit cast. In case of `struct` arguments, extract the 0-th index, emit an implicit cast if required, and store at the destination. Fixes #112908
1 parent 30e276d commit d7e48fb

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,33 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
264264
return Result;
265265
}
266266

267+
/// Emit an implicit cast to convert \p XRead to type of variable \p V
268+
static llvm::Value *emitImplicitCast(IRBuilder<> &Builder, llvm::Value *XRead,
269+
llvm::Value *V) {
270+
// TODO: Add this functionality to the `AtomicInfo` interface
271+
llvm::Type *XReadType = XRead->getType();
272+
llvm::Type *VType = V->getType();
273+
if (llvm::AllocaInst *vAlloca = dyn_cast<llvm::AllocaInst>(V))
274+
VType = vAlloca->getAllocatedType();
275+
276+
if (XReadType->isStructTy() && VType->isStructTy())
277+
// No need to extract or convert. A direct
278+
// `store` will suffice.
279+
return XRead;
280+
281+
if (XReadType->isStructTy())
282+
XRead = Builder.CreateExtractValue(XRead, /*Idxs=*/0);
283+
if (VType->isIntegerTy() && XReadType->isFloatingPointTy())
284+
XRead = Builder.CreateFPToSI(XRead, VType);
285+
else if (VType->isFloatingPointTy() && XReadType->isIntegerTy())
286+
XRead = Builder.CreateSIToFP(XRead, VType);
287+
else if (VType->isIntegerTy() && XReadType->isIntegerTy())
288+
XRead = Builder.CreateIntCast(XRead, VType, true);
289+
else if (VType->isFloatingPointTy() && XReadType->isFloatingPointTy())
290+
XRead = Builder.CreateFPCast(XRead, VType);
291+
return XRead;
292+
}
293+
267294
/// Make \p Source branch to \p Target.
268295
///
269296
/// Handles two situations:
@@ -8501,6 +8528,8 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
85018528
}
85028529
}
85038530
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
8531+
if (XRead->getType() != V.Var->getType())
8532+
XRead = emitImplicitCast(Builder, XRead, V.Var);
85048533
Builder.CreateStore(XRead, V.Var, V.IsVolatile);
85058534
return Builder.saveIP();
85068535
}
@@ -8785,6 +8814,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
87858814
return AtomicResult.takeError();
87868815
Value *CapturedVal =
87878816
(IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
8817+
if (CapturedVal->getType() != V.Var->getType())
8818+
CapturedVal = emitImplicitCast(Builder, CapturedVal, V.Var);
87888819
Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
87898820

87908821
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,77 @@ llvm.func @omp_atomic_read(%arg0 : !llvm.ptr, %arg1 : !llvm.ptr) -> () {
13681368

13691369
// -----
13701370

1371+
// CHECK-LABEL: @omp_atomic_read_implicit_cast
1372+
llvm.func @omp_atomic_read_implicit_cast () {
1373+
//CHECK: %[[Z:.*]] = alloca float, i64 1, align 4
1374+
//CHECK: %[[Y:.*]] = alloca double, i64 1, align 8
1375+
//CHECK: %[[X:.*]] = alloca [2 x { float, float }], i64 1, align 8
1376+
//CHECK: %[[W:.*]] = alloca i32, i64 1, align 4
1377+
//CHECK: %[[X_ELEMENT:.*]] = getelementptr { float, float }, ptr %3, i64 0
1378+
%0 = llvm.mlir.constant(1 : i64) : i64
1379+
%1 = llvm.alloca %0 x f32 {bindc_name = "z"} : (i64) -> !llvm.ptr
1380+
%2 = llvm.mlir.constant(1 : i64) : i64
1381+
%3 = llvm.alloca %2 x f64 {bindc_name = "y"} : (i64) -> !llvm.ptr
1382+
%4 = llvm.mlir.constant(1 : i64) : i64
1383+
%5 = llvm.alloca %4 x !llvm.array<2 x struct<(f32, f32)>> {bindc_name = "x"} : (i64) -> !llvm.ptr
1384+
%6 = llvm.mlir.constant(1 : i64) : i64
1385+
%7 = llvm.alloca %6 x i32 {bindc_name = "w"} : (i64) -> !llvm.ptr
1386+
%8 = llvm.mlir.constant(1 : index) : i64
1387+
%9 = llvm.mlir.constant(2 : index) : i64
1388+
%10 = llvm.mlir.constant(1 : i64) : i64
1389+
%11 = llvm.mlir.constant(0 : i64) : i64
1390+
%12 = llvm.sub %8, %10 overflow<nsw> : i64
1391+
%13 = llvm.mul %12, %10 overflow<nsw> : i64
1392+
%14 = llvm.mul %13, %10 overflow<nsw> : i64
1393+
%15 = llvm.add %14, %11 overflow<nsw> : i64
1394+
%16 = llvm.mul %10, %9 overflow<nsw> : i64
1395+
%17 = llvm.getelementptr %5[%15] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f32, f32)>
1396+
1397+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = alloca { float, float }, align 8
1398+
//CHECK: call void @__atomic_load(i64 8, ptr %[[X_ELEMENT]], ptr %[[ATOMIC_LOAD_TEMP]], i32 0)
1399+
//CHECK: %[[LOAD:.*]] = load { float, float }, ptr %[[ATOMIC_LOAD_TEMP]], align 8
1400+
//CHECK: %[[EXT:.*]] = extractvalue { float, float } %[[LOAD]], 0
1401+
//CHECK: store float %[[EXT]], ptr %[[Y]], align 4
1402+
omp.atomic.read %3 = %17 : !llvm.ptr, !llvm.ptr, !llvm.struct<(f32, f32)>
1403+
1404+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
1405+
//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
1406+
//CHECK: %[[LOAD:.*]] = fpext float %[[CAST]] to double
1407+
//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
1408+
omp.atomic.read %3 = %1 : !llvm.ptr, !llvm.ptr, f32
1409+
1410+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
1411+
//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to double
1412+
//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
1413+
omp.atomic.read %3 = %7 : !llvm.ptr, !llvm.ptr, i32
1414+
1415+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
1416+
//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
1417+
//CHECK: %[[LOAD:.*]] = fptrunc double %[[CAST]] to float
1418+
//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
1419+
omp.atomic.read %1 = %3 : !llvm.ptr, !llvm.ptr, f64
1420+
1421+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
1422+
//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to float
1423+
//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
1424+
omp.atomic.read %1 = %7 : !llvm.ptr, !llvm.ptr, i32
1425+
1426+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
1427+
//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
1428+
//CHECK: %[[LOAD:.*]] = fptosi double %[[CAST]] to i32
1429+
//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
1430+
omp.atomic.read %7 = %3 : !llvm.ptr, !llvm.ptr, f64
1431+
1432+
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
1433+
//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
1434+
//CHECK: %[[LOAD:.*]] = fptosi float %[[CAST]] to i32
1435+
//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
1436+
omp.atomic.read %7 = %1 : !llvm.ptr, !llvm.ptr, f32
1437+
llvm.return
1438+
}
1439+
1440+
// -----
1441+
13711442
// CHECK-LABEL: @omp_atomic_write
13721443
// CHECK-SAME: (ptr %[[x:.*]], i32 %[[expr:.*]])
13731444
llvm.func @omp_atomic_write(%x: !llvm.ptr, %expr: i32) -> () {

0 commit comments

Comments
 (0)