Skip to content

Commit 1bc2d8e

Browse files
committed
[mlir][tensor] Introduce FoldTensorCastUnPackOp
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change mirrors a similar update made for `tensor::PackOp` in #114559. Below is the updated rationale for `tensor::UnPackOp`. Currently, `FoldTensorCastProducerOp` incorrectly folds the following: ```mlir %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> %unpack = tensor.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32> ``` as: ```mlir %unpack = tensor.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32> ``` (note that the source type is folded to the static `tensor<1x1x8x1xi32>` while the dynamic inner tile size `%c8` is left unchanged). This leads to an Op verification failure because the folder does not update the inner tile sizes in the unpack Op. This patch resolves the issue. Additional Changes: * invalid.mlir: Fixes a typo. * TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra comments following this discussion: #115772.
1 parent 998bdae commit 1bc2d8e

File tree

3 files changed

+107
-5
lines changed

3 files changed

+107
-5
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48374837
// Already a constant
48384838
newMixedTileSizes.push_back(std::get<1>(it));
48394839
} else {
4840-
int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
4841-
assert(tileSize == shape && "tile size and dim size don't match!");
4842-
(void)tileSize;
4840+
assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
4841+
"tile size and dim size don't match!");
48434842
newMixedTileSizes.push_back(
48444843
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
48454844
}
48464845
}
48474846

48484847
// Clone op.
4848+
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
4849+
// this point. However, in practice, we use them for things that we'd like
4850+
// to preserve. Implement a better abstraction.
48494851
PackOp newOp = rewriter.create<PackOp>(
48504852
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
48514853
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48654867
}
48664868
};
48674869

4870+
/// Folds a tensor.cast op into a consuming tensor::UnPackOp if the
4871+
/// `tensor.cast` has a source that is more static than the consuming op.
4872+
///
4873+
/// Example:
4874+
/// ```mlir
4875+
/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
4876+
/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
4877+
/// ```
4878+
///
4879+
/// folds into:
4880+
///
4881+
/// ```mlir
4882+
/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4883+
/// ```
4884+
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
4885+
using OpRewritePattern<UnPackOp>::OpRewritePattern;
4886+
4887+
LogicalResult matchAndRewrite(UnPackOp op,
4888+
PatternRewriter &rewriter) const override {
4889+
if (!foldTensorCastPrecondition(op))
4890+
return failure();
4891+
4892+
SmallVector<Type> newResultTypes(op->getResultTypes());
4893+
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4894+
Value sourceTensor = newOperands[0];
4895+
4896+
// Get the updated mixed-tile-sizes attribute.
4897+
SmallVector<OpFoldResult> newMixedTileSizes;
4898+
for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
4899+
.getShape()
4900+
.take_back(op.getMixedTiles().size()),
4901+
op.getMixedTiles())) {
4902+
int64_t shape = std::get<0>(it);
4903+
// If the current source shape is dynamic, just preserve this mixed
4904+
// size.
4905+
if (shape == ShapedType::kDynamic) {
4906+
newMixedTileSizes.push_back(std::get<1>(it));
4907+
continue;
4908+
}
4909+
4910+
// If the current source is static, update the dynamic mixed-size
4911+
// (provided the original value is dynamic).
4912+
if (Attribute attr =
4913+
llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4914+
// Already a constant
4915+
newMixedTileSizes.push_back(std::get<1>(it));
4916+
} else {
4917+
assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
4918+
"tile size and dim size don't match!");
4919+
newMixedTileSizes.push_back(
4920+
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4921+
}
4922+
}
4923+
4924+
// Clone op.
4925+
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
4926+
// this point. However, in practice, we use them for things that we'd like
4927+
// to preserve. Implement a better abstraction.
4928+
UnPackOp newOp = rewriter.create<UnPackOp>(
4929+
op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
4930+
newMixedTileSizes, op.getOuterDimsPerm());
4931+
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
4932+
4933+
// Replace op.
4934+
Value oldResult = op.getResult();
4935+
Value newResult = newOp.getResult();
4936+
Value replacement = (newResult.getType() != oldResult.getType())
4937+
? rewriter.create<tensor::CastOp>(
4938+
op->getLoc(), oldResult.getType(), newResult)
4939+
: newResult;
4940+
4941+
rewriter.replaceOp(op, {replacement});
4942+
4943+
return success();
4944+
}
4945+
};
4946+
48684947
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
48694948
/// the `tensor.cast` has a source that is more static than the consuming op.
48704949
///
@@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp
48904969
PatternRewriter &rewriter) const override {
48914970

48924971
// Reject tensor::PackOp - there's dedicated pattern for that instead.
4893-
if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
4972+
if (!foldTensorCastPrecondition(op) ||
4973+
isa<tensor::PackOp, tensor::UnPackOp>(*op))
48944974
return failure();
48954975

48964976
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +5003,7 @@ struct FoldTensorCastProducerOp
49235003
void TensorDialect::getCanonicalizationPatterns(
49245004
RewritePatternSet &results) const {
49255005
results.add<FoldTensorCastPackOp>(getContext());
5006+
results.add<FoldTensorCastUnPackOp>(getContext());
49265007
results.add<FoldTensorCastProducerOp>(getContext());
49275008
}
49285009

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
27862786
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
27872787
return %0#1 : index
27882788
}
2789+
27892790
// -----
27902791

27912792
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(
28142815

28152816
// -----
28162817

2818+
// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
2819+
// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
2820+
// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
2821+
// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
2822+
// CHECK: return %[[RES]] : tensor<7x?xi32>
2823+
func.func @fold_cast_unpack_dynamic_tile_size(
2824+
%src: tensor<1x1x8x1xi32>,
2825+
%res: tensor<7x?xi32>) -> tensor<7x?xi32> {
2826+
2827+
%cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
2828+
%c8 = arith.constant 8 : index
2829+
%unpack = tensor.unpack %cast
2830+
inner_dims_pos = [0, 1]
2831+
inner_tiles = [%c8, 1]
2832+
into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
2833+
return %unpack : tensor<7x?xi32>
2834+
}
2835+
2836+
// -----
2837+
28172838
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
28182839
// CHECK: tensor.pack {{.*}} {test_attr}
28192840
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
699699

700700
// -----
701701

702-
func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
702+
func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
703703
// expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
704704
%0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
705705
return %0 : tensor<256x128xf32>

0 commit comments

Comments
 (0)