Skip to content

Commit dca5361

Browse files
committed
[MLIR][Shape] Concretize broadcast result type if possible
As a canonicalization, infer the rank of the broadcast result from the operands' statically known ranks when possible, and cast the concretized result back to the original dynamic extent-tensor type. Differential Revision: https://reviews.llvm.org/D101377
1 parent 2d37f21 commit dca5361

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

mlir/include/mlir/Dialect/Shape/IR/Shape.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class PatternRewriter;
2929
namespace shape {
3030

3131
/// Alias type for extent tensors.
32-
RankedTensorType getExtentTensorType(MLIRContext *ctx);
32+
RankedTensorType getExtentTensorType(MLIRContext *ctx,
33+
int64_t rank = ShapedType::kDynamicSize);
3334

3435
// Check if a type is an extent tensor, e.g., tensor<?xindex>.
3536
bool isExtentTensorType(Type);

mlir/lib/Dialect/Shape/IR/Shape.cpp

+34-3
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ namespace {
2727
#include "ShapeCanonicalization.inc"
2828
}
2929

30-
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31-
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
30+
/// Returns the extent-tensor type `tensor<rank x index>` used to represent
/// shapes. `rank` defaults to ShapedType::kDynamicSize (see the header
/// declaration), which yields the dynamic form `tensor<?xindex>`.
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
31+
return RankedTensorType::get({rank}, IndexType::get(ctx));
3232
}
3333

3434
bool shape::isExtentTensorType(Type type) {
@@ -660,11 +660,42 @@ struct CanonicalizeCastExtentTensorOperandsPattern
660660
return success();
661661
}
662662
};
663+
664+
// Canonicalization: when a shape.broadcast produces an extent tensor of
// dynamic rank (tensor<?xindex>) but every operand's rank is statically
// known, rebuild the broadcast with the concrete result rank (the maximum
// operand rank) and tensor.cast back to the original dynamic type so all
// existing uses remain type-correct.
struct BroadcastConcretizeResultTypePattern
665+
: public OpRewritePattern<BroadcastOp> {
666+
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
667+
668+
LogicalResult matchAndRewrite(BroadcastOp op,
669+
PatternRewriter &rewriter) const override {
670+
// Only concretize dynamic extent tensor result types.
671+
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
672+
if (!resultTy || !resultTy.isDynamicDim(0))
673+
return failure();
674+
675+
// Infer resulting shape rank if possible. Broadcasting yields a shape
// whose rank is the maximum of the operand ranks.
676+
int64_t maxRank = 0;
677+
for (Value shape : op.shapes()) {
678+
// NOTE(review): operands whose type is not a RankedTensorType (e.g. a
// !shape.shape value) are silently skipped here — presumably they cannot
// reach this pattern with an extent-tensor result; confirm against the
// op's verifier.
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
679+
// Cannot infer resulting shape rank if any operand is dynamically
680+
// ranked.
681+
if (extentTensorTy.isDynamicDim(0))
682+
return failure();
683+
maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
684+
}
685+
}
686+
687+
// Recreate the broadcast with the concretized result type, then cast back
// to the original (dynamic) type so existing users stay valid.
auto newOp = rewriter.create<BroadcastOp>(
688+
op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
689+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
690+
return success();
691+
}
692+
};
663693
} // namespace
664694

665695
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
666696
MLIRContext *context) {
667-
patterns.add<BroadcastFoldConstantOperandsPattern,
697+
patterns.add<BroadcastConcretizeResultTypePattern,
698+
BroadcastFoldConstantOperandsPattern,
668699
BroadcastForwardSingleOperandPattern,
669700
CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
670701
RemoveDuplicateOperandsPattern<BroadcastOp>,

mlir/test/Dialect/Shape/canonicalize.mlir

+16-1
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,8 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
13441344
%arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
13451345
// CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
13461346
// CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
1347-
// CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
1347+
// CHECK: %[[UNCAST_RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
1348+
// CHECK: %[[RES:.*]] = tensor.cast %[[UNCAST_RES]] : tensor<3xindex> to tensor<?xindex>
13481349
// CHECK: return %[[WIT]], %[[RES]]
13491350
%0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
13501351
%1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
@@ -1353,3 +1354,17 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
13531354
-> tensor<?xindex>
13541355
return %2, %3 : !shape.witness, tensor<?xindex>
13551356
}
1357+
1358+
// -----
1359+
1360+
// CHECK-LABEL: @concretize_broadcast_result_type
1361+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xindex>, %[[ARG1:.*]]: tensor<3xindex>)
1362+
func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
1363+
%arg1 : tensor<3xindex>) -> tensor<?xindex> {
1364+
// CHECK: %[[CONCR:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
1365+
// CHECK: %[[RES:.*]] = tensor.cast %[[CONCR]] : tensor<3xindex> to tensor<?xindex>
1366+
// CHECK: return %[[RES]]
1367+
%0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex>
1368+
-> tensor<?xindex>
1369+
return %0 : tensor<?xindex>
1370+
}

0 commit comments

Comments
 (0)