Skip to content

Commit 5df2c00

Browse files
authored
[mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (llvm#77838)
Remove the somewhat redundant rank attribute. Before this change: ``` mesh.cluster @mesh(rank = 3, dim_sizes = 2x3) ``` After: ``` mesh.cluster @mesh(shape = 2x3x?) ``` The rank is instead determined by the provided shape. With this change, `getDimSizes()` can no longer be wrongly assumed to have a size equal to the cluster rank. Now `getShape().size()` will always equal `getRank()`.
1 parent d85df3f commit 5df2c00

File tree

13 files changed

+134
-166
lines changed

13 files changed

+134
-166
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
117117
Example:
118118

119119
```
120-
mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
120+
mesh.cluster @mesh0(shape = 2x2x4)
121121

122122
// The tensor is fully replicated on @mesh0.
123123
// Currently, there must be at least one sub-array present in axes, even

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
3636
cluster. This name serves as a symbolic reference to the cluster throughout
3737
the MLIR module, allowing for consistent referencing and easier debugging.
3838

39-
2. `rank`: This attribute specifies the number of axes of the cluster. The
40-
rank indicates the dimensionality of the mesh cluster and can be used to
41-
determine the layout and the addressing space of the computation distributed
42-
across the mesh.
43-
44-
3. `dim_sizes`: This attribute represents the shape of the device cluster.
39+
2. `shape`: This attribute represents the shape of the device cluster.
4540
It uses the same notation as a tensor shape, and also allows for dynamic
4641
dimensions.
4742
This flexibility allows for dynamic device assignment or configurations
@@ -53,19 +48,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
5348
```
5449
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
5550
// The dimension sizes are 4, 8, 12
56-
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
51+
mesh.cluster @mesh0(shape = 4x8x12)
5752

5853
// A device mesh cluster with 2 axes, the total device number is unknown
5954
// The first dimension size is 4 and the second is unknown
60-
mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
55+
mesh.cluster @mesh1(shape = 4x?)
6156

6257
// A device mesh cluster with 2 axes, the total device number is unknown
6358
// The first dimension size is unknown and the second is 4
64-
mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
59+
mesh.cluster @mesh2(shape = ?x4)
6560

6661
// A device mesh cluster with 2 axes, the number of devices along both axes
6762
// is unknown
68-
mesh.cluster @mesh3(rank = 2)
63+
mesh.cluster @mesh3(shape = ?x?)
6964

7065
// Used in the mesh sharding attribute to extend the standard tensor to
7166
// distributed
@@ -74,24 +69,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
7469
}];
7570
let arguments = (ins
7671
SymbolNameAttr:$sym_name,
77-
I64Attr:$rank,
78-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
72+
DenseI64ArrayAttr:$shape
7973
);
8074
let assemblyFormat = [{
81-
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
75+
$sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
8276
attr-dict
8377
}];
8478
let extraClassDeclaration = [{
85-
// The `dim_sizes` attribute may have size less than the rank of the mesh.
86-
// Returns the shape of the mesh with missing trailing dimensions
87-
// explicitly set as dynamic.
88-
::mlir::SmallVector<int64_t> canonicalDimSizes();
89-
90-
template <typename OutIt>
91-
void canonicalDimSizes(OutIt outIt) {
92-
std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
93-
std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
94-
}
79+
int64_t getRank() { return getShape().size(); }
9580
}];
9681
let hasVerifier = 1;
9782
}
@@ -283,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
283268

284269
Example:
285270
```mlir
286-
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
271+
mesh.cluster @mesh0(shape = 2x2)
287272
...
288273
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
289274
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -368,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
368353

369354
Example:
370355
```
371-
mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
356+
mesh.cluster @mesh0(shape = 3)
372357
...
373358
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
374359
split_axis = 0 concat_axis = 0
@@ -425,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
425410

