
[MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern #92426


Merged: 3 commits, Jun 12, 2024

Conversation

@nujaa (Contributor) commented on May 16, 2024

Implements TransferOpReduceRank as a MaskableOpRewritePattern, allowing the pattern to exit gracefully when run on a vector::transfer_read located inside a vector::MaskOp, instead of failing because the pattern generated multiple ops inside the MaskOp with the error: 'vector.mask' op expects only one operation to mask.

Split off from #90835.
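
For illustration, this is the kind of IR the pattern must now handle gracefully (adapted from the masked rank-reducing test added in this PR). Before this change, rewriting it would materialize both a lower-rank transfer_read and a vector.broadcast inside the mask region, triggering the verifier error quoted above:

#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
%res = vector.mask %mask {
  vector.transfer_read %mem[%i, %i, %i, %i], %cf0
    {in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
    : memref<8x8x8x8xf32>, vector<4x4x4x4xf32>
} : vector<4x4xi1> -> vector<4x4x4x4xf32>

With this patch, the pattern instead returns a match failure when it sees the surrounding MaskingOpInterface op.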

@nujaa (Author) commented on May 16, 2024

CC @MacDue @banach-space.
On top of #91987.

@llvmbot (Member) commented on May 16, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Changes

Implements TransferOpReduceRank as a MaskableOpRewritePattern, allowing the pattern to exit gracefully when the matched vector::transfer_read is inside a vector::MaskOp.

Split off from #90835.


Full diff: https://github.com/llvm/llvm-project/pull/92426.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+62-35)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+76)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+15)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f..63d3ec91e512f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -90,14 +90,19 @@ namespace {
 /// Note that an alternative is to transform it to linalg.transpose +
 /// vector.transfer_read to do the transpose in memory instead.
 struct TransferReadPermutationLowering
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_read inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
 
     // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
-                                                     transposePerm);
-    return success();
+    return rewriter
+        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+        .getResult();
   }
 };
 
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
 ///     %v = vector.transfer_write %tmp ...
 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
 struct TransferWritePermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_write inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -207,11 +217,14 @@ struct TransferWritePermutationLowering
         op.getLoc(), op.getVector(), indices);
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        op.getMask(), newInBoundsAttr);
-
-    return success();
+    auto newWrite = rewriter.create<vector::TransferWriteOp>(
+        op.getLoc(), newVec, op.getSource(), op.getIndices(),
+        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
+    if (newWrite.hasPureTensorSemantics())
+      return newWrite.getResult();
+    // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+    rewriter.eraseOp(op);
+    return failure();
   }
 };
 
@@ -231,14 +244,19 @@ struct TransferWritePermutationLowering
 ///     vector<1x8x16xf32>
 /// ```
 struct TransferWriteNonPermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_write inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -285,10 +303,14 @@ struct TransferWriteNonPermutationLowering
       newInBoundsValues.push_back(op.isDimInBounds(i));
     }
     ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        newMask, newInBoundsAttr);
-    return success();
+    auto newWrite = rewriter.create<vector::TransferWriteOp>(
+        op.getLoc(), newVec, op.getSource(), op.getIndices(),
+        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+    if (newWrite.hasPureTensorSemantics())
+      return newWrite.getResult();
+    // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+    rewriter.eraseOp(op);
+    return failure();
   }
 };
 
@@ -300,14 +322,19 @@ struct TransferWriteNonPermutationLowering
 ///     %v = vector.transfer_read ...
 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
 ///     vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+struct TransferOpReduceRank
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     AffineMap map = op.getPermutationMap();
     unsigned numLeadingBroadcast = 0;
@@ -347,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
             op.getLoc(), originalVecType.getElementType(), op.getSource(),
             op.getIndices());
       }
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                       newRead);
-      return success();
+      return rewriter
+          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+          .getVector();
     }
 
     SmallVector<int64_t> newShape(
@@ -371,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
         newInBoundsAttr);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                     newRead);
-    return success();
+    return rewriter
+        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+        .getVector();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index e48af3cd7aace..349dc1ab31d4c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -46,6 +46,51 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
     return
 }
 
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x?xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<16xf32>,
+//  CHECK-SAME:        %[[IDX:.*]]: index,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<16xi1>
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]]{{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
+  %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
+//  CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>)
+//  CHECK-SAME:        -> tensor<?x?x?x?xf32> {
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask:  vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+     %c0 = arith.constant 0 : index
+     %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+
+    return %r : tensor<?x?x?x?xf32>
+}
+
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
+//  CHECK-SAME:      %[[ARG0:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:      %[[ARG1:.*]]: vector<14x8x16xf32>
+//  CHECK-SAME:      %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+func.func @masked_non_permutation_xfer_write_fixed_width(
+    %arg0 : tensor<?x?x?x?xf32>,
+    %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
+  %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_read
 ///----------------------------------------------------------------------------------------
@@ -101,6 +146,37 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
   return %1 : vector<8x[4]x2xf32>
 }
 
+// transfer_read in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x1xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<4x1xi1>
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+  call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
+  return
+}
+func.func private @test.some_use(vector<1x4x4xf32>)
+
+// CHECK-LABEL:  func.func @masked_permutation_xfer_read_scalable(
+//  CHECK-SAME:      %[[ARG_0:.*]]: tensor<?x?xf32>,
+//  CHECK-SAME:      %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+//   CHECK-NOT:    vector.transpose
+//       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+
+  %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
+    {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
+    : tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+  return %1 : vector<8x[4]x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 2f2bdcaab5b3e..7c50bfa155472 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -219,6 +219,21 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
   return %res : vector<4x4xf32>
 }
 
+// CHECK-LABEL:   func @masked_transfer_read_reduce_rank_with_broadcast(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8x8x8xf32>,
+// CHECK-SAME:                                %[[MASK:.*]]: vector<4x4xi1>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
+//      CHECK:      %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4x4x4x4xf32>
+#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
+    {in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
+      : memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
+  return %res : vector<4x4x4x4xf32>
+}
+
 // More complex broadcasting case (here a `vector.load` is generated).
 // CHECK-LABEL:   func @transfer_broadcasting_complex(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<10x20x30x8x8xf32>,

@MacDue requested review from banach-space and MacDue on May 16, 2024
@MacDue (Member) left a comment:

LGTM, thanks!

@nujaa force-pushed the hugo.MaskedTfReduceRank branch from dca2419 to d1acbb1 on May 22, 2024
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4x4x4x4xf32>
#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
Contributor commented:

Is this a “masked” version of an existing test? If yes, which one? (sorry if this is obvious and I missed it). If “not”, why not? :)

@nujaa (Author) replied:

It is the masked version of the example in the Doxygen comment of TransferOpReduceRank.
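
For reference, the unmasked example from that Doxygen comment:

vector.transfer_read ...
    permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
into:
%v = vector.transfer_read ...
    permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
vector.broadcast %v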

@nujaa force-pushed the hugo.MaskedTfReduceRank branch from d1acbb1 to 1c3f8c8 on May 24, 2024
@nujaa (Author) commented on May 30, 2024

Ping 🏓, thanks.

@hanhanW (Contributor) left a comment:

Looks okay to me. Is it NFC, or does it enable support for scalable vectors? Can you add that information to the PR description? Thanks!

} : !transform.any_op
transform.yield
}
}
Contributor commented:

Add a newline at the end.

@nujaa (Author) replied:

👍

@nujaa force-pushed the hugo.MaskedTfReduceRank branch from 1c3f8c8 to 9a547ad on May 31, 2024
@nujaa (Author) commented on May 31, 2024

Looks okay to me. Is it NFC? Or it enables the support for scalable vectors? Can you add such information to PR description, thanks!

Hi, it is not really an NFC, as the added test case would segfault before this MR; it is more of a bugfix. Is there a flag to mention bugfixes?

@nujaa merged commit 0170498 into llvm:main on Jun 12, 2024
7 checks passed
// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
Contributor commented:

How is this test different from @permutation_with_mask_xfer_read_scalable? Both seem to check a vector.transfer_read from a memref into a scalable vector.

@nujaa (Author) replied on Jun 13, 2024:

There is a transposition in the permutation map of permutation_with_mask_xfer_read_scalable. It would hence also trigger TransferReadPermutationLowering, which IIRC would fail without #91987 if I masked it with vector.mask. The changes were not completely independent; this test is a bit more unit-test-like.
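
Concretely (maps taken from the tests in this patch), masked_permutation_xfer_read_scalable uses

affine_map<(d0, d1) -> (0, d1, d0)>

which both broadcasts and transposes (d0 and d1 swap), so TransferReadPermutationLowering would also fire, while the rank-reduction test masked_transfer_read_reduce_rank_with_broadcast uses a broadcast-only map,

affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>

where the remaining dims stay in minor-identity order and only TransferOpReduceRank applies.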

Contributor replied:

There is a transposition in the permutation map of permutation_with_mask_xfer_read_scalable.

Cool, right now that's very unclear unless you happen to know what to look for. Such a distinction should be captured either:

  • through comments, and/or
  • the test function name.

(preferably both). That's currently not the case, so let's fix that. I'm suggesting the following plan-of-action:

  • add a big bold comment to document what's being tested here, something similar to what's on L94-L101
  • add no_transpose to the test function names that you added.

I can take it from there.

I have two more high-level points. A bit tangential to this PR - more like TODOs for myself 😅

What tests should be included in vector-transfer-permutation-lowering.mlir?

Btw, based on the file name and the patterns being tested here:

there should be no non-permutation examples in this file, right? To make things even more confusing, these patterns are also tested in:

  • transform-vector.mlir
  • vector-transfer-to-vector-load-store.mlir

😱 :) I'm on a mission to tidy this up a bit (while adding tests for scalable vectors).

Note on populateVectorTransferPermutationMapLoweringPatterns

Separately, it's a bit confusing that TransferOpReduceRank is added to populateVectorTransferPermutationMapLoweringPatterns:

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

I think that it would be easier to "categorise" tests if the granularity was lower. Another PR though. I can handle this.
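
For context, this is roughly how this pattern set is exercised in the tests, reassembled from the transform-sequence fragments quoted elsewhere on this page (the exact sequence in vector-transfer-permutation-lowering.mlir may differ slightly):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op
    // Applies all four patterns from the populate* function above.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.any_op
    transform.yield
  }
}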

@nujaa (Author) replied:

Makes sense to me, it deserves refactoring and standardization. I could have done better on commenting and naming.

add a big bold comment to document what's being tested here, something similar to what's on L94-L101

Alternatively, separating the tests with a big bold comment, grouped by the pattern from the set they actually test, could make sense.

add no_transpose to the test functions names that you added.

I think this rather goes against the idea of having a permutation-specific file, as you mention there:

there should be no non-permutation examples in this file, right? To make things even more confusing, these patterns are also tested in:

But I think we should rather emphasize the fact that it is not a permutation lowering but a lowering based on the permutation map. Hence the no_transpose answer works.

these patterns are also tested in:

I would not mind about transform-vector.mlir, as it is a full-pipeline test. The goal is to show and test a complete lowering, which makes sense, a bit like the end-to-end examples you added for SME.
As for vector-transfer-to-vector-load-store.mlir, I think this file deserves improvements too (such as removing the repetition of the same transform sequences). The concept is to combine it with lower_transfer to test the complete lowering of xfer ops. It is not very unit-test-like, but it shows how to properly lower xfer ops, which is a benefit. Mixed feelings. I think there should not be transfer_permutation_patterns-specific tests here, but I find it sensible to test this combination of patterns. We should maybe ask @ftynse.

it's bit confusing that TransferOpReduceRank is added to populateVectorTransferPermutationMapLoweringPatterns:

Indeed, but after all it is a lowering based on the permutation map of a TransferRead. Hence, I think that rather than removing TransferOpReduceRank, the ideal solution is to rename TransferOpReduceRank to make clear it matches patterns in the permutation map, or to rename populateVectorTransferPermutationMapLoweringPatterns to reflect that it is not necessarily a permutation. This set of patterns is useful to the community as a whole as a lowering pass.

I would also add something: as @hanhanW mentioned in #93664 (review), we could get rid of /// ``` in Doxygen descriptions and standardize them, for improved readability.
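
For context, the style in question wraps IR snippets in markdown-style fences inside Doxygen comments, roughly:

/// Ex:
/// ```
///    %v = vector.transfer_read ...
/// ```

versus the bare indented examples used elsewhere in this file.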

Contributor replied:

in Doxygen descriptions for improved readability.

It is also helpful when I parse the comment in the code directly. I'd appreciate it if someone could help make it consistent. :)

Contributor commented:

The syntax is mostly a markdown thing to me, so I don't know why people started adding that style to comments.

Contributor replied:

But I think we should rather emphasize on the fact it is not a permutation lowering but a lowering based on the permutation map . Hence the no_transpose answer works.

That's spot-on and something that I missed. In fact, this is documented here:

/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
///     into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
///     into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

With this in mind, please ignore my comments under this heading:

Note on populateVectorTransferPermutationMapLoweringPatterns

In fact, it sounds like vector-transfer-permutation-lowering.mlir is the right file for all patterns under populateVectorTransferPermutationMapLoweringPatterns.

Going back to the problem at hand ... this:

Makes sense to me, it deserves refactoring and standardization.

and this:

Alternatively, separating tests with a big bold comment by pattern they are actually testing from the set could make sense.

First step in this direction: #95529. Let me know if that makes sense and I will prepare more tests.

I could have done better on commenting and naming.

Not your fault, it's hard to navigate this ATM (this is also why it takes me ages to review things). Clearly our test coverage is not great either. In general, Vector needs some TLC 😅 I really appreciate you helping us here 🙏🏻

Contributor replied:

First step in this direction: #95529. Let me know if that makes sense and I will prepare more tests.

Next one: #96033
