@@ -1506,20 +1506,104 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }

+/// Determines whether the mask for a corresponding `vector.transfer_write` op
+/// is trivially foldable (i.e., guaranteed to be all true).
+///
+/// Requirements:
+///   * All involved shapes (destination, mask) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant.
+///
+/// Once verified, the method checks for each destination dimension `d`:
+///   (1) maskShape[d] <= destDimSize[rankDiff + d]
+///   (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
+///
+/// rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///     : vector<5x1xi32>, tensor<5x1xi32>
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///     : vector<8x1xi32>, tensor<5x1xi32>
+///
+/// TODO: Re-use in createReadOrMaskedRead
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //    `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+                     cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
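For context, a rough MLIR sketch of the two outcomes documented above; the value names (%vecToStore_1, %vecToStore_2, %dest) are illustrative and not taken from the patch. For the foldable case, maskShape = [5, 1], destDimSize = [5, 1] and writeIndex = [0, 0], so checks (1) and (2) hold in both dimensions (5 <= 5, 1 <= 1 and 0 + 5 <= 5, 0 + 1 <= 1) and no mask is emitted:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c5 = arith.constant 5 : index

    // EXAMPLE 2 above - mask required: vector<8x1xi32> stored into
    // tensor<5x1xi32>; the mask guards the out-of-bounds lanes.
    %mask = vector.create_mask %c5, %c1 : vector<8x1xi1>
    %res = vector.mask %mask {
      vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
          {in_bounds = [true, true]} : vector<8x1xi32>, tensor<5x1xi32>
    } : vector<8x1xi1> -> tensor<5x1xi32>

    // EXAMPLE 1 above - trivially foldable: the plain transfer_write is kept.
    %res2 = vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
        {in_bounds = [true, true]} : vector<5x1xi32>, tensor<5x1xi32>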
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `vectorToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {

   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();

-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();

   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }

+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);

   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;

+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));

   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }

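To see the new `writeIndices` handling end to end, here is a hedged sketch (shapes and value names are made up, not from the patch): a vector<3xi32> stored at constant offset 2 into a tensor<8xi32> passes both foldability checks (3 <= 8 and 2 + 3 <= 8), so the helper returns the unmasked write even though the vector does not cover the whole destination:

    %c2 = arith.constant 2 : index
    %res = vector.transfer_write %vec, %dest[%c2]
        {in_bounds = [true]} : vector<3xi32>, tensor<8xi32>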
@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());

   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();

+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);

   // 4. Finalize
   newResults.push_back(write->getResult(0));
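As a rough before/after sketch of the rewritten tensor.insert_slice path (shapes, offsets and value names are illustrative only): with static shapes and constant offsets, both the read and the write stay unmasked, and the slice offsets become the write indices:

    // Before vectorization.
    %r = tensor.insert_slice %src into %dest[2, 3] [2, 3] [1, 1]
        : tensor<2x3xi32> into tensor<8x8xi32>

    // After vectorization (assuming a zero pad value of the element type).
    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %pad = arith.constant 0 : i32
    %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
        : tensor<2x3xi32>, vector<2x3xi32>
    %w = vector.transfer_write %v, %dest[%c2, %c3] {in_bounds = [true, true]}
        : vector<2x3xi32>, tensor<8x8xi32>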