Skip to content

Commit 0637778

Browse files
authored
[mlir][LLVMIR][OpenMP] fix dominance for reduction init block (#96052)
It was incorrect to set the insertion point to the init block after inlining the initialization region, because the code generated in the init block depends upon the value yielded from the init region. When there were multiple reduction initialization regions, each with multiple blocks, this could lead to the initialization region being inlined after the init block which depends upon it. Moving the insertion point to before inlining the initialization block turned up further issues around the handling of the terminator for the initialization block, which are also fixed here. This fixes a bug in #92430 (but the affected code couldn't compile before #92430 anyway).
1 parent 906316e commit 0637778

File tree

3 files changed

+139
-4
lines changed

3 files changed

+139
-4
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,19 @@ static LogicalResult inlineConvertOmpRegions(
388388
// be processed multiple times.
389389
moduleTranslation.forgetMapping(region);
390390

391-
if (potentialTerminator && potentialTerminator->isTerminator())
392-
potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
391+
if (potentialTerminator && potentialTerminator->isTerminator()) {
392+
llvm::BasicBlock *block = builder.GetInsertBlock();
393+
if (block->empty()) {
394+
// this can happen for really simple reduction init regions e.g.
395+
// %0 = llvm.mlir.constant(0 : i32) : i32
396+
// omp.yield(%0 : i32)
397+
// because the llvm.mlir.constant (MLIR op) isn't converted into any
398+
// llvm op
399+
potentialTerminator->insertInto(block, block->begin());
400+
} else {
401+
potentialTerminator->insertAfter(&block->back());
402+
}
403+
}
393404

394405
return success();
395406
}
@@ -1171,6 +1182,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11711182
}
11721183
}
11731184

1185+
builder.SetInsertPoint(initBlock->getFirstNonPHIOrDbgOrAlloca());
1186+
11741187
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
11751188
SmallVector<llvm::Value *> phis;
11761189

@@ -1183,7 +1196,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11831196
assert(phis.size() == 1 &&
11841197
"expected one value to be yielded from the "
11851198
"reduction neutral element declaration region");
1186-
builder.SetInsertPoint(initBlock->getTerminator());
1199+
1200+
// mapInitializationArg finishes its block with a terminator. We need to
1201+
// insert before that terminator.
1202+
builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
11871203

