Skip to content

Commit adf838d

Browse files
authored
[mlir][Vectorizer] Added support to Vectorize tensor.unpack (#76087)
Added support to vectorize tensor.unpack. The unpack Op is split into a `vector.transfer_read`, `vector.transpose`, `vector.shape_cast` and a `vector.transfer_write`.
1 parent 5a20a20 commit adf838d

File tree

6 files changed

+337
-54
lines changed

6 files changed

+337
-54
lines changed

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,11 @@ FailureOr<RankedTensorType>
3232
computeTransposedType(RankedTensorType rankedTensorType,
3333
ArrayRef<int64_t> transposeVector);
3434

35-
/// Given a tensor::PackOp, compute the permutation vector to shuffle the
36-
/// packed shape into the shape before any outer or inner permutations have
37-
/// been applied.
38-
/// i.e. for a pack from an ABCD layout to an ABCDba:
39-
/// The packed shape would be ABCDba.
40-
/// The pre-permutation shape would be AaBbCD.
41-
SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
35+
SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
36+
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
37+
38+
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
39+
PackingMetadata &metadata);
4240

4341
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
4442
/// source tensor or inserts the source tensor into a destination tensor with

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
31523152

31533153
// TODO: Check that the correct number of vectorSizes was provided.
31543154
for (Operation *target : targets) {
3155-
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
3155+
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3156+
target)) {
31563157
return mlir::emitSilenceableFailure(target->getLoc())
31573158
<< "Unsupported Op, cannot vectorize";
31583159
}

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
237237
PackingMetadata packingMetadata = computePackingMetadata(
238238
packedTensorType.getRank(), packOp.getInnerDimsPos());
239239
SmallVector<int64_t> packedToStripMinedShapePerm =
240-
tensor::getPackInverseDestPermutation(packOp);
240+
tensor::getPackInverseDestPerm(packOp);
241241

242242
// 3. Compute the stripMinedShape: this is the packed shape before any outer
243243
// or inner permutations have been applied.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 157 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14051405
/// permutations.
14061406
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
14071407
ArrayRef<int64_t> destShape) {
1408-
return applyPermutation(destShape,
1409-
tensor::getPackInverseDestPermutation(packOp));
1408+
return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
14101409
}
14111410

14121411
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15471546

15481547
// Create TransposeOp.
15491548
auto destPermutation =
1550-
invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
1549+
invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
15511550
auto transposeOp = rewriter.create<vector::TransposeOp>(
15521551
loc, shapeCastOp.getResult(), destPermutation);
15531552

@@ -1559,6 +1558,112 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15591558
return success();
15601559
}
15611560

