Skip to content

Commit 48f8865

Browse files
authored
[MLIR] Extend MPI dialect (llvm#123255)
cc @tobiasgrosser @wsmoses this PR adds some new ops and types to the MLIR MPI dialect. the goal is to get the minimum required ops here to get a project of us working, and if everything works well, continue adding ops to the mpi dialect on subsequent PRs until we achieve some level of compliance with the MPI standard. --- Things left to do in subsequent PRs: - Add back the `mpi.comm` type and add as optional argument of current implemented ops that should support it (i.e. `send`, `recv`, `isend`, `irecv`, `allreduce`, `barrier`). - Support defining custom `MPI_Op`s (the MPI operations, not the tablegen `MPI_Op`) as regions. - Add more ops.
1 parent 9725595 commit 48f8865

File tree

5 files changed

+293
-12
lines changed

5 files changed

+293
-12
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPI.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,43 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
215215
let assemblyFormat = "`<` $value `>`";
216216
}
217217

218+
def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
219+
def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
220+
def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
221+
def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
222+
def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
223+
def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
224+
def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
225+
def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
226+
def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
227+
def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
228+
def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
229+
def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
230+
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
231+
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
232+
233+
def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
234+
MPI_OpNull,
235+
MPI_OpMax,
236+
MPI_OpMin,
237+
MPI_OpSum,
238+
MPI_OpProd,
239+
MPI_OpLand,
240+
MPI_OpBand,
241+
MPI_OpLor,
242+
MPI_OpBor,
243+
MPI_OpLxor,
244+
MPI_OpBxor,
245+
MPI_OpMinloc,
246+
MPI_OpMaxloc,
247+
MPI_OpReplace
248+
]> {
249+
let genSpecializedAttr = 0;
250+
let cppNamespace = "::mlir::mpi";
251+
}
252+
253+
def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
254+
let assemblyFormat = "`<` $value `>`";
255+
}
256+
218257
#endif // MLIR_DIALECT_MPI_IR_MPI_TD

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 186 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,28 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
5959
let assemblyFormat = "attr-dict `:` type(results)";
6060
}
6161

62+
//===----------------------------------------------------------------------===//
63+
// CommSizeOp
64+
//===----------------------------------------------------------------------===//
65+
66+
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
67+
let summary = "Get the size of the group associated to the communicator, "
68+
"equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
69+
let description = [{
70+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
71+
72+
This operation can optionally return an `!mpi.retval` value that can be used
73+
to check for errors.
74+
}];
75+
76+
let results = (
77+
outs Optional<MPI_Retval> : $retval,
78+
I32 : $size
79+
);
80+
81+
let assemblyFormat = "attr-dict `:` type(results)";
82+
}
83+
6284
//===----------------------------------------------------------------------===//
6385
// SendOp
6486
//===----------------------------------------------------------------------===//
@@ -71,13 +93,17 @@ def MPI_SendOp : MPI_Op<"send", []> {
7193
`dest`. The `tag` value and communicator enables the library to determine
7294
the matching of multiple sends and receives between the same ranks.
7395

74-
Communicators other than `MPI_COMM_WORLD` are not supprted for now.
96+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
7597

7698
This operation can optionally return an `!mpi.retval` value that can be used
7799
to check for errors.
78100
}];
79101

80-
let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
102+
let arguments = (
103+
ins AnyMemRef : $ref,
104+
I32 : $tag,
105+
I32 : $rank
106+
);
81107

82108
let results = (outs Optional<MPI_Retval>:$retval);
83109

@@ -87,6 +113,42 @@ def MPI_SendOp : MPI_Op<"send", []> {
87113
let hasCanonicalizer = 1;
88114
}
89115

116+
//===----------------------------------------------------------------------===//
117+
// ISendOp
118+
//===----------------------------------------------------------------------===//
119+
120+
def MPI_ISendOp : MPI_Op<"isend", []> {
121+
let summary =
122+
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
123+
let description = [{
124+
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
125+
rank `dest`. The `tag` value and communicator enables the library to
126+
determine the matching of multiple sends and receives between the same
127+
ranks.
128+
129+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
130+
131+
This operation can optionally return an `!mpi.retval` value that can be used
132+
to check for errors.
133+
}];
134+
135+
let arguments = (
136+
ins AnyMemRef : $ref,
137+
I32 : $tag,
138+
I32 : $rank
139+
);
140+
141+
let results = (
142+
outs Optional<MPI_Retval>:$retval,
143+
MPI_Request : $req
144+
);
145+
146+
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
147+
"`:` type($ref) `,` type($tag) `,` type($rank) "
148+
"`->` type(results)";
149+
let hasCanonicalizer = 1;
150+
}
151+
90152
//===----------------------------------------------------------------------===//
91153
// RecvOp
92154
//===----------------------------------------------------------------------===//
@@ -100,24 +162,142 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
100162
determine the matching of multiple sends and receives between the same
101163
ranks.
102164

103-
Communicators other than `MPI_COMM_WORLD` are not supprted for now.
165+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
104166
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
105167
is not yet ported to MLIR.
106168

107169
This operation can optionally return an `!mpi.retval` value that can be used
108170
to check for errors.
109171
}];
110172

111-
let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
173+
let arguments = (
174+
ins AnyMemRef : $ref,
175+
I32 : $tag, I32 : $rank
176+
);
112177

113178
let results = (outs Optional<MPI_Retval>:$retval);
114179

115-
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
180+
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
116181
"type($ref) `,` type($tag) `,` type($rank)"
117182
"(`->` type($retval)^)?";
118183
let hasCanonicalizer = 1;
119184
}
120185

