@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
4837
4837
// Already a constant
4838
4838
newMixedTileSizes.push_back (std::get<1 >(it));
4839
4839
} else {
4840
- int64_t tileSize = getConstantIntValue (std::get<1 >(it)).value ();
4841
- assert (tileSize == shape && " tile size and dim size don't match!" );
4842
- (void )tileSize;
4840
+ assert (getConstantIntValue (std::get<1 >(it)).value () == shape &&
4841
+ " tile size and dim size don't match!" );
4843
4842
newMixedTileSizes.push_back (
4844
4843
(rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
4845
4844
}
4846
4845
}
4847
4846
4848
4847
// Clone op.
4848
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4849
+ // this point. However, in practice, we use them for things that we'd like
4850
+ // to preserve. Implement a better abstraction.
4849
4851
PackOp newOp = rewriter.create <PackOp>(
4850
4852
op.getLoc (), newOperands[0 ], newOperands[1 ], op.getInnerDimsPos (),
4851
4853
newMixedTileSizes, op.getPaddingValue (), op.getOuterDimsPerm ());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
4865
4867
}
4866
4868
};
4867
4869
4870
+ // / Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
4871
+ // / `tensor.cast` has source that is more static than the consuming op.
4872
+ // /
4873
+ // / Example:
4874
+ // / ```mlir
4875
+ // / %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
4876
+ // / %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4877
+ // / ```
4878
+ // /
4879
+ // / folds into:
4880
+ // /
4881
+ // / ```mlir
4882
+ // / %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4883
+ // / ```
4884
+ struct FoldTensorCastUnPackOp : public OpRewritePattern <UnPackOp> {
4885
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
4886
+
4887
+ LogicalResult matchAndRewrite (UnPackOp op,
4888
+ PatternRewriter &rewriter) const override {
4889
+ if (!foldTensorCastPrecondition (op))
4890
+ return failure ();
4891
+
4892
+ SmallVector<Type> newResultTypes (op->getResultTypes ());
4893
+ SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4894
+ Value sourceTensor = newOperands[0 ];
4895
+
4896
+ // Get the updated mixed-tile-sizes attribute.
4897
+ SmallVector<OpFoldResult> newMixedTileSizes;
4898
+ for (auto it : llvm::zip (cast<ShapedType>(sourceTensor.getType ())
4899
+ .getShape ()
4900
+ .take_back (op.getMixedTiles ().size ()),
4901
+ op.getMixedTiles ())) {
4902
+ int64_t shape = std::get<0 >(it);
4903
+ // If the current source shape is dynamic, just preserve this mixed
4904
+ // size.
4905
+ if (shape == ShapedType::kDynamic ) {
4906
+ newMixedTileSizes.push_back (std::get<1 >(it));
4907
+ continue ;
4908
+ }
4909
+
4910
+ // If the current source is static, update the dynamic mixed-size
4911
+ // (provided the original value is dynamic).
4912
+ if (Attribute attr =
4913
+ llvm::dyn_cast_if_present<Attribute>(std::get<1 >(it))) {
4914
+ // Already a constant
4915
+ newMixedTileSizes.push_back (std::get<1 >(it));
4916
+ } else {
4917
+ assert (getConstantIntValue (std::get<1 >(it)).value () == shape &&
4918
+ " tile size and dim size don't match!" );
4919
+ newMixedTileSizes.push_back (
4920
+ (rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
4921
+ }
4922
+ }
4923
+
4924
+ // Clone op.
4925
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4926
+ // this point. However, in practice, we use them for things that we'd like
4927
+ // to preserve. Implement a better abstraction.
4928
+ UnPackOp newOp = rewriter.create <UnPackOp>(
4929
+ op.getLoc (), sourceTensor, newOperands[1 ], op.getInnerDimsPos (),
4930
+ newMixedTileSizes, op.getOuterDimsPerm ());
4931
+ newOp->setDiscardableAttrs (op->getDiscardableAttrDictionary ());
4932
+
4933
+ // Replace op.
4934
+ Value oldResult = op.getResult ();
4935
+ Value newResult = newOp.getResult ();
4936
+ Value replacement = (newResult.getType () != oldResult.getType ())
4937
+ ? rewriter.create <tensor::CastOp>(
4938
+ op->getLoc (), oldResult.getType (), newResult)
4939
+ : newResult;
4940
+
4941
+ rewriter.replaceOp (op, {replacement});
4942
+
4943
+ return success ();
4944
+ }
4945
+ };
4946
+
4868
4947
// / Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4869
4948
// / the `tensor.cast` has source that is more static than the consuming op.
4870
4949
// /
@@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp
4890
4969
PatternRewriter &rewriter) const override {
4891
4970
4892
4971
// Reject tensor::PackOp - there's dedicated pattern for that instead.
4893
- if (!foldTensorCastPrecondition (op) || dyn_cast<tensor::PackOp>(*op))
4972
+ if (!foldTensorCastPrecondition (op) ||
4973
+ isa<tensor::PackOp, tensor::UnPackOp>(*op))
4894
4974
return failure ();
4895
4975
4896
4976
SmallVector<Type> newResultTypes (op->getResultTypes ());
@@ -4923,6 +5003,7 @@ struct FoldTensorCastProducerOp
4923
5003
void TensorDialect::getCanonicalizationPatterns (
4924
5004
RewritePatternSet &results) const {
4925
5005
results.add <FoldTensorCastPackOp>(getContext ());
5006
+ results.add <FoldTensorCastUnPackOp>(getContext ());
4926
5007
results.add <FoldTensorCastProducerOp>(getContext ());
4927
5008
}
4928
5009
0 commit comments