Skip to content

Commit fca6983

Browse files
authored
[mlir][mesh] adding shard-size control (#98145)
- Replacing `#mesh.sharding` attribute with operation `mesh.sharding`
- extended semantics now allow providing optional `halo_sizes` and `sharded_dims_sizes`
- internally a sharding is represented as a non-IR class `mesh::MeshSharding`

What previously was

```mlir
%sharded0 = mesh.shard %arg0 <@Mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@Mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```

is now

```mlir
%sharding = mesh.sharding @Mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```

and allows additional annotations to control the shard sizes:

```mlir
mesh.mesh @Mesh0 (shape = 4)
%sharding0 = mesh.sharding @Mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @Mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```

- `mesh.shard` op accepts additional optional attribute `force`, useful for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting on `sharded_dims_sizes` and `halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as a spmdized target for shardings with `halo_sizes`

@sogartar @yaochengji
1 parent 0c25f85 commit fca6983

28 files changed

+1639
-695
lines changed

mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
1313
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
1414
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
1515

16+
set(LLVM_TARGET_DEFINITIONS MeshBase.td)
17+
mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
18+
mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
19+
1620
set(LLVM_TARGET_DEFINITIONS MeshOps.td)
1721
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
1822
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 22 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/IR/OpBase.td"
1313
include "mlir/IR/AttrTypeBase.td"
1414
include "mlir/IR/BuiltinTypeInterfaces.td"
15+
include "mlir/IR/CommonAttrConstraints.td"
1516
include "mlir/IR/EnumAttr.td"
1617

1718
//===----------------------------------------------------------------------===//
@@ -31,11 +32,13 @@ def Mesh_Dialect : Dialect {
3132
];
3233

3334
let useDefaultAttributePrinterParser = 1;
35+
let useDefaultTypePrinterParser = 1;
3436
let hasConstantMaterializer = 1;
3537
}
3638

3739
def Mesh_MeshAxis : I<16>;
3840
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
41+
def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
3942

4043
//===----------------------------------------------------------------------===//
4144
// Mesh Enums.
@@ -59,104 +62,33 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
5962
}
6063

6164
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
62-
let assemblyFormat = "`<` $value `>`";
65+
let assemblyFormat = "$value";
66+
}
67+
68+
class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
69+
string baseCppClass = "::mlir::Type">
70+
: TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
71+
let mnemonic = typeMnemonic;
72+
}
73+
74+
def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
75+
let summary = "sharding definition";
76+
let assemblyFormat = "";
6377
}
6478

6579
//===----------------------------------------------------------------------===//
6680
// Mesh Attribute
6781
//===----------------------------------------------------------------------===//
6882

69-
def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
70-
let mnemonic = "shard";
71-
72-
let parameters = (ins
73-
AttrParameter<"::mlir::FlatSymbolRefAttr",
74-
"The mesh on which tensors are sharded.">:$mesh,
75-
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
76-
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
77-
OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
78-
);
79-
80-
let summary = "Attribute that extends tensor type to distributed tensor type.";
81-
82-
let description = [{
83-
The MeshSharding attribute is used in a `mesh.shard` operation.
84-
It specifies how a tensor is sharded and distributed across the process
85-
mesh.
86-
87-
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
88-
mesh where the distributed tensor is placed. The symbol must resolve to a
89-
`mesh.mesh` operation.
90-
91-
2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
92-
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
93-
its value is [x, y], it indicates that the tensor's i-th dimension is splitted
94-
along the x and y axes of the device mesh.
95-
96-
3. `partial_axes`: if not empty, this signifies that the tensor is partial
97-
one along the specified mesh axes. An all-reduce should be applied to obtain
98-
the complete tensor, with reduction type being specified by `partial_type`.
99-
100-
4. `partial_type`: indicates the reduction type of the possible all-reduce
101-
op. It has 4 possible values:
102-
`generic`: is not an allowed value inside a shard attribute.
103-
104-
Example:
105-
106-
```
107-
mesh.mesh @mesh0(shape = 2x2x4)
108-
109-
// The tensor is fully replicated on @mesh0.
110-
// Currently, there must be at least one sub-array present in axes, even
111-
// if it's empty. Otherwise, a parsing error will occur.
112-
#mesh.shard<@mesh0, [[]]>
113-
114-
// The tensor is sharded on the first dimension along axis 0 of @mesh0
115-
#mesh.shard<@mesh0, [[0]]>
116-
117-
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
118-
// it is also a partial_sum along mesh axis 1.
119-
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>
120-
121-
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
122-
// it is also a partial_max along mesh axis 1.
123-
#mesh.shard<@mesh0, [[0]], partial = max[1]>
124-
125-
// Could be used in the attribute of mesh.shard op
126-
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
127-
```
128-
}];
129-
let assemblyFormat = [{
130-
`<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
131-
$partial_axes^ `]`)? `>`
132-
}];
133-
134-
let builders = [
135-
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
136-
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
137-
"ArrayRef<MeshAxis>": $partial_axes,
138-
"mesh::ReductionKind": $partial_type), [{
139-
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
140-
split_axes, [&](ArrayRef<MeshAxis> array) {
141-
return MeshAxesAttr::get($_ctxt, array);
142-
});
143-
return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
144-
partial_type);
145-
}]>,
146-
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
147-
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
148-
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
149-
}]>
150-
];
151-
83+
def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
84+
let mnemonic = "axisarray";
85+
let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
86+
let assemblyFormat = "`[` $axes `]`";
15287
let extraClassDeclaration = [{
153-
bool operator==(::mlir::Attribute rhs) const;
154-
bool operator!=(::mlir::Attribute rhs) const;
155-
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
156-
bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
88+
size_t size() const { return getAxes().size(); }
89+
auto begin() const { return getAxes().begin(); }
90+
auto end() const { return getAxes().end(); }
15791
}];
158-
159-
let genVerifyDecl = 1;
16092
}
16193