1561+
/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
1562+
/// Vector::TransferReadOp - Reads a vector from the source tensor
1563+
/// vector::TransposeOp - Transpose the Source tensor
1564+
/// vector::ShapeCastOp - Reshape the data based on the target.
1565+
/// vector::TransferWriteOp. - Write the result vector back to the destination
1566+
/// tensor
1567+
static LogicalResult
1568+
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1569+
ArrayRef<int64_t> inputVectorSizes,
1570+
SmallVectorImpl<Value> &newResults) {
1571+
1572+
OpBuilder::InsertionGuard g(rewriter);
1573+
rewriter.setInsertionPoint(unpackOp);
1574+
1575+
RankedTensorType unpackTensorType = unpackOp.getSourceType();
1576+
1577+
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1578+
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1579+
1580+
SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
1581+
inputVectorSizes.end());
1582+
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1583+
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1584+
1585+
// ReadMask is the size of tensor used to read and apply mask. It is
1586+
// set like this: Let's say the vectorSize (VS) array is size 'N' and
1587+
// the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1588+
// size M-N
1589+
// Thus:
1590+
// - initially: ReadMaskShape = vectorInputSizes
1591+
// - Divide all the readMaskShape locations pointed by innerDimPos
1592+
// by the innerTileSize attribute value.
1593+
/// - if outer_dims_perm is present: do that permutation on readMaskShape.
1594+
// - Append the remaining shape from SS
1595+
/// E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
1596+
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1597+
// 128] and outer_dims_perm is [1, 0] then read shape is:
1598+
// ReadMaskShape(initial): [512, 128]
1599+
// Final Value(after innerDim Adjustment): [512/32, 128/16]
1600+
// = [16, 8]
1601+
// After applying outer_dims_perm: [8, 16]
1602+
// After appending the rest of the sourceShape: [8, 16, 32, 16]
1603+
1604+
for (auto [index, size] : enumerate(innerTiles)) {
1605+
readMaskShape[innerDimPos[index]] =
1606+
llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
1607+
}
1608+
if (!outerDimsPerm.empty()) {
1609+
applyPermutationToVector(readMaskShape, outerDimsPerm);
1610+
}
1611+
readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
1612+
sourceShape.end());
1613+
1614+
ReifiedRankedShapedTypeDims reifiedRetShapes;
1615+
LogicalResult status =
1616+
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1617+
.reifyResultShapes(rewriter, reifiedRetShapes);
1618+
if (status.failed()) {
1619+
LDBG("Unable to reify result shapes of " << unpackOp);
1620+
return failure();
1621+
}
1622+
Location loc = unpackOp->getLoc();
1623+
1624+
auto padValue = rewriter.create<arith::ConstantOp>(
1625+
loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1626+
1627+
// Read result, mask if necessary. If transferReadOp shape is not equal
1628+
// to shape of source, then a mask is necessary.
1629+
Value readResult = createReadOrMaskedRead(
1630+
rewriter, loc, unpackOp.getSource(),
1631+
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
1632+
1633+
PackingMetadata packMetadata;
1634+
SmallVector<int64_t> lastDimToInsertPosPerm =
1635+
tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1636+
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1637+
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1638+
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1639+
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1640+
RankedTensorType stripMineTensorType =
1641+
RankedTensorType::get(stripMineShape, stripMineElemType);
1642+
// Transpose the appropriate rows to match output.
1643+
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1644+
loc, readResult, lastDimToInsertPosPerm);
1645+
1646+
// Collapse the vector to the size required by result.
1647+
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1648+
stripMineTensorType, packMetadata.reassociations);
1649+
mlir::VectorType vecCollapsedType =
1650+
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1651+
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1652+
loc, vecCollapsedType, transposeOp->getResult(0));
1653+
1654+
// WriteMaskShape had to match the shapecast shape for dynamic sizes,
1655+
// otherwise the validator complains that the mask size is invalid.
1656+
SmallVector<int64_t> writeMaskShape(
1657+
unpackOp.getDestType().hasStaticShape()
1658+
? inputVectorSizes
1659+
: shapeCastOp.getResultVectorType().getShape());
1660+
Operation *write =
1661+
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
1662+
reifiedRetShapes[0], writeMaskShape);
1663+
newResults.push_back(write->getResult(0));
1664+
return success();
1665+
}
1666+
15621667
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
15631668
/// and (3) all-zero lowPad to
15641669
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1655,6 +1760,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
16551760
return success();
16561761
}
16571762

1763+
/// Need to check if the inner-tiles are static/constant.
1764+
static LogicalResult
1765+
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1766+
ArrayRef<int64_t> inputVectorSizes) {
1767+
1768+
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1769+
return !getConstantIntValue(res).has_value();
1770+
})) {
1771+
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1772+
return failure();
1773+
}
1774+
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1775+
if (!inputVectorSizes.empty() &&
1776+
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
1777+
return failure();
1778+
1779+
return success();
1780+
}
1781+
16581782
static LogicalResult
16591783
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
16601784
ArrayRef<int64_t> inputVectorSizes,
@@ -1703,9 +1827,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
17031827
}
17041828
if (isElementwise(linalgOp))
17051829
return success();
1706-
// TODO: isaConvolutionOpInterface that can also infer from generic features.
1707-
// But we will still need stride/dilation attributes that will be annoying to
1708-
// reverse-engineer...
1830+
1831+
// TODO: isaConvolutionOpInterface that can also infer from generic
1832+
// features. But we will still need stride/dilation attributes that will be
1833+
// annoying to reverse-engineer...
17091834
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
17101835
return success();
17111836
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1935,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
18101935
.Case<tensor::PackOp>([&](auto packOp) {
18111936
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
18121937
})
1938+
.Case<tensor::UnPackOp>([&](auto unpackOp) {
1939+
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
1940+
})
18131941
.Default([](auto) { return failure(); });
18141942
}
18151943

@@ -1829,11 +1957,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
18291957
}
18301958

