[mlir][mesh] Handling changed halo region sizes during spmdization #114238


Merged
18 commits merged into llvm:main on Nov 11, 2024

Conversation

fschlimb
Contributor

  • Changed MeshSharding::sharded_dims_sizes from representing sizes per shard to offsets to origin per shard.
    • Local shard sizes are now a simple subtraction
    • Offsets are now readily available without a reduction operation
    • Enables constant value/shape propagation through standard canonicalization
    • Renamed to sharded_dims_offsets accordingly.
  • First spmdization pattern for halo regions.
    • Triggers when source and destination shardings differ only in their halo sizes
    • Copies local data from source into a new tensor and calls update_halo
    • Supports arbitrary mesh dimensions (unlike the other patterns which work on 1d meshes only)
  • UpdateHaloOp implements DestinationStyleOpInterface and accepts tensors and memrefs
    • Also accepts target and source halo sizes; both are required for proper lowering
  • Minor refactoring for testing partial MeshSharding equality
  • Canonicalization for ShardingOp that folds constant values into the respective static_* attributes

At some point, we should probably refactor how spmdization treats various resharding patterns.

@sogartar @yaochengji @mfrancio Could you please have a look?

@llvmbot
Member

llvmbot commented Oct 30, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes
(Same change summary as in the PR description above.)


Patch is 41.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114238.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+11-8)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+46-25)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+107-55)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+94-5)
  • (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+17)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+3-3)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+12-14)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+31)
  • (modified) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+11-11)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
   SmallVector<MeshAxis> partial_axes;
   ReductionKind partial_type;
   SmallVector<int64_t> static_halo_sizes;
-  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
-  SmallVector<Value> dynamic_sharded_dims_sizes;
+  SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
   MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
                           ArrayRef<MeshAxis> partial_axes_ = {},
                           ReductionKind partial_type_ = ReductionKind::Sum,
                           ArrayRef<int64_t> static_halo_sizes_ = {},
-                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
-                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+                          ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
   ::llvm::StringRef getMesh() const { return mesh.getValue(); }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
-  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
-    return static_sharded_dims_sizes;
+  ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+    return static_sharded_dims_offsets;
   }
   ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
-  ArrayRef<Value> getDynamicShardedDimsSizes() const {
-    return dynamic_sharded_dims_sizes;
+  ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+    return dynamic_sharded_dims_offsets;
   }
   operator bool() const { return (!mesh) == false; }
   bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
   bool operator!=(const MeshSharding &rhs) const;
   bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+  bool equalHaloSizes(const MeshSharding &rhs) const;
+  bool equalShardSizes(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..04b4b55a433803 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -196,16 +197,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
     `?` indicates dynamic halo sizes.
     
-    6. [Optional] Sizes of sharded dimensions of each shard.
-    `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
-    device-mesh one value for each sharded tensor dimension.
+    6. [Optional] Offsets for each shard and sharded tensor dimension.
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+    For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
+    The offset of each first shard is omitted and is implicitly assumed to be 0.
+    The last value per dimension denotes the end of the last shard.
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
-    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
-    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
-    shard of shape 16x24x32.
+    `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
+    the device-mesh will get a shard of shape 24x20x32 and the second device will get a
+    shard of shape 8x12x32.
     `?` indicates dynamic shard dimensions.
     
-    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+    `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
 
@@ -240,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_offsets = [2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
@@ -250,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     Mesh_MeshAxesArrayAttr:$split_axes,
     OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
     OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
-    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+    Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
     Variadic<I64>:$dynamic_halo_sizes
   );
@@ -263,7 +266,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `split_axes` `=` $split_axes
     (`partial` `=` $partial_type $partial_axes^)?
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+    (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
@@ -272,16 +275,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxis>":$partial_axes,
                    "mesh::ReductionKind":$partial_type,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
-                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1056,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 }
 
 def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DestinationStyleOpInterface,
+  TypesMatchWith<
+    "result has same type as destination",
+    "result", "destination", "$_self">,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  AttrSizedOperandSegments
 ]> {
   let summary = "Update halo data.";
   let description = [{
     This operation updates halo regions of shards, e.g. if their sharding
-    specified halos and the actual tensor data might have changed
+    specified halos and the actual tensor/memref data might have changed
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
+    Source and destination might have different halo sizes.
+
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `dynamic/static_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes` and
+    `destination_halo_sizes/static_destination_halo_sizes` in source shard
+    and destination/result shard.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Optionally resizes to new halo sizes `target_halo_sizes`.
   }];
   let arguments = (ins
-    AnyNon0RankedMemRef:$input,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$dynamic_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+    Variadic<I64>:$source_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+    Variadic<I64>:$destination_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+  );
+  let results = (outs
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh
+    $source `into` $destination
+    `on` $mesh
     `split_axes` `=` $split_axes
-    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`target_halo_sizes` `=` $target_halo_sizes^)?
-    attr-dict `:` type($input)
+    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
+    attr-dict `:` type($source) `->` type($result)
+  }];
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
   }];
 }
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..d65f7e4bbadd1a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
-                       ArrayRef<int64_t> shardedDimsSizes = {},
+                       ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
-  if (!shardedDimsSizes.empty()) {
+  if (!shardedDimsOffsets.empty()) {
+    uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-      if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
-        for (auto dimSz : shardedDimsSizes) {
-          auto inAxis = dimSz % inShape.size();
-          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
-                 inShape[inAxis] == ShapedType::kDynamic);
-        }
-#endif // NDEBUG
-      } else {
-        // find sharded dims in sharded_dims_sizes with same static size on
-        // all devices. Use kDynamic for dimensions with dynamic or non-uniform
-        // sizes in sharded_dims_sizes.
-        auto sz = shardedDimsSizes[tensorAxis];
-        bool same = true;
-        for (size_t i = tensorAxis + inShape.size();
-             i < shardedDimsSizes.size(); i += inShape.size()) {
-          if (shardedDimsSizes[i] != sz) {
-            same = false;
-            break;
+      if (!innerSplitAxes.empty()) {
+        auto sz = shardedDimsOffsets[pos];
+        bool same = !ShapedType::isDynamicShape(meshShape);
+        if (same) {
+          // find sharded dims in shardedDimsOffsets with same static size on
+          // all devices. Use kDynamic for dimensions with dynamic or
+          // non-uniform offs in shardedDimsOffsets.
+          uint64_t numShards = 0;
+          for (auto i : innerSplitAxes.asArrayRef()) {
+            numShards += meshShape[i];
+          }
+          for (size_t i = 1; i < numShards; ++i) {
+            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+                sz) {
+              same = false;
+              break;
+            }
           }
+          pos += numShards;
         }
         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
@@ -255,7 +255,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
   shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr, sharding.getStaticShardedDimsSizes(),
+             resShapeArr, sharding.getStaticShardedDimsOffsets(),
              sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
@@ -432,13 +432,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_sizes) {
+                       ArrayRef<int64_t> static_sharded_dims_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+                                     static_sharded_dims_offsets),
       {});
 }
 
@@ -455,11 +456,11 @@ void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
-    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
   mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
-  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +478,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
             : b.getDenseI16ArrayAttr(from.getPartialAxes()),
         ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
                                              from.getPartialType()),
-        from.getStaticShardedDimsSizes().empty()
+        from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
-            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
-        from.getDynamicShardedDimsSizes(),
+            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+        from.getDynamicShardedDimsOffsets(),
         from.getStaticHaloSizes().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,7 +510,7 @@ LogicalResult ShardingOp::verify() {
       failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
 
-  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
     return emitOpError("halo sizes and shard shapes are mutually exclusive");
   }
 
@@ -539,13 +540,49 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
-      getStaticShardedDimsSizes().size() > 0) {
-    return emitError() << "sharded dims sizes are not allowed for "
+      getStaticShardedDimsOffsets().size() > 0) {
+    return emitError() << "sharded dims offsets are not allowed for "
                           "devices meshes with dynamic shape.";
   }
   return success();
 }
 
+namespace {
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+  using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardingOp op,
+                                PatternRewriter &b) const override {
+    auto mixedHalos =
+        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+                                    op.getDynamicShardedDimsOffsets(), b);
+
+    // No constant operands were folded, just return;
+    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+      return failure();
+    }
+
+    auto halos = decomposeMixedValues(mixedHalos);
+    auto offs = decomposeMixedValues(mixedOffs);
+
+    op.setStaticHaloSizes(halos.first);
+    op.getDynamicHaloSizesMutable().assign(halos.second);
+    op.setStaticShardedDimsOffsets(offs.first);
+    op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+    return success();
+  }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                             mlir::MLIRContext *context) {
+  results.add<FoldDynamicLists>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MeshSharding
 //===----------------------------------------------------------------------===//
@@ -555,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
     return false;
   }
 
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+      (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+      !llvm::equal(
+          llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+          llvm::make_range(rhs.getPartialAxes().begin(),
+                           rhs.getPartialAxes().end()))) {
     return false;
   }
 
@@ -576,6 +618,31 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+  return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+  if (rhs.getStaticShardedDimsOffsets().size() !=
+          getStaticShardedDimsOffsets().size() ||
+      !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+                                    getStaticShardedDimsOffsets().end()),
+                   llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+                                    rhs.getStaticShardedDimsOffsets().end()))) {
+    return false;
+  }
+  if (rhs.getDynamicShardedDimsOffsets().size() !=
+          getDynamicShardedDimsOffsets().size() ||
+      !llvm::equal(
+          llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+                   ...
[truncated]


github-actions bot commented Oct 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@mfrancio mfrancio left a comment

Wouldn't it be more intuitive to have offsets starting from zero unless there is a specific reason not to follow this approach?

6. [Optional] Offsets for each shard and sharded tensor dimension.
`sharded_dims_offsets` is provided as a flattened 1d array of i64s:
For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
The offset of each first shard is omitted and is implicitly assumed to be 0.
Contributor

Is there a particular benefit in following this convention? Would this allow, for example, cropping the last dimension and not getting the full extent of the tensor?

I would think that listing the offsets directly (an exclusive prefix sum) would be more intuitive if the above case is disallowed anyway.

Contributor Author

Yes, I get your point. We need the last value to be able to compute the size of the last shard. As said elsewhere, I had to either do it as you suggest and have an array of size N+1, or have the first value given implicitly. I often prefer avoiding redundancy, so I did it this way. I have no strong opinion here; if you have a clear preference I can switch to the longer-array solution.

Contributor

@mfrancio mfrancio Nov 5, 2024

I feel pretty strongly about having the offset computed as exclusive prefix sum of the "size" array. That is, the offset array is still size N, the first element is 0, and the offset for shard i can be retrieved as sharded_dims_offsets[i].

Contributor Author

We will need an additional last element; otherwise we cannot deduce the size of the last shard.
That is because there is no way to implicitly define the global tensor shape, and it should be possible to deduce the sizes without pulling in a dependency on the input tensor. The start will always be 0, so that is redundant information and could be omitted to keep the array at size N; the end cannot.
I will change to having the offset of shard i as sharded_dims_offsets[i] and append the end offset as the last element.

Contributor Author

done

Contributor

Apologies if I did not provide more context in my previous message, but is there an instance where the input tensor shape is not directly available that warrants the N+1 representation?

If that's the case, then I would agree with you that the original implementation (size N, inclusive prefix sum representing the end offsets) was probably better, especially because now you have a flattened 2d array of N+1 values, which adds additional complexity.

Contributor Author

I see; at this point I don't know of a case where the shape is unavailable.

Still, I prefer a form that is self-sufficient. There is some complexity involved no matter what; whether extracting the tensor size from another entity is more complex than dealing with N+1 values probably lies in the eye of the beholder.

"devices meshes with dynamic shape.";
}
return success();
}

namespace {
class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
Contributor

Could we add a comment above explaining the purpose of this canonicalization?

Contributor Author

done

MeshSharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
// currently handles only cases where halo sizes differ but everything else
Contributor

Some comments in this file miss proper capitalization and punctuation.

Contributor Author

Made some improvements.

TypedValue<ShapedType> sourceShard) {
// currently handles only cases where halo sizes differ but everything else
// stays the same (from source to destination sharding)
if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
Contributor

We can probably return early and reduce the indentation of the block.

Contributor Author

done

Member

@yaochengji yaochengji left a comment

Thanks for your contribution!

`sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
device-mesh one value for each sharded tensor dimension.
6. [Optional] Offsets for each shard and sharded tensor dimension.
`sharded_dims_offsets` is provided as a flattened 1d array of i64s:
Member

Why is it a flattened 1d array?

In my understanding, the offsets in each dimension form a strictly increasing array. I feel it might be more straightforward to organize this as a 2d array.

BTW, I think we'd better add some code that checks whether the array is strictly increasing.

Contributor Author

Yes, ideally this would be a 2d array. Since we allow both dynamic and static values, it was not obvious to me how to implement the parsing for it; the mechanics exist for 1d arrays, though. We can add this in a follow-up PR, but at this point I didn't want to spend time on it. Of course, it'd be great if someone had a solution.

Contributor Author

The offsets can be dynamic values, so the check cannot be a static check and requires generating IR. Where would you put such a check?

Member

I think for static values, we can use IndexListArrayAttr or SmallVector<SmallVector<int64_t>>. For dynamic values, we can use SmallVector<SmallVector>. But I'm totally fine with optimizing this later.

Regarding the check, we can add the check logic to the verify function for static values only.

Contributor Author:

Added a check and tests.

fschlimb (Contributor Author) commented Nov 4, 2024

> Wouldn't it be more intuitive to have offsets starting from zero unless there is a specific reason not to follow this approach?

Yes, we could do that. It would require the array to be of size N+1, though. I was undecided at the time; if you have a clear preference I will make the change.

fschlimb (Contributor Author) commented Nov 6, 2024

@mfrancio @yaochengji Thanks for your thoughtful reviews. If there are no other concerns (@sogartar ?), could you please merge (I do not yet have write permissions).

mfrancio (Contributor) commented Nov 6, 2024

> @mfrancio @yaochengji Thanks for your thoughtful reviews. If there are no other concerns (@sogartar ?), could you please merge (I do not yet have write permissions).

I can help with that - I'll wait until the weekend to give @yaochengji or @sogartar the opportunity to take another look.

yaochengji (Member):

LGTM, thanks

@mfrancio mfrancio merged commit ffc7fea into llvm:main Nov 11, 2024
8 checks passed
llvm-ci (Collaborator) commented Nov 11, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia running on mlir-nvidia while building mlir at step 5 "build-check-mlir-build-only".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/6200

Here is the relevant piece of the build log for the reference
Step 5 (build-check-mlir-build-only) failure: build (failure)
...
35.777 [253/9/4755] Linking CXX shared library lib/libMLIRGPUTransforms.so.20.0git
35.784 [252/9/4756] Creating library symlink lib/libMLIRGPUTransforms.so
35.792 [252/8/4757] Linking CXX shared library lib/libMLIRCAPIExecutionEngine.so.20.0git
35.798 [251/8/4758] Creating library symlink lib/libMLIRCAPIExecutionEngine.so
36.506 [251/7/4759] Building CXX object tools/mlir/test/lib/Dialect/Mesh/CMakeFiles/MLIRMeshTest.dir/TestSimplifications.cpp.o
38.677 [251/6/4760] Building CXX object tools/mlir/test/lib/Dialect/Mesh/CMakeFiles/MLIRMeshTest.dir/TestOpLowering.cpp.o
39.518 [251/5/4761] Building CXX object tools/mlir/test/lib/Dialect/Mesh/CMakeFiles/MLIRMeshTest.dir/TestReshardingSpmdization.cpp.o
46.255 [251/4/4762] Building CXX object tools/mlir/lib/Dialect/Linalg/Transforms/CMakeFiles/obj.MLIRLinalgTransforms.dir/MeshShardingInterfaceImpl.cpp.o
47.483 [251/3/4763] Building CXX object tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o
47.567 [250/3/4764] Linking CXX shared library lib/libMLIRMeshDialect.so.20.0git
FAILED: lib/libMLIRMeshDialect.so.20.0git 
: && /usr/bin/clang++ -fPIC -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -Werror=unguarded-availability-new -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wmissing-field-initializers -pedantic -Wno-long-long -Wc++98-compat-extra-semi -Wimplicit-fallthrough -Wcovered-switch-default -Wno-noexcept-type -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wsuggest-override -Wstring-conversion -Wmisleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Werror=mismatched-tags -Werror=global-constructors -O3 -DNDEBUG  -Wl,-z,defs -Wl,-z,nodelete -fuse-ld=lld -Wl,--color-diagnostics   -Wl,--gc-sections -shared -Wl,-soname,libMLIRMeshDialect.so.20.0git -o lib/libMLIRMeshDialect.so.20.0git tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o  -Wl,-rpath,"\$ORIGIN/../lib:/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib:"  lib/libMLIRArithDialect.so.20.0git  lib/libMLIRDialectUtils.so.20.0git  lib/libMLIRViewLikeInterface.so.20.0git  lib/libMLIRCastInterfaces.so.20.0git  lib/libMLIRDialect.so.20.0git  lib/libMLIRInferIntRangeCommon.so.20.0git  lib/libMLIRInferIntRangeInterface.so.20.0git  lib/libMLIRInferTypeOpInterface.so.20.0git  lib/libMLIRUBDialect.so.20.0git  lib/libMLIRIR.so.20.0git  lib/libMLIRSupport.so.20.0git  lib/libLLVMSupport.so.20.0git  -Wl,-rpath-link,/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib && :
ld.lld: error: undefined symbol: mlir::detail::verifyDestinationStyleOpInterface(mlir::Operation*)
>>> referenced by MeshOps.cpp
>>>               tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o:(mlir::Op<mlir::mesh::UpdateHaloOp, mlir::OpTrait::ZeroRegions, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult<mlir::Type>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::AtLeastNOperands<2u>::Impl, mlir::OpTrait::AttrSizedOperandSegments, mlir::OpTrait::OpInvariants, mlir::BytecodeOpInterface::Trait, mlir::DestinationStyleOpInterface::Trait, mlir::SymbolUserOpInterface::Trait>::verifyRegionInvariants(mlir::Operation*))
clang: error: linker command failed with exit code 1 (use -v to see invocation)
68.305 [250/2/4765] Building CXX object tools/mlir/lib/Dialect/Tosa/CMakeFiles/obj.MLIRTosaDialect.dir/IR/TosaOps.cpp.o
101.738 [250/1/4766] Building CXX object tools/mlir/lib/Dialect/Linalg/IR/CMakeFiles/obj.MLIRLinalgDialect.dir/LinalgDialect.cpp.o
ninja: build stopped: subcommand failed.

llvm-ci (Collaborator) commented Nov 11, 2024

LLVM Buildbot has detected a new failure on builder flang-aarch64-libcxx running on linaro-flang-aarch64-libcxx while building mlir at step 5 "build-unified-tree".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/89/builds/10252

Here is the relevant piece of the build log for the reference
Step 5 (build-unified-tree) failure: build (failure)
...
79.428 [593/15/6651] Building CXX object tools/llvm-objcopy/CMakeFiles/llvm-objcopy.dir/llvm-objcopy.cpp.o
79.436 [583/24/6652] Building CXX object tools/llvm-objcopy/CMakeFiles/llvm-objcopy.dir/ObjcopyOptions.cpp.o
79.437 [583/23/6653] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/SourcePrinter.cpp.o
79.439 [583/22/6654] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/COFFDump.cpp.o
79.439 [583/21/6655] Linking CXX executable bin/llvm-ml
79.446 [575/28/6656] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/llvm-objdump.cpp.o
79.447 [575/27/6657] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/ELFDump.cpp.o
79.461 [575/26/6658] Linking CXX executable bin/llvm-nm
79.475 [575/25/6659] Linking CXX executable bin/llvm-mt
79.508 [575/24/6660] Linking CXX shared library lib/libMLIRMeshDialect.so.20.0git
FAILED: lib/libMLIRMeshDialect.so.20.0git 
: && /usr/local/bin/c++ -fPIC -stdlib=libc++ -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -Werror=unguarded-availability-new -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wmissing-field-initializers -pedantic -Wno-long-long -Wc++98-compat-extra-semi -Wimplicit-fallthrough -Wcovered-switch-default -Wno-noexcept-type -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wsuggest-override -Wstring-conversion -Wmisleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Werror=mismatched-tags -Werror=global-constructors -O3 -DNDEBUG  -stdlib=libc++ -Wl,-z,defs -Wl,-z,nodelete   -Wl,-rpath-link,/home/tcwg-buildbot/worker/flang-aarch64-libcxx/build/./lib  -Wl,--gc-sections -shared -Wl,-soname,libMLIRMeshDialect.so.20.0git -o lib/libMLIRMeshDialect.so.20.0git tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o  -Wl,-rpath,"\$ORIGIN/../lib:/home/tcwg-buildbot/worker/flang-aarch64-libcxx/build/lib:"  lib/libMLIRArithDialect.so.20.0git  lib/libMLIRDialectUtils.so.20.0git  lib/libMLIRViewLikeInterface.so.20.0git  lib/libMLIRCastInterfaces.so.20.0git  lib/libMLIRDialect.so.20.0git  lib/libMLIRInferIntRangeCommon.so.20.0git  lib/libMLIRInferIntRangeInterface.so.20.0git  lib/libMLIRInferTypeOpInterface.so.20.0git  lib/libMLIRUBDialect.so.20.0git  lib/libMLIRIR.so.20.0git  lib/libMLIRSupport.so.20.0git  lib/libLLVMSupport.so.20.0git  -Wl,-rpath-link,/home/tcwg-buildbot/worker/flang-aarch64-libcxx/build/lib && :
/usr/bin/ld: tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o: in function `mlir::Op<mlir::mesh::UpdateHaloOp, mlir::OpTrait::ZeroRegions, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult<mlir::Type>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::AtLeastNOperands<2u>::Impl, mlir::OpTrait::AttrSizedOperandSegments, mlir::OpTrait::OpInvariants, mlir::BytecodeOpInterface::Trait, mlir::DestinationStyleOpInterface::Trait, mlir::SymbolUserOpInterface::Trait>::verifyRegionInvariants(mlir::Operation*)':
MeshOps.cpp:(.text._ZN4mlir2OpINS_4mesh12UpdateHaloOpEJNS_7OpTrait11ZeroRegionsENS3_9OneResultENS3_14OneTypedResultINS_4TypeEE4ImplENS3_14ZeroSuccessorsENS3_16AtLeastNOperandsILj2EE4ImplENS3_24AttrSizedOperandSegmentsENS3_12OpInvariantsENS_19BytecodeOpInterface5TraitENS_27DestinationStyleOpInterface5TraitENS_21SymbolUserOpInterface5TraitEEE22verifyRegionInvariantsEPNS_9OperationE[_ZN4mlir2OpINS_4mesh12UpdateHaloOpEJNS_7OpTrait11ZeroRegionsENS3_9OneResultENS3_14OneTypedResultINS_4TypeEE4ImplENS3_14ZeroSuccessorsENS3_16AtLeastNOperandsILj2EE4ImplENS3_24AttrSizedOperandSegmentsENS3_12OpInvariantsENS_19BytecodeOpInterface5TraitENS_27DestinationStyleOpInterface5TraitENS_21SymbolUserOpInterface5TraitEEE22verifyRegionInvariantsEPNS_9OperationE]+0x14): undefined reference to `mlir::detail::verifyDestinationStyleOpInterface(mlir::Operation*)'
clang++: error: linker command failed with exit code 1 (use -v to see invocation)
79.524 [575/23/6661] Linking CXX executable bin/llvm-opt-report
79.552 [575/22/6662] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/llvm-objdump-driver.cpp.o
79.556 [575/21/6663] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/WasmDump.cpp.o
79.557 [575/20/6664] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyClassLayoutGraphicalDumper.cpp.o
79.559 [575/19/6665] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyClassDefinitionDumper.cpp.o
79.560 [575/18/6666] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyCompilandDumper.cpp.o
79.561 [575/17/6667] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/OffloadDump.cpp.o
79.562 [575/16/6668] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/XCOFFDump.cpp.o
79.563 [575/15/6669] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyExternalSymbolDumper.cpp.o
79.567 [575/14/6670] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyTypedefDumper.cpp.o
79.567 [575/13/6671] Linking CXX executable bin/llvm-objcopy
79.568 [575/12/6672] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyFunctionDumper.cpp.o
79.570 [575/11/6673] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyVariableDumper.cpp.o
79.572 [575/10/6674] Building CXX object tools/llvm-objdump/CMakeFiles/llvm-objdump.dir/MachODump.cpp.o
79.573 [575/9/6675] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyEnumDumper.cpp.o
79.574 [575/8/6676] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/StreamUtil.cpp.o
79.575 [575/7/6677] Building CXX object tools/llvm-pdbutil/CMakeFiles/llvm-pdbutil.dir/PrettyTypeDumper.cpp.o
79.618 [575/6/6678] Linking CXX executable bin/llvm-opt-fuzzer
81.901 [575/5/6679] Linking CXX shared library lib/libFortranEvaluate.so.20.0git
83.097 [575/4/6680] Linking CXX shared library lib/libclang-cpp.so.20.0git
95.317 [575/3/6681] Building CXX object tools/mlir/lib/Dialect/Linalg/Transforms/CMakeFiles/obj.MLIRLinalgTransforms.dir/MeshShardingInterfaceImpl.cpp.o
105.792 [575/2/6682] Building CXX object tools/mlir/lib/Dialect/Tosa/CMakeFiles/obj.MLIRTosaDialect.dir/IR/TosaOps.cpp.o
137.595 [575/1/6683] Building CXX object tools/mlir/lib/Dialect/Linalg/IR/CMakeFiles/obj.MLIRLinalgDialect.dir/LinalgDialect.cpp.o
ninja: build stopped: subcommand failed.

llvm-ci (Collaborator) commented Nov 11, 2024

LLVM Buildbot has detected a new failure on builder flang-aarch64-latest-gcc running on linaro-flang-aarch64-latest-gcc while building mlir at step 5 "build-unified-tree".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/130/builds/6014

Here is the relevant piece of the build log for the reference
Step 5 (build-unified-tree) failure: build (failure)
...
196.382 [299/8/6968] Linking CXX executable bin/verify-uselistorder
196.414 [299/7/6969] Linking CXX executable bin/sancov
196.419 [299/6/6970] Linking CXX executable bin/clang-repl
196.456 [299/5/6971] Linking CXX shared library lib/libLLVMOptDriver.so.20.0git
196.464 [298/5/6972] Creating library symlink lib/libLLVMOptDriver.so
196.680 [297/5/6973] Linking CXX executable bin/opt
196.692 [297/4/6974] Linking CXX shared library lib/libFortranEvaluate.so.20.0git
196.702 [296/4/6975] Creating library symlink lib/libFortranEvaluate.so
198.009 [295/4/6976] Building CXX object tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o
198.244 [294/4/6977] Linking CXX shared library lib/libMLIRMeshDialect.so.20.0git
FAILED: lib/libMLIRMeshDialect.so.20.0git 
: && /usr/local/bin/c++ -fPIC -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -fno-lifetime-dse -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wno-missing-field-initializers -pedantic -Wno-long-long -Wimplicit-fallthrough -Wno-maybe-uninitialized -Wno-nonnull -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wdelete-non-virtual-dtor -Wsuggest-override -Wno-comment -Wno-misleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Wno-unused-but-set-parameter -O3 -DNDEBUG  -Wl,-z,defs -Wl,-z,nodelete   -Wl,-rpath-link,/home/tcwg-buildbot/worker/flang-aarch64-latest-gcc/build/./lib  -Wl,--gc-sections -shared -Wl,-soname,libMLIRMeshDialect.so.20.0git -o lib/libMLIRMeshDialect.so.20.0git tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o  -Wl,-rpath,"\$ORIGIN/../lib:/home/tcwg-buildbot/worker/flang-aarch64-latest-gcc/build/lib:"  lib/libMLIRArithDialect.so.20.0git  lib/libMLIRDialectUtils.so.20.0git  lib/libMLIRViewLikeInterface.so.20.0git  lib/libMLIRCastInterfaces.so.20.0git  lib/libMLIRDialect.so.20.0git  lib/libMLIRInferIntRangeCommon.so.20.0git  lib/libMLIRInferIntRangeInterface.so.20.0git  lib/libMLIRInferTypeOpInterface.so.20.0git  lib/libMLIRUBDialect.so.20.0git  lib/libMLIRIR.so.20.0git  lib/libMLIRSupport.so.20.0git  lib/libLLVMSupport.so.20.0git  -Wl,-rpath-link,/home/tcwg-buildbot/worker/flang-aarch64-latest-gcc/build/lib && :
/usr/bin/ld: tools/mlir/lib/Dialect/Mesh/IR/CMakeFiles/obj.MLIRMeshDialect.dir/MeshOps.cpp.o: in function `mlir::Op<mlir::mesh::UpdateHaloOp, mlir::OpTrait::ZeroRegions, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult<mlir::Type>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::AtLeastNOperands<2u>::Impl, mlir::OpTrait::AttrSizedOperandSegments, mlir::OpTrait::OpInvariants, mlir::BytecodeOpInterface::Trait, mlir::DestinationStyleOpInterface::Trait, mlir::SymbolUserOpInterface::Trait>::verifyRegionInvariants(mlir::Operation*)':
MeshOps.cpp:(.text._ZN4mlir2OpINS_4mesh12UpdateHaloOpEJNS_7OpTrait11ZeroRegionsENS3_9OneResultENS3_14OneTypedResultINS_4TypeEE4ImplENS3_14ZeroSuccessorsENS3_16AtLeastNOperandsILj2EE4ImplENS3_24AttrSizedOperandSegmentsENS3_12OpInvariantsENS_19BytecodeOpInterface5TraitENS_27DestinationStyleOpInterface5TraitENS_21SymbolUserOpInterface5TraitEEE22verifyRegionInvariantsEPNS_9OperationE[_ZN4mlir2OpINS_4mesh12UpdateHaloOpEJNS_7OpTrait11ZeroRegionsENS3_9OneResultENS3_14OneTypedResultINS_4TypeEE4ImplENS3_14ZeroSuccessorsENS3_16AtLeastNOperandsILj2EE4ImplENS3_24AttrSizedOperandSegmentsENS3_12OpInvariantsENS_19BytecodeOpInterface5TraitENS_27DestinationStyleOpInterface5TraitENS_21SymbolUserOpInterface5TraitEEE22verifyRegionInvariantsEPNS_9OperationE]+0x10): undefined reference to `mlir::detail::verifyDestinationStyleOpInterface(mlir::Operation*)'
collect2: error: ld returned 1 exit status
199.676 [294/3/6978] Linking CXX shared library lib/libFortranSemantics.so.20.0git
202.970 [294/2/6979] Linking CXX shared library lib/libclang-cpp.so.20.0git
222.006 [294/1/6980] Building CXX object tools/mlir/lib/Dialect/Tosa/CMakeFiles/obj.MLIRTosaDialect.dir/IR/TosaOps.cpp.o
ninja: build stopped: subcommand failed.

jplehr (Contributor) commented Nov 11, 2024

Is someone investigating these CI issues? Our flang bot turned red, too.

fschlimb (Contributor Author):

Yep, a PR will be available in a few minutes.

fschlimb (Contributor Author):

See #115703; I do not yet have write permissions, so if you can, @jplehr, please merge #115703.

jplehr pushed a commit that referenced this pull request Nov 11, 2024
fixing CI failures caused by #114238 by adding
MLIRDestinationStyleOpInterface lib
@jplehr @mfrancio @rengolin
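Per the commit message above, the fix amounts to declaring the missing interface library as a link dependency of the Mesh dialect. A hedged sketch of what that addition could look like in the dialect's CMakeLists (everything except the MLIRDestinationStyleOpInterface line is elided or assumed):

```cmake
# Sketch of mlir/lib/Dialect/Mesh/IR/CMakeLists.txt after the fix;
# only the relevant addition is shown, other entries elided.
add_mlir_dialect_library(MLIRMeshDialect
  MeshOps.cpp
  ...
  LINK_LIBS PUBLIC
  MLIRDestinationStyleOpInterface  # provides verifyDestinationStyleOpInterface
  ...
)
```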
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…lvm#114238)

* Changed `MeshSharding::sharded_dims_sizes` from representing sizes per
  shard to offsets to origin per shard.
  - Local shard size are now a simple subtraction
  - Offsets are now readily available without a reduction operation
  - Enables constant value/shape propagation through standard
    canonicalization
  - Renamed to `sharded_dims_offsets` accordingly.
* First spmdization pattern for halo regions.
  - Triggers when source and destination shardings differ only in their
    halo sizes
  - Copies local data from source into a new tensor and calls update_halo
  - Supports arbitrary mesh dimensions (unlike the other patterns which
    work on 1d meshes only)
* `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts
  tensors and memrefs
  - also accepts target and source halo sizes; both are required for
    proper lowering
* minor refactoring for testing partial MeshSharding equality
* Canonicalization for ShardingOp folding constant values into
  respective `static_*` attributes
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
fixing CI failures caused by llvm#114238 by adding
MLIRDestinationStyleOpInterface lib
@jplehr @mfrancio @rengolin