Skip to content

Commit a435e1f

Browse files
[acc] Add attribute for combined constructs (#80319)
Combined constructs are decomposed into separate operations. However, this does not adhere to the `acc` dialect's goal of being able to regenerate clauses that are semantically equivalent to the user's intent. Thus, add an attribute to keep track of the combined constructs.
1 parent cfdfeb4 commit a435e1f

File tree

5 files changed

+168
-7
lines changed

5 files changed

+168
-7
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() {
133133
return StringLiteral("acc.routine_info");
134134
}
135135

136+
/// Returns the attribute name used for the combined-constructs attribute,
/// taken from the `name` constant of `CombinedConstructsTypeAttr` rather than
/// being spelled out as a string literal here.
static constexpr StringLiteral getCombinedConstructsAttrName() {
137+
return CombinedConstructsTypeAttr::name;
138+
}
139+
136140
struct RuntimeCounters
137141
: public mlir::SideEffects::Resource::Base<RuntimeCounters> {
138142
mlir::StringRef getName() final { return "AccRuntimeCounters"; }

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,24 @@ def GangArgTypeArrayAttr :
218218
let constBuilderCall = ?;
219219
}
220220

221+
// Combined constructs enumerations.
// Each case names the compute construct (kernels, parallel or serial) that a
// loop was combined with; the cases are numbered 1-3 and have the string
// forms `kernels_loop`, `parallel_loop` and `serial_loop`.
222+
def OpenACC_KernelsLoop : I32EnumAttrCase<"KernelsLoop", 1, "kernels_loop">;
223+
def OpenACC_ParallelLoop : I32EnumAttrCase<"ParallelLoop", 2, "parallel_loop">;
224+
def OpenACC_SerialLoop : I32EnumAttrCase<"SerialLoop", 3, "serial_loop">;
225+
226+
// 32-bit integer enum grouping the three combined-construct cases above.
// `genSpecializedAttr = 0` leaves the attribute class to the EnumAttr
// definition that follows; the C++ enum lives in namespace `::mlir::acc`.
def OpenACC_CombinedConstructsType : I32EnumAttr<"CombinedConstructsType",
227+
"Differentiate between combined constructs",
228+
[OpenACC_KernelsLoop, OpenACC_ParallelLoop, OpenACC_SerialLoop]> {
229+
let genSpecializedAttr = 0;
230+
let cppNamespace = "::mlir::acc";
231+
}
232+
233+
// Dialect attribute wrapping CombinedConstructsType, declared in the OpenACC
// dialect with the mnemonic `combined_constructs`. Its assembly format is the
// enum value enclosed in angle brackets.
def OpenACC_CombinedConstructsAttr : EnumAttr<OpenACC_Dialect,
234+
OpenACC_CombinedConstructsType,
235+
"combined_constructs"> {
236+
let assemblyFormat = [{ ```<` $value `>` }];
237+
}
238+
221239
// Define a resource for the OpenACC runtime counters.
222240
def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
223241

@@ -933,7 +951,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
933951
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
934952
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
935953
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
936-
OptionalAttr<DefaultValueAttr>:$defaultAttr);
954+
OptionalAttr<DefaultValueAttr>:$defaultAttr,
955+
UnitAttr:$combined);
937956

938957
let regions = (region AnyRegion:$region);
939958

@@ -993,6 +1012,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
9931012
}];
9941013

9951014
let assemblyFormat = [{
1015+
( `combined` `(` `loop` `)` $combined^)?
9961016
oilist(
9971017
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
9981018
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
@@ -1068,7 +1088,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
10681088
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
10691089
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
10701090
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
1071-
OptionalAttr<DefaultValueAttr>:$defaultAttr);
1091+
OptionalAttr<DefaultValueAttr>:$defaultAttr,
1092+
UnitAttr:$combined);
10721093

10731094
let regions = (region AnyRegion:$region);
10741095

@@ -1109,6 +1130,7 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
11091130
}];
11101131

11111132
let assemblyFormat = [{
1133+
( `combined` `(` `loop` `)` $combined^)?
11121134
oilist(
11131135
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
11141136
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
@@ -1182,7 +1204,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
11821204
Optional<I1>:$selfCond,
11831205
UnitAttr:$selfAttr,
11841206
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
1185-
OptionalAttr<DefaultValueAttr>:$defaultAttr);
1207+
OptionalAttr<DefaultValueAttr>:$defaultAttr,
1208+
UnitAttr:$combined);
11861209

11871210
let regions = (region AnyRegion:$region);
11881211

@@ -1242,6 +1265,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
12421265
}];
12431266

