
[mlir][mesh]fixes for 0d tensors #132948


Merged · 3 commits merged into llvm:main on Mar 26, 2025

Conversation

fschlimb (Contributor)

0d tensors are generally treated as scalars; for example, they are always replicated.
In some cases 0d tensors have no sharding. This PR provides a few minor fixes to account for such cases.
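As a purely illustrative sketch (not taken from this PR's diff; the function name is made up), the kind of IR in question is a rank-0 tensor value that never gets an explicit mesh.shard user, so spmdization must fall back to treating it as replicated:

// Illustrative only: a rank-0 result with no mesh.shard annotation.
// Spmdization should leave tensor<f32> unchanged, i.e. replicated on every device.
func.func @scalar_like_0d() -> tensor<f32> {
  %0 = tensor.empty() : tensor<f32>
  return %0 : tensor<f32>
}

The new test added to mlir/test/Dialect/Tensor/mesh-spmdization.mlir exercises essentially this pattern.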

@tkarna Could you please have a look at this?

llvmbot (Member) commented Mar 25, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Frank Schlimbach (fschlimb)

Changes

0d tensors are generally treated as scalars; for example, they are always replicated.
In some cases 0d tensors have no sharding. This PR provides a few minor fixes to account for such cases.

@tkarna Could you please have a look at this?


Full diff: https://github.com/llvm/llvm-project/pull/132948.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+2)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+21-8)
  • (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+12-6)
  • (modified) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+7)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fc5cfffea27a7..32c2eca2cefa8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
 inline mesh::MeshOp
 getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
               SymbolTableCollection &symbolTableCollection) {
+  if (!meshSymbol)
+    return nullptr;
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
       op, meshSymbol);
 }
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3e9f86fde64f3..65475b69dbdb1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
 
 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
-  if (rankedTensorType) {
+  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
     return shardShapedType(rankedTensorType, mesh, sharding);
   }
   return type;
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index f427d004c558f..8aaa0704119a8 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -718,6 +718,6 @@ void mesh::spmdizeTriviallyShardableOperation(
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
     newResult.setType(
         shardType(newResult.getType(),
-                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+                  getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 601af0200e785..b6cb06ae3170f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
           return arg.getType();
         }
 
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
-    if (!rankedTensor) {
+    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
       return MeshSharding();
     }
 
@@ -690,18 +690,31 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
   llvm::transform(op.getResults(), std::back_inserter(res),
-                  [](OpResult result) {
+                  [&op](OpResult result) {
+                    if (!result.hasOneUse() || result.use_empty()) {
+                      return MeshSharding();
+                    }
                     TypedValue<RankedTensorType> rankedTensor =
                         dyn_cast<TypedValue<RankedTensorType>>(result);
                     if (!rankedTensor) {
                       return MeshSharding();
                     }
-                    if (!result.hasOneUse()) {
-                      return MeshSharding();
-                    }
                     Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    return MeshSharding(shardOp.getSharding());
+                    ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+                    if (shardOp) {
+                      return MeshSharding(shardOp.getSharding());
+                    }
+                    if (rankedTensor.getType().getRank() == 0) {
+                      // This is a 0d tensor result without explicit sharding.
+                      // Find mesh symbol from operands, if any.
+                      // Shardings without mesh are not always fully supported yet.
+                      for (auto operand: op.getOperands()) {
+                        if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+                          return MeshSharding(sharding.getMeshAttr());
+                        }
+                      }
+                    }
+                    return MeshSharding();
                   });
   return res;
 }
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b3d69eb5e1a23..fc93f1c1c9220 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
-    auto mesh =
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
-    auto shardType = cast<ShapedType>(
-        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+    assert(resultShardings.size() == 1);
+    auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+    mlir::mesh::MeshOp mesh;
+    ShapedType shardType;
+    if (resType.getRank() > 0) {
+      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+      shardType =
+          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+    } else {
+      shardType = resType;
+    }
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
     // provided.
-    assert(resultShardings.size() == 1);
     if (!shardType.hasStaticShape()) {
       assert(op->getResult(0).hasOneUse());
       SmallVector<Value> newOperands;
-      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      auto oldType = cast<ShapedType>(resType);
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 01cf5972177f4..3fb8424745501 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {
 
   return
 }
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+  tensor.empty() : tensor<f32>
+  // CHECK-NEXT:  tensor.empty() : tensor<f32>
+  return
+}

fschlimb changed the title from "fixes for 0d tensors" to "[mlir][mesh]fixes for 0d tensors" on Mar 25, 2025

github-actions bot commented Mar 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

tkarna (Contributor) left a comment


Thanks, looks good to me.

swift-ci pushed a commit to swiftlang/llvm-project that referenced this pull request on Mar 26, 2025:
"In some cases 0d tensors have no sharding. This PR provides a few minor fixes to account for such cases."
fschlimb merged commit 545ea0d into llvm:main on Mar 26, 2025
9 checks passed