Skip to content

Commit 3bb7999

Browse files
stellaraccidentStella Laurenzo
authored and
Stella Laurenzo
committed
[mlir] Add global_load and global_store ops to ml_program.
* Adds simple, non-atomic, non-volatile, non-synchronized direct load/store ops. Differential Revision: https://reviews.llvm.org/D126230
1 parent aaf04c7 commit 3bb7999

File tree

12 files changed

+376
-5
lines changed

12 files changed

+376
-5
lines changed

mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@ mlir_tablegen(MLProgramAttributes.cpp.inc -gen-attrdef-defs)
88
add_public_tablegen_target(MLIRMLProgramAttributesIncGen)
99
add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen)
1010
add_mlir_doc(MLProgramAttributes MLProgramAttributes Dialects/ -gen-attrdef-doc)
11+
12+
set(LLVM_TARGET_DEFINITIONS MLProgramTypes.td)
13+
mlir_tablegen(MLProgramTypes.h.inc -gen-typedef-decls)
14+
mlir_tablegen(MLProgramTypes.cpp.inc -gen-typedef-defs)
15+
add_public_tablegen_target(MLIRMLProgramTypesIncGen)
16+
add_dependencies(mlir-headers MLIRMLProgramTypesIncGen)
17+
add_mlir_doc(MLProgramTypes MLProgramTypes Dialects/ -gen-typedef-doc)

mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
1010

1111
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h"
12+
#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.h"
1213
#include "mlir/IR/Dialect.h"
1314
#include "mlir/IR/FunctionInterfaces.h"
1415
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def MLProgram_Dialect : Dialect {
2828
}];
2929

3030
let useDefaultAttributePrinterParser = 1;
31+
let useDefaultTypePrinterParser = 1;
3132
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
3233
}
3334

mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLPROGRAM_OPS
1111

1212
include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
13+
include "mlir/Dialect/MLProgram/IR/MLProgramTypes.td"
1314
include "mlir/Interfaces/CallInterfaces.td"
1415
include "mlir/Interfaces/ControlFlowInterfaces.td"
1516
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -152,6 +153,51 @@ def MLProgram_GlobalOp : MLProgram_Op<"global", [
152153
let hasVerifier = 1;
153154
}
154155

156+
//===----------------------------------------------------------------------===//
157+
// GlobalLoadOp
158+
//===----------------------------------------------------------------------===//
159+
160+
def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
161+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
162+
]> {
163+
let summary = "Direct load of a mutable value from a global";
164+
let description = [{
165+
Performs a non-atomic, non-volatile, non-synchronized load from a global
166+
that may be mutable.
167+
168+
It is fully expected that these constraints are not suitable for
169+
all situations, and alternative ops should be defined and used for more
170+
advanced cases.
171+
172+
This op is side effecting and may not be valid to use in graph regions
173+
without additional consideration to evaluation order constraints.
174+
175+
Example:
176+
177+
```mlir
178+
%0 = ml_program.global_load @foobar : tensor<?xi32>
179+
```
180+
}];
181+
182+
let arguments = (ins
183+
Arg<SymbolRefAttr, "", [MemRead]>:$global,
184+
Variadic<MLProgram_TokenType>:$consumeTokens
185+
);
186+
let results = (outs
187+
AnyType:$result,
188+
Optional<MLProgram_TokenType>:$produceToken
189+
);
190+
191+
let assemblyFormat = [{
192+
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
193+
}];
194+
195+
let extraClassDeclaration = [{
196+
/// Gets the corresponding GlobalOp (or nullptr).
197+
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
198+
}];
199+
}
200+
155201
//===----------------------------------------------------------------------===//
156202
// GlobalLoadConstOp
157203
//===----------------------------------------------------------------------===//
@@ -175,14 +221,59 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
175221
}];
176222

177223
let arguments = (ins
178-
FlatSymbolRefAttr:$global
224+
SymbolRefAttr:$global
179225
);
180226
let results = (outs
181227
AnyType:$result
182228
);
183229

