Skip to content

Commit 3968942

Browse files
committed
Revert "[mlir][mesh] adding shard-size control (#98145)"
This reverts commit fca6983. Also reverts the fixup: "[mlir] Fix -Wunused-variable in MeshOps.cpp (NFC)" This reverts commit fc73736.
1 parent d07f106 commit 3968942

28 files changed

+695
-1641
lines changed

mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
1313
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
1414
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
1515

16-
set(LLVM_TARGET_DEFINITIONS MeshBase.td)
17-
mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
18-
mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
19-
2016
set(LLVM_TARGET_DEFINITIONS MeshOps.td)
2117
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
2218
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
include "mlir/IR/OpBase.td"
1313
include "mlir/IR/AttrTypeBase.td"
1414
include "mlir/IR/BuiltinTypeInterfaces.td"
15-
include "mlir/IR/CommonAttrConstraints.td"
1615
include "mlir/IR/EnumAttr.td"
1716

1817
//===----------------------------------------------------------------------===//
@@ -32,13 +31,11 @@ def Mesh_Dialect : Dialect {
3231
];
3332

3433
let useDefaultAttributePrinterParser = 1;
35-
let useDefaultTypePrinterParser = 1;
3634
let hasConstantMaterializer = 1;
3735
}
3836

3937
def Mesh_MeshAxis : I<16>;
4038
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
41-
def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
4239

4340
//===----------------------------------------------------------------------===//
4441
// Mesh Enums.
@@ -62,33 +59,104 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
6259
}
6360

6461
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
65-
let assemblyFormat = "$value";
66-
}
67-
68-
class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
69-
string baseCppClass = "::mlir::Type">
70-
: TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
71-
let mnemonic = typeMnemonic;
72-
}
73-
74-
def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
75-
let summary = "sharding definition";
76-
let assemblyFormat = "";
62+
let assemblyFormat = "`<` $value `>`";
7763
}
7864

7965
//===----------------------------------------------------------------------===//
8066
// Mesh Attribute
8167
//===----------------------------------------------------------------------===//
8268