426411
Example:
427412
```
428-
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
413+
mesh.cluster @mesh0(shape = 2x2)
429414

430415
%1 = mesh.broadcast %0 on @mesh0
431416
mesh_axes = [0]
@@ -481,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
481466

482467
Example:
483468
```mlir
484-
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
469+
mesh.cluster @mesh0(shape = 2x2)
485470
...
486471
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
487472
gather_axis = 1 root = [1]
@@ -604,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
604589
across the device group.
605590
Example:
606591
```
607-
mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
592+
mesh.cluster @mesh0(shape = 2x2)
608593
...
609594
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
610595
reduction = <max> scatter_axis = 0
@@ -667,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
667652

668653
Example:
669654
```
670-
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
655+
mesh.cluster @mesh0(shape = 2x2)
671656
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
672657
scatter_axis = 0
673658
root = [1]
@@ -763,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
763748

764749
Example:
765750
```
766-
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
751+
mesh.cluster @mesh0(shape = 2x4)
767752
%1 = mesh.shift on @mesh0 mesh_axes = [1]
768753
shift_axis = 1 offset = 2 rotate
769754
: tensor<2xi8> -> tensor<2xi8>

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
196196
//===----------------------------------------------------------------------===//
197197

198198
LogicalResult ClusterOp::verify() {
199-
ArrayRef<int64_t> dimSizes = getDimSizes();
200-
uint64_t rank = getRank();
199+
int64_t rank = getRank();
201200

202-
if (rank == 0)
201+
if (rank <= 0)
203202
return emitOpError("rank of cluster is expected to be a positive integer");
204203

205-
if (dimSizes.size() > rank)
204+
if (getShape().size() > rank)
206205
return emitOpError(
207-
"rank of dim_sizes is not expected to be larger than rank of cluster");
206+
"rank of shape is not expected to be larger than rank of cluster");
208207

209-
for (int64_t dimSize : dimSizes) {
208+
for (int64_t dimSize : getShape()) {
210209
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
211210
return emitOpError("dimension size of a mesh cluster is expected to be "
212211
"non-negative or dynamic");
@@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
215214
return success();
216215
}
217216

218-
SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
219-
SmallVector<int64_t> result;
220-
canonicalDimSizes(std::back_inserter(result));
221-
result.reserve(getRank());
222-
return result;
223-
}
224-
225217
//===----------------------------------------------------------------------===//
226218
// mesh.cluster_shape op
227219
//===----------------------------------------------------------------------===//
@@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
614606
auto gatherAxis = getGatherAxis().getSExtValue();
615607
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
616608
gatherAxis, getMeshAxes(),
617-
mesh.value().canonicalDimSizes());
609+
mesh.value().getShape());
618610
}
619611

620612
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
648640

649641
return verifyAllToAllOperandAndResultShape(
650642
getOperand(), getResult(), getSplitAxis().getSExtValue(),
651-
getConcatAxis().getSExtValue(), getMeshAxes(),
652-
mesh.value().canonicalDimSizes());
643+
getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
653644
}
654645

655646
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
667658
if (failed(mesh)) {
668659
return failure();
669660
}
670-
auto meshShape = mesh.value().canonicalDimSizes();
671661
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
672-
getRootDynamic(), getMeshAxes(), meshShape))) {
662+
getRootDynamic(), getMeshAxes(),
663+
mesh.value().getShape()))) {
673664
return failure();
674665
}
675666

@@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
690681
if (failed(mesh)) {
691682
return failure();
692683
}
693-
auto meshShape = mesh.value().canonicalDimSizes();
694684
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
695-
getRootDynamic(), getMeshAxes(), meshShape))) {
685+
getRootDynamic(), getMeshAxes(),
686+
mesh.value().getShape()))) {
696687
return failure();
697688
}
698689

699690
auto gatherAxis = getGatherAxis().getSExtValue();
700691
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
701692
getMeshAxes(),
702-
mesh.value().canonicalDimSizes());
693+
mesh.value().getShape());
703694
}
704695

705696
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
716707
if (failed(mesh)) {
717708
return failure();
718709
}
719-
auto meshShape = mesh.value().canonicalDimSizes();
720-
if (getSource() && failed(verifyInGroupDevice(
721-
getLoc(), getSourceAttrName(), getSource().value(),
722-
getSourceDynamic(), getMeshAxes(), meshShape))) {
710+
if (getSource() &&
711+
failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
712+
getSource().value(), getSourceDynamic(),
713+
getMeshAxes(), mesh.value().getShape()))) {
723714
return failure();
724715
}
725716
return success();
@@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
739730
if (failed(mesh)) {
740731
return failure();
741732
}
742-
auto meshShape = mesh.value().canonicalDimSizes();
743733
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
744-
getRootDynamic(), getMeshAxes(), meshShape))) {
734+
getRootDynamic(), getMeshAxes(),
735+
mesh.value().getShape()))) {
745736
return failure();
746737
}
747738

@@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
766757

767758
return verifyScatterOperandAndResultShape(
768759
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
769-
mesh.value().canonicalDimSizes());
760+
mesh.value().getShape());
770761
}
771762

772763
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
783774
if (failed(mesh)) {
784775
return failure();
785776
}
786-
auto meshShape = mesh.value().canonicalDimSizes();
787777
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
788-
getRootDynamic(), getMeshAxes(), meshShape))) {
778+
getRootDynamic(), getMeshAxes(),
779+
mesh.value().getShape()))) {
789780
return failure();
790781
}
791782

792783
auto scatterAxis = getScatterAxis().getSExtValue();
793784
return verifyScatterOperandAndResultShape(getInput(), getResult(),
794785
scatterAxis, getMeshAxes(),
795-
mesh.value().canonicalDimSizes());
786+
mesh.value().getShape());
796787
}
797788

798789
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
809800
if (failed(mesh)) {
810801
return failure();
811802
}
812-
auto meshShape = mesh.value().canonicalDimSizes();
813803
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
814804
getDestination(), getDestinationDynamic(),
815-
getMeshAxes(), meshShape))) {
805+
getMeshAxes(), mesh.value().getShape()))) {
816806
return failure();
817807
}
818808
return success();

mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
8080
opMeshAxes = opAxesIota;
8181
}
8282
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
83-
return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
83+
return ShapedType::isDynamic(mesh.getShape()[axis]);
8484
})) {
8585
// All mesh dimensions are dynamic. Nothing to fold.
8686
return failure();
@@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
9191
SmallVector<size_t> newToOldResultsIndexMap;
9292

9393
for (size_t i = 0; i < opMeshAxes.size(); ++i) {
94-
auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
94+
auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
9595
if (ShapedType::isDynamic(meshAxisSize)) {
9696
newToOldResultsIndexMap.push_back(i);
9797
newShapeOpMeshAxes.push_back(opMeshAxes[i]);

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
8888
MeshShardingAttr sharding) {
8989
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
9090
SmallVector<Dim> resShapeArr(shape.getShape().size());
91-
shardShape(shape.getShape(), mesh.canonicalDimSizes(),
92-
sharding.getSplitAxes(), resShapeArr);
91+
shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
92+
resShapeArr);
9393
return shape.clone(resShapeArr);
9494
}
9595

@@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
212212

213213
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
214214
ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
215-
ShapedType targetShape =
216-
targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
217-
mesh.canonicalDimSizes()[splitMeshAxis]);
215+
ShapedType targetShape = targetShapeInSplitLastAxis(
216+
sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
218217

219218
Value meshAxisSize =
220219
builder
@@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
391390
MeshShardingAttr targetSharding =
392391
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
393392
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
394-
sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
395-
splitTensorAxis);
393+
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
396394
Value allGatherResult = builder.create<AllGatherOp>(
397395
RankedTensorType::get(allGatherResultShape.getShape(),
398396
allGatherResultShape.getElementType()),
@@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
526524
MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
527525
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
528526
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
529-
sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
530-
sourceTensorAxis, targetTensorAxis);
527+
sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
528+
targetTensorAxis);
531529
Value allToAllResult = builder.create<AllToAllOp>(
532530
RankedTensorType::get(allToAllResultShape.getShape(),
533531
allToAllResultShape.getElementType()),

mlir/test/Dialect/Mesh/canonicalization.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt --canonicalize %s | FileCheck %s
22

3-
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
3+
mesh.cluster @mesh0(shape = 2x4)
44

55
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
66
func.func @all_reduce_empty_mesh_axes(

mlir/test/Dialect/Mesh/folding.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
22

3-
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
4-
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
3+
mesh.cluster @mesh0(shape = 4x?x2)
4+
mesh.cluster @mesh1(shape = 2x3)
55

66
// CHECK-LABEL: func.func @cluster_shape_op_folding
77
func.func @cluster_shape_op_folding() -> (index, index) {

0 commit comments

Comments
 (0)