Skip to content

Commit 34de7fd

Browse files
authored
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n) (#95743)
The main goal of this and subsequent PRs is to unify and categorize the tests in:
* vector-transfer-flatten.mlir

This should make it easier to identify the edge cases being tested (and how they differ), to remove duplicates, and to add tests for scalable vectors.

The main contributions of this PR:
* split tests that covered `xfer_read` + `xfer_write` into separate tests (the majority of the existing tests check _one_ xfer Op at a time),
* organise tests for `xfer_read` and `xfer_write` into separate groups (separated by a big bold comment).

Note that all tests (i.e. test cases) are preserved and some new tests are added. The deletions that you will see in `git diff` correspond to `xfer_write` and `xfer_read` Ops being extracted into separate functions (so that there's one xfer Op per function). In particular, the number of test functions has grown from 26 to 30.

In addition, this PR unifies the tests so that:
* input variable names are consistent (e.g. the input memref is always `arg`),
* CHECK lines use similar indentation,
* 2 x tabs are always used for function arguments, 1 x tab for the function body.

Finally, the changes in "VectorTransferOpTransforms.cpp" are merely meant to unify comments and logic between:
* `FlattenContiguousRowMajorTransferWritePattern` and
* `FlattenContiguousRowMajorTransferReadPattern`.
1 parent d4d95ee commit 34de7fd

File tree

2 files changed

+264
-113
lines changed

2 files changed

+264
-113
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,7 @@ namespace {
568568
/// memref.collapse_shape on the source so that the resulting
569569
/// vector.transfer_read has a 1D source. Requires the source shape to be
570570
/// already reduced i.e. without unit dims.
571+
///
571572
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
572573
/// the trailing dimension of the vector read is smaller than the provided
573574
/// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
617618
Value collapsedSource =
618619
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
619620
MemRefType collapsedSourceType =
620-
dyn_cast<MemRefType>(collapsedSource.getType());
621+
cast<MemRefType>(collapsedSource.getType());
621622
int64_t collapsedRank = collapsedSourceType.getRank();
622623
assert(collapsedRank == firstDimToCollapse + 1);
623624

@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
658659
/// memref.collapse_shape on the source so that the resulting
659660
/// vector.transfer_write has a 1D source. Requires the source shape to be
660661
/// already reduced i.e. without unit dims.
662+
///
663+
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
664+
/// the trailing dimension of the vector write is smaller than the provided
665+
/// bitwidth.
661666
class FlattenContiguousRowMajorTransferWritePattern
662667
: public OpRewritePattern<vector::TransferWriteOp> {
663668
public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
674679
VectorType vectorType = cast<VectorType>(vector.getType());
675680
Value source = transferWriteOp.getSource();
676681
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
682+
683+
// 0. Check pre-conditions
677684
// Contiguity check is valid on memrefs only.
678685
if (!sourceType)
679686
return failure();
687+
// If this is already 0D/1D, there's nothing to do.
680688
if (vectorType.getRank() <= 1)
681689
// Already 0D/1D, nothing to do.
682690
return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
688696
return failure();
689697
if (!vector::isContiguousSlice(sourceType, vectorType))
690698
return failure();
691-
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
692699
// TODO: generalize this pattern, relax the requirements here.
693700
if (transferWriteOp.hasOutOfBoundsDim())
694701
return failure();
@@ -697,22 +704,30 @@ class FlattenContiguousRowMajorTransferWritePattern
697704
if (transferWriteOp.getMask())
698705
return failure();
699706

700-
SmallVector<Value> collapsedIndices =
701-
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
702-
transferWriteOp.getIndices(), firstDimToCollapse);
707+
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
703708

709+
// 1. Collapse the source memref
704710
Value collapsedSource =
705711
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
706712
MemRefType collapsedSourceType =
707713
cast<MemRefType>(collapsedSource.getType());
708714
int64_t collapsedRank = collapsedSourceType.getRank();
709715
assert(collapsedRank == firstDimToCollapse + 1);
710716

717+
// 2. Generate input args for a new vector.transfer_write that will write
718+
// to the collapsed memref.
719+
// 2.1. New dim exprs + affine map
711720
SmallVector<AffineExpr, 1> dimExprs{
712721
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
713722
auto collapsedMap =
714723
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
715724

725+
// 2.2 New indices
726+
SmallVector<Value> collapsedIndices =
727+
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
728+
transferWriteOp.getIndices(), firstDimToCollapse);
729+
730+
// 3. Create new vector.transfer_write that writes to the collapsed memref
716731
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
717732
vectorType.getElementType());
718733
Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
721736
rewriter.create<vector::TransferWriteOp>(
722737
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
723738
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
739+
740+
// 4. Replace the old transfer_write with the new one writing the
741+
// collapsed shape
724742
rewriter.eraseOp(transferWriteOp);
725743
return success();
726744
}

0 commit comments

Comments
 (0)