@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                              ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape,
-                          tensor::getPackInverseDestPermutation(packOp));
+  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
}

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,

  // Create TransposeOp.
  auto destPermutation =
-      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+      invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

@@ -1559,6 +1558,112 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
  return success();
}

+/// Vectorize a `tensor::UnPackOp` to these 4 ops:
+///   vector::TransferReadOp - reads a vector from the source tensor
+///   vector::TransposeOp - transposes the read vector
+///   vector::ShapeCastOp - reshapes the data based on the target
+///   vector::TransferWriteOp - writes the result vector back to the
+///   destination tensor
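+///
+/// A rough sketch of the emitted IR (illustrative only; the exact shapes,
+/// permutation, and masking depend on the op being vectorized):
+///
+///   %read = vector.transfer_read %src ...   // masked if the vector sizes
+///                                           // do not cover the source
+///   %tr = vector.transpose %read, [...]
+///   %sc = vector.shape_cast %tr ...
+///   vector.transfer_write %sc, %dest ...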
+static LogicalResult
+vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
+                          ArrayRef<int64_t> inputVectorSizes,
+                          SmallVectorImpl<Value> &newResults) {
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+
+  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
+                                     inputVectorSizes.end());
+  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+
+  // readMaskShape is the shape of the vector used to read the source tensor
+  // and to apply the mask. It is computed as follows. Let the vectorSizes
+  // (VS) array have size 'N', let the sourceShape (SS) have rank 'M' with
+  // M >= N, and let the innerTileSizes (IT) have size M-N. Then:
+  // - Initially: readMaskShape = vectorInputSizes.
+  // - Divide each readMaskShape entry pointed to by innerDimPos by the
+  //   corresponding innerTileSize attribute value.
+  // - If outer_dims_perm is present, apply that permutation to readMaskShape.
+  // - Append the remaining dimensions of SS.
+  // E.g. suppose unpackTensorType.getShape() = <8x8x32x16>,
+  // innerDimPos = [0, 1], innerTiles = [32, 16], vectorSizes = [512, 128],
+  // and outer_dims_perm = [1, 0]. Then the read shape is computed as:
+  //   ReadMaskShape (initial): [512, 128]
+  //   after the innerDim adjustment: [512/32, 128/16] = [16, 8]
+  //   after applying outer_dims_perm: [8, 16]
+  //   after appending the rest of sourceShape: [8, 16, 32, 16]
+
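+  // Ceil-divide the entries of readMaskShape at the tiled positions by the
+  // corresponding inner tile sizes.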
+  for (auto [index, size] : enumerate(innerTiles)) {
+    readMaskShape[innerDimPos[index]] =
+        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+  }
+  if (!outerDimsPerm.empty()) {
+    applyPermutationToVector(readMaskShape, outerDimsPerm);
+  }
+  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+                       sourceShape.end());
+
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  Location loc = unpackOp->getLoc();
+
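+  // Zero is used as the padding value for the (possibly masked) transfer
+  // read.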
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
+
+  // Read the result, masking if necessary: if the shape of the
+  // transferReadOp does not match the shape of the source, a mask is needed.
+  Value readResult = createReadOrMaskedRead(
+      rewriter, loc, unpackOp.getSource(),
+      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
+
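+  // Compute the permutation that moves each inner tile next to its
+  // corresponding outer dimension, along with the reassociation indices
+  // needed to collapse the tiles back afterwards.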
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> lastDimToInsertPosPerm =
+      tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
+  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::get(stripMineShape, stripMineElemType);
+  // Transpose the appropriate rows to match the output.
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, readResult, lastDimToInsertPosPerm);
+
+  // Collapse the vector to the size required by the result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  mlir::VectorType vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      loc, vecCollapsedType, transposeOp->getResult(0));
+
+  // The writeMaskShape must match the shape-cast shape for dynamic sizes;
+  // otherwise the verifier complains that the mask size is invalid.
+  SmallVector<int64_t> writeMaskShape(
+      unpackOp.getDestType().hasStaticShape()
+          ? inputVectorSizes
+          : shapeCastOp.getResultVectorType().getShape());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
+                               reifiedRetShapes[0], writeMaskShape);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1655,6 +1760,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
  return success();
}

+/// Check that the inner tile sizes are static/constant.
+static LogicalResult
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+
+  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+        return !getConstantIntValue(res).has_value();
+      })) {
+    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+    return failure();
+  }
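+  // If vector sizes were provided, validate them against the destination
+  // shape.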
+  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+    return failure();
+
+  return success();
+}
+
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
@@ -1703,9 +1827,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
  }
  if (isElementwise(linalgOp))
    return success();
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
+
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. But we will still need stride/dilation attributes that will be
+  // annoying to reverse-engineer...
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1935,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
      .Case<tensor::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+      })
      .Default([](auto) { return failure(); });
}

@@ -1829,11 +1957,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}

/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the input
-/// vector sizes must be greater than or equal to their counterpart iteration
-/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
-/// operations with dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
@@ -1867,8 +1995,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from generic
-            // features. Will require stride/dilation attributes inference.
+            // TODO: isaConvolutionOpInterface that can also infer from
+            // generic features. Will require stride/dilation attributes
+            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2031,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
+          .Case<tensor::UnPackOp>([&](auto unpackOp) {
+            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
+                                             inputVectorSizes, results);
+          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
@@ -1919,7 +2052,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,

LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
-
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2965,8 @@ struct Conv1DGenerator
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);

-    // The base vectorization case for channeled convolution is input: {n,w,c},
-    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // The base vectorization case for channeled convolution is input:
+    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
    // vectorization case, we do pre transpose on input, weight, and output.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
@@ -2877,9 +3009,9 @@ struct Conv1DGenerator
      return kw * (wSize / wSizeStep) + w;
    };

-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
-    // perform outerproduct for non-channeled convolution or
-    // perform simple arith operation for pooling
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+    // or perform outerproduct for non-channeled convolution or perform simple
+    // arith operation for pooling
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
@@ -2908,9 +3040,9 @@ struct Conv1DGenerator
    // End vector-only rewrite part
    //===------------------------------------------------------------------===//

-    // The base vectorization case for channeled convolution is output: {n,w,f}
-    // To reuse the result from base pattern vectorization case, we post
-    // transpose the base case result.
+    // The base vectorization case for channeled convolution is output:
+    // {n,w,f}. To reuse the result from base pattern vectorization case, we
+    // post transpose the base case result.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
@@ -3348,9 +3480,9 @@ static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                     bool flatten1DDepthwiseConv) {
  // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
+  // strides/dilations. However, we do not need to rely on those, we can
+  // simply use them if present, otherwise use the default and let the generic
+  // conv. matcher in the ConvGenerator succeed or fail.
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;