Skip to content

[MLIR][Tensor] Canonicalize fully covering slice insertions into tensors with unit prefixes #92912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2835,6 +2835,91 @@ struct InsertSliceOpSourceCastInserter final
return success();
}
};

/// If the destination tensor of the insertion of a slice has the same
/// number of elements as the slice, but with a shape that only
/// differs by a prefix of unit-sized dimensions, and if the insertion
/// happens at zero offsets, unit strides and with a size matching the
/// size of the destination, the insertion covers all elements of the
/// destination. The result of such an insertion is equivalent to the
/// slice, with its shape expanded to the type of the destination.
///
/// Example:
/// ```mlir
/// %0 = tensor.insert_slice %slice into
///        %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
///        tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
///        tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
/// ```
struct InsertSliceOpFullRewriteCanonicalizer final
    : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    RankedTensorType resultType = insertSliceOp.getType();

    // Only handle statically-shaped tensors whose destination shape is the
    // source shape with a (possibly empty) prefix of unit dimensions.
    if (sourceType == resultType || !sourceType.hasStaticShape() ||
        !resultType.hasStaticShape() ||
        !isSameSizedSuffixShape(resultType.getShape(), sourceType.getShape()))
      return failure();

    // The insertion must fully cover the destination: zero offsets, unit
    // strides and sizes equal to the destination shape.
    if (failed(foldIdentityOffsetSizeAndStrideOpInterface(insertSliceOp,
                                                          resultType)))
      return failure();

    // A rank-0 source has no dimension to fold the unit prefix into; the
    // reassociation construction below would emit an out-of-range index
    // (`unitPrefixLength` equals the result rank in that case), so bail out.
    if (sourceType.getRank() == 0)
      return failure();

    // Number of leading unit-sized dimensions that are not shared with the
    // source type.
    int64_t unitPrefixLength = resultType.getRank() - sourceType.getRank();

    // Compose the mapping of all leading unit dimensions plus the first
    // common dimension onto the first dimension of the source tensor.
    SmallVector<ReassociationIndices> reassocIndices;
    ReassociationIndices unitPrefixExpansion;
    for (int64_t dim = 0; dim <= unitPrefixLength; ++dim)
      unitPrefixExpansion.push_back(dim);
    reassocIndices.push_back(unitPrefixExpansion);

    // Map the remaining common dimensions of the source to the target
    // one-to-one.
    for (int64_t dim = unitPrefixLength + 1; dim < resultType.getRank(); ++dim)
      reassocIndices.push_back({dim});

    rewriter.replaceOpWithNewOp<ExpandShapeOp>(
        insertSliceOp, resultType, insertSliceOp.getSource(), reassocIndices);

    return success();
  }

private:
  /// Returns true if `suffix` is a suffix of `shape` and all preceding
  /// elements in `shape` are ones.
  static bool isSameSizedSuffixShape(ArrayRef<int64_t> shape,
                                     ArrayRef<int64_t> suffix) {
    if (shape.size() < suffix.size())
      return false;
    ArrayRef<int64_t> prefix = shape.drop_back(suffix.size());
    return llvm::all_of(prefix, [](int64_t d) { return d == 1; }) &&
           shape.take_back(suffix.size()) == suffix;
  }
};
} // namespace

llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
Expand All @@ -2845,7 +2930,8 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
InsertSliceOpCastFolder<InsertSliceOp>,
InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
InsertSliceOpSourceCastInserter<InsertSliceOp>,
InsertSliceOpFullRewriteCanonicalizer>(context);
}

Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6

// -----

// Checks that an insert_slice that fully covers a destination differing from
// the source only by a prefix of unit dimensions is canonicalized into a
// tensor.expand_shape of the source (no insert_slice remains).
// CHECK-LABEL: func @trivial_insert_slice_unit_prefix
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: tensor.insert_slice
// CHECK: %[[EXPANDED:.[a-z0-9A-Z_]+]] = tensor.expand_shape %[[ARG0]] {{\[\[0, 1, 2, 3\], \[4\], \[5\], \[6\]\] output}}_shape {{\[1, 1, 1, 4, 6, 16, 32\]}} : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
// CHECK: return %[[EXPANDED]] : tensor<1x1x1x4x6x16x32xi8>
func.func @trivial_insert_slice_unit_prefix(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<1x1x1x4x6x16x32xi8>) -> tensor<1x1x1x4x6x16x32xi8> {
  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 4, 6, 16, 32] [1, 1, 1, 1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
  return %0 : tensor<1x1x1x4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @empty_insert_slice
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
Expand Down
Loading