@@ -27,8 +27,8 @@ namespace {
27
27
#include " ShapeCanonicalization.inc"
28
28
}
29
29
30
- RankedTensorType shape::getExtentTensorType (MLIRContext *ctx) {
31
- return RankedTensorType::get ({ShapedType:: kDynamicSize }, IndexType::get (ctx));
30
+ RankedTensorType shape::getExtentTensorType (MLIRContext *ctx, int64_t rank ) {
31
+ return RankedTensorType::get ({rank }, IndexType::get (ctx));
32
32
}
33
33
34
34
bool shape::isExtentTensorType (Type type) {
@@ -660,11 +660,42 @@ struct CanonicalizeCastExtentTensorOperandsPattern
660
660
return success ();
661
661
}
662
662
};
663
+
664
+ struct BroadcastConcretizeResultTypePattern
665
+ : public OpRewritePattern<BroadcastOp> {
666
+ using OpRewritePattern<BroadcastOp>::OpRewritePattern;
667
+
668
+ LogicalResult matchAndRewrite (BroadcastOp op,
669
+ PatternRewriter &rewriter) const override {
670
+ // Only concretize dynamic extent tensor result types.
671
+ auto resultTy = op.getType ().dyn_cast <RankedTensorType>();
672
+ if (!resultTy || !resultTy.isDynamicDim (0 ))
673
+ return failure ();
674
+
675
+ // Infer resulting shape rank if possible.
676
+ int64_t maxRank = 0 ;
677
+ for (Value shape : op.shapes ()) {
678
+ if (auto extentTensorTy = shape.getType ().dyn_cast <RankedTensorType>()) {
679
+ // Cannot infer resulting shape rank if any operand is dynamically
680
+ // ranked.
681
+ if (extentTensorTy.isDynamicDim (0 ))
682
+ return failure ();
683
+ maxRank = std::max (maxRank, extentTensorTy.getDimSize (0 ));
684
+ }
685
+ }
686
+
687
+ auto newOp = rewriter.create <BroadcastOp>(
688
+ op.getLoc (), getExtentTensorType (getContext (), maxRank), op.shapes ());
689
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, op.getType (), newOp);
690
+ return success ();
691
+ }
692
+ };
663
693
} // namespace
664
694
665
695
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
666
696
MLIRContext *context) {
667
- patterns.add <BroadcastFoldConstantOperandsPattern,
697
+ patterns.add <BroadcastConcretizeResultTypePattern,
698
+ BroadcastFoldConstantOperandsPattern,
668
699
BroadcastForwardSingleOperandPattern,
669
700
CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
670
701
RemoveDuplicateOperandsPattern<BroadcastOp>,
0 commit comments