12441267
let assemblyFormat = [{
1268+
( `combined` `(` `loop` `)` $combined^)?
12451269
oilist(
12461270
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
12471271
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
@@ -1573,7 +1597,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
15731597
Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
15741598
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
15751599
Variadic<AnyType>:$reductionOperands,
1576-
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
1600+
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
1601+
OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined
15771602
);
15781603

15791604
let results = (outs Variadic<AnyType>:$results);
@@ -1665,6 +1690,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
16651690

16661691
let hasCustomAssemblyFormat = 1;
16671692
let assemblyFormat = [{
1693+
custom<CombinedConstructsLoop>($combined)
16681694
oilist(
16691695
`gang` `` custom<GangClause>($gangOperands, type($gangOperands),
16701696
$gangOperandsArgType, $gangOperandsDeviceType,

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,50 @@ static void printDeviceTypeOperandsWithKeywordOnly(
12831283
p << ")";
12841284
}
12851285

1286+
/// Parses the optional `combined(<construct>)` prefix, where `<construct>` is
/// one of the keywords `kernels`, `parallel` or `serial`.
///
/// When the `combined` keyword is absent, nothing is consumed, `attr` is left
/// untouched and the parse succeeds. When it is present, the parenthesised
/// construct keyword is mandatory: `attr` is set to the matching
/// CombinedConstructsTypeAttr, and a missing or unknown keyword produces the
/// "expected compute construct name" diagnostic.
static ParseResult
parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
                            mlir::acc::CombinedConstructsTypeAttr &attr) {
  // The whole clause is optional; bail out early when it is not there.
  if (failed(parser.parseOptionalKeyword("combined")))
    return success();

  if (parser.parseLParen())
    return failure();

  // Keyword spellings paired with the enum case they select, tried in order.
  struct ConstructKeyword {
    const char *spelling;
    mlir::acc::CombinedConstructsType kind;
  };
  const ConstructKeyword candidates[] = {
      {"kernels", mlir::acc::CombinedConstructsType::KernelsLoop},
      {"parallel", mlir::acc::CombinedConstructsType::ParallelLoop},
      {"serial", mlir::acc::CombinedConstructsType::SerialLoop},
  };

  bool matched = false;
  for (const ConstructKeyword &candidate : candidates) {
    if (failed(parser.parseOptionalKeyword(candidate.spelling)))
      continue;
    attr = mlir::acc::CombinedConstructsTypeAttr::get(parser.getContext(),
                                                      candidate.kind);
    matched = true;
    break;
  }

  if (!matched) {
    parser.emitError(parser.getCurrentLocation(),
                     "expected compute construct name");
    return failure();
  }

  if (parser.parseRParen())
    return failure();
  return success();
}
1311+
1312+
static void
1313+
printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
1314+
mlir::acc::CombinedConstructsTypeAttr attr) {
1315+
if (attr) {
1316+
switch (attr.getValue()) {
1317+
case mlir::acc::CombinedConstructsType::KernelsLoop:
1318+
p << "combined(kernels)";
1319+
break;
1320+
case mlir::acc::CombinedConstructsType::ParallelLoop:
1321+
p << "combined(parallel)";
1322+
break;
1323+
case mlir::acc::CombinedConstructsType::SerialLoop:
1324+
p << "combined(serial)";
1325+
break;
1326+
};
1327+
}
1328+
}
1329+
12861330
//===----------------------------------------------------------------------===//
12871331
// SerialOp
12881332
//===----------------------------------------------------------------------===//
@@ -1851,6 +1895,13 @@ LogicalResult acc::LoopOp::verify() {
18511895
"reductions", false)))
18521896
return failure();
18531897

1898+
if (getCombined().has_value() &&
1899+
(getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1900+
getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1901+
getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1902+
return emitError("unexpected combined constructs attribute");
1903+
}
1904+
18541905
// Check non-empty body().
18551906
if (getRegion().empty())
18561907
return emitError("expected non-empty body.");

mlir/test/Dialect/OpenACC/invalid.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,3 +738,43 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
738738
acc.terminator
739739
}
740740
}
741+
742+
// -----
743+
744+
func.func @acc_combined() {
745+
// expected-error @below {{expected 'loop'}}
746+
acc.parallel combined() {
747+
}
748+
749+
return
750+
}
751+
752+
// -----
753+
754+
func.func @acc_combined() {
755+
// expected-error @below {{expected compute construct name}}
756+
acc.loop combined(loop) {
757+
}
758+
759+
return
760+
}
761+
762+
// -----
763+
764+
func.func @acc_combined() {
765+
// expected-error @below {{expected 'loop'}}
766+
acc.parallel combined(parallel loop) {
767+
}
768+
769+
return
770+
}
771+
772+
// -----
773+
774+
func.func @acc_combined() {
775+
// expected-error @below {{expected ')'}}
776+
acc.loop combined(parallel loop) {
777+
}
778+
779+
return
780+
}

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,9 +1846,49 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
18461846

18471847
// -----
18481848

1849-
%c2 = arith.constant 2 : i32
1850-
%c1 = arith.constant 1 : i32
1851-
acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
1849+
// CHECK-LABEL: func.func @acc_num_gangs
1850+
func.func @acc_num_gangs() {
1851+
%c2 = arith.constant 2 : i32
1852+
%c1 = arith.constant 1 : i32
1853+
acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
1854+
}
1855+
1856+
return
18521857
}
18531858

18541859
// CHECK: acc.parallel num_gangs({%c2{{.*}} : i32} [#acc.device_type<default>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
1860+
1861+
// -----
1862+
1863+
// CHECK-LABEL: func.func @acc_combined
1864+
func.func @acc_combined() {
1865+
acc.parallel combined(loop) {
1866+
acc.loop combined(parallel) {
1867+
acc.yield
1868+
}
1869+
acc.terminator
1870+
}
1871+
1872+
acc.kernels combined(loop) {
1873+
acc.loop combined(kernels) {
1874+
acc.yield
1875+
}
1876+
acc.terminator
1877+
}
1878+
1879+
acc.serial combined(loop) {
1880+
acc.loop combined(serial) {
1881+
acc.yield
1882+
}
1883+
acc.terminator
1884+
}
1885+
1886+
return
1887+
}
1888+
1889+
// CHECK: acc.parallel combined(loop)
1890+
// CHECK: acc.loop combined(parallel)
1891+
// CHECK: acc.kernels combined(loop)
1892+
// CHECK: acc.loop combined(kernels)
1893+
// CHECK: acc.serial combined(loop)
1894+
// CHECK: acc.loop combined(serial)

0 commit comments

Comments
 (0)