[mlir][mesh] Handling changed halo region sizes during spmdization #114238
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes
At some point, we should probably refactor how spmdization treats various resharding patterns.

@sogartar @yaochengji @mfrancio Could you please have a look?

Patch is 41.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114238.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
SmallVector<MeshAxis> partial_axes;
ReductionKind partial_type;
SmallVector<int64_t> static_halo_sizes;
- SmallVector<int64_t> static_sharded_dims_sizes;
+ SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
- SmallVector<Value> dynamic_sharded_dims_sizes;
+ SmallVector<Value> dynamic_sharded_dims_offsets;
public:
MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
ArrayRef<MeshAxis> partial_axes_ = {},
ReductionKind partial_type_ = ReductionKind::Sum,
ArrayRef<int64_t> static_halo_sizes_ = {},
- ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
ArrayRef<Value> dynamic_halo_sizes_ = {},
- ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+ ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
- ArrayRef<int64_t> getStaticShardedDimsSizes() const {
- return static_sharded_dims_sizes;
+ ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+ return static_sharded_dims_offsets;
}
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
- ArrayRef<Value> getDynamicShardedDimsSizes() const {
- return dynamic_sharded_dims_sizes;
+ ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+ return dynamic_sharded_dims_offsets;
}
operator bool() const { return (!mesh) == false; }
bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
bool operator!=(const MeshSharding &rhs) const;
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+ bool equalHaloSizes(const MeshSharding &rhs) const;
+ bool equalShardSizes(const MeshSharding &rhs) const;
};
} // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..04b4b55a433803 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -196,16 +197,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
`?` indicates dynamic halo sizes.
- 6. [Optional] Sizes of sharded dimensions of each shard.
- `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
- device-mesh one value for each sharded tensor dimension.
+ 6. [Optional] Offsets for each shard and sharded tensor dimension.
+ `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+ For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
+ The offset of each first shard is omitted and is implicitly assumed to be 0.
+ The last value per dimension denotes the end of the last shard.
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
- `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
- the device-mesh will get a shard of shape 16x8x32 and the second device will get a
- shard of shape 16x24x32.
+ `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
+ the device-mesh will get a shard of shape 24x20x32 and the second device will get a
+ shard of shape 8x12x32.
`?` indicates dynamic shard dimensions.
- `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+ `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
Examples:
@@ -240,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
- %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+ %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_offsets = [2, 5, 9, 14]
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
```
}];
@@ -250,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
Mesh_MeshAxesArrayAttr:$split_axes,
OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
- Variadic<I64>:$dynamic_sharded_dims_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+ Variadic<I64>:$dynamic_sharded_dims_offsets,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
Variadic<I64>:$dynamic_halo_sizes
);
@@ -263,7 +266,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
`split_axes` `=` $split_axes
(`partial` `=` $partial_type $partial_axes^)?
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
- (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+ (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
}];
let builders = [
@@ -272,16 +275,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxis>":$partial_axes,
"mesh::ReductionKind":$partial_type,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
- CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1056,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
}
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DestinationStyleOpInterface,
+ TypesMatchWith<
+ "result has same type as destination",
+ "result", "destination", "$_self">,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ AttrSizedOperandSegments
]> {
let summary = "Update halo data.";
let description = [{
This operation updates halo regions of shards, e.g. if their sharding
- specified halos and the actual tensor data might have changed
+ specified halos and the actual tensor/memref data might have changed
on the remote devices. Changes might be caused by mutating operations
and/or if the new halo regions are larger than the existing ones.
+ Source and destination might have different halo sizes.
+
Assumes all devices hold tensors with same-sized halo data as specified
- by `dynamic/static_halo_sizes`.
+ by `source_halo_sizes/static_source_halo_sizes` and
+ `destination_halo_sizes/static_destination_halo_sizes` in source shard
+ and destination/result shard.
`split_axes` specifies for each tensor axis along which mesh axes its halo
data is updated.
- Optionally resizes to new halo sizes `target_halo_sizes`.
}];
let arguments = (ins
- AnyNon0RankedMemRef:$input,
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
FlatSymbolRefAttr:$mesh,
Mesh_MeshAxesArrayAttr:$split_axes,
- Variadic<I64>:$dynamic_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+ Variadic<I64>:$source_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+ Variadic<I64>:$destination_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+ );
+ let results = (outs
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
);
let assemblyFormat = [{
- $input `on` $mesh
+ $source `into` $destination
+ `on` $mesh
`split_axes` `=` $split_axes
- (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
- (`target_halo_sizes` `=` $target_halo_sizes^)?
- attr-dict `:` type($input)
+ (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+ (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
+ attr-dict `:` type($source) `->` type($result)
+ }];
+ let extraClassDeclaration = [{
+ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
}];
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..d65f7e4bbadd1a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
- ArrayRef<int64_t> shardedDimsSizes = {},
+ ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
- if (!shardedDimsSizes.empty()) {
+ if (!shardedDimsOffsets.empty()) {
+ uint64_t pos = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
- if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
- for (auto dimSz : shardedDimsSizes) {
- auto inAxis = dimSz % inShape.size();
- assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
- inShape[inAxis] == ShapedType::kDynamic);
- }
-#endif // NDEBUG
- } else {
- // find sharded dims in sharded_dims_sizes with same static size on
- // all devices. Use kDynamic for dimensions with dynamic or non-uniform
- // sizes in sharded_dims_sizes.
- auto sz = shardedDimsSizes[tensorAxis];
- bool same = true;
- for (size_t i = tensorAxis + inShape.size();
- i < shardedDimsSizes.size(); i += inShape.size()) {
- if (shardedDimsSizes[i] != sz) {
- same = false;
- break;
+ if (!innerSplitAxes.empty()) {
+ auto sz = shardedDimsOffsets[pos];
+ bool same = !ShapedType::isDynamicShape(meshShape);
+ if (same) {
+ // find sharded dims in shardedDimsOffsets with same static size on
+ // all devices. Use kDynamic for dimensions with dynamic or
+ // non-uniform offs in shardedDimsOffsets.
+ uint64_t numShards = 0;
+ for (auto i : innerSplitAxes.asArrayRef()) {
+ numShards += meshShape[i];
+ }
+ for (size_t i = 1; i < numShards; ++i) {
+ if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+ sz) {
+ same = false;
+ break;
+ }
}
+ pos += numShards;
}
outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
}
@@ -255,7 +255,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
- resShapeArr, sharding.getStaticShardedDimsSizes(),
+ resShapeArr, sharding.getStaticShardedDimsOffsets(),
sharding.getStaticHaloSizes());
return shape.clone(resShapeArr);
}
@@ -432,13 +432,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
- ArrayRef<int64_t> static_sharded_dims_sizes) {
+ ArrayRef<int64_t> static_sharded_dims_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+ static_sharded_dims_offsets),
{});
}
@@ -455,11 +456,11 @@ void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+ ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
- dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+ dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +478,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
: b.getDenseI16ArrayAttr(from.getPartialAxes()),
::mlir::mesh::ReductionKindAttr::get(b.getContext(),
from.getPartialType()),
- from.getStaticShardedDimsSizes().empty()
+ from.getStaticShardedDimsOffsets().empty()
? DenseI64ArrayAttr()
- : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
- from.getDynamicShardedDimsSizes(),
+ : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+ from.getDynamicShardedDimsOffsets(),
from.getStaticHaloSizes().empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,7 +510,7 @@ LogicalResult ShardingOp::verify() {
failed(checkMeshAxis(getPartialAxes().value())))
return failure();
- if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+ if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
return emitOpError("halo sizes and shard shapes are mutually exclusive");
}
@@ -539,13 +540,49 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
- getStaticShardedDimsSizes().size() > 0) {
- return emitError() << "sharded dims sizes are not allowed for "
+ getStaticShardedDimsOffsets().size() > 0) {
+ return emitError() << "sharded dims offsets are not allowed for "
"devices meshes with dynamic shape.";
}
return success();
}
+namespace {
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+ using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShardingOp op,
+ PatternRewriter &b) const override {
+ auto mixedHalos =
+ getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+ auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+ op.getDynamicShardedDimsOffsets(), b);
+
+ // No constant operands were folded, just return;
+ if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+ return failure();
+ }
+
+ auto halos = decomposeMixedValues(mixedHalos);
+ auto offs = decomposeMixedValues(mixedOffs);
+
+ op.setStaticHaloSizes(halos.first);
+ op.getDynamicHaloSizesMutable().assign(halos.second);
+ op.setStaticShardedDimsOffsets(offs.first);
+ op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+ return success();
+ }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+ mlir::MLIRContext *context) {
+ results.add<FoldDynamicLists>(context);
+}
+
//===----------------------------------------------------------------------===//
// MeshSharding
//===----------------------------------------------------------------------===//
@@ -555,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
return false;
}
- if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+ if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+ (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+ !llvm::equal(
+ llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+ llvm::make_range(rhs.getPartialAxes().begin(),
+ rhs.getPartialAxes().end()))) {
return false;
}
@@ -576,6 +618,31 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
}
bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+ return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+ if (rhs.getStaticShardedDimsOffsets().size() !=
+ getStaticShardedDimsOffsets().size() ||
+ !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+ getStaticShardedDimsOffsets().end()),
+ llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+ rhs.getStaticShardedDimsOffsets().end()))) {
+ return false;
+ }
+ if (rhs.getDynamicShardedDimsOffsets().size() !=
+ getDynamicShardedDimsOffsets().size() ||
+ !llvm::equal(
+ llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+ ...
[truncated]
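To make the new `sharded_dims_offsets` semantics documented in the `MeshOps.td` hunk above more concrete, here is a small standalone sketch. It is not code from this patch; the helper name and argument layout are invented for illustration. It assumes the convention stated in the op documentation: per sharded dimension the leading 0 is omitted and the last value marks the end of the last shard.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: compute the shard shape of device `deviceIdx` from a
// flattened `sharded_dims_offsets` array. Each sharded dimension contributes
// `numShards` entries; the leading 0 is implicit and the last entry is the
// end of the last shard.
std::vector<int64_t> shardShapeFromOffsets(const std::vector<int64_t> &tensorShape,
                                           const std::vector<bool> &isSharded,
                                           const std::vector<int64_t> &offsets,
                                           int64_t numShards, int64_t deviceIdx) {
  std::vector<int64_t> result = tensorShape;
  size_t pos = 0; // running position in the flattened offsets array
  for (size_t dim = 0; dim < tensorShape.size(); ++dim) {
    if (!isSharded[dim])
      continue;
    int64_t begin = deviceIdx == 0 ? 0 : offsets[pos + deviceIdx - 1];
    int64_t end = offsets[pos + deviceIdx];
    result[dim] = end - begin; // shard size is a simple subtraction
    pos += numShards;
  }
  return result;
}

int main() {
  // The documented example: a 32x32x32 tensor, first two dims sharded across
  // 2 devices, sharded_dims_offsets = [24, 32, 20, 32].
  std::vector<int64_t> shape = {32, 32, 32};
  std::vector<bool> sharded = {true, true, false};
  std::vector<int64_t> offsets = {24, 32, 20, 32};
  for (int64_t dev = 0; dev < 2; ++dev) {
    auto s = shardShapeFromOffsets(shape, sharded, offsets, 2, dev);
    std::cout << s[0] << "x" << s[1] << "x" << s[2] << "\n"; // 24x20x32, then 8x12x32
  }
  return 0;
}
```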
✅ With the latest revision this PR passed the C/C++ code formatter.
Wouldn't it be more intuitive to have offsets starting from zero unless there is a specific reason not to follow this approach?
6. [Optional] Offsets for each shard and sharded tensor dimension.
   `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
   For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
   The offset of each first shard is omitted and is implicitly assumed to be 0.
Is there a particular benefit in following this convention? Would this allow, for example, cropping the last dimension and not getting the full dimensionality of the tensor?
I would think that actually listing the offsets directly (exclusive prefix-sum) would be more intuitive if the above case is anyway disallowed.
Yes, I get your point. We need the last value to be able to compute the shard size of the last shard. As said elsewhere, I had to either do it as you suggest and have an array of size N+1, or have the first value given implicitly. I often prefer avoiding redundancy, so I did it this way. I have no strong opinion here; if you have a clear preference I can switch to the longer array solution.
I feel pretty strongly about having the offset computed as exclusive prefix sum of the "size" array. That is, the offset array is still size N, the first element is 0, and the offset for shard `i` can be retrieved as `sharded_dims_offsets[i]`.
We will need an additional last element, otherwise we cannot deduce the size of the last shard.
That is because there is no way to implicitly define the global tensor shape. It should be possible to deduce the sizes without pulling in a dependency on the input tensor. The start will always be 0, so that is redundant information and can be avoided to keep the array to size N. Not so the end.
I will change to having the offset for shard `i` as `sharded_dims_offsets[i]` and append the end offset.
done
Apologies if I did not provide more context in my previous message, but is there an instance where the input tensor shape is not directly available that warrants the N+1 representation?
If that's the case, then I would agree with you that the original implementation (size N, inclusive prefix sum representing the end offset) was probably better, especially because now you have a flattened 2d array of N+1 values, which adds additional complexity.
I see; at this point I don't know of a case where the shape is unavailable.
Still, I prefer a form that is self-sufficient. There is some complexity involved no matter what. Whether extracting the tensor size from another entity is more complex than dealing with N+1 values probably lies in the eye of the beholder.
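To make the trade-off debated in this thread concrete, here is a purely illustrative sketch (not code from the patch; the helper names are invented) of the exclusive-prefix-sum representation discussed above: per-shard sizes become N+1 offsets per dimension, and each shard size can then be recovered with a single subtraction.

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: convert per-shard sizes into the offset form discussed
// above -- an exclusive prefix sum with the total appended, i.e. N+1 values
// for N shards, so shard i spans [offsets[i], offsets[i+1]).
std::vector<int64_t> sizesToOffsets(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> offsets;
  offsets.reserve(sizes.size() + 1);
  int64_t running = 0;
  offsets.push_back(running); // the first shard always starts at 0
  for (int64_t sz : sizes) {
    running += sz;
    offsets.push_back(running);
  }
  // e.g. the shard sizes {2, 3, 4, 5} from the op documentation become
  // {0, 2, 5, 9, 14}; dropping the leading 0 gives the [2, 5, 9, 14] form.
  return offsets;
}

// Recovering a shard size is a single subtraction, no reduction needed.
int64_t shardSize(const std::vector<int64_t> &offsets, size_t i) {
  return offsets[i + 1] - offsets[i];
}
```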
"devices meshes with dynamic shape."; | ||
} | ||
return success(); | ||
} | ||
|
||
namespace { | ||
class FoldDynamicLists final : public OpRewritePattern<ShardingOp> { |
Could we add a comment above explaining the purpose of this canonicalization?
done
MeshSharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
// currently handles only cases where halo sizes differ but everything else
Some comments in this file miss proper capitalization and punctuation.
Made some improvements.
TypedValue<ShapedType> sourceShard) {
// currently handles only cases where halo sizes differ but everything else
// stays the same (from source to destination sharding)
if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
We can probably early return and reduce the indentation block.
done
Thanks for your contribution!
`sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
device-mesh one value for each sharded tensor dimension.
6. [Optional] Offsets for each shard and sharded tensor dimension.
   `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
Why is it a flattened 1d array?
In my understanding, the offsets in each dimension form a strictly increasing array. I feel it might be more straightforward to organize it as a 2d array.
BTW, I think we'd better add some code to check whether it's a strictly increasing array or not.
Yes, ideally this would be a 2d array. As we allow dynamic and static values, I didn't intuitively know how to implement the parsing for it. Mechanics exist for 1d arrays, though. We can add this in an extra PR, but at this point I didn't want to spend time on it. Of course it'd be great if someone had a solution.
The offsets can be dynamic values, so the check cannot be a static check and requires generating IR. Where would you put such a check?
I think for static values, we can use IndexListArrayAttr or SmallVector<SmallVector<int64_t>>. For dynamic values, we can use SmallVector<SmallVector>. But I'm totally fine with optimizing it later.
With regard to the check, we can add the check logic in the `verify` function only for static values.
added a check and tests
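A minimal sketch of the kind of verifier check discussed here (illustrative only; the check actually added in the patch may differ, and the `kDynamic` sentinel and helper name are assumptions): static offsets must be strictly increasing within each sharded dimension, while dynamic entries are skipped because they can only be validated at runtime.

```cpp
#include <cstdint>
#include <vector>

// Assumption: a sentinel marking dynamic entries in the static offsets list
// (MLIR uses a similar convention for mixed static/dynamic index lists).
constexpr int64_t kDynamic = INT64_MIN;

// Illustrative only: verify that the static offsets of each sharded dimension
// are strictly increasing. Assumes every sharded dimension contributes the
// same number of entries (`numShardsPerDim`) to the flattened list.
bool offsetsStrictlyIncreasing(const std::vector<int64_t> &offsets,
                               size_t numShardsPerDim) {
  if (numShardsPerDim == 0)
    return true;
  for (size_t pos = 0; pos + numShardsPerDim <= offsets.size();
       pos += numShardsPerDim) {
    int64_t prev = kDynamic;
    for (size_t i = 0; i < numShardsPerDim; ++i) {
      int64_t cur = offsets[pos + i];
      if (cur == kDynamic)
        continue; // dynamic values cannot be checked statically
      if (prev != kDynamic && cur <= prev)
        return false;
      prev = cur;
    }
  }
  return true;
}
```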
Yes, we could do that. It would require the array to be of size N+1, though. I was undecided at the time; if you have a clear preference I will make the change.
Co-authored-by: Matteo Franciolini <[email protected]>
Co-authored-by: Matteo Franciolini <[email protected]>
Co-authored-by: Chengji Yao <[email protected]>
Co-authored-by: Matteo Franciolini <[email protected]>
@mfrancio @yaochengji Thanks for your thoughtful reviews. If there are no other concerns (@sogartar ?), could you please merge? (I do not yet have write permissions.)
I can help with that - I'll wait until the weekend to give @yaochengji or @sogartar the opportunity to take another look.
LGTM, thanks |
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/6200

Here is the relevant piece of the build log for reference.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/89/builds/10252

Here is the relevant piece of the build log for reference.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/130/builds/6014

Here is the relevant piece of the build log for reference.
Is someone investigating these CI issues? Our flang bot turned red, too.
Yep, a PR will be available in a few minutes.
…lvm#114238)

* Changed `MeshSharding::sharded_dims_sizes` from representing sizes per shard to offsets to origin per shard.
  - Local shard sizes are now a simple subtraction
  - Offsets are now readily available without a reduction operation
  - Enables constant value/shape propagation through standard canonicalization
  - Renamed to `sharded_dims_offsets` accordingly
* First spmdization pattern for halo regions.
  - Triggers when source and destination shardings differ only in their halo sizes
  - Copies local data from source into a new tensor and calls `update_halo`
  - Supports arbitrary mesh dimensions (unlike the other patterns, which work on 1d meshes only)
* `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts tensors and memrefs
  - Also accepts target and source halo sizes; both are required for proper lowering
* Minor refactoring for testing partial `MeshSharding` equality
* Canonicalization for `ShardingOp` folding constant values into the respective `static_*` attributes
Fixing CI failures caused by llvm#114238 by adding the `MLIRDestinationStyleOpInterface` lib. @jplehr @mfrancio @rengolin