[mlir][mesh] Add spmdization pass #80518

Merged: 1 commit, Feb 7, 2024
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
//===- MeshShardingExtensions.h ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_

namespace mlir {

class DialectRegistry;

namespace func {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace func
} // namespace mlir

#endif // MLIR_DIALECT_FUNC_IR_SHARDINGINTERFACEIMPL_H_
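
The header above only declares the registration hook. As a minimal sketch (not part of the diff), a tool would wire it into its dialect registry roughly as follows; `registerMyToolExtensions` is a hypothetical helper:

```c++
#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
#include "mlir/IR/DialectRegistry.h"

// Hypothetical helper: attach the func-dialect sharding external models to
// the registry before constructing an MLIRContext from it.
void registerMyToolExtensions(mlir::DialectRegistry &registry) {
  mlir::func::registerShardingInterfaceExternalModels(registry);
}
```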
45 changes: 45 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,11 +10,13 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/MathExtras.h"

namespace mlir {
namespace mesh {
@@ -48,6 +50,11 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {

Partial getPartialTypeFromReduction(IteratorType iType);

// Returns true if the same tensor is replicated on all processes, i.e. the
// sharding has no split and no partial axes.
inline bool isFullReplication(MeshShardingAttr attr) {
return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
}

inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) {
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
@@ -60,6 +67,13 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
}

template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
symbolTableCollection);
}

// Get the number of processes that participate in each group
// induced by `meshAxes`.
template <typename MeshAxesRange, typename MeshShapeRange>
@@ -78,6 +92,37 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}

// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
return ShapedType::kDynamic;

assert(dimSize % shardCount == 0);
return ceilDiv(dimSize, shardCount);
}

// Get the size of an unsharded dimension.
inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
return ShapedType::kDynamic;

return dimSize * shardCount;
}

// Returns the sharded shape `shape` according to the sharding `sharding`,
// i.e. the shape of the tensor on each device in the mesh.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1 would
// result in a shape for each shard of ?x2x?.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshShardingAttr sharding);

// If `type` is a ranked tensor type, return its sharded counterpart.
//
// Otherwise return `type` unchanged; `sharding` must be null in that case.
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);

} // namespace mesh
} // namespace mlir

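For intuition, here is a standalone restatement of the shard-size arithmetic above, assuming a local sentinel in place of `ShapedType::kDynamic` and plain division where the header uses `ceilDiv` (equivalent under the divisibility assert):

```c++
#include <cassert>
#include <cstdint>

// Stand-in for ShapedType::kDynamic; any sentinel works for this sketch.
constexpr int64_t kDynamic = INT64_MIN;

// Mirrors mesh::shardDimension: size of one shard of a split dimension.
int64_t shardDim(int64_t dimSize, int64_t shardCount) {
  if (dimSize == kDynamic || shardCount == kDynamic)
    return kDynamic; // dynamic sizes stay dynamic
  assert(dimSize % shardCount == 0 && "uneven sharding is rejected");
  return dimSize / shardCount; // equals ceilDiv when evenly divisible
}

// Mirrors mesh::gatherDimension: size after concatenating all shards.
int64_t gatherDim(int64_t dimSize, int64_t shardCount) {
  if (dimSize == kDynamic || shardCount == kDynamic)
    return kDynamic;
  return dimSize * shardCount;
}

int main() {
  assert(shardDim(8, 4) == 2);   // 8 elements split over 4 shards
  assert(gatherDim(2, 4) == 8);  // gathering 4 shards of 2 elements each
  assert(shardDim(kDynamic, 4) == kDynamic);
  return 0;
}
```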
11 changes: 10 additions & 1 deletion mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -10,11 +10,14 @@
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

namespace mlir {

class Operation;
class IRMapping;
class SymbolTableCollection;

namespace mesh {

@@ -58,8 +61,14 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,

} // namespace detail

} // namespace mesh
// Assumes full replication on all ranked tensor arguments and results.
void spmdizeFullyReplicatedOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable, OpBuilder &builder);

} // namespace mesh
} // namespace mlir

/// Include the ODS generated interface header files.
47 changes: 46 additions & 1 deletion mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -88,7 +88,52 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
return detail::defaultAddShardingAnnotations(
$_op.getOperation(), b, shardingOption);
}]
>
>,
InterfaceMethod<
/*desc=*/[{
Convert self to SPMD form.
This method is used during the spmdization pass on a program that is
fully annotated with shardings.

The spmdization algorithm reads the surrounding sharding
annotations from the IR for each argument/result and prepares
`operandShardings` and `resultShardings`.
Values that are not ranked tensors do not have sharding annotations.
In this case their corresponding MeshShardingAttr is null.

For convenience it will also prepare `spmdizedOperands`, although
they can be retrieved from the `spmdizationMap`.

The `spmdizationMap` contains a mapping from unsharded to
sharded/spmdized values that are constructed during the spmdization
pass. The interface implementation must populate `spmdizationMap`
with the mapping for this op's results.

`builder` is set to insert new operations at the appropriate point.
The implementation should not restore the builder to its original
insertion point; it should leave it as is after all insertions are
done.

The default implementation performs full replication, which assumes
that all sharding annotations specify full replication.
}],
/*retTy=*/"LogicalResult",
/*methodName=*/"spmdize",
/*args=*/(ins
"ArrayRef<Value>": $spmdizedOperands,
"ArrayRef<MeshShardingAttr>": $operandShardings,
"ArrayRef<MeshShardingAttr>": $resultShardings,
"IRMapping &": $spmdizationMap,
"SymbolTableCollection &": $symbolTableCollection,
"OpBuilder &": $builder
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
spmdizeFullyReplicatedOperation(
*$_op.getOperation(), spmdizedOperands, operandShardings,
resultShardings, spmdizationMap, symbolTableCollection, builder);
return success();
}]>
];

let extraClassDeclaration = [{
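To make the interface contract concrete, here is a hedged sketch of a trivial `spmdize` override for a hypothetical `MyOp` (neither the op nor this body is part of the diff). It clones the op over the spmdized operands; a real implementation would also recompute result types from `resultShardings` and insert resharding communication where needed:

```c++
LogicalResult MyOp::spmdize(ArrayRef<Value> spmdizedOperands,
                            ArrayRef<MeshShardingAttr> operandShardings,
                            ArrayRef<MeshShardingAttr> resultShardings,
                            IRMapping &spmdizationMap,
                            SymbolTableCollection &symbolTableCollection,
                            OpBuilder &builder) {
  // OpBuilder::clone remaps operands through `spmdizationMap` and records
  // the old-result -> new-result mapping in it, which is exactly what the
  // interface documentation asks implementations to do.
  builder.clone(*getOperation(), spmdizationMap);
  return success();
}
```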
125 changes: 125 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -0,0 +1,125 @@
//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"

namespace mlir {

class Operation;
class IRMapping;
class SymbolTableCollection;

namespace mesh {

// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
void spmdizeTriviallyShardableOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable, OpBuilder &builder);

// All ranked tensor argument and result dimensions have
// independent parallel loop iterators.
template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
: public ShardingInterface::ExternalModel<
IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
SmallVector<IteratorType> iterTypes;
for (Type t : operation->getOperandTypes()) {
populateIteratorTypes(t, iterTypes);
}
for (Type t : operation->getResultTypes()) {
populateIteratorTypes(t, iterTypes);
}
return iterTypes;
}

SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
// TODO: implement.
return SmallVector<AffineMap>();
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
resultShardings, spmdizationMap,
symbolTable, builder);
return success();
}

private:
void populateIteratorTypes(Type t,
SmallVector<IteratorType> &iterTypes) const {
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
if (!rankedTensorType) {
return;
}

iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
iterTypes.push_back(IteratorType::Parallel);
}
}
};

// Sharding of elementwise operations like tensor addition and multiplication.
template <typename ElemwiseOp>
struct ElementwiseShardingInterface
: public ShardingInterface::ExternalModel<
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
if (!type)
return {};
SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
return types;
}

SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
MLIRContext *ctx = op->getContext();
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
if (!type)
return {};
int64_t rank = type.getRank();
int64_t num = op->getNumOperands() + op->getNumResults();
SmallVector<AffineMap> maps(num,
AffineMap::getMultiDimIdentityMap(rank, ctx));
return maps;
}

LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
resultShardings, spmdizationMap,
symbolTable, builder);
return success();
}
};

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
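
These templates are meant to be attached as external models. A sketch of registering the elementwise model for a hypothetical op (`SomeDialect` and `SomeAddOp` are illustrative names), using MLIR's standard `addExtension`/`attachInterface` pattern:

```c++
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

void registerSomeDialectShardingModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, SomeDialect *dialect) {
    // Any elementwise op on ranked tensors fits this model.
    SomeAddOp::attachInterface<
        mesh::ElementwiseShardingInterface<SomeAddOp>>(*ctx);
  });
}
```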
58 changes: 58 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -29,4 +29,62 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}

def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
let summary = "Partition a function into SPMD form.";
let description = [{
This pass fits right after a pass that annotates the function with
shardings, such as the `ShardingPropagation` pass.
It operates on fully annotated IR.

Fully annotated IR requires that all ranked tensor operands, results and
block arguments are annotated with the `mesh.shard` operation.

All direct descendant operations in the function must implement the
`ShardingInterface` interface or all their ranked tensor operands and
results must have full replication sharding.

The input IR must have sharding annotations such that each operation
that implements `ShardingInterface` can handle them during spmdization
with its `spmdize` method.
This can be achieved with the `ShardingPropagation` pass.

If the function has multiple terminating blocks,
it is the responsibility of whoever annotates the function with
shardings to make sure that all returns are consistent, that is,
have the same sharding.

Example:
```mlir
mesh.mesh @mesh_1d(shape = 2)

func.func @f(
%arg0: tensor<2xi8>
) -> tensor<2xi8> {
%0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
%2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
%3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
%4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
return %4 : tensor<2xi8>
}
```
Spmdizing the above would result in:
* Performing the element-wise `abs` operation on each device.
* Resharding to full replication with an all-gather.

```mlir
mesh.mesh @mesh_1d(shape = 2)

func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
%0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
%1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
return %1 : tensor<2xi8>
}
```
}];
let dependentDialects = [
"mesh::MeshDialect"
];
}

#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
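
Assuming the passes are registered with the standard `mlir-opt` driver, the transformation in the example above can be reproduced by running `--sharding-propagation` followed by `--mesh-spmdization` on the annotated function.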