11881204
if (isByRef[i]) {
11891205
// Store the result of the inlined region to the allocated reduction var
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// This is basically a test that we don't crash while translating this IR
4+
5+
omp.declare_reduction @add_reduction_byref_box_heap_i32 : !llvm.ptr init {
6+
^bb0(%arg0: !llvm.ptr):
7+
%7 = llvm.mlir.constant(0 : i64) : i64
8+
%16 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
9+
%17 = llvm.icmp "eq" %16, %7 : i64
10+
llvm.cond_br %17, ^bb1, ^bb2
11+
^bb1: // pred: ^bb0
12+
llvm.br ^bb3
13+
^bb2: // pred: ^bb0
14+
llvm.br ^bb3
15+
^bb3: // 2 preds: ^bb1, ^bb2
16+
omp.yield(%arg0 : !llvm.ptr)
17+
} combiner {
18+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
19+
omp.yield(%arg0 : !llvm.ptr)
20+
} cleanup {
21+
^bb0(%arg0: !llvm.ptr):
22+
omp.yield
23+
}
24+
llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: !llvm.ptr {fir.bindc_name = "y"}) attributes {fir.internal_name = "_QPmissordered_blocks", frame_pointer = #llvm.framePointerKind<"non-leaf">, target_cpu = "generic", target_features = #llvm.target_features<["+outline-atomics", "+v8a", "+fp-armv8", "+neon"]>} {
25+
omp.parallel reduction(byref @add_reduction_byref_box_heap_i32 %arg0 -> %arg2 : !llvm.ptr, byref @add_reduction_byref_box_heap_i32 %arg1 -> %arg3 : !llvm.ptr) {
26+
omp.terminator
27+
}
28+
llvm.return
29+
}
30+
31+
// CHECK: %[[VAL_0:.*]] = alloca { ptr, ptr }, align 8
32+
// CHECK: br label %[[VAL_1:.*]]
33+
// CHECK: entry: ; preds = %[[VAL_2:.*]]
34+
// CHECK: %[[VAL_3:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
35+
// CHECK: br label %[[VAL_4:.*]]
36+
// CHECK: omp_parallel: ; preds = %[[VAL_1]]
37+
// CHECK: %[[VAL_5:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_0]], i32 0, i32 0
38+
// CHECK: store ptr %[[VAL_6:.*]], ptr %[[VAL_5]], align 8
39+
// CHECK: %[[VAL_7:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_0]], i32 0, i32 1
40+
// CHECK: store ptr %[[VAL_8:.*]], ptr %[[VAL_7]], align 8
41+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @missordered_blocks_..omp_par, ptr %[[VAL_0]])
42+
// CHECK: br label %[[VAL_9:.*]]
43+
// CHECK: omp.par.outlined.exit: ; preds = %[[VAL_4]]
44+
// CHECK: br label %[[VAL_10:.*]]
45+
// CHECK: omp.par.exit.split: ; preds = %[[VAL_9]]
46+
// CHECK: ret void
47+
// CHECK: omp.par.entry:
48+
// CHECK: %[[VAL_11:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_12:.*]], i32 0, i32 0
49+
// CHECK: %[[VAL_13:.*]] = load ptr, ptr %[[VAL_11]], align 8
50+
// CHECK: %[[VAL_14:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_12]], i32 0, i32 1
51+
// CHECK: %[[VAL_15:.*]] = load ptr, ptr %[[VAL_14]], align 8
52+
// CHECK: %[[VAL_16:.*]] = alloca i32, align 4
53+
// CHECK: %[[VAL_17:.*]] = load i32, ptr %[[VAL_18:.*]], align 4
54+
// CHECK: store i32 %[[VAL_17]], ptr %[[VAL_16]], align 4
55+
// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_16]], align 4
56+
// CHECK: %[[VAL_20:.*]] = alloca ptr, align 8
57+
// CHECK: %[[VAL_21:.*]] = alloca ptr, align 8
58+
// CHECK: %[[VAL_22:.*]] = alloca [2 x ptr], align 8
59+
// CHECK: br label %[[VAL_23:.*]]
60+
// CHECK: omp.reduction.init: ; preds = %[[VAL_24:.*]]
61+
// CHECK: br label %[[VAL_25:.*]]
62+
// CHECK: omp.reduction.neutral: ; preds = %[[VAL_23]]
63+
// CHECK: %[[VAL_26:.*]] = ptrtoint ptr %[[VAL_13]] to i64
64+
// CHECK: %[[VAL_27:.*]] = icmp eq i64 %[[VAL_26]], 0
65+
// CHECK: br i1 %[[VAL_27]], label %[[VAL_28:.*]], label %[[VAL_29:.*]]
66+
// CHECK: omp.reduction.neutral2: ; preds = %[[VAL_25]]
67+
// CHECK: br label %[[VAL_30:.*]]
68+
// CHECK: omp.reduction.neutral3: ; preds = %[[VAL_28]], %[[VAL_29]]
69+
// CHECK: br label %[[VAL_31:.*]]
70+
// CHECK: omp.region.cont: ; preds = %[[VAL_30]]
71+
// CHECK: %[[VAL_32:.*]] = phi ptr [ %[[VAL_13]], %[[VAL_30]] ]
72+
// CHECK: store ptr %[[VAL_32]], ptr %[[VAL_20]], align 8
73+
// CHECK: br label %[[VAL_33:.*]]
74+
// CHECK: omp.reduction.neutral5: ; preds = %[[VAL_31]]
75+
// CHECK: %[[VAL_34:.*]] = ptrtoint ptr %[[VAL_15]] to i64
76+
// CHECK: %[[VAL_35:.*]] = icmp eq i64 %[[VAL_34]], 0
77+
// CHECK: br i1 %[[VAL_35]], label %[[VAL_36:.*]], label %[[VAL_37:.*]]
78+
// CHECK: omp.reduction.neutral7: ; preds = %[[VAL_33]]
79+
// CHECK: br label %[[VAL_38:.*]]
80+
// CHECK: omp.reduction.neutral8: ; preds = %[[VAL_36]], %[[VAL_37]]
81+
// CHECK: br label %[[VAL_39:.*]]
82+
// CHECK: omp.region.cont4: ; preds = %[[VAL_38]]
83+
// CHECK: %[[VAL_40:.*]] = phi ptr [ %[[VAL_15]], %[[VAL_38]] ]
84+
// CHECK: store ptr %[[VAL_40]], ptr %[[VAL_21]], align 8
85+
// CHECK: br label %[[VAL_41:.*]]
86+
// CHECK: omp.par.region: ; preds = %[[VAL_39]]
87+
// CHECK: br label %[[VAL_42:.*]]
88+
// CHECK: omp.par.region10: ; preds = %[[VAL_41]]
89+
// CHECK: br label %[[VAL_43:.*]]
90+
// CHECK: omp.region.cont9: ; preds = %[[VAL_42]]
91+
// CHECK: %[[VAL_44:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_22]], i64 0, i64 0
92+
// CHECK: store ptr %[[VAL_20]], ptr %[[VAL_44]], align 8
93+
// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_22]], i64 0, i64 1
94+
// CHECK: store ptr %[[VAL_21]], ptr %[[VAL_45]], align 8
95+
// CHECK: %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
96+
// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_reduce(ptr @1, i32 %[[VAL_46]], i32 2, i64 16, ptr %[[VAL_22]], ptr @.omp.reduction.func, ptr @.gomp_critical_user_.reduction.var)
97+
// CHECK: switch i32 %[[VAL_47]], label %[[VAL_48:.*]] [
98+
// CHECK: i32 1, label %[[VAL_49:.*]]
99+
// CHECK: i32 2, label %[[VAL_50:.*]]
100+
// CHECK: ]
101+
// CHECK: reduce.switch.atomic: ; preds = %[[VAL_43]]
102+
// CHECK: unreachable
103+
// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_43]]
104+
// CHECK: %[[VAL_51:.*]] = load ptr, ptr %[[VAL_20]], align 8
105+
// CHECK: %[[VAL_52:.*]] = load ptr, ptr %[[VAL_21]], align 8
106+
// CHECK: call void @__kmpc_end_reduce(ptr @1, i32 %[[VAL_46]], ptr @.gomp_critical_user_.reduction.var)
107+
// CHECK: br label %[[VAL_48]]
108+
// CHECK: reduce.finalize: ; preds = %[[VAL_49]], %[[VAL_43]]
109+
// CHECK: br label %[[VAL_53:.*]]
110+
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_48]]
111+
// CHECK: %[[VAL_54:.*]] = load ptr, ptr %[[VAL_20]], align 8
112+
// CHECK: %[[VAL_55:.*]] = load ptr, ptr %[[VAL_21]], align 8
113+
// CHECK: br label %[[VAL_56:.*]]
114+
// CHECK: omp.reduction.neutral6: ; preds = %[[VAL_33]]
115+
// CHECK: br label %[[VAL_38]]
116+
// CHECK: omp.reduction.neutral1: ; preds = %[[VAL_25]]
117+
// CHECK: br label %[[VAL_30]]
118+
// CHECK: omp.par.outlined.exit.exitStub: ; preds = %[[VAL_53]]
119+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ module {
6161
// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_16]], align 4
6262
// CHECK: %[[VAL_21:.*]] = alloca ptr, align 8
6363
// CHECK: %[[VAL_23:.*]] = alloca ptr, align 8
64-
// CHECK: %[[VAL_20:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_13]], align 8
6564
// CHECK: %[[VAL_24:.*]] = alloca [2 x ptr], align 8
6665
// CHECK: br label %[[INIT_LABEL:.*]]
6766
// CHECK: [[INIT_LABEL]]:
67+
// CHECK: %[[VAL_20:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_13]], align 8
6868
// CHECK: store ptr %[[VAL_13]], ptr %[[VAL_21]], align 8
6969
// CHECK: %[[VAL_22:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_15]], align 8
7070
// CHECK: store ptr %[[VAL_15]], ptr %[[VAL_23]], align 8

0 commit comments

Comments
 (0)