Reapply "[mlir][linalg] Relax tensor.extract vectorization" (#102232) #102321

Merged

Conversation

banach-space (Contributor) commented Aug 7, 2024

[This reverts commit 6662523d6b2ca0198141c94ee80ebbb41601df9f]

Simplifies the vectorization of tensor.extract so that:
* all cases that read into a genuinely multi-dim vector (*) are
  considered a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor
is correctly identified as a scalar load followed by a broadcast (rather
than a gather load).

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%in : tensor<1x1x4xi32>) {

  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out:tensor<1x1x4xi32>
}
```

Overview of the delta when compared to the original submission (#99299):
* removed an assert representing a condition that is being relaxed
  here,
* added a test (reading from a column tensor) based on a repro from
  @hanhanW.

(*) `vector<1x4x1xf32>` is considered a 1-D vector in this context.
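
For contrast, here is a hypothetical example (not taken from this patch; names and shapes are illustrative) that reads into a genuinely multi-dim vector, `vector<2x4xf32>`, and would therefore still be classified as a gather load under this change:

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @hypothetical_gather_load(%src: tensor<8x8xf32>, %init: tensor<2x4xf32>) -> tensor<2x4xf32> {
  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%init : tensor<2x4xf32>) {
  ^bb0(%out: f32):
    // Transposed access: the extract indices swap the loop indices, so the
    // load cannot be contiguous, and the 2-D output rules out a broadcast.
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %e = tensor.extract %src[%j, %i] : tensor<8x8xf32>
    linalg.yield %e : f32
  } -> tensor<2x4xf32>
  return %res : tensor<2x4xf32>
}
```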
llvmbot (Member) commented Aug 7, 2024

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/102321.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-25)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+105-12)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3d0d6abf702d7..63dcda78d0f2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,11 +814,9 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
-                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+  assert(llvm::count_if(targetShape,
+                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -879,8 +877,6 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   assert(((llvm::count_if(targetShape,
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -946,27 +942,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
+  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+  // otherwise.
+  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+                             return dimSize > 1;
+                           }) == 1);
 
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // 1. Assume that it's a gather load when reading non-1D vector.
+  if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
 
   bool leadingIdxsLoopInvariant = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
+  // 2. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
+  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+  // gather load.
   auto indices = extractOp.getIndices();
   auto leadIndices = indices.drop_back(1);
 
@@ -982,13 +973,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Gather;
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
+  // 3. Analyze the trailing index for `extractOp`.
   // At this point we know that the leading indices are loop invariant. This
   // means that is potentially a scalar or a contiguous load. We can decide
   // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
 
-  // 4a. Scalar broadcast load
+  // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -997,7 +988,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::ScalarBroadcast;
   }
 
-  // 4b. Contiguous loads
+  // 3b. Contiguous loads
   // The trailing `extractOp` index should increment with every loop iteration.
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
@@ -1011,7 +1002,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Contiguous;
   }
 
-  // 5. Fallback case - gather load.
+  // 4. Fallback case - gather load.
   LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..bdaa20c3bf971 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
 }
 
 // -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
@@ -74,20 +75,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
-  %1 = linalg.generic {
-    indexing_maps = [#map1],
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(
+    %arg0: tensor<3x3x3xf32>,
+    %arg1: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg2 : tensor<1x1x3xf32>) {
-  ^bb0(%arg4: f32):
-    %2 = linalg.index 0 : index
-    %3 = linalg.index 1 : index
-    %4 = linalg.index 2 : index
-    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
-    linalg.yield %5 : f32
+  } outs(%arg1 : tensor<1x1x3xf32>) {
+  ^bb0(%out: f32):
+    %1 = linalg.index 0 : index
+    %2 = linalg.index 1 : index
+    %3 = linalg.index 2 : index
+    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x3xf32>
+    linalg.yield %4 : f32
   } -> tensor<1x1x3xf32>
-  return %1 : tensor<1x1x3xf32>
+
+  return %res : tensor<1x1x3xf32>
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
@@ -104,6 +109,38 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
 // CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
+// Same as example above, but reading into a column tensor. Note that after the
+// vectorizatoin, the `TransferOpReduceRank` will replace
+// `vector.transfer_read` with `tensor.extract -> scalar`.
+
+// TODO: Currently this fails to vectorise when the indices are non-constant.
+
+func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+    %input: tensor<3x3x3xf32>,
+    %output: tensor<3x1x1xf32>) -> tensor<3x1x1xf32> {
+
+  %c0 = arith.constant 0 : index
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%output : tensor<3x1x1xf32>) {
+  ^bb0(%out: f32):
+    %5 = tensor.extract %input[%c0, %c0, %c0] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<3x1x1xf32>
+
+  return %res : tensor<3x1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<3x3x3xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+// CHECK:           return %[[RES]] : tensor<3x1x1xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -595,3 +632,59 @@ module attributes {transform.with_named_sequence} {
      transform.yield
    }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out:tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL:   func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK:           %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0  vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+    transform.yield
+  }
+}
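
The heart of the relaxation is the predicate added in the first hunk above. Below is a minimal standalone sketch of that check, assuming plain C++ containers in place of the actual `llvm::count_if` over `linalgOp.getStaticLoopRanges()`; names are illustrative, not the upstream API:

```cpp
#include <cstdint>
#include <vector>

// A target vector is "effectively 1-D" iff exactly one of its dims is > 1:
// vector<1x4x1xi32> qualifies, vector<2x4xi32> does not. Outputs that fail
// this test are classified as gather loads straight away; everything else
// remains a candidate for a scalar-broadcast or contiguous load.
static bool isEffectively1DVector(const std::vector<int64_t> &targetShape) {
  int64_t numNonUnitDims = 0;
  for (int64_t dimSize : targetShape)
    if (dimSize > 1)
      ++numNonUnitDims;
  return numNonUnitDims == 1;
}

int main() {
  bool colVec = isEffectively1DVector({1, 1, 4}); // true  -> not a gather
  bool ndVec = isEffectively1DVector({2, 4});     // false -> gather load
  return (colVec && !ndVec) ? 0 : 1;
}
```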

hanhanW (Contributor) left a comment

thanks!

@banach-space banach-space merged commit 62e5032 into llvm:main Aug 8, 2024
10 checks passed
@banach-space banach-space deleted the andrzej/relax_extract_vectorization_v2 branch August 8, 2024 08:36
nirvedhmeshram added a commit that referenced this pull request Sep 21, 2024
…sor.extract case (#107922)

In #102321 we relaxed the
vectorizer so that, when checking for contiguous loads, we no longer
require a trailing non-unit dim. For example, in the test case added we
have `tensor<8x1xf32>`, which is now a valid candidate for a contiguous
load. However, the logic that checks for contiguous loads assumed that
only the trailing dim would be non-unit, so this PR updates that logic
to find the actual non-unit dim.
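
A hypothetical standalone sketch of that fix (the real helper lives in Vectorization.cpp, operates on the linalg op's static loop ranges, and was later renamed `getTrailingNonUnitLoopDimIdx`; names and types here are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Instead of assuming the trailing loop dim is the non-unit one, scan the
// loop ranges for it. The contiguous-load analysis only applies when there
// is exactly one non-unit dim, hence the assert.
static size_t getNonUnitLoopDimIdx(const std::vector<int64_t> &loopRanges) {
  size_t nonUnitIdx = 0;
  size_t numNonUnitDims = 0;
  for (size_t i = 0; i < loopRanges.size(); ++i) {
    if (loopRanges[i] != 1) {
      nonUnitIdx = i;
      ++numNonUnitDims;
    }
  }
  assert(numNonUnitDims == 1 && "expected exactly one non-unit loop dim");
  return nonUnitIdx;
}

int main() {
  // For tensor<8x1xf32> the loop ranges are {8, 1}: the non-unit dim is the
  // leading one (index 0), which trailing-dim-only logic would have missed.
  return getNonUnitLoopDimIdx({8, 1}) == 0 ? 0 : 1;
}
```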
banach-space added a commit that referenced this pull request Sep 24, 2024
This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the
following case is vectorised as `vector.gather` (as opposed to
attempting a contiguous load):
```mlir
  func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
      %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
      linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```

Specifically, when looking for loop-invariant indices in
`tensor.extract` Ops, any `linalg.index` Op that's used in address
calculation should only access loop dims that are == 1. In the example
above, the following does not meet that criterion:
```mlir
  %1 = linalg.index 0 : index
```

Note that this PR also effectively addresses the issue fixed in #107922,
i.e. exercised by:
  * `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load`

`getNonUnitLoopDim` introduced in #107922 is still valid though. In
fact, it is required to identify that the following case is a contiguous
load:
```mlir
  func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
      %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
      linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```
Some logic is still missing to lower the above to
`vector.transfer_read`, so it is conservatively lowered to
`vector.gather` instead (see TODO in
`getTensorExtractMemoryAccessPattern`).

There are a few additional changes:
  * `getNonUnitLoopDim` is simplified and renamed to
    `getTrailingNonUnitLoopDimIdx`; additional comments are added (note
    that the functionality didn't change);
  * extra comments in a few places; variable names in comments updated to
    use Markdown (which is the preferred approach in MLIR).

This is a follow-on for:
  * #107922
  * #102321