[mlir][mesh]fixes for 0d tensors #132948
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Frank Schlimbach (fschlimb)

Changes

0d tensors are generally treated as scalars, e.g. they are always replicated.

@tkarna Could you please have a look at this?

Full diff: https://github.com/llvm/llvm-project/pull/132948.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fc5cfffea27a7..32c2eca2cefa8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
inline mesh::MeshOp
getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) {
+ if (!meshSymbol)
+ return nullptr;
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
op, meshSymbol);
}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3e9f86fde64f3..65475b69dbdb1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
- if (rankedTensorType) {
+ if (rankedTensorType && !rankedTensorType.getShape().empty()) {
return shardShapedType(rankedTensorType, mesh, sharding);
}
return type;
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index f427d004c558f..8aaa0704119a8 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -718,6 +718,6 @@ void mesh::spmdizeTriviallyShardableOperation(
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(
shardType(newResult.getType(),
- getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+ getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
}
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 601af0200e785..b6cb06ae3170f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
- if (!rankedTensorArg) {
+ if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
return arg.getType();
}
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(operand);
- if (!rankedTensor) {
+ if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
return MeshSharding();
}
@@ -690,18 +690,31 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
llvm::transform(op.getResults(), std::back_inserter(res),
- [](OpResult result) {
+ [&op](OpResult result) {
+ if (!result.hasOneUse() || result.use_empty()) {
+ return MeshSharding();
+ }
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
return MeshSharding();
}
- if (!result.hasOneUse()) {
- return MeshSharding();
- }
Operation *userOp = *result.getUsers().begin();
- ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- return MeshSharding(shardOp.getSharding());
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+ if (shardOp) {
+ return MeshSharding(shardOp.getSharding());
+ }
+ if (rankedTensor.getType().getRank() == 0) {
+ // This is a 0d tensor result without explicit sharding.
+ // Find mesh symbol from operands, if any.
+ // Shardings without mesh are not always fully supported yet.
+ for (auto operand: op.getOperands()) {
+ if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+ return MeshSharding(sharding.getMeshAttr());
+ }
+ }
+ }
+ return MeshSharding();
});
return res;
}
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b3d69eb5e1a23..fc93f1c1c9220 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
- auto mesh =
- mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
- auto shardType = cast<ShapedType>(
- mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+ assert(resultShardings.size() == 1);
+ auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+ mlir::mesh::MeshOp mesh;
+ ShapedType shardType;
+ if (resType.getRank() > 0) {
+ mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+ shardType =
+ cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+ } else {
+ shardType = resType;
+ }
Operation *newOp = nullptr;
// if the sharding introduces a new dynamic dimension, we take it from
// the dynamic sharding info. For now bail out if it's not
// provided.
- assert(resultShardings.size() == 1);
if (!shardType.hasStaticShape()) {
assert(op->getResult(0).hasOneUse());
SmallVector<Value> newOperands;
- auto oldType = cast<ShapedType>(op->getResult(0).getType());
+ auto oldType = cast<ShapedType>(resType);
assert(oldType.getRank() == shardType.getRank());
int currOldOprndNum = -1;
mesh::ShardShapeOp shapeForDevice;
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 01cf5972177f4..3fb8424745501 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {
return
}
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+ tensor.empty() : tensor<f32>
+ // CHECK-NEXT: tensor.empty() : tensor<f32>
+ return
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
Thanks, looks good to me.
0d tensors are generally treated as scalars, e.g. they are always replicated. In some cases 0d tensors have no sharding; this PR provides a few minor fixes to account for such cases.
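For illustration, here is a minimal IR sketch of the intended behavior, mirroring the tensor_empty_0d test added in this PR (the extra comments are illustrative, not part of the PR): a 0d tensor has no dimension that could be split over a mesh axis, so spmdization keeps its type unchanged and every device holds the same, replicated value.

// A 0d tensor stays tensor<f32> after spmdization: there is nothing to
// shard, so the value is simply replicated on every device in the mesh.
func.func @tensor_empty_0d() -> () {
  tensor.empty() : tensor<f32>
  return
}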
@tkarna Could you please have a look at this?