18311959
/// Emit a suitable vector form for an operation. If provided,
1832-
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
1833-
/// must match the rank of the iteration space of the operation and the input
1834-
/// vector sizes must be greater than or equal to their counterpart iteration
1835-
/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
1836-
/// operations with dynamic shapes.
1960+
/// `inputVectorSizes` are used to vectorize this operation.
1961+
/// `inputVectorSizes` must match the rank of the iteration space of the
1962+
/// operation and the input vector sizes must be greater than or equal to
1963+
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
1964+
/// also allows the vectorization of operations with dynamic shapes.
18371965
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
18381966
ArrayRef<int64_t> inputVectorSizes,
18391967
ArrayRef<bool> inputScalableVecDims,
@@ -1867,8 +1995,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
18671995
auto vectorizeResult =
18681996
TypeSwitch<Operation *, LogicalResult>(op)
18691997
.Case<linalg::LinalgOp>([&](auto linalgOp) {
1870-
// TODO: isaConvolutionOpInterface that can also infer from generic
1871-
// features. Will require stride/dilation attributes inference.
1998+
// TODO: isaConvolutionOpInterface that can also infer from
1999+
// generic features. Will require stride/dilation attributes
2000+
// inference.
18722001
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
18732002
FailureOr<Operation *> convOr = vectorizeConvolution(
18742003
rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2031,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
19022031
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
19032032
results);
19042033
})
2034+
.Case<tensor::UnPackOp>([&](auto unpackOp) {
2035+
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2036+
inputVectorSizes, results);
2037+
})
19052038
.Default([](auto) { return failure(); });
19062039

19072040
if (failed(vectorizeResult)) {
@@ -1919,7 +2052,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
19192052

19202053
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
19212054
memref::CopyOp copyOp) {
1922-
19232055
auto srcType = cast<MemRefType>(copyOp.getSource().getType());
19242056
auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
19252057
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2965,8 @@ struct Conv1DGenerator
28332965
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
28342966
resPadding);
28352967

2836-
// The base vectorization case for channeled convolution is input: {n,w,c},
2837-
// weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
2968+
// The base vectorization case for channeled convolution is input:
2969+
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
28382970
// vectorization case, we do pre transpose on input, weight, and output.
28392971
switch (conv1DOpOrder) {
28402972
case Conv1DOpOrder::W:
@@ -2877,9 +3009,9 @@ struct Conv1DGenerator
28773009
return kw * (wSize / wSizeStep) + w;
28783010
};
28793011

2880-
// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
2881-
// perform outerproduct for non-channeled convolution or
2882-
// perform simple arith operation for pooling
3012+
// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3013+
// or perform outerproduct for non-channeled convolution or perform simple
3014+
// arith operation for pooling
28833015
for (int64_t kw = 0; kw < kwSize; ++kw) {
28843016
for (int64_t w = 0; w < wSize; w += wSizeStep) {
28853017
switch (oper) {
@@ -2908,9 +3040,9 @@ struct Conv1DGenerator
29083040
// End vector-only rewrite part
29093041
//===------------------------------------------------------------------===//
29103042

2911-
// The base vectorization case for channeled convolution is output: {n,w,f}
2912-
// To reuse the result from base pattern vectorization case, we post
2913-
// transpose the base case result.
3043+
// The base vectorization case for channeled convolution is output:
3044+
// {n,w,f} To reuse the result from base pattern vectorization case, we
3045+
// post transpose the base case result.
29143046
switch (conv1DOpOrder) {
29153047
case Conv1DOpOrder::W:
29163048
case Conv1DOpOrder::Nwc:
@@ -3348,9 +3480,9 @@ static FailureOr<Operation *>
33483480
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
33493481
bool flatten1DDepthwiseConv) {
33503482
// The ConvolutionOpInterface gives us guarantees of existence for
3351-
// strides/dilations. However, we do not need to rely on those, we can simply
3352-
// use them if present, otherwise use the default and let the generic conv.
3353-
// matcher in the ConvGenerator succeed or fail.
3483+
// strides/dilations. However, we do not need to rely on those, we can
3484+
// simply use them if present, otherwise use the default and let the generic
3485+
// conv. matcher in the ConvGenerator succeed or fail.
33543486
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
33553487
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
33563488
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;

0 commit comments

Comments
 (0)