Skip to content

[mlir][Vector] Fix "scalability" in CastAwayExtractStridedSliceLeadingOneDim #81187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

Makes sure that "scalability" flags in the
CastAwayExtractStridedSliceLeadingOneDim pattern are correctly
updated.

…gOneDim

Makes sure that "scalability" flags in the
`CastAwayExtractStridedSliceLeadingOneDim` pattern are correctly
updated.
@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Makes sure that "scalability" flags in the
CastAwayExtractStridedSliceLeadingOneDim pattern are correctly
updated.


Full diff: https://github.com/llvm/llvm-project/pull/81187.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+2-1)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+21)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index e1ed5d81625d8e..74382b027c2f48 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -73,7 +73,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     VectorType oldDstType = extractOp.getType();
     VectorType newDstType =
         VectorType::get(oldDstType.getShape().drop_front(dropCount),
-                        oldDstType.getElementType());
+                        oldDstType.getElementType(),
+                        oldDstType.getScalableDims().drop_front(dropCount));
 
     Location loc = extractOp.getLoc();
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index f601be04168144..bb2d30f2092435 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -206,6 +206,16 @@ func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8x
   return %0: vector<1x1x8xf16>
 }
 
+// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
+func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
+  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
+  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x1x[8]xf16>
+}
+
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
 func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
   // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
@@ -217,6 +227,17 @@ func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16
   return %0: vector<1x8x8xf16>
 }
 
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
+func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
+  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
+  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
+  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16>
+  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x8x[8]xf16>
+}
+
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
 func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {

@banach-space banach-space merged commit 0d72f0b into llvm:main Feb 9, 2024
@banach-space banach-space deleted the andrzej/fix_drop_leading_unit_dim branch February 9, 2024 17:23
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 9, 2024
This is a follow-up for llvm#81187, it simply adds missing tests for
scalable vectors.
banach-space added a commit that referenced this pull request Feb 12, 2024
This is a follow-up for #81187, it simply adds missing tests for
scalable vectors.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants