[MLIR][Shape] Support >2 args in shape.broadcast folder #126808

Merged · 2 commits · Apr 14, 2025
mlir/lib/Dialect/Shape/IR/Shape.cpp (19 additions, 14 deletions)

@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }

-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;

-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+  SmallVector<int64_t, 6> resultShape(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;

-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
Review thread on the line "auto nextShape = llvm::to_vector<6>(":

Contributor (author): In the getBroadcastedShape implementation the shape vector size is hardcoded to 6, so I did the same here. Does that make sense? From the outside it looks like an arbitrary value.

Member: Yes, semi. If I recall, it was either the default elsewhere in an ML framework where this was used, or the max rank across a set of ML models. But it is a bit arbitrary. Elsewhere folks also use the default of SmallVector. (The latter is probably a little more arbitrary, but neither is very fine-tuned.)
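For context, the 6 in SmallVector<int64_t, 6> is only an inline-storage size, not a limit on rank. A minimal sketch of the behavior under discussion (an editor's illustration, not code from the PR):

#include "llvm/ADT/SmallVector.h"

#include <cassert>

void inlineCapacityDemo() {
  // The first 6 elements live in the vector's inline buffer: no heap
  // allocation for shapes of rank <= 6.
  llvm::SmallVector<int64_t, 6> shape = {7, 2, 10, 1, 4, 9};
  assert(shape.size() == 6);

  // Higher ranks still work; the vector transparently spills to the heap.
  // Undersizing costs one allocation for rare high-rank shapes, while
  // oversizing only wastes a little stack space.
  shape.push_back(5);
  assert(shape.size() == 7);
}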

The diff continues:

+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(),
Review thread on the std::copy call:

Member: Was this what clang-format produced?

Contributor (author, @mtsokol, Mar 9, 2025): @jpienaar Yes, that's correct: it was produced by clang-format. Here's another place where std::copy is formatted the same way:

std::copy(Overrides.begin(), Overrides.end(),
          reinterpret_cast<ModuleMacro **>(this + 1));

The diff concludes:

+              std::back_inserter(resultShape));
+  }

   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
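Taken together, the new folder left-folds the pairwise broadcast over all constant operands instead of bailing out beyond two. A self-contained sketch of the same accumulation, using the real getBroadcastedShape utility but a hypothetical helper name (an illustration, not the PR's code):

#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include <optional>

// Fold a list of static shapes left to right with the same pairwise
// broadcast the folder uses; std::nullopt if any pair is incompatible.
static std::optional<llvm::SmallVector<int64_t, 6>>
foldBroadcast(llvm::ArrayRef<llvm::SmallVector<int64_t, 6>> shapes) {
  llvm::SmallVector<int64_t, 6> result = shapes.front();
  for (const llvm::SmallVector<int64_t, 6> &next : shapes.drop_front()) {
    llvm::SmallVector<int64_t, 6> tmp;
    if (!mlir::OpTrait::util::getBroadcastedShape(result, next, tmp))
      return std::nullopt;
    // Carry the partial result into the next pairwise broadcast.
    result = std::move(tmp);
  }
  return result;
}

On the shapes from the new test below, the fold combines [2, 1] with [7, 2, 1] to get [7, 2, 1], then folds in [1, 10] to reach [7, 2, 10], which is exactly what the canonicalized IR checks for.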
mlir/lib/Dialect/Traits.cpp (1 addition, 1 deletion)

@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
     if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
       // One or both dimensions is unknown. Follow TensorFlow behavior:
       // - If either dimension is greater than 1, we assume that the program is
-      //   correct, and the other dimension will be broadcast to match it.
+      //   correct, and the other dimension will be broadcasted to match it.
       // - If either dimension is 1, the other dimension is the output.
       if (*i1 > 1) {
         *iR = *i1;
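The TensorFlow-style rule in the comment above is easiest to see with concrete values. A small sketch using the real getBroadcastedShape utility and the ShapedType::kDynamic sentinel (an editor's illustration, not part of the PR):

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/ADT/SmallVector.h"

#include <cassert>

void dynamicDimDemo() {
  const int64_t dyn = mlir::ShapedType::kDynamic;
  llvm::SmallVector<int64_t, 6> result;

  // dyn vs. 3: assume the program is correct and take the 3.
  // 2 vs. 1: ordinary broadcast, take the 2.
  bool ok = mlir::OpTrait::util::getBroadcastedShape({dyn, 2}, {3, 1}, result);
  assert(ok && result[0] == 3 && result[1] == 2);

  // dyn vs. 1: nothing pins the extent down, so the output stays dynamic.
  ok = mlir::OpTrait::util::getBroadcastedShape({dyn}, {1}, result);
  assert(ok && mlir::ShapedType::isDynamic(result[0]));
}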
mlir/test/Dialect/Shape/canonicalize.mlir (13 additions, 0 deletions)

@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {

 // -----

+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+  %0 = shape.const_shape [2, 1] : tensor<2xindex>
+  %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+  %2 = shape.const_shape [1, 10] : tensor<2xindex>
+  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func.func @f(%arg0 : !shape.shape) -> !shape.shape {