83-
def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
84-
let mnemonic = "axisarray";
85-
let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
86-
let assemblyFormat = "`[` $axes `]`";
69+
def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
70+
let mnemonic = "shard";
71+
72+
let parameters = (ins
73+
AttrParameter<"::mlir::FlatSymbolRefAttr",
74+
"The mesh on which tensors are sharded.">:$mesh,
75+
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
76+
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
77+
OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
78+
);
79+
80+
let summary = "Attribute that extends tensor type to distributed tensor type.";
81+
82+
let description = [{
83+
The MeshSharding attribute is used in a `mesh.shard` operation.
84+
It specifies how a tensor is sharded and distributed across the process
85+
mesh.
86+
87+
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
88+
mesh where the distributed tensor is placed. The symbol must resolve to a
89+
`mesh.mesh` operation.
90+
91+
2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
92+
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
93+
its value is [x, y], it indicates that the tensor's i-th dimension is splitted
94+
along the x and y axes of the device mesh.
95+
96+
3. `partial_axes`: if not empty, this signifies that the tensor is partial
97+
one along the specified mesh axes. An all-reduce should be applied to obtain
98+
the complete tensor, with reduction type being specified by `partial_type`.
99+
100+
4. `partial_type`: indicates the reduction type of the possible all-reduce
101+
op. It has 4 possible values:
102+
`generic`: is not an allowed value inside a shard attribute.
103+
104+
Example:
105+
106+
```
107+
mesh.mesh @mesh0(shape = 2x2x4)
108+
109+
// The tensor is fully replicated on @mesh0.
110+
// Currently, there must be at least one sub-array present in axes, even
111+
// if it's empty. Otherwise, a parsing error will occur.
112+
#mesh.shard<@mesh0, [[]]>
113+
114+
// The tensor is sharded on the first dimension along axis 0 of @mesh0
115+
#mesh.shard<@mesh0, [[0]]>
116+
117+
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
118+
// it is also a partial_sum along mesh axis 1.
119+
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>
120+
121+
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
122+
// it is also a partial_max along mesh axis 1.
123+
#mesh.shard<@mesh0, [[0]], partial = max[1]>
124+
125+
// Could be used in the attribute of mesh.shard op
126+
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
127+
```
128+
}];
129+
let assemblyFormat = [{
130+
`<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
131+
$partial_axes^ `]`)? `>`
132+
}];
133+
134+
let builders = [
135+
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
136+
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
137+
"ArrayRef<MeshAxis>": $partial_axes,
138+
"mesh::ReductionKind": $partial_type), [{
139+
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
140+
split_axes, [&](ArrayRef<MeshAxis> array) {
141+
return MeshAxesAttr::get($_ctxt, array);
142+
});
143+
return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
144+
partial_type);
145+
}]>,
146+
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
147+
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
148+
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
149+
}]>
150+
];
151+
87152
let extraClassDeclaration = [{
88-
size_t size() const { return getAxes().size(); }
89-
auto begin() const { return getAxes().begin(); }
90-
auto end() const { return getAxes().end(); }
153+
bool operator==(::mlir::Attribute rhs) const;
154+
bool operator!=(::mlir::Attribute rhs) const;
155+
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
156+
bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
91157
}];
158+
159+
let genVerifyDecl = 1;
92160
}
93161

94162
#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 11 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ namespace mesh {
2424

2525
using MeshAxis = int16_t;
2626
using MeshAxesAttr = DenseI16ArrayAttr;
27-
using ShardShapeAttr = DenseI64ArrayAttr;
28-
using HaloSizePairAttr = DenseI64ArrayAttr;
2927

3028
} // namespace mesh
3129
} // namespace mlir
@@ -35,59 +33,6 @@ using HaloSizePairAttr = DenseI64ArrayAttr;
3533
#define GET_ATTRDEF_CLASSES
3634
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
3735

38-
namespace mlir {
39-
namespace mesh {
40-
41-
class MeshSharding {
42-
private:
43-
::mlir::FlatSymbolRefAttr mesh;
44-
SmallVector<MeshAxesAttr> split_axes;
45-
SmallVector<MeshAxis> partial_axes;
46-
ReductionKind partial_type;
47-
SmallVector<int64_t> static_halo_sizes;
48-
SmallVector<int64_t> static_sharded_dims_sizes;
49-
SmallVector<Value> dynamic_halo_sizes;
50-
SmallVector<Value> dynamic_sharded_dims_sizes;
51-
52-
public:
53-
MeshSharding() = default;
54-
MeshSharding(Value rhs);
55-
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
56-
ArrayRef<MeshAxesAttr> split_axes_,
57-
ArrayRef<MeshAxis> partial_axes_ = {},
58-
ReductionKind partial_type_ = ReductionKind::Sum,
59-
ArrayRef<int64_t> static_halo_sizes_ = {},
60-
ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
61-
ArrayRef<Value> dynamic_halo_sizes_ = {},
62-
ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
63-
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
64-
::llvm::StringRef getMesh() const { return mesh.getValue(); }
65-
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
66-
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
67-
ReductionKind getPartialType() const { return partial_type; }
68-
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
69-
ArrayRef<int64_t> getStaticShardedDimsSizes() const {
70-
return static_sharded_dims_sizes;
71-
}
72-
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
73-
ArrayRef<Value> getDynamicShardedDimsSizes() const {
74-
return dynamic_sharded_dims_sizes;
75-
}
76-
operator bool() const { return (!mesh) == false; }
77-
bool operator==(Value rhs) const;
78-
bool operator!=(Value rhs) const;
79-
bool operator==(const MeshSharding &rhs) const;
80-
bool operator!=(const MeshSharding &rhs) const;
81-
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
82-
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
83-
};
84-
85-
} // namespace mesh
86-
} // namespace mlir
87-
88-
#define GET_TYPEDEF_CLASSES
89-
#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
90-
9136
#define GET_OP_CLASSES
9237
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
9338

@@ -105,9 +50,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
10550
}
10651

10752
// Is the same tensor replicated on all processes.
108-
inline bool isFullReplication(MeshSharding sharding) {
109-
return sharding.getPartialAxes().empty() &&
110-
llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
53+
inline bool isFullReplication(MeshShardingAttr attr) {
54+
return attr.getPartialAxes().empty() &&
55+
llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
11156
return axes.asArrayRef().empty();
11257
});
11358
}
@@ -135,10 +80,8 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
13580
template <>
13681
inline mesh::MeshOp
13782
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
138-
return getMesh(
139-
op.getOperation(),
140-
cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
141-
symbolTableCollection);
83+
return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
84+
symbolTableCollection);
14285
}
14386

14487
// Get the number of processes that participate in each group
@@ -188,22 +131,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
188131
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
189132
// result in a shape for each shard of ?x2x?.
190133
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
191-
MeshSharding sharding);
134+
MeshShardingAttr sharding);
192135

193136
// If ranked tensor type return its sharded counterpart.
194137
//
195138
// If not ranked tensor type return `type`.
196139
// `sharding` in that case must be null.
197-
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
140+
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
198141

199142
// Insert shard op if there is not one that already has the same sharding.
200143
// May insert resharding if required.
201-
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
144+
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
202145
OpOperand &operand,
203146
OpBuilder &builder);
204-
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
205-
OpBuilder &builder);
206-
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
147+
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
148+
OpResult result, OpBuilder &builder);
149+
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
207150
OpOperand &operand,
208151
OpBuilder &builder);
209152

0 commit comments

Comments
 (0)