184230
let assemblyFormat = [{
185-
$global attr-dict `:` type($result)
231+
$global `:` type($result) attr-dict
232+
}];
233+
234+
let extraClassDeclaration = [{
235+
/// Gets the corresponding GlobalOp (or nullptr).
236+
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
237+
}];
238+
}
239+
240+
//===----------------------------------------------------------------------===//
241+
// GlobalStoreOp
242+
//===----------------------------------------------------------------------===//
243+
244+
def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
245+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
246+
]> {
247+
let summary = "Direct store of a value into a mutable global";
248+
let description = [{
249+
Performs a non-atomic, non-volatile, non-synchronized store to a mutable
250+
global.
251+
252+
It is fully expected that these constraints are not suitable for
253+
all situations, and alternative ops should be defined and used for more
254+
advanced cases.
255+
256+
This op is side effecting and may not be valid to use in graph regions
257+
without additional consideration to evaluation order constraints.
258+
259+
Example:
260+
261+
```mlir
262+
ml_program.global_store @foobar = %0 : tensor<?xi32>
263+
```
264+
}];
265+
266+
let arguments = (ins
267+
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
268+
AnyType:$value,
269+
Variadic<MLProgram_TokenType>:$consumeTokens
270+
);
271+
let results = (outs
272+
Optional<MLProgram_TokenType>:$produceToken
273+
);
274+
275+
let assemblyFormat = [{
276+
$global `=` $value `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($value) attr-dict
186277
}];
187278

188279
let extraClassDeclaration = [{
@@ -310,4 +401,24 @@ def MLProgram_ReturnOp : MLProgram_Op<"return", [
310401
let hasVerifier = 1;
311402
}
312403

404+
//===----------------------------------------------------------------------===//
405+
// TokenOp
406+
//===----------------------------------------------------------------------===//
407+
408+
def MLProgram_TokenOp : MLProgram_Op<"token", [
409+
NoSideEffect
410+
]> {
411+
let summary = "Produces a new token value";
412+
let description = [{
413+
Token values are used to chain side effecting ops in a graph so as to
414+
establish an execution order. This op produces a token.
415+
}];
416+
417+
let results = (outs
418+
MLProgram_TokenType:$token
419+
);
420+
421+
let assemblyFormat = "attr-dict";
422+
}
423+
313424
#endif // MLPROGRAM_OPS
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MLProgramTypes.h - Type Classes --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMTYPES_H_
10+
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMTYPES_H_
11+
12+
#include "mlir/IR/Types.h"
13+
14+
//===----------------------------------------------------------------------===//
15+
// Tablegen Type Declarations
16+
//===----------------------------------------------------------------------===//
17+
18+
#define GET_TYPEDEF_CLASSES
19+
#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.h.inc"
20+
21+
#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMTYPES_H_
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- MLProgramTypes.td - Type definitions ----------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLPROGRAM_TYPES
10+
#define MLPROGRAM_TYPES
11+
12+
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
14+
15+
class MLProgram_Type<string name, list<Trait> traits = [],
16+
string baseCppClass = "::mlir::Type">
17+
: TypeDef<MLProgram_Dialect, name, traits, baseCppClass> {}
18+
19+
def MLProgram_TokenType : MLProgram_Type<"Token"> {
20+
let summary = "Token for establishing execution ordering in a graph";
21+
let mnemonic = "token";
22+
}
23+
24+
#endif // MLPROGRAM_TYPES

mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMLProgram
88
DEPENDS
99
MLIRMLProgramOpsIncGen
1010
MLIRMLProgramAttributesIncGen
11+
MLIRMLProgramTypesIncGen
1112

1213
LINK_LIBS PUBLIC
1314
MLIRDialect

mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ using namespace mlir::ml_program;
2020
#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
2121
#define GET_ATTRDEF_CLASSES
2222
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
23+
#define GET_TYPEDEF_CLASSES
24+
#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.cpp.inc"
2325

2426
namespace {
2527
struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
@@ -40,9 +42,16 @@ void ml_program::MLProgramDialect::initialize() {
4042
addAttributes<
4143
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
4244
>();
45+
46+
#define GET_TYPEDEF_LIST
47+
addTypes<
48+
#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.cpp.inc"
49+
>();
50+
4351
addOperations<
4452
#define GET_OP_LIST
4553
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
4654
>();
55+
4756
addInterfaces<MLProgramOpAsmDialectInterface>();
4857
}

mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,64 @@ using namespace mlir::ml_program;
1717
// Custom asm helpers
1818
//===----------------------------------------------------------------------===//
1919

20+
/// Parse and print an ordering clause for a variadic of consuming tokens
21+
/// and an optional producing token.
22+
///
23+
/// Syntax:
24+
/// ordering(%0, %1 -> !ml_program.token)
25+
/// ordering(() -> !ml_program.token)
26+
/// ordering(%0, %1)
27+
///
28+
/// If both the consuming and producing token are not present on the op, then
29+
/// the clause prints nothing.
30+
static ParseResult parseTokenOrdering(
31+
OpAsmParser &parser,
32+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
33+
Type &produceTokenType) {
34+
if (failed(parser.parseOptionalKeyword("ordering")) ||
35+
failed(parser.parseLParen()))
36+
return success();
37+
38+
// Parse consuming token list. If there are no consuming tokens, the
39+
// '()' null list represents this.
40+
if (succeeded(parser.parseOptionalLParen())) {
41+
if (failed(parser.parseRParen()))
42+
return failure();
43+
} else {
44+
if (failed(parser.parseOperandList(consumeTokens,
45+
/*requiredOperandCount=*/-1)))
46+
return failure();
47+
}
48+
49+
// Parse optional producer token.
50+
if (succeeded(parser.parseOptionalArrow()))
51+
if (failed(parser.parseType(produceTokenType)))
52+
return failure();
53+
54+
if (failed(parser.parseRParen()))
55+
return failure();
56+
57+
return success();
58+
}
59+
60+
static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61+
OperandRange consumeTokens,
62+
Type produceTokenType) {
63+
if (consumeTokens.empty() && !produceTokenType)
64+
return;
65+
66+
p << " ordering(";
67+
if (consumeTokens.empty())
68+
p << "()";
69+
else
70+
p.printOperands(consumeTokens);
71+
if (produceTokenType) {
72+
p << " -> ";
73+
p.printType(produceTokenType);
74+
}
75+
p << ")";
76+
}
77+
2078
/// some.op custom<TypeOrAttr>($type, $attr)
2179
///
2280
/// Uninitialized:
@@ -111,6 +169,30 @@ LogicalResult GlobalOp::verify() {
111169
return success();
112170
}
113171

172+
//===----------------------------------------------------------------------===//
173+
// GlobalLoadOp
174+
//===----------------------------------------------------------------------===//
175+
176+
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
177+
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
178+
getOperation()->getParentOp(), getGlobalAttr());
179+
}
180+
181+
LogicalResult
182+
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
183+
GlobalOp referrent = getGlobalOp(symbolTable);
184+
if (!referrent)
185+
return emitOpError() << "undefined global: " << getGlobal();
186+
187+
if (referrent.getType() != getResult().getType()) {
188+
return emitOpError() << "cannot load from global typed "
189+
<< referrent.getType() << " as "
190+
<< getResult().getType();
191+
}
192+
193+
return success();
194+
}
195+
114196
//===----------------------------------------------------------------------===//
115197
// GlobalLoadConstOp
116198
//===----------------------------------------------------------------------===//
@@ -138,6 +220,35 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
138220
return success();
139221
}
140222

223+
//===----------------------------------------------------------------------===//
224+
// GlobalStoreOp
225+
//===----------------------------------------------------------------------===//
226+
227+
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
228+
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
229+
getOperation()->getParentOp(), getGlobalAttr());
230+
}
231+
232+
LogicalResult
233+
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
234+
GlobalOp referrent = getGlobalOp(symbolTable);
235+
if (!referrent)
236+
return emitOpError() << "undefined global: " << getGlobal();
237+
238+
if (!referrent.getIsMutable()) {
239+
return emitOpError() << "cannot store to an immutable global "
240+
<< getGlobal();
241+
}
242+
243+
if (referrent.getType() != getValue().getType()) {
244+
return emitOpError() << "cannot store to a global typed "
245+
<< referrent.getType() << " from "
246+
<< getValue().getType();
247+
}
248+
249+
return success();
250+
}
251+
141252
//===----------------------------------------------------------------------===//
142253
// SubgraphOp
143254
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)