[MLIR][Shape] Support >2 args in shape.broadcast folder #126808

Merged
merged 2 commits into llvm:main from shape-broadcast-fold-vararg
Apr 14, 2025

Conversation

mtsokol
Contributor

@mtsokol mtsokol commented Feb 11, 2025

Hi!

As the title says, this PR adds support for >2 arguments in the `shape.broadcast` folder by sequentially calling `getBroadcastedShape`.
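
For illustration, here is a minimal standalone C++ sketch of the sequential pairwise strategy (the helper names broadcastPair and broadcastAll are made up for this example, dynamic dimensions are ignored, and this is not the MLIR implementation itself):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// NumPy-style broadcast of two static shapes, aligned from the right.
// Returns std::nullopt when an extent pair is incompatible (e.g. 2 vs 3).
std::optional<std::vector<int64_t>>
broadcastPair(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  std::vector<int64_t> result(std::max(a.size(), b.size()));
  auto ai = a.rbegin(), bi = b.rbegin();
  for (auto ri = result.rbegin(); ri != result.rend(); ++ri) {
    int64_t da = (ai != a.rend()) ? *ai++ : 1;
    int64_t db = (bi != b.rend()) ? *bi++ : 1;
    if (da != db && da != 1 && db != 1)
      return std::nullopt;
    *ri = std::max(da, db);
  }
  return result;
}

// Folds any number of shapes by accumulating pairwise results, which is the
// strategy the updated folder uses.
std::optional<std::vector<int64_t>>
broadcastAll(const std::vector<std::vector<int64_t>> &shapes) {
  std::vector<int64_t> acc = shapes.front();
  for (std::size_t i = 1, e = shapes.size(); i < e; ++i) {
    auto next = broadcastPair(acc, shapes[i]);
    if (!next)
      return std::nullopt;
    acc = *next;
  }
  return acc;
}

// broadcastAll({{2, 1}, {7, 2, 1}, {1, 10}}) yields {7, 2, 10}, matching the
// @broadcast_variadic test added in this PR.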

@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-mlir

Author: Mateusz Sokół (mtsokol)

Changes

Hi!

As the title says, this PR adds support for >2 arguments in the `shape.broadcast` folder by sequentially calling `getBroadcastedShape`.


Full diff: https://github.com/llvm/llvm-project/pull/126808.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+21-13)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+13)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c403..daa33ea865a5c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -649,24 +649,32 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;
 
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+  auto firstShape = llvm::to_vector<6>(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
+
   SmallVector<int64_t, 6> resultShape;
+  resultShape.clear();
+  std::copy(firstShape.begin(), firstShape.end(), std::back_inserter(resultShape));
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
+  }
 
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..6e62a33037eb8 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
     if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
       // One or both dimensions is unknown. Follow TensorFlow behavior:
       // - If either dimension is greater than 1, we assume that the program is
-      //   correct, and the other dimension will be broadcast to match it.
+      //   correct, and the other dimension will be broadcasted to match it.
       // - If either dimension is 1, the other dimension is the output.
       if (*i1 > 1) {
         *iR = *i1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9ed4837a2fe7e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
 
 // -----
 
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+  %0 = shape.const_shape [2, 1] : tensor<2xindex>
+  %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+  %2 = shape.const_shape [1, 10] : tensor<2xindex>
+  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func.func @f(%arg0 : !shape.shape) -> !shape.shape {

@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-mlir-shape


for (auto next : adaptor.getShapes().drop_front()) {
if (!next)
return nullptr;
auto nextShape = llvm::to_vector<6>(
Contributor Author

In the `getBroadcastedShape` implementation the shape vector size is hardcoded to 6, so I did the same here. Does that make sense? It looks like an arbitrary value from the outside.

Member

Yes, semi. If I recall correctly, it was either the default elsewhere in an ML framework where this was used, or the maximum rank across a set of ML models. But it is a bit arbitrary. Elsewhere folks also use the default of SmallVector. (The latter is probably a little more arbitrary, but neither is very finely tuned.)
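
For readers less familiar with SmallVector: the 6 only sets how many elements fit in inline storage before the vector spills to the heap, so it affects allocation behavior rather than correctness. A small sketch of the two variants mentioned above, assuming a reasonably recent LLVM where the inline element count can be omitted:

#include <cstdint>

#include "llvm/ADT/SmallVector.h"

void smallVectorInlineSizeDemo() {
  // Explicit inline capacity: the first 6 elements live inside the object
  // itself; element number 7 forces a heap allocation.
  llvm::SmallVector<int64_t, 6> explicitSize = {2, 1, 7, 2, 1, 10};
  explicitSize.push_back(3); // now spills to the heap

  // Omitted inline capacity: LLVM derives a default from the element size,
  // which is the "default of SmallVector" mentioned above.
  llvm::SmallVector<int64_t> defaultSize = {2, 1};
  (void)defaultSize;
}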

Comment on lines 675 to 676
resultShape.clear();
std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
Contributor Author

I followed the `getBroadcastedShape` implementation and I'm not sure this is the best way to handle the vectors/shapes here, so a penny for your thoughts!
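
For comparison, two equivalent ways to replace the accumulator's contents, sketched under the assumption that both vectors are SmallVector<int64_t, 6>:

#include <cstdint>
#include <utility>

#include "llvm/ADT/SmallVector.h"

// Sketch: equivalent ways to overwrite the accumulator without the
// clear() + std::copy + std::back_inserter combination.
void replaceAccumulator(llvm::SmallVector<int64_t, 6> &resultShape,
                        llvm::SmallVector<int64_t, 6> &tmpShape) {
  // Option 1: assign drops the old contents and copies the new elements.
  resultShape.assign(tmpShape.begin(), tmpShape.end());
  // Option 2: since both vectors have the same type, move assignment avoids
  // the copy entirely (tmpShape is left in a valid but unspecified state).
  resultShape = std::move(tmpShape);
}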


github-actions bot commented Feb 11, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@mtsokol mtsokol force-pushed the shape-broadcast-fold-vararg branch from c4a815d to c3cf613 Compare February 11, 2025 22:15
Member

@jpienaar jpienaar left a comment

Overall looks good, thanks!

SmallVector<int64_t, 6> resultShape;
resultShape.clear();
std::copy(firstShape.begin(), firstShape.end(),
Member

Why is firstShape needed vs directly initializing resultShape?

Contributor Author

I think it isn't needed here - updated!
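
For reference, the simplification amounts to initializing the accumulator directly inside the fold method, roughly like this (sketch only; the exact code in the updated revision may differ):

// Inside BroadcastOp::fold: build resultShape straight from the first
// constant shape instead of copying it out of a separate firstShape vector.
SmallVector<int64_t, 6> resultShape = llvm::to_vector<6>(
    llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
        .getValues<int64_t>());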

@mtsokol mtsokol force-pushed the shape-broadcast-fold-vararg branch from c3cf613 to b259199 Compare February 19, 2025 16:57
@mtsokol mtsokol requested a review from jpienaar February 20, 2025 09:34
Member

@jpienaar jpienaar left a comment

LG, modulo checking formatting.

return nullptr;

resultShape.clear();
std::copy(tmpShape.begin(), tmpShape.end(),
Member

Was this what clang-format produced?

Contributor Author

@mtsokol mtsokol Mar 9, 2025

@jpienaar Yes, that's correct; it was produced by clang-format. Here's another place where std::copy is formatted the same way:

std::copy(Overrides.begin(), Overrides.end(),
          reinterpret_cast<ModuleMacro **>(this + 1));

@mtsokol
Contributor Author

mtsokol commented Mar 13, 2025

@jpienaar Should we merge it?

@mtsokol
Contributor Author

mtsokol commented Apr 3, 2025

I think it's ready, right?

@makslevental makslevental merged commit df84aa8 into llvm:main Apr 14, 2025
8 checks passed
@mtsokol mtsokol deleted the shape-broadcast-fold-vararg branch April 14, 2025 19:11
bcardosolopes added a commit to bcardosolopes/llvm-project that referenced this pull request Apr 14, 2025
* origin/main: (199 commits)
  [NFC][AsmPrinter] Refactor AsmPrinter and AArch64AsmPrinter to prepare for jump table partitions on aarch64 (llvm#125993)
  [HEXAGON] Fix corner cases for hwloops pass (llvm#135439)
  [flang] Handle volatility in lowering and codegen (llvm#135311)
  [MLIR][Shape] Support >2 args in `shape.broadcast` folder (llvm#126808)
  [DirectX] Use scalar arguments for @llvm.dx.dot intrinsics (llvm#134570)
  Remove the redundant check for "WeakPtr" in isSmartPtrClass to fix the issue 135612. (llvm#135629)
  [BOLT] Support relative vtable (llvm#135449)
  [flang] Fix linking to libMLIR (llvm#135483)
  [AsmPrinter] Link .section_sizes to the correct section (llvm#135583)
  [ctxprof] Handle instrumenting functions with `musttail` calls (llvm#135121)
  [SystemZ] Consider VST/VL as SimpleBDXStore/Load (llvm#135623)
  [libc++][CI] Pin the XCode version. (llvm#135412)
  [lldb-dap] Fix win32 build. (llvm#135638)
  [Interp] Mark inline-virtual.cpp as unsupported with ASan (llvm#135402)
  [libc++] Removes the _LIBCPP_VERBOSE_ABORT_NOT_NOEXCEPT macro. (llvm#135494)
  [CVP] Add tests for ucmp/scmp with switch (NFC)
  [mlir][tosa] Align AbsOp example variable names (llvm#135268)
  [mlir][tosa] Align AddOp examples to spec (llvm#135266)
  [mlir][tosa] Align RFFT2d and FFT2d operator examples (llvm#135261)
  [flang][OpenMP][HLFIR] Support vector subscripted array sections for DEPEND (llvm#133892)
  ...
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025