Skip to content

Commit 3668a3a

Browse files
committed
[OpenACC][CIR] 'if'/'self' combined construct lowering
These two require that we correctly set up the 'insertion points' for the compute construct when doing a combined construct. This patch adds that and verifies that we're doing it correctly.
1 parent 25a03c1 commit 3668a3a

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

clang/lib/CIR/CodeGen/CIRGenOpenACCClause.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ class OpenACCClauseCIREmitter final
148148
template <typename U = void,
149149
typename = std::enable_if_t<isCombinedType<OpTy>, U>>
150150
void applyToComputeOp(const OpenACCClause &c) {
151-
// TODO OpenACC: we have to set the insertion scope here correctly still.
151+
mlir::OpBuilder::InsertionGuard guardCase(builder);
152+
builder.setInsertionPoint(operation.computeOp);
152153
OpenACCClauseCIREmitter<typename OpTy::ComputeOpTy> computeEmitter{
153154
operation.computeOp, cgf, builder, dirKind, dirLoc};
154155
computeEmitter.lastDeviceTypeValues = lastDeviceTypeValues;
@@ -288,9 +289,11 @@ class OpenACCClauseCIREmitter final
288289
} else {
289290
llvm_unreachable("var-list version of self shouldn't get here");
290291
}
292+
} else if constexpr (isCombinedType<OpTy>) {
293+
applyToComputeOp(clause);
291294
} else {
292295
// TODO: When we've implemented this for everything, switch this to an
293-
// unreachable. If, combined constructs remain.
296+
// unreachable. update construct remains.
294297
return clauseNotImplemented(clause);
295298
}
296299
}
@@ -302,13 +305,15 @@ class OpenACCClauseCIREmitter final
302305
mlir::acc::DataOp, mlir::acc::WaitOp>) {
303306
operation.getIfCondMutable().append(
304307
createCondition(clause.getConditionExpr()));
308+
} else if constexpr (isCombinedType<OpTy>) {
309+
applyToComputeOp(clause);
305310
} else {
306311
// 'if' applies to most of the constructs, but hold off on lowering them
307312
// until we can write tests/know what we're doing with codegen to make
308313
// sure we get it right.
309314
// TODO: When we've implemented this for everything, switch this to an
310-
// unreachable. Enter data, exit data, host_data, update, combined
311-
// constructs remain.
315+
// unreachable. Enter data, exit data, host_data, update constructs
316+
// remain.
312317
return clauseNotImplemented(clause);
313318
}
314319
}

