[mlir][linalg] Relax tensor.extract vectorization #99299

Merged

Conversation

banach-space
Contributor

@banach-space banach-space commented Jul 17, 2024

Simplifies the vectorization of tensor.extract so that:

* all cases that read into a genuinely multi-dim vector (*) are
  considered a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor
is correctly identified as a scalar load followed by a broadcast (rather
than a gather load).

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%in : tensor<1x1x4xi32>) {

  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out : tensor<1x1x4xi32>
}
```

(*) `vector<1x4x1xf32>` is considered a 1D vector in this context.
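For reference, the shape check at the heart of this change can be sketched standalone as follows (a minimal approximation using `std::count_if` in place of `llvm::count_if`; `isEffectively1D` is a name invented for this sketch):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the patch's classification: a shape is "effectively 1D" when
// exactly one of its dimensions is greater than 1 (e.g. 1x1x4 or 1x4x1).
// Anything else is classified as a gather load.
static bool isEffectively1D(const std::vector<int64_t> &shape) {
  return std::count_if(shape.begin(), shape.end(),
                       [](int64_t d) { return d > 1; }) == 1;
}

int main() {
  assert(isEffectively1D({1, 1, 4}));   // candidate for contiguous/scalar load
  assert(isEffectively1D({1, 4, 1}));   // trailing unit dim no longer forces a gather
  assert(!isEffectively1D({2, 1, 4}));  // genuinely multi-dim -> gather load
}
```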

@banach-space banach-space force-pushed the andrzej/relax_extract_vectorization branch from 6966f4e to 20b886a on July 30, 2024 18:50
@banach-space banach-space changed the title [mlir][linalg] Refine tensor.extract vectorization [mlir][linalg] Relax tensor.extract vectorization Jul 30, 2024
@banach-space banach-space marked this pull request as ready for review July 30, 2024 18:51
@llvmbot
Member

llvmbot commented Jul 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/99299.diff

2 Files Affected:

* (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+15-20)
* (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+56)
```diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9185663799e52..34d9c4247b916 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -944,27 +944,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
-
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+  // otherwise.
+  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+                             return dimSize > 1;
+                           }) == 1);
+
+  // 1. Assume that it's a gather load when reading non-1D vector.
+  if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
 
   bool leadingIdxsLoopInvariant = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
+  // 2. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
+  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+  // gather load.
   auto indices = extractOp.getIndices();
   auto leadIndices = indices.drop_back(1);
 
@@ -980,13 +975,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Gather;
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
+  // 3. Analyze the trailing index for `extractOp`.
   // At this point we know that the leading indices are loop invariant. This
   // means that is potentially a scalar or a contiguous load. We can decide
   // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
 
-  // 4a. Scalar broadcast load
+  // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -995,7 +990,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::ScalarBroadcast;
   }
 
-  // 4b. Contiguous loads
+  // 3b. Contiguous loads
   // The trailing `extractOp` index should increment with every loop iteration.
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
@@ -1009,7 +1004,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Contiguous;
   }
 
-  // 5. Fallback case - gather load.
+  // 4. Fallback case - gather load.
   LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..ac75a19cbeb28 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -595,3 +595,59 @@ module attributes {transform.with_named_sequence} {
      transform.yield
    }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out:tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL:   func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK:           %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0  vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+    transform.yield
+  }
+}
```

@dcaballe
Contributor

Hey! I was looking at the changes and thinking that we may end up introducing too much complexity if we add support for all these unit dimension special cases. I think we discussed this in the past and decided to leverage the drop unit dimension utilities for that. For this particular case, could we generate a gather and then have a canonicalization pattern that turns it into a load when the unit dimension is removed?
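For concreteness, one shape such a canonicalization could take (a hypothetical sketch against recent MLIR APIs, not an existing upstream pattern; `GatherToScalarBroadcast` and its scope are assumptions) is to rewrite a `vector.gather` whose mask is constant all-true and whose per-lane offsets are a constant splat of zero, i.e. every lane reads the same element, into a scalar `tensor.extract` plus `vector.broadcast`:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern: a vector.gather from a tensor in which every lane
// reads the same element is equivalent to one scalar tensor.extract
// followed by a vector.broadcast.
struct GatherToScalarBroadcast : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    // Only gathers from ranked tensors are handled in this sketch.
    if (!isa<RankedTensorType>(gatherOp.getBase().getType()))
      return failure();

    // The mask must be a constant all-true splat.
    DenseIntElementsAttr maskAttr;
    if (!matchPattern(gatherOp.getMask(), m_Constant(&maskAttr)) ||
        !maskAttr.isSplat() || !maskAttr.getSplatValue<bool>())
      return failure();

    // The per-lane offsets must be a constant splat of zero, i.e. every
    // lane reads base[indices].
    DenseIntElementsAttr idxAttr;
    if (!matchPattern(gatherOp.getIndexVec(), m_Constant(&idxAttr)) ||
        !idxAttr.isSplat() || idxAttr.getSplatValue<APInt>() != 0)
      return failure();

    // Load the element once and broadcast it to the result vector type.
    Value scalar = rewriter.create<tensor::ExtractOp>(
        gatherOp.getLoc(), gatherOp.getBase(), gatherOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        gatherOp, gatherOp.getType(), scalar);
    return success();
  }
};
```

The general unit-dim case (non-splat offsets that only become contiguous once unit dims are dropped) would presumably still need the drop-unit-dims utilities mentioned above, which is where the extra complexity would live.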

@banach-space
Contributor Author

> Hey! I was looking at the changes and thinking that we may end up introducing too much complexity if we add support for all these unit dimension special cases.

Thanks for taking a look! Note that this is actually simplifying the current logic 😅 (and reducing the number of special cases).

> I think we discussed this in the past and decided to leverage the drop unit dimension utilities for that. For this particular case, could we generate a gather and then have a canonicalization pattern that turns it into a load when the unit dimension is removed?

This is still not obvious to me, tbh. My concern is that generating gather loads introduces a chain of not-so-trivial address calculation Ops. But this is definitely worth exploring - I agree that we should reduce the complexity in the vectoriser.

And yes, we have discussed this in the past - I finally have some cycles to revisit :)

@dcaballe
Contributor

dcaballe commented Aug 1, 2024

> > Hey! I was looking at the changes and thinking that we may end up introducing too much complexity if we add support for all these unit dimension special cases.
>
> Thanks for taking a look! Note that this is actually simplifying the current logic 😅 (and reducing the number of special cases).

The way I see it is that it's replacing an early exit with a special case.

Would you have an example of what the gather would look like after vectorization and after removing the unit dim? That would be helpful to make a call.

@banach-space
Contributor Author

banach-space commented Aug 1, 2024

> > > Hey! I was looking at the changes and thinking that we may end up introducing too much complexity if we add support for all these unit dimension special cases.
> >
> > Thanks for taking a look! Note that this is actually simplifying the current logic 😅 (and reducing the number of special cases).
>
> The way I see it is that it's replacing an early exit with a special case.

Well, I am replacing:

```cpp
if (some_complex_condition)
  return VectorMemoryAccessKind::Gather;

if (some_other_complex_condition)
  return VectorMemoryAccessKind::Gather;
```

with:

```cpp
// One less complex condition
if (!isOutput1DVector)
  return VectorMemoryAccessKind::Gather;
```

😅

> Would you have an example of what the gather would look like after vectorization and after removing the unit dim? That would be helpful to make a call.

Sure!

BEFORE THIS CHANGE

```mlir
  func.func @vectorize_scalar_broadcast_column_tensor(%arg0: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
    %c1 = arith.constant 1 : index
    %c1_0 = arith.constant 1 : index
    %c4_1 = arith.constant 4 : index
    %c0_2 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %c0_i32 : tensor<1x1x4xi32>, vector<1x1x4xi32>
    %1 = vector.step : vector<1xindex>
    %2 = vector.broadcast %1 : vector<1xindex> to vector<4x1x1xindex>
    %3 = vector.transpose %2, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %4 = vector.step : vector<1xindex>
    %5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
    %6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %cst_3 = arith.constant dense<true> : vector<1x1x4xi1>
    %cst_4 = arith.constant dense<0> : vector<1x1x4xi32>
    %c0_5 = arith.constant 0 : index
    %c1_6 = arith.constant 1 : index
    %dim = tensor.dim %cst, %c1_6 : tensor<30x1xi32>
    %7 = vector.broadcast %dim : index to vector<1x1x4xindex>
    %8 = arith.muli %6, %7 : vector<1x1x4xindex>
    %cst_7 = arith.constant dense<0> : vector<1x1x4xindex>
    %9 = arith.addi %cst_7, %8 : vector<1x1x4xindex>
    %10 = vector.gather %cst[%c0_5, %c0_5] [%9], %cst_3, %cst_4 : tensor<30x1xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
    %c0_8 = arith.constant 0 : index
    %11 = vector.transfer_write %10, %arg0[%c0_8, %c0_8, %c0_8] : vector<1x1x4xi32>, tensor<1x1x4xi32>
    return %11 : tensor<1x1x4xi32>
  }
```

AFTER THIS CHANGE

```mlir
  func.func @vectorize_scalar_broadcast_column_tensor(%arg0: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
    %c1 = arith.constant 1 : index
    %c1_0 = arith.constant 1 : index
    %c4_1 = arith.constant 4 : index
    %c0_2 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %c0_i32 : tensor<1x1x4xi32>, vector<1x1x4xi32>
    %1 = vector.step : vector<1xindex>
    %2 = vector.broadcast %1 : vector<1xindex> to vector<4x1x1xindex>
    %3 = vector.transpose %2, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %4 = vector.step : vector<1xindex>
    %5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
    %6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %cst_3 = arith.constant dense<true> : vector<1x1x4xi1>
    %cst_4 = arith.constant dense<0> : vector<1x1x4xi32>
    %c0_5 = arith.constant 0 : index
    %c0_i32_6 = arith.constant 0 : i32
    %7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex>
    %8 = vector.extractelement %7[%c0_i32_6 : i32] : vector<4xindex>
    %c0_i32_7 = arith.constant 0 : i32
    %9 = vector.transfer_read %cst[%8, %c0], %c0_i32_7 {in_bounds = [true, true, true], permutation_map = #map} : tensor<30x1xi32>, vector<1x1x4xi32>
    %c0_8 = arith.constant 0 : index
    %10 = vector.transfer_write %9, %arg0[%c0_8, %c0_8, %c0_8] : vector<1x1x4xi32>, tensor<1x1x4xi32>
    return %10 : tensor<1x1x4xi32>
  }
```

So, we'd need to match:

```mlir
    %4 = vector.step : vector<1xindex>
    %5 = vector.broadcast %4 : vector<1xindex> to vector<4x1x1xindex>
    %6 = vector.transpose %5, [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
    %dim = tensor.dim %cst, %c1_6 : tensor<30x1xi32>
    %7 = vector.broadcast %dim : index to vector<1x1x4xindex>
    %8 = arith.muli %6, %7 : vector<1x1x4xindex>
    %cst_7 = arith.constant dense<0> : vector<1x1x4xindex>
    %9 = arith.addi %cst_7, %8 : vector<1x1x4xindex>
    %10 = vector.gather %cst[%c0_5, %c0_5] [%9], %cst_3, %cst_4 : tensor<30x1xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
```

Not that bad, but that's "roughly" what the vectoriser matches today to decide that the underlying tensor.extract is a broadcast of a scalar. IIUC, you are suggesting removing that logic from the vectorizer and creating a Vector dialect pattern instead? Would you go as far as simplifying the vectorizer to always generate vector.gather and then let Vector patterns "lower" that to something more efficient?

Contributor

@dcaballe dcaballe left a comment


My concern is beyond gather: I'm worried that we start adding logic to vectorize code as if the unit dimension weren't there in general. This case looks complex enough to keep the complexity where it is, at least for now. I guess we can move that outside the vectorizer if that's ever needed.

@banach-space banach-space merged commit 8868c02 into llvm:main Aug 6, 2024
11 checks passed
@hanhanW
Contributor

hanhanW commented Aug 6, 2024

Hi, this breaks the downstream IREE project, and I'm going to revert it. Below is the upstream repro (I will also attach the case to the revert commit).

Run: `mlir-opt -transform-interpreter ~/repro.mlir`

```mlir
#map = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
module {
  func.func @foo(%arg0: index, %arg1: tensor<2xf32>, %arg2: tensor<4xf32>, %arg3: tensor<1xf32>) -> tensor<4x1xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 1.000000e+00 : f32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<4x1xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%arg2, %arg3 : tensor<4xf32>, tensor<1xf32>) outs(%0 : tensor<4x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %2 = linalg.index 0 : index
      %3 = linalg.index 1 : index
      %4 = affine.apply #map3(%3, %arg0)
      %extracted = tensor.extract %arg1[%c0] : tensor<2xf32>
      %5 = arith.cmpi eq, %2, %c0 : index
      %6 = arith.cmpi ult, %2, %c0 : index
      %7 = arith.select %5, %cst, %in : f32
      %8 = arith.select %6, %cst_0, %7 : f32
      %9 = arith.cmpi eq, %4, %c0 : index
      %10 = arith.cmpi ult, %4, %c0 : index
      %11 = arith.select %9, %cst, %in_1 : f32
      %12 = arith.select %10, %cst_0, %11 : f32
      %13 = arith.mulf %8, %12 : f32
      %14 = arith.mulf %13, %extracted : f32
      %15 = arith.cmpi eq, %2, %4 : index
      %16 = arith.select %15, %cst, %cst_0 : f32
      %17 = arith.subf %16, %14 : f32
      linalg.yield %17 : f32
    } -> tensor<4x1xf32>
    return %1 : tensor<4x1xf32>
  }
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
```

Error:

```
mlir-opt: llvm-project/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp:820: bool isLoopInvariantIdx(mlir::linalg::LinalgOp&, mlir::Value&): Assertion `targetShape.back() != 1 && "1-D vectors with the trailing dim eqaual 1 are not yet supported"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
...
```
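To see why the repro trips this: the generic's output is vectorized to `vector<4x1xf32>`, which the relaxed check classifies as effectively 1D (only one dim is greater than 1), so control now reaches `isLoopInvariantIdx`, which still guards against trailing unit dims. Below is a standalone sketch of the failing condition; `targetShape` stands in for the vectorizer's target shape, and the assertion text (typo included) is copied from the log above:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Shape of the vector the repro reads into: 4x1, i.e. a trailing unit dim.
  std::vector<int64_t> targetShape = {4, 1};
  // Reproduces the guard from isLoopInvariantIdx that aborts mlir-opt:
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim eqaual 1 are not yet supported");
}
```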

hanhanW added a commit that referenced this pull request Aug 6, 2024
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Aug 6, 2024
hanhanW added a commit that referenced this pull request Aug 6, 2024
Reverts #99299 because it breaks the lowering. To
repro: `mlir-opt -transform-interpreter ~/repro.mlir`

@banach-space
Contributor Author

Sorry about that :( Thanks for the quick revert and for the reproducer 🙏🏻!

@banach-space
Contributor Author

@hanhanW Please take a look when you get a chance: #102321. AFAIK, Diego is OOO this and the following week.

banach-space added a commit that referenced this pull request Aug 8, 2024
[This reverts commit 6662523d6b2ca0198141c94ee80ebbb41601df9f]

Simplifies the vectorization of tensor.extract so that:
* all cases that read into a genuinely multi-dim vector (*) are
  considered a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor
is correctly identified as a scalar load followed by a broadcast (rather
than a gather load).

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%in : tensor<1x1x4xi32>) {

  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out:tensor<1x1x4xi32>
}
```

Overview of the delta compared to the original submission (#99299):
  * removed an assert representing a condition that is being relaxed
    here,
  * added a test (reading from a column tensor) based on a repro from
    @hanhanW.

(*) `vector<1x4x1xf32>` is considered as 1D vector in this context.