186+
//===----------------------------------------------------------------------===//
187+
// IRecvOp
188+
//===----------------------------------------------------------------------===//
189+
190+
def MPI_IRecvOp : MPI_Op<"irecv", []> {
191+
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
192+
"MPI_COMM_WORLD, &req)`";
193+
let description = [{
194+
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
195+
from rank `dest`. The `tag` value and communicator enables the library to
196+
determine the matching of multiple sends and receives between the same
197+
ranks.
198+
199+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
200+
201+
This operation can optionally return an `!mpi.retval` value that can be used
202+
to check for errors.
203+
}];
204+
205+
let arguments = (
206+
ins AnyMemRef : $ref,
207+
I32 : $tag,
208+
I32 : $rank
209+
);
210+
211+
let results = (
212+
outs Optional<MPI_Retval>:$retval,
213+
MPI_Request : $req
214+
);
215+
216+
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
217+
"type($ref) `,` type($tag) `,` type($rank) `->`"
218+
"type(results)";
219+
let hasCanonicalizer = 1;
220+
}
221+
222+
//===----------------------------------------------------------------------===//
223+
// AllReduceOp
224+
//===----------------------------------------------------------------------===//
225+
226+
def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
227+
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
228+
"MPI_COMM_WORLD)`";
229+
let description = [{
230+
MPI_Allreduce performs a reduction operation on the values in the sendbuf
231+
array and stores the result in the recvbuf array. The operation is
232+
performed across all processes in the communicator.
233+
234+
The `op` attribute specifies the reduction operation to be performed.
235+
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
236+
supported.
237+
238+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
239+
240+
This operation can optionally return an `!mpi.retval` value that can be used
241+
to check for errors.
242+
}];
243+
244+
let arguments = (
245+
ins AnyMemRef : $sendbuf,
246+
AnyMemRef : $recvbuf,
247+
MPI_OpClassAttr : $op
248+
);
249+
250+
let results = (outs Optional<MPI_Retval>:$retval);
251+
252+
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
253+
"type($sendbuf) `,` type($recvbuf)"
254+
"(`->` type($retval)^)?";
255+
}
256+
257+
//===----------------------------------------------------------------------===//
258+
// BarrierOp
259+
//===----------------------------------------------------------------------===//
260+
261+
def MPI_Barrier : MPI_Op<"barrier", []> {
262+
let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
263+
let description = [{
264+
MPI_Barrier blocks execution until all processes in the communicator have
265+
reached this routine.
266+
267+
Communicators other than `MPI_COMM_WORLD` are not supported for now.
268+
269+
This operation can optionally return an `!mpi.retval` value that can be used
270+
to check for errors.
271+
}];
272+
273+
let results = (outs Optional<MPI_Retval>:$retval);
274+
275+
let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
276+
}
277+
278+
//===----------------------------------------------------------------------===//
279+
// WaitOp
280+
//===----------------------------------------------------------------------===//
281+
282+
def MPI_Wait : MPI_Op<"wait", []> {
283+
let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
284+
let description = [{
285+
MPI_Wait blocks execution until the request has completed.
286+
287+
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
288+
is not yet ported to MLIR.
289+
290+
This operation can optionally return an `!mpi.retval` value that can be used
291+
to check for errors.
292+
}];
293+
294+
let arguments = (ins MPI_Request : $req);
295+
296+
let results = (outs Optional<MPI_Retval>:$retval);
297+
298+
let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
299+
"(`->` type($retval) ^)?";
300+
}
121301

122302
//===----------------------------------------------------------------------===//
123303
// FinalizeOp
@@ -139,7 +319,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
139319
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
140320
}
141321

142-
143322
//===----------------------------------------------------------------------===//
144323
// RetvalCheckOp
145324
//===----------------------------------------------------------------------===//
@@ -163,10 +342,8 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
163342
let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)";
164343
}
165344

166-
167-
168345
//===----------------------------------------------------------------------===//
169-
// RetvalCheckOp
346+
// ErrorClassOp
170347
//===----------------------------------------------------------------------===//
171348

172349
def MPI_ErrorClassOp : MPI_Op<"error_class", []> {

mlir/include/mlir/Dialect/MPI/IR/MPITypes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,26 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
4040
}];
4141
}
4242

43+
//===----------------------------------------------------------------------===//
44+
// mpi::RequestType
45+
//===----------------------------------------------------------------------===//
46+
47+
def MPI_Request : MPI_Type<"Request", "request"> {
48+
let summary = "MPI asynchronous request handler";
49+
let description = [{
50+
This type represents a handler to an asynchronous request.
51+
}];
52+
}
53+
54+
//===----------------------------------------------------------------------===//
55+
// mpi::StatusType
56+
//===----------------------------------------------------------------------===//
57+
58+
def MPI_Status : MPI_Type<"Status", "status"> {
59+
let summary = "MPI reception operation status type";
60+
let description = [{
61+
This type represents the status of a reception operation.
62+
}];
63+
}
64+
4365
#endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD

mlir/lib/Dialect/MPI/IR/MPIOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ void mlir::mpi::RecvOp::getCanonicalizationPatterns(
5353
results.add<FoldCast<mlir::mpi::RecvOp>>(context);
5454
}
5555

56+
void mlir::mpi::ISendOp::getCanonicalizationPatterns(
57+
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
58+
results.add<FoldCast<mlir::mpi::ISendOp>>(context);
59+
}
60+
61+
void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
62+
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
63+
results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
64+
}
65+
5666
//===----------------------------------------------------------------------===//
5767
// TableGen'd op method definitions
5868
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)