@@ -1506,84 +1506,86 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1506
1506
return applyPermutation (destShape, linalg::getPackInverseDestPerm (packOp));
1507
1507
}
1508
1508
1509
- // / Creates a TransferWriteOp to write `input` into a newly initialized
1510
- // / output tensor.
1509
+ // / Creates an optionally masked TransferWriteOp
1511
1510
// /
1512
- // / Given:
1513
- // / - an input vector to write,
1514
- // / - the mixed destination sizes for the output tensor,
1515
- // / - and the vector sizes used for vectorization (i.e., the leading N dims,
1516
- // / for some value of N),
1517
- // /
1518
- // / this function generates the following sequence of ops:
1519
- // /
1520
- // / %dest = tensor.empty(%destSizes)
1521
- // / %res = vector.transfer_write %input into %dest
1511
+ // / Generates the following operation:
1512
+ // / %res = vector.transfer_write %vectorToStore into %dest
1522
1513
// /
1523
1514
// / If the leading N dimensions of the destination tensor do not match
1524
- // / `inputVecSizesForLeadingDims` (where N =
1525
- // / rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
1526
- // / correctness:
1515
+ // / `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
1516
+ // / masking is applied to ensure correctness:
1527
1517
// /
1528
- // / %dest = tensor.empty(%destSizes )
1529
- // / %write = vector.transfer_write %input into %dest
1530
- // / %mask = vector.create_mask(%destSizes)
1531
- // / %res = vector.mask %mask { %write }
1518
+ // / %mask = vector.create_mask(%destShape )
1519
+ // / %res = vector.mask %mask {
1520
+ // / vector.transfer_write %vectorToStore into %dest
1521
+ // / }
1532
1522
// /
1533
1523
// / If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1534
1524
// / is used instead of masking:
1535
1525
// /
1536
- // / %dest = tensor.empty(%destSizes)
1526
+ // / %write = vector.transfer_write %vectorToStore into %dest
1537
1527
// / in_bounds_flags = (...)
1538
1528
// / %res = vector.transfer_write %input into %dest
1539
1529
// / {in_bounds = in_bounds_flags}
1540
1530
// /
1541
- // / NOTE: all write offsets are set to 0.
1531
+ // / NOTE: All write offsets are set to 0.
1532
+ // / TODO: Allow specyfying write offsets.
1542
1533
// / NOTE: When N < rank(input), the missing vector sizes are effectively
1543
1534
// / extracted from the trailing sizes of `destSizes`. This means those sizes
1544
- // / must be static. Supporting dynamic sizes will require the user to specify
1545
- // / the remaining vector sizes. This is left as a TODO.
1535
+ // / must be static.
1536
+ // / TODO: Support cases where an arbitrary dim is dynamic - this will require
1537
+ // / specifying all the vector sizes.
1546
1538
static Operation *
1547
- createWriteOrMaskedWrite (OpBuilder &builder, Location loc, Value input ,
1548
- SmallVector<OpFoldResult> destSizes ,
1539
+ createWriteOrMaskedWrite (OpBuilder &builder, Location loc, Value vectorToStore ,
1540
+ Value dest ,
1549
1541
ArrayRef<int64_t > inputVecSizesForLeadingDims,
1550
1542
bool useInBoundsInsteadOfMasking = false ) {
1551
1543
1552
- auto inputType = cast<VectorType>(input.getType ());
1553
- assert (inputType.getRank () == static_cast <int64_t >(destSizes.size ()) &&
1544
+ ShapedType destType = cast<ShapedType>(dest.getType ());
1545
+ assert (cast<VectorType>(vectorToStore.getType ()).getRank () ==
1546
+ static_cast <int64_t >(destType.getRank ()) &&
1554
1547
" Rank mismatch!" );
1555
1548
1556
- Value dest = builder.create <tensor::EmptyOp>(loc, destSizes,
1557
- inputType.getElementType ());
1558
1549
int64_t rank = cast<ShapedType>(dest.getType ()).getRank ();
1559
- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1560
1550
auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1551
+
1552
+ // Compute the in_bounds attribute
1561
1553
SmallVector<bool > inBoundsVal (rank, true );
1562
1554
if (useInBoundsInsteadOfMasking) {
1563
1555
// In this case, assume that all the required vector sizes have been
1564
1556
// provided.
1565
- assert (inputVecSizesForLeadingDims.size () == destSizes.size () &&
1557
+ assert (inputVecSizesForLeadingDims.size () ==
1558
+ static_cast <size_t >(destType.getRank ()) &&
1566
1559
" Insufficient number of input vector sizes!" );
1567
1560
// Update the inBounds attribute.
1568
1561
for (unsigned i = 0 ; i < rank; i++)
1569
1562
inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
1570
1563
!ShapedType::isDynamic (destShape[i]);
1571
1564
}
1565
+
1566
+ // Generate the xfer_write Op
1567
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1572
1568
Operation *write = builder.create <vector::TransferWriteOp>(
1573
1569
loc,
1574
- /* vector=*/ input ,
1570
+ /* vector=*/ vectorToStore ,
1575
1571
/* source=*/ dest,
1576
1572
/* indices=*/ SmallVector<Value>(rank, zero),
1577
1573
/* inBounds=*/ inBoundsVal);
1578
1574
assert (llvm::none_of (
1579
1575
destShape.drop_front (inputVecSizesForLeadingDims.size ()),
1580
1576
[](int64_t size) { return size == ShapedType::kDynamic ; }) &&
1581
1577
" Only dims aligned with inputVecSizesForLeadingDims may be dynamic" );
1578
+
1579
+ // If masking is disabled, exit.
1582
1580
if (useInBoundsInsteadOfMasking)
1583
1581
return write ;
1582
+
1583
+ // Check if masking is needed.
1584
1584
bool needMaskForWrite =
1585
1585
!llvm::equal (inputVecSizesForLeadingDims,
1586
1586
destShape.take_front (inputVecSizesForLeadingDims.size ()));
1587
+
1588
+ // If masking is needed, generate the mask and mask the operation.
1587
1589
if (needMaskForWrite) {
1588
1590
SmallVector<int64_t > writeMaskShape;
1589
1591
writeMaskShape.append (inputVecSizesForLeadingDims.begin (),
@@ -1592,10 +1594,11 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
1592
1594
inputVecSizesForLeadingDims.size (),
1593
1595
destShape.end ());
1594
1596
auto writeMaskType = VectorType::get (writeMaskShape, builder.getI1Type ());
1595
- Value maskForWrite =
1596
- builder. create <vector::CreateMaskOp>(loc, writeMaskType, destSizes );
1597
+ Value maskForWrite = builder. create <vector::CreateMaskOp>(
1598
+ loc, writeMaskType, tensor::getMixedSizes (builder, loc, dest) );
1597
1599
write = mlir::vector::maskOperation (builder, write , maskForWrite);
1598
1600
}
1601
+
1599
1602
return write ;
1600
1603
}
1601
1604
@@ -1693,9 +1696,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1693
1696
loc, shapeCastOp.getResult (), destPermutation);
1694
1697
1695
1698
// Create TransferWriteOp.
1699
+ Value dest = rewriter.create <tensor::EmptyOp>(
1700
+ loc, reifiedReturnShapes[0 ],
1701
+ transposeOp.getResult ().getType ().getElementType ());
1696
1702
Operation *write =
1697
- createWriteOrMaskedWrite (rewriter, loc, transposeOp.getResult (),
1698
- /* destSizes=*/ reifiedReturnShapes[0 ],
1703
+ createWriteOrMaskedWrite (rewriter, loc, transposeOp.getResult (), dest,
1699
1704
/* inputVecSizesForLeadingDims=*/ inputVectorSizes,
1700
1705
/* useInBoundsInsteadOfMasking=*/ false );
1701
1706
newResults.push_back (write ->getResult (0 ));
@@ -1830,10 +1835,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1830
1835
unpackOp.getDestType ().hasStaticShape ()
1831
1836
? vectorSizes
1832
1837
: shapeCastOp.getResultVectorType ().getShape ());
1833
- Operation *write = createWriteOrMaskedWrite (
1834
- rewriter, loc, shapeCastOp.getResult (), /* destSizes=*/ reifiedRetShapes[0 ],
1835
- /* inputVecSizesForLeadingDims=*/ writeVectorSizes,
1836
- useInBoundsInsteadOfMasking);
1838
+ Value dest = rewriter.create <tensor::EmptyOp>(
1839
+ loc, reifiedRetShapes[0 ],
1840
+ shapeCastOp.getResult ().getType ().getElementType ());
1841
+ Operation *write =
1842
+ createWriteOrMaskedWrite (rewriter, loc, shapeCastOp.getResult (), dest,
1843
+ /* inputVecSizesForLeadingDims=*/ writeVectorSizes,
1844
+ useInBoundsInsteadOfMasking);
1837
1845
newResults.push_back (write ->getResult (0 ));
1838
1846
return success ();
1839
1847
}
@@ -1861,10 +1869,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1861
1869
auto maskedRead = vector::createReadOrMaskedRead (
1862
1870
rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
1863
1871
/* useInBoundsInsteadOfMasking=*/ false );
1864
- Operation *write = createWriteOrMaskedWrite (
1865
- rewriter, loc, maskedRead, reifiedReturnShapes[0 ],
1866
- /* inputVecSizesForLeadingDims=*/ inputVectorSizes,
1867
- /* useInBoundsInsteadOfMasking=*/ false );
1872
+
1873
+ // Create Xfer write Op
1874
+ Value dest = rewriter.create <tensor::EmptyOp>(
1875
+ loc, reifiedReturnShapes[0 ], padOp.getResultType ().getElementType ());
1876
+ Operation *write =
1877
+ createWriteOrMaskedWrite (rewriter, loc, maskedRead, dest,
1878
+ /* inputVecSizesForLeadingDims=*/ inputVectorSizes,
1879
+ /* useInBoundsInsteadOfMasking=*/ false );
1868
1880
newResults.push_back (write ->getResult (0 ));
1869
1881
return success ();
1870
1882
}
0 commit comments