
[Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case #107922


Merged

Conversation

nirvedhmeshram
Contributor

@nirvedhmeshram commented Sep 9, 2024

In #102321 we relaxed the vectorizer so that, when checking for contiguous loads, we don't always require a trailing non-unit dim. For example, the test case added here uses tensor<8x1xf32>, which is now a valid candidate for a contiguous load. However, the logic that checks for a contiguous load assumed that only the trailing dim would be non-unit, so this PR updates that logic to find the actual non-unit dim.
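
A minimal standalone sketch of that idea (plain C++ with a hypothetical `findNonUnitDim` helper, not the actual MLIR code): scan the shape for the non-unit dim instead of assuming it is the trailing one.

```cpp
// Hypothetical sketch: locate the single non-unit dim of a vector shape.
// For tensor<8x1xf32> the non-unit dim is 0, not the trailing dim 1.
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the index of the non-unit dim; falls back to the trailing dim if
// every dim is 1. Mirrors the assumption that at most one dim is non-unit.
static int64_t findNonUnitDim(const std::vector<int64_t> &shape) {
  int64_t nonUnitDim = static_cast<int64_t>(shape.size()) - 1;
  int64_t numNonUnit = 0;
  for (int64_t i = 0, e = static_cast<int64_t>(shape.size()); i < e; ++i) {
    if (shape[i] != 1) {
      nonUnitDim = i;
      ++numNonUnit;
    }
  }
  assert(numNonUnit <= 1 && "expected at most one non-unit dim");
  return nonUnitDim;
}

int main() {
  assert(findNonUnitDim({8, 1}) == 0); // the case from the added test
  assert(findNonUnitDim({1, 4}) == 1); // trailing non-unit dim still works
  return 0;
}
```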

@llvmbot
Member

llvmbot commented Sep 9, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

There is a case, shown in #107476, that the current vectorization patterns can't handle. This PR provides a way of handling it by adding an extra transpose op, which should get canceled with the existing transpose.
Fixes: #107476
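
As a rough illustration of the permutation being built (plain C++ with a hypothetical `makeTransposition` helper, not the patch itself): start from the identity permutation and swap the trailing position with the non-unit dim, so that dim becomes trailing before the shape cast to a 1-D vector.

```cpp
// Hypothetical sketch of the transpose permutation used in the patch below.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

static std::vector<int64_t> makeTransposition(int64_t rank, int64_t nonUnitDim) {
  std::vector<int64_t> perm(rank);
  std::iota(perm.begin(), perm.end(), int64_t{0}); // identity [0, ..., rank-1]
  std::swap(perm.back(), perm[nonUnitDim]); // move the non-unit dim to the back
  return perm;
}

int main() {
  // For vector<8x1xindex> the non-unit dim is 0, giving the [1, 0] transpose
  // that appears in the CHECK lines of the added test.
  assert(makeTransposition(2, 0) == (std::vector<int64_t>{1, 0}));
  return 0;
}
```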


Full diff: https://github.com/llvm/llvm-project/pull/107922.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+25)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+48)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..16d1b1d6e0d0d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1079,6 +1079,31 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       continue;
     }
 
+    auto idxType = dyn_cast<VectorType>(idx.getType());
+
+    if (idxType && idxType.getShape().size() == resultType.getShape().size()) {
+      auto maxElement = std::max_element(resultType.getShape().begin(),
+                                         resultType.getShape().end());
+      auto maxElementDim =
+          std::distance(resultType.getShape().begin(), maxElement);
+      // This means that the result type of the index is non trailing and we
+      // insert transpose op in this case to match it to the extract type.
+      if (maxElementDim != resultType.getShape().size() - 1) {
+        SmallVector<int64_t> transposition = llvm::to_vector<16>(
+            llvm::seq<int64_t>(0, resultType.getShape().size()));
+        std::swap(transposition.back(), transposition[maxElementDim]);
+        auto transposeOp =
+            rewriter.create<vector::TransposeOp>(loc, idx, transposition);
+        auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+            loc,
+            VectorType::get(*maxElement, rewriter.getIndexType(),
+                            resultType.getScalableDims().back()),
+            transposeOp);
+        transferReadIdxs.push_back(rewriter.create<vector::ExtractElementOp>(
+            loc, indexAs1dVector, zero));
+        continue;
+      }
+    }
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
         loc,
         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..b66a0c4e4093b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map], 
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:     %[[IDX0:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:     %[[IDX1:.*]] = vector.broadcast %[[CST_0]] : vector<8xindex> to vector<1x8xindex
+// CHECK:    %[[IDX2:.*]] = vector.transpose %[[IDX1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:    %[[IDX3:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK:    %[[IDX4:.*]] = vector.transpose %[[IDX2]], [1, 0] : vector<8x1xindex> to vector<1x8xindex>
+// CHECK:    %[[IDX5:.*]] = vector.shape_cast %[[IDX4]] : vector<1x8xindex> to vector<8xindex>
+// CHECK:    %[[IDX6:.*]] = vector.extractelement %[[IDX5]][%[[C0_i32]] : i32] : vector<8xindex>
+// CHECK:    %[[IDX7:.*]] = vector.transfer_read %[[ARG0]][%[[IDX6]], %[[C0]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true]} : tensor<8x128x768xf32>, vector<8x1xf32>
+// CHECK:    vector.transfer_write %[[IDX7]], %[[IDX0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>

@dcaballe
Contributor

Hi, thanks for the fix! Could you please elaborate a bit more on "a way of handling it by adding an extra transpose op"? I'm not sure I can infer what that means. An IR example in the description would help.

@nirvedhmeshram
Contributor Author

Sounds good. It seems like this is creating wrong IR as is, but once we have a resolution, with help from @banach-space, I can post the IR we are generating for this case in the PR description.

@nirvedhmeshram force-pushed the nm_vectorize_nd_tensor_extract branch from a1d86c8 to 6c6ed79 on September 18, 2024 16:01
@nirvedhmeshram
Contributor Author

@banach-space could you please take another look at this when you have a chance?

Contributor

@banach-space left a comment


This makes sense, thank you and sorry for the delay. A few more minor comments/suggestions.

@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Sep 19, 2024

@banach-space, #100582 breaks the logic here, as one of the test cases is:

%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%arg1 : tensor<?x?xf32>) {
^bb0(%out: f32):
  %6 = linalg.index 1 : index
  %7 = arith.addi %6, %arg2 : index
  %extracted = tensor.extract %arg0[%c79, %7] : tensor<?x?xf32>
  linalg.yield %extracted : f32
} -> tensor<?x?xf32>

where it will find two non-unit dims and hit the assert.

@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Sep 19, 2024

I adapted the logic to check for dynamic dims and to return the trailing dim if no non-dynamic/non-unit dim is found. That seems to work, although it feels ad hoc.
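
Roughly, the adapted selection looks like the sketch below (plain C++ with a hypothetical `pickNonUnitLoopDim` helper and a `kDynamic` stand-in for `ShapedType::kDynamic`; not the exact code that landed): skip dynamic dims as well, and fall back to the trailing dim when no static non-unit dim is found.

```cpp
// Hypothetical sketch: pick the static non-unit dim, or fall back to the
// trailing dim when the shape has no static non-unit dim (e.g. tensor<?x?xf32>).
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

static int64_t pickNonUnitLoopDim(const std::vector<int64_t> &shape) {
  for (int64_t i = 0, e = static_cast<int64_t>(shape.size()); i < e; ++i) {
    if (shape[i] != 1 && shape[i] != kDynamic)
      return i; // the static non-unit dim
  }
  return static_cast<int64_t>(shape.size()) - 1; // fallback: trailing dim
}

int main() {
  assert(pickNonUnitLoopDim({8, 1}) == 0);               // static non-unit dim
  assert(pickNonUnitLoopDim({kDynamic, kDynamic}) == 1); // fallback case above
  return 0;
}
```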

@banach-space
Contributor

Apologies, I didn't expect that to be such a curveball. But I am really glad that we are hitting these cases; it really helps to stress-test this logic. Hopefully my explanation makes sense. I've also deleted a couple of my comments - #100582 invalidated them 😅

@nirvedhmeshram force-pushed the nm_vectorize_nd_tensor_extract branch from a567a54 to 69ece45 on September 20, 2024 15:58
@nirvedhmeshram force-pushed the nm_vectorize_nd_tensor_extract branch from 69ece45 to 95189eb on September 20, 2024 16:03
@llvm deleted a comment from github-actions bot on Sep 20, 2024
Contributor

@banach-space left a comment


LGTM, thanks for fixing this and for bearing with me 🙏🏻

I will try to prepare a follow-on to clarify the differences between loading from statically and dynamically shaped tensors. That's two cases. There's also masked vectorisation for statically shaped tensors, so that's 3 cases in total 😅

@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Sep 21, 2024

Thank you so much for the help on this and for taking up support for those other cases. This back-and-forth gave me a lot of new insight into vectorization, so no complaints here :)

@nirvedhmeshram merged commit e45fc51 into llvm:main on Sep 21, 2024
8 checks passed
banach-space added a commit that referenced this pull request Sep 24, 2024
This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the
following case is vectorised as `vector.gather` (as opposed to
attempting a contiguous load):
```mlir
  func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
        %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
        linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```

Specifically, when looking for loop-invariant indices in
`tensor.extract` Ops, any `linalg.index` Op that's used in address
calculation should only access loop dims that are == 1. In the example
above, the following does not meet that criterion:
```mlir
  %1 = linalg.index 0 : index
```
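
A minimal sketch of that condition (plain C++, hypothetical `indexIsLoopInvariant` helper, not the actual `isLoopInvariantIdx`): a `linalg.index <dim>` feeding the address computation only counts as loop-invariant when that loop dim has extent 1.

```cpp
// Hypothetical sketch of the loop-invariance condition described above.
#include <cassert>
#include <cstdint>
#include <vector>

static bool indexIsLoopInvariant(int64_t indexedLoopDim,
                                 const std::vector<int64_t> &loopBounds) {
  return loopBounds[indexedLoopDim] == 1;
}

int main() {
  std::vector<int64_t> bounds = {8, 1}; // loop bounds for the 8x1 output above
  assert(!indexIsLoopInvariant(0, bounds)); // linalg.index 0 -> must gather
  assert(indexIsLoopInvariant(1, bounds));  // linalg.index 1 is invariant
  return 0;
}
```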

Note that this PR also effectively addresses the issue fixed in #107922,
i.e. exercised by:
  * `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load`

`getNonUnitLoopDim` introduced in #107922 is still valid though. In
fact, it is required to identify that the following case is a contiguous
load:
```mlir
  func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
        %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
        linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```
Some logic is still missing to lower the above to
`vector.transfer_read`, so it is conservatively lowered to
`vector.gather` instead (see TODO in
`getTensorExtractMemoryAccessPattern`).

There are a few additional changes:
  * `getNonUnitLoopDim` is simplified and renamed to
    `getTrailingNonUnitLoopDimIdx`, and additional comments are added (note
    that the functionality didn't change);
  * extra comments in a few places; variable names in comments are updated to
    use Markdown (which is the preferred approach in MLIR).

This is a follow-on for:
  * #107922
  * #102321
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Oct 31, 2024