@@ -27,8 +27,8 @@ namespace {
27
27
#include " ShapeCanonicalization.inc"
28
28
}
29
29
30
- RankedTensorType shape::getExtentTensorType (MLIRContext *ctx, int64_t rank ) {
31
- return RankedTensorType::get ({rank }, IndexType::get (ctx));
30
+ RankedTensorType shape::getExtentTensorType (MLIRContext *ctx) {
31
+ return RankedTensorType::get ({ShapedType:: kDynamicSize }, IndexType::get (ctx));
32
32
}
33
33
34
34
bool shape::isExtentTensorType (Type type) {
@@ -660,42 +660,11 @@ struct CanonicalizeCastExtentTensorOperandsPattern
660
660
return success ();
661
661
}
662
662
};
663
-
664
- struct BroadcastConcretizeResultTypePattern
665
- : public OpRewritePattern<BroadcastOp> {
666
- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
667
-
668
- LogicalResult matchAndRewrite (BroadcastOp op,
669
- PatternRewriter &rewriter) const override {
670
- // Only concretize dynamic extent tensor result types.
671
- auto resultTy = op.getType ().dyn_cast <RankedTensorType>();
672
- if (!resultTy || !resultTy.isDynamicDim (0 ))
673
- return failure ();
674
-
675
- // Infer resulting shape rank if possible.
676
- int64_t maxRank = 0 ;
677
- for (Value shape : op.shapes ()) {
678
- if (auto extentTensorTy = shape.getType ().dyn_cast <RankedTensorType>()) {
679
- // Cannot infer resulting shape rank if any operand is dynamically
680
- // ranked.
681
- if (extentTensorTy.isDynamicDim (0 ))
682
- return failure ();
683
- maxRank = std::max (maxRank, extentTensorTy.getDimSize (0 ));
684
- }
685
- }
686
-
687
- auto newOp = rewriter.create <BroadcastOp>(
688
- op.getLoc (), getExtentTensorType (getContext (), maxRank), op.shapes ());
689
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, op.getType (), newOp);
690
- return success ();
691
- }
692
- };
693
663
} // namespace
694
664
695
665
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
696
666
MLIRContext *context) {
697
- patterns.add <BroadcastConcretizeResultTypePattern,
698
- BroadcastFoldConstantOperandsPattern,
667
+ patterns.add <BroadcastFoldConstantOperandsPattern,
699
668
BroadcastForwardSingleOperandPattern,
700
669
CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
701
670
RemoveDuplicateOperandsPattern<BroadcastOp>,
0 commit comments