clang/test/CIR/CodeGenOpenACC/combined.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,80 @@ extern "C" void acc_combined(int N) {
176176
// CHECK-NEXT: } attributes {collapse = [1, 2, 2, 3], collapseDeviceType = [#acc.device_type<none>, #acc.device_type<radeon>, #acc.device_type<nvidia>, #acc.device_type<host>]}
177177
// CHECK: acc.yield
178178
// CHECK-NEXT: } loc
179+
180+
#pragma acc kernels loop self
181+
for(unsigned I = 0; I < N; ++I);
182+
// CHECK-NEXT: acc.kernels combined(loop) {
183+
// CHECK-NEXT: acc.loop combined(kernels) {
184+
// CHECK: acc.yield
185+
// CHECK-NEXT: } loc
186+
// CHECK-NEXT: acc.terminator
187+
// CHECK-NEXT: } attributes {selfAttr}
188+
189+
#pragma acc serial loop self(N)
190+
for(unsigned I = 0; I < N; ++I);
191+
// CHECK-NEXT: %[[N_LOAD:.*]] = cir.load %[[ALLOCA_N]] : !cir.ptr<!s32i>, !s32i
192+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[N_LOAD]] : !s32i), !cir.bool
193+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
194+
// CHECK-NEXT: acc.serial combined(loop) self(%[[CONV_CAST]]) {
195+
// CHECK-NEXT: acc.loop combined(serial) {
196+
// CHECK: acc.yield
197+
// CHECK-NEXT: } loc
198+
// CHECK-NEXT: acc.yield
199+
// CHECK-NEXT: } loc
200+
201+
#pragma acc parallel loop if(N)
202+
for(unsigned I = 0; I < N; ++I);
203+
// CHECK-NEXT: %[[N_LOAD:.*]] = cir.load %[[ALLOCA_N]] : !cir.ptr<!s32i>, !s32i
204+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[N_LOAD]] : !s32i), !cir.bool
205+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
206+
// CHECK-NEXT: acc.parallel combined(loop) if(%[[CONV_CAST]]) {
207+
// CHECK-NEXT: acc.loop combined(parallel) {
208+
// CHECK: acc.yield
209+
// CHECK-NEXT: } loc
210+
// CHECK-NEXT: acc.yield
211+
// CHECK-NEXT: } loc
212+
213+
#pragma acc serial loop if(1)
214+
for(unsigned I = 0; I < N; ++I);
215+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
216+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[ONE_LITERAL]] : !s32i), !cir.bool
217+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
218+
// CHECK-NEXT: acc.serial combined(loop) if(%[[CONV_CAST]]) {
219+
// CHECK-NEXT: acc.loop combined(serial) {
220+
// CHECK: acc.yield
221+
// CHECK-NEXT: } loc
222+
// CHECK-NEXT: acc.yield
223+
// CHECK-NEXT: } loc
224+
225+
#pragma acc kernels loop if(N == 1)
226+
for(unsigned I = 0; I < N; ++I);
227+
// CHECK-NEXT: %[[N_LOAD:.*]] = cir.load %[[ALLOCA_N]] : !cir.ptr<!s32i>, !s32i
228+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
229+
// CHECK-NEXT: %[[EQ_RES:.*]] = cir.cmp(eq, %[[N_LOAD]], %[[ONE_LITERAL]]) : !s32i, !cir.bool
230+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[EQ_RES]] : !cir.bool to i1
231+
// CHECK-NEXT: acc.kernels combined(loop) if(%[[CONV_CAST]]) {
232+
// CHECK-NEXT: acc.loop combined(kernels) {
233+
// CHECK: acc.yield
234+
// CHECK-NEXT: } loc
235+
// CHECK-NEXT: acc.terminator
236+
// CHECK-NEXT: } loc
237+
238+
#pragma acc parallel loop if(N == 1) self(N == 2)
239+
for(unsigned I = 0; I < N; ++I);
240+
// CHECK-NEXT: %[[N_LOAD:.*]] = cir.load %[[ALLOCA_N]] : !cir.ptr<!s32i>, !s32i
241+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
242+
// CHECK-NEXT: %[[EQ_RES_IF:.*]] = cir.cmp(eq, %[[N_LOAD]], %[[ONE_LITERAL]]) : !s32i, !cir.bool
243+
// CHECK-NEXT: %[[CONV_CAST_IF:.*]] = builtin.unrealized_conversion_cast %[[EQ_RES_IF]] : !cir.bool to i1
244+
// CHECK-NEXT: %[[N_LOAD:.*]] = cir.load %[[ALLOCA_N]] : !cir.ptr<!s32i>, !s32i
245+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
246+
// CHECK-NEXT: %[[EQ_RES_SELF:.*]] = cir.cmp(eq, %[[N_LOAD]], %[[TWO_LITERAL]]) : !s32i, !cir.bool
247+
// CHECK-NEXT: %[[CONV_CAST_SELF:.*]] = builtin.unrealized_conversion_cast %[[EQ_RES_SELF]] : !cir.bool to i1
248+
// CHECK-NEXT: acc.parallel combined(loop) self(%[[CONV_CAST_SELF]]) if(%[[CONV_CAST_IF]]) {
249+
// CHECK-NEXT: acc.loop combined(parallel) {
250+
// CHECK: acc.yield
251+
// CHECK-NEXT: } loc
252+
// CHECK-NEXT: acc.yield
253+
// CHECK-NEXT: } loc
254+
179255
}

0 commit comments

Comments
 (0)