16294
#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ namespace mesh {
2424

2525
using MeshAxis = int16_t;
2626
using MeshAxesAttr = DenseI16ArrayAttr;
27+
using ShardShapeAttr = DenseI64ArrayAttr;
28+
using HaloSizePairAttr = DenseI64ArrayAttr;
2729

2830
} // namespace mesh
2931
} // namespace mlir
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
3335
#define GET_ATTRDEF_CLASSES
3436
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
3537

38+
namespace mlir {
39+
namespace mesh {
40+
41+
class MeshSharding {
42+
private:
43+
::mlir::FlatSymbolRefAttr mesh;
44+
SmallVector<MeshAxesAttr> split_axes;
45+
SmallVector<MeshAxis> partial_axes;
46+
ReductionKind partial_type;
47+
SmallVector<int64_t> static_halo_sizes;
48+
SmallVector<int64_t> static_sharded_dims_sizes;
49+
SmallVector<Value> dynamic_halo_sizes;
50+
SmallVector<Value> dynamic_sharded_dims_sizes;
51+
52+
public:
53+
MeshSharding() = default;
54+
MeshSharding(Value rhs);
55+
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
56+
ArrayRef<MeshAxesAttr> split_axes_,
57+
ArrayRef<MeshAxis> partial_axes_ = {},
58+
ReductionKind partial_type_ = ReductionKind::Sum,
59+
ArrayRef<int64_t> static_halo_sizes_ = {},
60+
ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
61+
ArrayRef<Value> dynamic_halo_sizes_ = {},
62+
ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
63+
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
64+
::llvm::StringRef getMesh() const { return mesh.getValue(); }
65+
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
66+
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
67+
ReductionKind getPartialType() const { return partial_type; }
68+
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
69+
ArrayRef<int64_t> getStaticShardedDimsSizes() const {
70+
return static_sharded_dims_sizes;
71+
}
72+
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
73+
ArrayRef<Value> getDynamicShardedDimsSizes() const {
74+
return dynamic_sharded_dims_sizes;
75+
}
76+
operator bool() const { return (!mesh) == false; }
77+
bool operator==(Value rhs) const;
78+
bool operator!=(Value rhs) const;
79+
bool operator==(const MeshSharding &rhs) const;
80+
bool operator!=(const MeshSharding &rhs) const;
81+
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
82+
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
83+
};
84+
85+
} // namespace mesh
86+
} // namespace mlir
87+
88+
#define GET_TYPEDEF_CLASSES
89+
#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
90+
3691
#define GET_OP_CLASSES
3792
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
3893

@@ -50,9 +105,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
50105
}
51106

52107
// Is the same tensor replicated on all processes.
53-
inline bool isFullReplication(MeshShardingAttr attr) {
54-
return attr.getPartialAxes().empty() &&
55-
llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
108+
inline bool isFullReplication(MeshSharding sharding) {
109+
return sharding.getPartialAxes().empty() &&
110+
llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
56111
return axes.asArrayRef().empty();
57112
});
58113
}
@@ -80,8 +135,10 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
80135
template <>
81136
inline mesh::MeshOp
82137
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
83-
return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
84-
symbolTableCollection);
138+
return getMesh(
139+
op.getOperation(),
140+
cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
141+
symbolTableCollection);
85142
}
86143

87144
// Get the number of processes that participate in each group
@@ -131,22 +188,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
131188
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
132189
// result in a shape for each shard of ?x2x?.
133190
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
134-
MeshShardingAttr sharding);
191+
MeshSharding sharding);
135192

136193
// If ranked tensor type return its sharded counterpart.
137194
//
138195
// If not ranked tensor type return `type`.
139196
// `sharding` in that case must be null.
140-
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
197+
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
141198

142199
// Insert shard op if there is not one that already has the same sharding.
143200
// May insert resharding if required.
144-
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
201+
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
145202
OpOperand &operand,
146203
OpBuilder &builder);
147-
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
148-
OpResult result, OpBuilder &builder);
149-
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
204+
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
205+
OpBuilder &builder);
206+
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
150207
OpOperand &operand,
151208
OpBuilder &builder);
152209

0 commit comments

Comments
 (0)