12
12
include "mlir/IR/OpBase.td"
13
13
include "mlir/IR/AttrTypeBase.td"
14
14
include "mlir/IR/BuiltinTypeInterfaces.td"
15
- include "mlir/IR/CommonAttrConstraints.td"
16
15
include "mlir/IR/EnumAttr.td"
17
16
18
17
//===----------------------------------------------------------------------===//
@@ -32,13 +31,11 @@ def Mesh_Dialect : Dialect {
32
31
];
33
32
34
33
let useDefaultAttributePrinterParser = 1;
35
- let useDefaultTypePrinterParser = 1;
36
34
let hasConstantMaterializer = 1;
37
35
}
38
36
39
37
def Mesh_MeshAxis : I<16>;
40
38
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
41
- def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
42
39
43
40
//===----------------------------------------------------------------------===//
44
41
// Mesh Enums.
@@ -62,33 +59,104 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
62
59
}
63
60
64
61
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
65
- let assemblyFormat = "$value";
66
- }
67
-
68
- class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
69
- string baseCppClass = "::mlir::Type">
70
- : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
71
- let mnemonic = typeMnemonic;
72
- }
73
-
74
- def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
75
- let summary = "sharding definition";
76
- let assemblyFormat = "";
62
+ let assemblyFormat = "`<` $value `>`";
77
63
}
78
64
79
65
//===----------------------------------------------------------------------===//
80
66
// Mesh Attribute
81
67
//===----------------------------------------------------------------------===//
82
68
83
- def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
84
- let mnemonic = "axisarray";
85
- let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
86
- let assemblyFormat = "`[` $axes `]`";
69
+ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
70
+ let mnemonic = "shard";
71
+
72
+ let parameters = (ins
73
+ AttrParameter<"::mlir::FlatSymbolRefAttr",
74
+ "The mesh on which tensors are sharded.">:$mesh,
75
+ ArrayRefParameter<"MeshAxesAttr">:$split_axes,
76
+ OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
77
+ OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
78
+ );
79
+
80
+ let summary = "Attribute that extends tensor type to distributed tensor type.";
81
+
82
+ let description = [{
83
+ The MeshSharding attribute is used in a `mesh.shard` operation.
84
+ It specifies how a tensor is sharded and distributed across the process
85
+ mesh.
86
+
87
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
88
+ mesh where the distributed tensor is placed. The symbol must resolve to a
89
+ `mesh.mesh` operation.
90
+
91
+ 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
92
+ maximum size is the `rank` of the related tensor. For the i-th sub-array, if
93
+ its value is [x, y], it indicates that the tensor's i-th dimension is splitted
94
+ along the x and y axes of the device mesh.
95
+
96
+ 3. `partial_axes`: if not empty, this signifies that the tensor is partial
97
+ one along the specified mesh axes. An all-reduce should be applied to obtain
98
+ the complete tensor, with reduction type being specified by `partial_type`.
99
+
100
+ 4. `partial_type`: indicates the reduction type of the possible all-reduce
101
+ op. It has 4 possible values:
102
+ `generic`: is not an allowed value inside a shard attribute.
103
+
104
+ Example:
105
+
106
+ ```
107
+ mesh.mesh @mesh0(shape = 2x2x4)
108
+
109
+ // The tensor is fully replicated on @mesh0.
110
+ // Currently, there must be at least one sub-array present in axes, even
111
+ // if it's empty. Otherwise, a parsing error will occur.
112
+ #mesh.shard<@mesh0, [[]]>
113
+
114
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0
115
+ #mesh.shard<@mesh0, [[0]]>
116
+
117
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
118
+ // it is also a partial_sum along mesh axis 1.
119
+ #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
120
+
121
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
122
+ // it is also a partial_max along mesh axis 1.
123
+ #mesh.shard<@mesh0, [[0]], partial = max[1]>
124
+
125
+ // Could be used in the attribute of mesh.shard op
126
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
127
+ ```
128
+ }];
129
+ let assemblyFormat = [{
130
+ `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
131
+ $partial_axes^ `]`)? `>`
132
+ }];
133
+
134
+ let builders = [
135
+ AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
136
+ "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
137
+ "ArrayRef<MeshAxis>": $partial_axes,
138
+ "mesh::ReductionKind": $partial_type), [{
139
+ SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
140
+ split_axes, [&](ArrayRef<MeshAxis> array) {
141
+ return MeshAxesAttr::get($_ctxt, array);
142
+ });
143
+ return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
144
+ partial_type);
145
+ }]>,
146
+ AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
147
+ "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
148
+ return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
149
+ }]>
150
+ ];
151
+
87
152
let extraClassDeclaration = [{
88
- size_t size() const { return getAxes().size(); }
89
- auto begin() const { return getAxes().begin(); }
90
- auto end() const { return getAxes().end(); }
153
+ bool operator==(::mlir::Attribute rhs) const;
154
+ bool operator!=(::mlir::Attribute rhs) const;
155
+ bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
156
+ bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
91
157
}];
158
+
159
+ let genVerifyDecl = 1;
92
160
}
93
161
94
162
#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
0 commit comments