
[MLIR][Vector] Implement transferXXPermutationLowering as MaskableOpRewritePattern #91987


Merged: 8 commits into llvm:main, May 20, 2024

Conversation

@nujaa (Contributor) commented May 13, 2024:

  • Implements TransferWritePermutationLowering, TransferReadPermutationLowering and TransferWriteNonPermutationLowering as MaskableOpRewritePatterns, allowing the patterns to exit gracefully when such an xferOp is used inside a vector::MaskOp.
  • Updates MaskableOpRewritePattern to handle MemRefs and buffer semantics: returning an empty Value() from matchAndRewriteMaskableOp now represents a successful rewrite with no value to replace the original op with.

Split of #90835
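
For context, a simplified sketch (not the verbatim upstream code) of the driver logic this builds on: MaskableOpRewritePattern checks whether the matched op sits inside a vector.mask, hands the masking op to the hook, and, with this patch, treats an empty Value() as "rewrite succeeded, nothing to replace" (the memref / buffer-semantics case).

```cpp
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using vector::MaskableOpInterface;
using vector::MaskingOpInterface;

template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SourceOp sourceOp,
                                PatternRewriter &rewriter) const final {
    // If the op is masked, the surrounding vector.mask becomes the root op
    // that gets erased or replaced.
    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
    MaskingOpInterface maskOp;
    Operation *rootOp = sourceOp;
    if (maskableOp && maskableOp.isMasked()) {
      maskOp = maskableOp.getMaskingOp();
      rootOp = maskOp;
    }

    FailureOr<Value> newValue =
        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    if (failed(newValue))
      return failure();

    // An empty Value() signals a successful rewrite with no result to
    // substitute, e.g. transfer_write on a memref.
    if (rootOp->getNumResults() == 0) {
      rewriter.eraseOp(rootOp);
    } else {
      assert(*newValue != Value() && "can't replace an op use with Value()");
      rewriter.replaceOp(rootOp, *newValue);
    }
    return success();
  }

  // Patterns return the replacement value, or Value() when the rewritten op
  // produces no result.
  virtual FailureOr<Value>
  matchAndRewriteMaskableOp(SourceOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const = 0;
};
```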

@nujaa (Contributor, Author) commented May 13, 2024:

@banach-space

@llvmbot (Member) commented May 13, 2024:

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Hugo Trachino (nujaa)

Changes

Implements TransferWritePermutationLowering, TransferReadPermutationLowering and TransferWriteNonPermutationLowering as MaskableOpRewritePatterns, allowing the patterns to exit gracefully when such an xferOp is used inside a vector::MaskOp.

Split of #90835


Full diff: https://github.com/llvm/llvm-project/pull/91987.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+40-24)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+77)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f..7f5703b635068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -90,14 +90,19 @@ namespace {
 /// Note that an alternative is to transform it to linalg.transpose +
 /// vector.transfer_read to do the transpose in memory instead.
 struct TransferReadPermutationLowering
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_read inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
 
     // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
-                                                     transposePerm);
-    return success();
+    return rewriter
+        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+        .getResult();
   }
 };
 
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
 ///     %v = vector.transfer_write %tmp ...
 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
 struct TransferWritePermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_write inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -207,11 +217,11 @@ struct TransferWritePermutationLowering
         op.getLoc(), op.getVector(), indices);
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        op.getMask(), newInBoundsAttr);
-
-    return success();
+    return rewriter
+        .create<vector::TransferWriteOp>(
+            op.getLoc(), newVec, op.getSource(), op.getIndices(),
+            AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
+        .getResult();
   }
 };
 
@@ -231,14 +241,19 @@ struct TransferWritePermutationLowering
 ///     vector<1x8x16xf32>
 /// ```
 struct TransferWriteNonPermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_write inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -285,10 +300,11 @@ struct TransferWriteNonPermutationLowering
       newInBoundsValues.push_back(op.isDimInBounds(i));
     }
     ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        newMask, newInBoundsAttr);
-    return success();
+    return rewriter
+        .create<vector::TransferWriteOp>(
+            op.getLoc(), newVec, op.getSource(), op.getIndices(),
+            AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
+        .getResult();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index e48af3cd7aace..a53e2a9e50ba2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -46,6 +46,52 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
     return
 }
 
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x?xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<16xf32>,
+//  CHECK-SAME:        %[[IDX:.*]]: index,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<16xi1>
+//       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+//       CHECK:   return %[[RES]]
+func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
+  %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:           func.func @masked_permutation_xfer_write_scalable(
+//  CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>)
+//  CHECK-SAME:        -> tensor<?x?x?x?xf32> {
+//       CHECK:             %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:             %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+//       CHECK:             return %[[R]] : tensor<?x?x?x?xf32>
+func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask:  vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+     %c0 = arith.constant 0 : index
+     %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+
+    return %r : tensor<?x?x?x?xf32>
+}
+
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
+//  CHECK-SAME:      %[[ARG0:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:      %[[ARG1:.*]]: vector<14x8x16xf32>
+//  CHECK-SAME:      %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+func.func @masked_non_permutation_xfer_write_fixed_width(
+    %arg0 : tensor<?x?x?x?xf32>,
+    %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
+  %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+  // CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_read
 ///----------------------------------------------------------------------------------------
@@ -101,6 +147,37 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
   return %1 : vector<8x[4]x2xf32>
 }
 
+// transfer_read in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x1xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<4x1xi1>
+//       CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+  call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
+  return
+}
+func.func private @test.some_use(vector<1x4x4xf32>)
+
+// CHECK-LABEL:   func.func @masked_permutation_xfer_read_scalable(
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+// CHECK:           return %[[T_READ]] : vector<8x[4]x2xf32>
+func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+
+  %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
+    {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
+    : tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+  return %1 : vector<8x[4]x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
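
For reference, a minimal harness sketch showing how the patterns exercised by these tests are typically applied. This driver function is not part of the PR, but populateVectorTransferPermutationMapLoweringPatterns and applyPatternsAndFoldGreedily are the existing upstream entry points.

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Greedily apply the transfer-permutation lowering patterns. With this PR,
// xfer ops wrapped in vector.mask make the patterns bail out gracefully
// instead of producing incorrect IR.
static mlir::LogicalResult lowerTransferPermutations(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```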

@banach-space self-requested a review May 13, 2024 18:47
@banach-space (Contributor) left a comment:

Thanks, I really like how the code is gradually becoming self-documenting :)

LGTM!

@nujaa force-pushed the hugo.maskTfPermutationLowering branch from 74894d7 to c11da46 on May 17, 2024 12:27
@banach-space (Contributor) left a comment:

Hey Hugo, sorry for the delay with this.

Having read this again, I am realising that I forgot about MemRef semantics when implementing MaskableOpRewritePattern, thanks for fixing that! I think it would be good to capture that with some additional comments; see my suggestions inline. It would also be good to update the summary accordingly (something along these lines):

Updates MaskableOpRewritePattern so that it works correctly with MemRefs.

Feel free to re-use and/or re-write.

        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In memref case, MaskableOpRewritePattern cannot replaceOp with result.

Suggested change:
-    // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+    // In the memref case there's no return value. Use empty value to signal success.

Comment on lines 160 to 163:

    if (rootOp->getNumResults() == 0 || *newOp == Value())
      rewriter.eraseOp(rootOp);
    else
      rewriter.replaceOp(rootOp, *newOp);

IIUC, the only case that we are testing now is when "there's no return value" and "newOp is Value()". Hence I suggest replacing || with &&.

Suggested change:
-    if (rootOp->getNumResults() == 0 || *newOp == Value())
-      rewriter.eraseOp(rootOp);
-    else
-      rewriter.replaceOp(rootOp, *newOp);
+    // In the memref case there won't be a return value to replace. Instead,
+    // use an empty value to signal success.
+    if (rootOp->getNumResults() == 0 && *newOp == Value())
+      rewriter.eraseOp(rootOp);
+    else
+      rewriter.replaceOp(rootOp, *newOp);

@nujaa (Contributor, Author) commented May 20, 2024:

Sorry for the late answer, I have been thinking about it while implementing and did not come up with a solution I liked. With a fresh mind after the weekend, here is my point: returning Value() means the pattern did NOT fail, i.e. code updates happened but there is no value to give back (e.g. the memref case).
If we split the cases:

if (*newOp == Value()) {
  if (rootOp->getNumResults() == 0) {
    // Simple case.
    rewriter.eraseOp(rootOp);
  } else {
    // We would have to replace a real result with Value(), so there might be
    // uses of rootOp left in the rest of the program if we tried to erase it.
    // So I suggest raising an error.
    raiseError();
  }
} else { // The pattern returned a value.
  if (rootOp->getNumResults() == 1) {
    // Simple case.
    rewriter.replaceOp(rootOp, *newOp);
  } else {
    // We created ops producing a value which should replace something without
    // a result. We can't use that value in the program; it will most likely
    // be DCE-ed.
    rewriter.eraseOp(rootOp);
  }
}

This can then be reduced to:

if (failed(newOp))
  return failure();
if (rootOp->getNumResults() == 0) {
  rewriter.eraseOp(rootOp);
} else {
  assert(*newOp != Value() && "Can't replace an op use with Value()");
  rewriter.replaceOp(rootOp, *newOp);
}
return success();

As an additional point: technically, matchAndRewriteMaskableOp could return a ValueRange, since replaceOp takes a ValueRange as input (replaceOp asserts that rootOp->getNumResults() matches newOp.size()). That would allow handling ops with multiple results, but I suggest doing that as part of a separate patch.
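
A hypothetical sketch of that follow-up (not implemented in this PR): the hook would return a vector of replacement values, and the driver would forward the range to replaceOp.

```cpp
// Hypothetical generalisation, not part of this PR: return a range of
// replacement values so patterns on multi-result ops can be supported.
FailureOr<SmallVector<Value>> newValues =
    matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
if (failed(newValues))
  return failure();
if (rootOp->getNumResults() == 0)
  rewriter.eraseOp(rootOp);
else
  rewriter.replaceOp(rootOp, *newValues); // Asserts matching result counts.
return success();
```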

@nujaa (Contributor, Author) commented May 20, 2024:

Thanks for your comments. I interpreted them slightly differently. Feel free to debate, or merge if you are satisfied.

@banach-space (Contributor) commented:

We are on the same page here, thanks for seeing this through!

One thing that this discussion makes me question: should vector.mask allow memref semantics? Is that needed at all? I’m leaning towards “no”, but that’s a discussion for a different PR.

@banach-space merged commit fdd245a into llvm:main May 20, 2024
4 checks passed