Skip to content

Commit e1fed24

Browse files
authored
[flang][OpenMP] Fix fir.convert in omp.atomic.update region (llvm#138397)
Region generation in omp.atomic.update currently emits a direct `fir.convert`. This crashes when the RHS expression involves complex type but the LHS variable is primitive type (say `f32`), since a `fir.convert` from `complex<f32>` to `f32` is emitted, which is illegal. This PR adds a conditional check to emit an additional `ExtractValueOp` in case RHS expression has a complex type. Fixes llvm#138396
1 parent 54aa16d commit e1fed24

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2858,9 +2858,23 @@ static void genAtomicUpdateStatement(
28582858
lower::StatementContext atomicStmtCtx;
28592859
mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
28602860
*semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
2861-
mlir::Value convertResult =
2862-
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
2863-
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
2861+
mlir::Type exprType = fir::unwrapRefType(rhsExpr.getType());
2862+
if (fir::isa_complex(exprType) && !fir::isa_complex(varType)) {
2863+
// Emit an additional `ExtractValueOp` if the expression is of complex
2864+
// type
2865+
auto extract = firOpBuilder.create<fir::ExtractValueOp>(
2866+
currentLocation,
2867+
mlir::cast<mlir::ComplexType>(exprType).getElementType(), rhsExpr,
2868+
firOpBuilder.getArrayAttr(
2869+
firOpBuilder.getIntegerAttr(firOpBuilder.getIndexType(), 0)));
2870+
mlir::Value convertResult = firOpBuilder.create<fir::ConvertOp>(
2871+
currentLocation, varType, extract);
2872+
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
2873+
} else {
2874+
mlir::Value convertResult =
2875+
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
2876+
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
2877+
}
28642878
converter.resetExprOverrides();
28652879
}
28662880
firOpBuilder.setInsertionPointAfter(atomicUpdateOp);

flang/test/Lower/OpenMP/atomic-update.f90

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ program OmpAtomicUpdate
2020
!CHECK: %[[VAL_C_DECLARE:.*]]:2 = hlfir.declare %[[VAL_C_ADDRESS]] {{.*}}
2121
!CHECK: %[[VAL_D_ADDRESS:.*]] = fir.address_of(@_QFEd) : !fir.ref<i32>
2222
!CHECK: %[[VAL_D_DECLARE:.*]]:2 = hlfir.declare %[[VAL_D_ADDRESS]] {{.}}
23+
!CHECK: %[[VAL_G_ADDRESS:.*]] = fir.alloca complex<f32> {bindc_name = "g", uniq_name = "_QFEg"}
24+
!CHECK: %[[VAL_G_DECLARE:.*]]:2 = hlfir.declare %[[VAL_G_ADDRESS]] {uniq_name = "_QFEg"} : (!fir.ref<complex<f32>>) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
2325
!CHECK: %[[VAL_i1_ALLOCA:.*]] = fir.alloca i8 {bindc_name = "i1", uniq_name = "_QFEi1"}
2426
!CHECK: %[[VAL_i1_DECLARE:.*]]:2 = hlfir.declare %[[VAL_i1_ALLOCA]] {{.*}}
2527
!CHECK: %[[VAL_c5:.*]] = arith.constant 5 : index
@@ -40,6 +42,7 @@ program OmpAtomicUpdate
4042
integer, target :: c, d
4143
integer(1) :: i1
4244
integer, dimension(5) :: k
45+
complex :: g
4346

4447
!CHECK: %[[EMBOX:.*]] = fir.embox %[[VAL_C_DECLARE]]#0 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
4548
!CHECK: fir.store %[[EMBOX]] to %[[VAL_A_DECLARE]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
@@ -200,4 +203,19 @@ program OmpAtomicUpdate
200203
!CHECK: }
201204
!$omp atomic update
202205
x = x + sum([ (y+2, y=1, z) ])
206+
207+
!CHECK: %[[LOAD:.*]] = fir.load %[[VAL_G_DECLARE]]#0 : !fir.ref<complex<f32>>
208+
!CHECK: omp.atomic.update %[[VAL_W_DECLARE]]#0 : !fir.ref<i32> {
209+
!CHECK: ^bb0(%[[ARG:.*]]: i32):
210+
!CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (i32) -> f32
211+
!CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
212+
!CHECK: %[[UNDEF:.*]] = fir.undefined complex<f32>
213+
!CHECK: %[[IDX1:.*]] = fir.insert_value %[[UNDEF]], %[[CVT]], [0 : index] : (complex<f32>, f32) -> complex<f32>
214+
!CHECK: %[[IDX2:.*]] = fir.insert_value %[[IDX1]], %[[CST]], [1 : index] : (complex<f32>, f32) -> complex<f32>
215+
!CHECK: %[[ADD:.*]] = fir.addc %[[IDX2]], %[[LOAD]] {fastmath = #arith.fastmath<contract>} : complex<f32>
216+
!CHECK: %[[EXT:.*]] = fir.extract_value %[[ADD]], [0 : index] : (complex<f32>) -> f32
217+
!CHECK: %[[RESULT:.*]] = fir.convert %[[EXT]] : (f32) -> i32
218+
!CHECK: omp.yield(%[[RESULT]] : i32)
219+
!$omp atomic update
220+
w = w + g
203221
end program OmpAtomicUpdate

0 commit comments

Comments
 (0)