Skip to content

Commit 14cd832

Browse files
mtsokol authored and var-const committed
[MLIR][Shape] Support >2 args in shape.broadcast folder (llvm#126808)
Hi! As the title says, this PR adds support for >2 arguments in `shape.broadcast` folder by sequentially calling `getBroadcastedShape`.
1 parent 135a0d5 commit 14cd832

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
649649
return getShapes().front();
650650
}
651651

652-
// TODO: Support folding with more than 2 input shapes
653-
if (getShapes().size() > 2)
652+
if (!adaptor.getShapes().front())
654653
return nullptr;
655654

656-
if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
657-
return nullptr;
658-
auto lhsShape = llvm::to_vector<6>(
659-
llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660-
.getValues<int64_t>());
661-
auto rhsShape = llvm::to_vector<6>(
662-
llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
655+
SmallVector<int64_t, 6> resultShape(
656+
llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
663657
.getValues<int64_t>());
664-
SmallVector<int64_t, 6> resultShape;
665658

666-
// If the shapes are not compatible, we can't fold it.
667-
// TODO: Fold to an "error".
668-
if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
669-
return nullptr;
659+
for (auto next : adaptor.getShapes().drop_front()) {
660+
if (!next)
661+
return nullptr;
662+
auto nextShape = llvm::to_vector<6>(
663+
llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
664+
665+
SmallVector<int64_t, 6> tmpShape;
666+
// If the shapes are not compatible, we can't fold it.
667+
// TODO: Fold to an "error".
668+
if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
669+
return nullptr;
670+
671+
resultShape.clear();
672+
std::copy(tmpShape.begin(), tmpShape.end(),
673+
std::back_inserter(resultShape));
674+
}
670675

671676
Builder builder(getContext());
672677
return builder.getIndexTensorAttr(resultShape);

mlir/lib/Dialect/Traits.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
8484
if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
8585
// One or both dimensions is unknown. Follow TensorFlow behavior:
8686
// - If either dimension is greater than 1, we assume that the program is
87-
// correct, and the other dimension will be broadcast to match it.
87+
// correct, and the other dimension will be broadcasted to match it.
8888
// - If either dimension is 1, the other dimension is the output.
8989
if (*i1 > 1) {
9090
*iR = *i1;

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
8686

8787
// -----
8888

89+
// Variadic case including extent tensors.
90+
// CHECK-LABEL: @broadcast_variadic
91+
func.func @broadcast_variadic() -> !shape.shape {
92+
// CHECK: shape.const_shape [7, 2, 10] : !shape.shape
93+
%0 = shape.const_shape [2, 1] : tensor<2xindex>
94+
%1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
95+
%2 = shape.const_shape [1, 10] : tensor<2xindex>
96+
%3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
97+
return %3 : !shape.shape
98+
}
99+
100+
// -----
101+
89102
// Rhs is a scalar.
90103
// CHECK-LABEL: func @f
91104
func.func @f(%arg0 : !shape.shape) -> !shape.shape {

0 commit comments

Comments (0)