Commit adbf21f

[mlir][mesh] Add spmdization pass (#80518)
Add a pass that converts a function that has sharding annotations into SPMD form.
1 parent 8ae0485 commit adbf21f

File tree

18 files changed: +848 −101 lines changed
mlir/include/mlir/Dialect/Func/IR/ShardingInterfaceImpl.h

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace func {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace func
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_

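The new header only declares a registration hook. Below is a minimal sketch of how such a hook is typically wired into a tool's dialect registry; the include path and the helper name `registerAllMyExtensions` are assumptions for illustration, not part of the commit.

```cpp
#include "mlir/Dialect/Func/IR/ShardingInterfaceImpl.h" // path assumed from the include guard
#include "mlir/IR/DialectRegistry.h"

// Hypothetical helper a tool could call before creating its MLIRContext.
void registerAllMyExtensions(mlir::DialectRegistry &registry) {
  // Attach the func dialect's ShardingInterface external models so that
  // func-dialect ops can participate in spmdization.
  mlir::func::registerShardingInterfaceExternalModels(registry);
}
```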
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 45 additions & 0 deletions
@@ -10,11 +10,13 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H

 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/MathExtras.h"

 namespace mlir {
 namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {

 Partial getPartialTypeFromReduction(IteratorType iType);

+// Is the same tensor replicated on all processes?
+inline bool isFullReplication(MeshShardingAttr attr) {
+  return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+}
+
 inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                             SymbolTableCollection &symbolTableCollection) {
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
   return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
 }

+template <>
+inline mesh::MeshOp
+getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
+                 symbolTableCollection);
+}
+
 // Get the number of processes that participate in each group
 // induced by `meshAxes`.
 template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,37 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }

+// Get the size of a sharded dimension.
+inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
+  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dimSize % shardCount == 0);
+  return ceilDiv(dimSize, shardCount);
+}
+
+// Get the size of an unsharded dimension.
+inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
+  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  return dimSize * shardCount;
+}
+
+// Return the sharded shape `shape` according to sharding `sharding`,
+// i.e. the shape of the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
+// result in a shape for each shard of ?x2x?.
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
+                           MeshShardingAttr sharding);
+
+// If `type` is a ranked tensor type, return its sharded counterpart.
+//
+// Otherwise return `type` unchanged; `sharding` must be null in that case.
+Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+
 } // namespace mesh
 } // namespace mlir

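For intuition, the new dimension helpers behave as follows on concrete sizes. This is a small illustration derived from the definitions above, not code from the commit; `dimensionHelperExamples` is a hypothetical function.

```cpp
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include <cassert>

void dimensionHelperExamples() {
  using namespace mlir;
  // Static case: an 8-element dimension split across 2 shards, and back.
  assert(mesh::shardDimension(/*dimSize=*/8, /*shardCount=*/2) == 4);
  assert(mesh::gatherDimension(/*dimSize=*/4, /*shardCount=*/2) == 8);
  // Dynamic case: a dynamic size or shard count keeps the result dynamic.
  assert(mesh::shardDimension(ShapedType::kDynamic, 2) == ShapedType::kDynamic);
}
```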
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 10 additions & 1 deletion
@@ -10,11 +10,14 @@
 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_

 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"

 namespace mlir {

 class Operation;
+class IRMapping;
+class SymbolTableCollection;

 namespace mesh {

@@ -58,8 +61,14 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,

 } // namespace detail

-} // namespace mesh
+// Assumes full replication on all ranked tensor arguments and results.
+void spmdizeFullyReplicatedOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);

+} // namespace mesh
 } // namespace mlir

 /// Include the ODS generated interface header files.

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td

Lines changed: 46 additions & 1 deletion
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         return detail::defaultAddShardingAnnotations(
           $_op.getOperation(), b, shardingOption);
       }]
-    >
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Convert self to SPMD form.
+        This method is used during the spmdization pass of a program fully
+        annotated with shardings.
+
+        The spmdization algorithm reads the surrounding sharding annotations
+        from the IR for each argument/result and prepares `operandShardings`
+        and `resultShardings`.
+        Values that are not ranked tensors do not have sharding annotations;
+        in this case their corresponding MeshShardingAttr is null.
+
+        For convenience it also prepares `spmdizedOperands`, although they
+        can be retrieved from the `spmdizationMap`.
+
+        The `spmdizationMap` contains a mapping from unsharded to
+        sharded/spmdized values that is constructed during the spmdization
+        pass. The interface implementation must populate `spmdizationMap`
+        with the mapping for this op's results.
+
+        `builder` is set to insert new operations at the appropriate point.
+        The implementation should not reset the builder to the original
+        insertion point; it should leave it as is after all insertions are
+        done.
+
+        The default implementation does full replication.
+        It assumes that all sharding annotations are for full replication.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"spmdize",
+      /*args=*/(ins
+        "ArrayRef<Value>": $spmdizedOperands,
+        "ArrayRef<MeshShardingAttr>": $operandShardings,
+        "ArrayRef<MeshShardingAttr>": $resultShardings,
+        "IRMapping&": $spmdizationMap,
+        "SymbolTableCollection &": $symbolTableCollection,
+        "OpBuilder &": $builder
+      ),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        spmdizeFullyReplicatedOperation(
+          *$_op.getOperation(), spmdizedOperands, operandShardings,
+          resultShardings, spmdizationMap, symbolTableCollection, builder);
+        return success();
+      }]>
   ];

   let extraClassDeclaration = [{
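From the caller's side, a spmdization driver would dispatch through this new interface method roughly as sketched below. The helper `spmdizeOneOp` and its error message are hypothetical and only illustrate the contract described above.

```cpp
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/IRMapping.h"

// Hypothetical driver-side helper; not the pass's actual implementation.
mlir::LogicalResult
spmdizeOneOp(mlir::Operation &op, llvm::ArrayRef<mlir::Value> spmdizedOperands,
             llvm::ArrayRef<mlir::mesh::MeshShardingAttr> operandShardings,
             llvm::ArrayRef<mlir::mesh::MeshShardingAttr> resultShardings,
             mlir::IRMapping &spmdizationMap,
             mlir::SymbolTableCollection &symbolTableCollection,
             mlir::OpBuilder &builder) {
  // Ops that implement the interface decide themselves how to partition.
  if (auto shardingOp = llvm::dyn_cast<mlir::mesh::ShardingInterface>(&op))
    return shardingOp.spmdize(spmdizedOperands, operandShardings,
                              resultShardings, spmdizationMap,
                              symbolTableCollection, builder);
  // Otherwise the op is not spmdizable (the default implementation above
  // only covers fully replicated annotations).
  return op.emitError("expected an op implementing ShardingInterface");
}
```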

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+
+class Operation;
+class IRMapping;
+class SymbolTableCollection;
+
+namespace mesh {
+
+// Inserts a clone of the operation that has all ranked tensor
+// arguments/results sharded.
+void spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
+
+// All ranked tensor argument and result dimensions have
+// independent parallel loop iterators.
+template <typename Op>
+struct IndependentParallelIteratorDomainShardingInterface
+    : public ShardingInterface::ExternalModel<
+          IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
+    SmallVector<IteratorType> iterTypes;
+    for (Type t : operation->getOperandTypes()) {
+      populateIteratorTypes(t, iterTypes);
+    }
+    for (Type t : operation->getResultTypes()) {
+      populateIteratorTypes(t, iterTypes);
+    }
+    return iterTypes;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    // TODO: implement.
+    return SmallVector<AffineMap>();
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+
+private:
+  void populateIteratorTypes(Type t,
+                             SmallVector<IteratorType> &iterTypes) const {
+    RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
+    if (!rankedTensorType) {
+      return;
+    }
+
+    iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
+    for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
+      iterTypes.push_back(IteratorType::Parallel);
+    }
+  }
+};
+
+// Sharding of elementwise operations like tensor addition and multiplication.
+template <typename ElemwiseOp>
+struct ElementwiseShardingInterface
+    : public ShardingInterface::ExternalModel<
+          ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
+  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+    Value val = op->getOperand(0);
+    auto type = val.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return {};
+    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+    return types;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    Value val = op->getOperand(0);
+    auto type = val.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return {};
+    int64_t rank = type.getRank();
+    int64_t num = op->getNumOperands() + op->getNumResults();
+    SmallVector<AffineMap> maps(num,
+                                AffineMap::getMultiDimIdentityMap(rank, ctx));
+    return maps;
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_

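These templates are external models, so a dialect opts its ops in by attaching them through a registry extension. A minimal sketch follows, assuming a hypothetical `MyDialect` with an elementwise `MyAddOp`; the attachment pattern is the standard MLIR external-model idiom rather than code from this commit.

```cpp
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

// Hypothetical dialect and op names, for illustration only.
void registerMyDialectShardingModels(mlir::DialectRegistry &registry) {
  registry.addExtension(+[](mlir::MLIRContext *ctx, MyDialect *dialect) {
    // Elementwise ops get identity indexing maps and all-parallel iterators.
    MyAddOp::attachInterface<
        mlir::mesh::ElementwiseShardingInterface<MyAddOp>>(*ctx);
  });
}
```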
mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td

Lines changed: 58 additions & 0 deletions
@@ -29,4 +29,62 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }

+def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+  let summary = "Partition a function into SPMD form.";
+  let description = [{
+    This pass fits in right after a pass that annotates the function with
+    shardings, like the `ShardingPropagation` pass.
+    It operates on fully annotated IR.
+
+    Fully annotated IR requires that all ranked tensor operands, results and
+    block arguments are annotated with the `mesh.shard` operation.
+
+    All direct descendant operations in the function must implement the
+    `ShardingInterface` interface, or all their ranked tensor operands and
+    results must have full replication sharding.
+
+    The input IR must have sharding annotations that each operation
+    implementing `ShardingInterface` can handle during spmdization with its
+    `spmdize` method.
+    This can be achieved with the `ShardingPropagation` pass.
+
+    If the function has multiple terminating blocks, it is the responsibility
+    of the one who annotates the function with shardings to make sure that
+    all returns are consistent, that is, have the same sharding.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+
+    func.func @f(
+      %arg0: tensor<2xi8>
+    ) -> tensor<2xi8> {
+      %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
+      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+      return %4 : tensor<2xi8>
+    }
+    ```
+    Spmdizing the above results in
+    * performing the element-wise `abs` operation on each device, and
+    * resharding to full replication with an all-gather.
+
+    ```mlir
+    mesh.mesh @mesh_1d(shape = 2)
+
+    func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
+      %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
+      %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+      return %1 : tensor<2xi8>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD

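In a pipeline, the new pass would typically run right after sharding propagation; on the command line that corresponds to `mlir-opt --sharding-propagation --mesh-spmdization`. The sketch below assumes the TableGen-generated constructors are named `createShardingPropagation()` and `createSpmdization()`, which this diff does not show.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

void buildMeshPartitioningPipeline(mlir::OpPassManager &pm) {
  // Annotate every ranked tensor value with a sharding first...
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mesh::createShardingPropagation());
  // ...then rewrite the fully annotated functions into SPMD form.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mesh::createSpmdization());
}
```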