[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n) #95743

Merged

Conversation

banach-space
Contributor

@banach-space banach-space commented Jun 17, 2024

The main goal of this and subsequent PRs is to unify and categorize
tests in:

  • vector-transfer-flatten.mlir

This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

  • split tests that covered xfer_read + xfer_write into separate
    tests (majority of the existing tests check one xfer Op at a time),
  • organise tests for xfer_read and xfer_write into separate
    groups (separate with a big bold comment).

Note, all tests (i.e. test cases) are preserved and some new tests are
added. Deletions that you will see in git diff correspond to
xfer_write and xfer_read Ops being extracted to separate functions (so
that there's one xfer Op per function). In particular, the number of
test functions has grown from 26 to 30.

In addition, this PR unifies the tests so that:

  • input variable names are consistent (e.g. make sure that the input
    memref is always arg)
  • CHECK lines use similar indentations
  • 2 x tabs are always used for function arguments, 1 x tab for
    function body

Finally, changes in "VectorTransferOpTransforms.cpp" are merely meant to
unify comments and logic between

  • FlattenContiguousRowMajorTransferWritePattern and
  • FlattenContiguousRowMajorTransferReadPattern.
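
For context, the doc comments touched by this patch state that, when `targetVectorBitwidth` is provided, flattening only happens if the trailing dimension of the vector is smaller than that bitwidth. The sketch below models that gate on plain integers. `shouldFlatten` is a hypothetical helper, not an MLIR API, and using `0` to mean "no threshold provided" is an assumption made only for this illustration.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): decides whether flattening would be
// attempted for a vector of the given shape and element bitwidth.
// Assumption: targetVectorBitwidth == 0 stands in for "no threshold provided".
bool shouldFlatten(const std::vector<int64_t> &vectorShape,
                   int64_t elementBitwidth, int64_t targetVectorBitwidth) {
  // 0-D and 1-D vectors are already flat, so there is nothing to do.
  if (vectorShape.size() <= 1)
    return false;
  if (targetVectorBitwidth == 0)
    return true;
  // Only the trailing (innermost) vector dimension is compared to the target.
  int64_t trailingBits = vectorShape.back() * elementBitwidth;
  return trailingBits < targetVectorBitwidth;
}
```

With a 128-bit target, `vector<5x4x3x2xi8>` has a 16-bit trailing dimension and passes the gate, whereas `vector<1x2x6xi32>` has a 192-bit trailing dimension and does not. This is consistent with the `CHECK-128B` and `CHECK-128B-NOT` lines in the tests in the patch.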

@llvmbot
Member

llvmbot commented Jun 17, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Patch is 28.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95743.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+23-5)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+241-108)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
 /// the trailing dimension of the vector read is smaller than the provided
 /// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
-        dyn_cast<MemRefType>(collapsedSource.getType());
+        cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_write has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
 public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices =
-        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
-                            transferWriteOp.getIndices(), firstDimToCollapse);
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
+    // 1. Collapse the source memref
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
+    // 2.2 New indices
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
+
+    // 3. Create new vector.transfer_write that writes to the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_write with the new one writing the
+    // collapsed shape
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..0de5a807affe0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_read_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
     %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 // contiguous subset of the memref, so "flattenable".
 
 func.func @transfer_read_dims_mismatch_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-    return %v : vector<1x1x2x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+  return %v : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,135 +78,160 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x43x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+    %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %idx0 : index,
+    %idx1 : index) -> vector<2x2xf32> {
+
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+    memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
   return %8 : vector<2x2xf32>
 }
 
-//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
 // CHECK-LABEL:  func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
 func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x?x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+    %arg : memref<i8>) -> vector<i8> {
+
+  %cst = arith.constant 0 : i8
+  %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+  return %0 : vector<i8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_0d
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_0d(
 //   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_write_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
 }
 
 // CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
@@ -208,42 +241,101 @@ func.func @transfer_write_dims_match_contiguous(
 
 // -----
 
+func.func @transfer_write_dims_match_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 func....
[truncated]
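
The `affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>` expected by `@transfer_read_dims_mismatch_non_zero_indices` can be reproduced by hand: the trailing dimensions `43x4x6` are collapsed into one, so the indices into them are linearised with row-major strides 24, 6 and 1. The sketch below models this on plain integers. `collapseIndices` is a hypothetical stand-in, not the `getCollapsedIndices` helper from the patch, and it assumes a fully static shape with the default row-major layout.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model (not the MLIR helper): works on integers rather than SSA
// values and assumes a static, row-major shape.
std::vector<int64_t> collapseIndices(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &indices,
                                     std::size_t firstDimToCollapse) {
  assert(shape.size() == indices.size());
  assert(firstDimToCollapse < shape.size());

  // Indices in front of the collapsed group are kept unchanged.
  std::vector<int64_t> result(
      indices.begin(),
      indices.begin() + static_cast<std::ptrdiff_t>(firstDimToCollapse));

  // Linearise the trailing indices, walking from the innermost dimension
  // outwards while accumulating the row-major stride.
  int64_t offset = 0;
  int64_t stride = 1;
  for (std::size_t i = shape.size(); i-- > firstDimToCollapse;) {
    offset += indices[i] * stride;
    stride *= shape[i];
  }
  result.push_back(offset);
  return result;
}
```

For `memref<1x43x4x6xi32>` read as `vector<1x2x6xi32>`, `firstDimToCollapse` is `4 - 3 = 1`, and indices `[0, idx_1, idx_2, 0]` become `[0, idx_1 * 24 + idx_2 * 6]`. For `memref<1x3x3x2xf32>` read as `vector<2x2xf32>`, `firstDimToCollapse` is 2 and the last index becomes `idx1 * 2`, which agrees with the second affine map in the tests.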

banach-space added a commit to banach-space/llvm-project that referenced this pull request Jun 17, 2024
The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:
1. `@transfer_{read|write}_dims_mismatch_non_contiguous` and
   `@transfer_read_flattenable_negative` duplicated
   `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. Both
   tests are deleted
   (`@transfer_{read|write}_dims_mismatch_non_contiguous_slice` is
   preserved).

2. `@transfer_read_flattenable_negative2` is replaced with
   `@transfer_read_non_contiguous_src` and
   `@transfer_write_non_contiguous_src` (i.e. a dedicated test for
   xfer_read and xfer_write with more descriptive func names)

Depends on llvm#95743.

**Only review the top commit.**
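
As a rough illustration of what the `_non_contiguous_slice` tests exercise: a `vector<2x1x2x2xi8>` taken from a `memref<5x4x3x2xi8>` does not cover one contiguous run of memory, while a `vector<1x1x2x2xi8>` does. The sketch below is a simplified model of such a check. `isContiguousSliceModel` is hypothetical and is not `vector::isContiguousSlice`; it assumes a static shape, the default row-major layout and an in-bounds access.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified model: the slice is contiguous when its trailing
// dimensions match the memref's trailing dimensions, at most one dimension is
// partially covered, and every dimension outside of that one is 1.
bool isContiguousSliceModel(const std::vector<int64_t> &memrefShape,
                            const std::vector<int64_t> &vectorShape) {
  if (vectorShape.size() > memrefShape.size())
    return false;
  // Vector dimensions are aligned with the trailing memref dimensions.
  std::size_t offset = memrefShape.size() - vectorShape.size();
  bool sawPartialDim = false;
  for (std::size_t i = vectorShape.size(); i-- > 0;) {
    int64_t vecDim = vectorShape[i];
    int64_t memDim = memrefShape[offset + i];
    if (sawPartialDim) {
      // An inner dimension was only partially covered, so stepping along this
      // outer dimension would skip elements unless its extent is 1.
      if (vecDim != 1)
        return false;
      continue;
    }
    if (vecDim > memDim)
      return false;
    if (vecDim != memDim)
      sawPartialDim = true;
  }
  return true;
}
```

Under these assumptions the model accepts `5x4x3x2` and `1x1x2x2` slices of a `5x4x3x2` memref and rejects `2x1x2x2`, matching the positive and negative tests in the patch. Strided or dynamically shaped sources need the real check.
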
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jun 17, 2024
The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`,
   i.e. move it near other tests for xfer_read, unify variable names to
   match other xfer_read tests, highlight what makes this a positive
   test to better contrast it with
   `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`

2. Similar changes for
   `@transfer_write_flattenable_with_dynamic_dims_and_indices`.

Depends on llvm#95743 and llvm#95744

**Only review the top commit.**
@banach-space banach-space force-pushed the andrzej/refactor_xfer_flatten_1 branch from d26b321 to f8bebff Compare June 17, 2024 07:31
@banach-space banach-space requested a review from nujaa June 17, 2024 07:36
Contributor

@nujaa nujaa left a comment

I like those changes, thanks for maintaining this. I am trying to review your PRs by the end of the week but won’t request any changes as I’ll be taking some time off and won’t be able to approve any of your possible changes.

@@ -568,6 +568,7 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
Contributor

Not sure if changes in this file have much to do with a test refactoring commit. I am happy to be proven wrong if this is following usual guidelines. I am still a newbie.

Contributor Author

In principle:

  • one patch == one change,
  • avoid unrelated changes.

In this patch, I am violating these rules. First, I am making 2 changes:

Second, this particular change qualifies as "unrelated". So, if we were to go by the book, I should split it into a separate PR. I am happy to do that, but I am also mindful that I'm generating a lot of PR traffic and want to reduce noise 😅

In situations like this, I try to make the intent clear in the summary:

Finally, changes in "VectorTransferOpTransforms.cpp" are merely meant to
unify comments and logic between

FlattenContiguousRowMajorTransferWritePattern and
FlattenContiguousRowMajorTransferReadPattern.

... and then follow the reviewer's recommendation. If I were reviewing this, I'd say "move to a different patch - I'll happily review that" ;-)

Contributor

That's ok with me. Code owners might have another point of view.

@banach-space banach-space requested a review from MacDue June 21, 2024 07:50
Member

@MacDue MacDue left a comment

LGTM, nice cleanup

@banach-space banach-space merged commit 34de7fd into llvm:main Jun 21, 2024
7 checks passed
@banach-space banach-space deleted the andrzej/refactor_xfer_flatten_1 branch June 21, 2024 09:56
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jun 21, 2024
The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:
1. `@transfer_{read|write}_dims_mismatch_non_contiguous` and
   `@transfer_read_flattenable_negative` duplicated
   `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. Both
   tests are deleted
   (`@transfer_{read|write}_dims_mismatch_non_contiguous_slice` is
   preserved).

2. `@transfer_read_flattenable_negative2` is replaced with
   `@transfer_read_non_contiguous_src` and
   `@transfer_write_non_contiguous_src` (i.e. a dedicated test for
   xfer_read and xfer_write with more descriptive func names)

Depends on llvm#95743.

**Only review the top commit.**
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jun 21, 2024
The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`,
   i.e. move it near other tests for xfer_read, unify variable names to
   match other xfer_read tests, highlight what makes this a positive
   test to better contrast it with
   `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`

2. Similar changes for
   `@transfer_write_flattenable_with_dynamic_dims_and_indices`.

Depends on llvm#95743 and llvm#95744

**Only review the top commit**
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…m#95743)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
  
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:
  * split tests that covered `xfer_read` + `xfer_write` into separate
    tests (majority of the existing tests check _one_ xfer Op at a time),
  * organise tests for `xfer_read` and `xfer_write` into separate
    groups (separate with a big bold comment).

Note, all tests (i.e. test cases) are preserved and some new tests are
added. Deletions that you will see in `git diff` correspond to
`xfer_write` and `xfer_read` Ops being extracted to separate functions
(so that there's one xfer Op per function). In particular, the number of
test functions has grown from 26 to 30.

In addition, this PR unifies the tests so that:
  * input variable names are consistent (e.g. make sure that the input
    memref is always `arg`)
  * CHECK lines use similar indentations
  * 2 x tabs are always used for function arguments, 1 x tab for
    function body

Finally, changes in "VectorTransferOpTransforms.cpp" are merely meant to
unify comments and logic between
  * `FlattenContiguousRowMajorTransferWritePattern` and
  * `FlattenContiguousRowMajorTransferReadPattern`.
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jul 21, 2024
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jul 22, 2024
banach-space added a commit that referenced this pull request Jul 22, 2024
)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir

This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

1. For consistency with other tests,
   `@transfer_read_flattenable_with_dynamic_dims_and_indices` is renamed
   as `@transfer_read_leading_dynamic_dims`. It is also moved near other
   tests for `xfer_read`, variable names are updated to match other
   `xfer_read` tests

2. `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`
   is renamed as `@negative_transfer_read_dynamic_dim_to_flatten` to
   better highlight that it's a negative test and to contrast it with
   `@transfer_read_leading_dynamic_dims` (and to emphasise the
   difference between the two).

3. Similar changes for tests for `xfer_write`.

4. Make sure that we consistently use `%idx_N` (as opposed to `%idxN`).

Follow-up for #95743 and #95744
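
The `%idx_N` naming rule from item 4 lends itself to a mechanical check. The following is only a minimal sketch and not part of the PR: the helper names `find_old_style_indices` and `rename_to_new_style`, and the regular expression itself, are assumptions made for illustration, and the sample MLIR line used with it is made up.

```python
import re

# Hypothetical lint rule: matches SSA names such as %idx1 or %idx12,
# but not the preferred spelling %idx_1.
OLD_STYLE_IDX = re.compile(r"%idx(\d+)\b")


def find_old_style_indices(mlir_text):
    """Return (line_number, name) pairs for every `%idxN` found in the text."""
    hits = []
    for line_no, line in enumerate(mlir_text.splitlines(), start=1):
        for match in OLD_STYLE_IDX.finditer(line):
            hits.append((line_no, match.group(0)))
    return hits


def rename_to_new_style(mlir_text):
    """Rewrite every `%idxN` as `%idx_N`, leaving existing `%idx_N` untouched."""
    return OLD_STYLE_IDX.sub(r"%idx_\1", mlir_text)
```

Running `find_old_style_indices` over a test file would list any remaining `%idxN` names, and `rename_to_new_style` would rewrite them. Matching CHECK-line captures would still need to be reviewed by hand.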
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024