[MLIR][Vector] Fix transferOps optimization inside maskOp #90835

Closed
wants to merge 4 commits into from

Conversation

nujaa
Contributor

@nujaa nujaa commented May 2, 2024

Some optimizations on vector.transfer_[read|write] generate additional ops. When the transfer op is located inside a MaskOp, compilation breaks with error: 'vector.mask' op expects only one operation to mask.

This commit fixes two sets of patterns: tensor.fold_tensor_subset_ops_into_vector_transfers and TransferPermutationMapLoweringPatterns. For TransferPermutationMapLoweringPatterns, I preferred returning a failure, as hoisting a transpose out of the MaskOp would require updating the MaskOp, which is cumbersome, and I think it is already taken care of in vector.lower_masked_transfers.

Origin discussed here.
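
For illustration only, the two strategies described above (creating the extra ops in front of the MaskOp, or bailing out) can be sketched with a standalone toy model. `ToyMask`, `ToyBlock`, `Strategy` and every function below are hypothetical stand-ins, not MLIR APIs; the actual change is the diff in this PR.

```cpp
#include <string>
#include <vector>

// Toy stand-ins for IR; these are NOT MLIR classes.
struct ToyMask {
  std::vector<std::string> body; // ops nested inside the mask region
};

struct ToyBlock {
  std::vector<std::string> beforeMask; // ops placed ahead of the mask
  ToyMask mask;
};

// Mirrors the rule behind the quoted error: exactly one op may be masked.
bool verifyMask(const ToyMask &mask) { return mask.body.size() == 1; }

enum class Strategy { Naive, Hoist, BailOut };

// Applies a toy rewrite that needs one helper op ("affine.apply").
// Returns true if the rewrite was applied, false if it was declined.
bool rewriteMaskedTransfer(ToyBlock &block, Strategy strategy) {
  switch (strategy) {
  case Strategy::Naive:
    // Helper op lands next to the transfer op, i.e. inside the mask region.
    block.mask.body.insert(block.mask.body.begin(), "affine.apply");
    return true;
  case Strategy::Hoist:
    // Helper op lands ahead of the mask; the region keeps a single op.
    block.beforeMask.push_back("affine.apply");
    return true;
  case Strategy::BailOut:
    // Decline the rewrite and leave the IR untouched.
    return false;
  }
  return false;
}

// Builds a block whose mask region holds a single transfer_read.
ToyBlock makeMaskedRead() {
  ToyBlock block;
  block.mask.body.push_back("vector.transfer_read");
  return block;
}
```

In this model only the naive strategy leaves the mask region with two ops; hoisting and bailing out both keep it valid, at the cost of either moving ops or skipping the optimization.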

@llvmbot
Member

llvmbot commented May 2, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Full diff: https://github.com/llvm/llvm-project/pull/90835.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+9)
  • (modified) mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir (+15)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+31)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb7314..ac63f93c1d756b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -100,11 +100,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
   SmallVector<Value> indices(readOp.getIndices().begin(),
                              readOp.getIndices().end());
   SmallVector<Value> sourceIndices;
+  // In case transfer_read is located inside a MaskOp we want to avoid creating
+  // more ops inside it.
+  if (isa<vector::MaskOp>(readOp->getParentOp()))
+    rewriter.setInsertionPoint(readOp->getParentOp());
   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
       extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
       indices, sourceIndices);
 
+  // Reset the insertion point.
+  rewriter.setInsertionPoint(readOp);
   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
       readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
       AffineMapAttr::get(expandDimsToRank(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f4..51a9d52cbe3880 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -98,6 +98,9 @@ struct TransferReadPermutationLowering
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    if (isa<vector::MaskOp>(op->getParentOp()))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot expand transfer read inside a Mask Op");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -173,6 +176,9 @@ struct TransferWritePermutationLowering
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    if (isa<vector::MaskOp>(op->getParentOp()))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot expand transfer write inside a Mask Op");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -239,6 +245,9 @@ struct TransferWriteNonPermutationLowering
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    if (isa<vector::MaskOp>(op->getParentOp()))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot expand transfer write inside a Mask Op");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
index 6213db3956f9a1..214b41461b98f6 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
@@ -111,3 +111,18 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
   return %1 : tensor<?x?x12xf32>
 }
+
+// CHECK-LABEL: func @masked_transfer_read_of_extract_slice
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//       CHECK-DAG: %[[m:.*]] = vector.create_mask{{.*}} : vector<5x6xi1>
+//       CHECK-DAG: %[[a:.*]] = affine.apply {{.*}}[[s1]]
+//       CHECK: vector.mask %[[m]] { vector.transfer_read %[[t]]{{.*}}: tensor<?x?xf32>, vector<5x6xf32> } : vector<5x6xi1> -> vector<5x6xf32>
+func.func @masked_transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %mask = vector.create_mask %c3, %c4 : vector<5x6xi1>
+  %1 = vector.mask %mask {vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>} : vector<5x6xi1> -> vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 31bd19c0be8e83..ec2cd478923cca 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -59,6 +59,37 @@ func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16
 
     return
 }
+
+
+#map = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+#map1 = affine_map<(d0, d1) -> (d0, 0, d1)>
+// CHECK-LABEL: func @masked_permutation_transfer_read
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x1xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<4x1xi1>
+//       CHECK: vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<4x4x1xf32> } : vector<4x1xi1> -> vector<4x4x1xf32>
+func.func @masked_permutation_transfer_read(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map1} : tensor<?x1xf32>, vector<4x4x1xf32> } : vector<4x1xi1> -> vector<4x4x1xf32>
+  call @dostuff(%3) : (vector<4x4x1xf32>) -> ()
+  return
+}
+func.func private @dostuff(vector<4x4x1xf32>)
+
+
+// CHECK-LABEL: func @masked_permutation_transfer_write
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x?xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<16xf32>,
+//  CHECK-SAME:        %[[IDX:.*]]: index,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<16xi1>
+//       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+//       CHECK:   return %[[RES]]
+func.func @masked_permutation_transfer_write(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?x?xf32> {
+  %r = vector.mask %m0 { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@nujaa
Contributor Author

nujaa commented May 2, 2024

Hi @ftynse @banach-space @dcaballe .

@ftynse ftynse requested review from dcaballe, banach-space and ftynse May 2, 2024 14:33
@ftynse
Member

ftynse commented May 2, 2024

LGTM, but I would rather defer to somebody more versed in masking to approve this.

@banach-space
Contributor

Thanks for working on this and sorry for the delay! Any chance you could use:

? Feels like that's exactly what we need here. And you might be able to use it for TransferPermutationMapLoweringPatterns as well.

@nujaa nujaa force-pushed the hugo.fixVectorization branch from ba819cb to d2bd488 Compare May 7, 2024 15:46
@nujaa nujaa force-pushed the hugo.fixVectorization branch from d2bd488 to fa53c9c Compare May 7, 2024 15:58
Contributor

@banach-space banach-space left a comment

Thanks for the updates, looks good!

I've noticed that the following patterns have been updated:

  • TransferReadOfExtractSliceOpFolder
  • TransferReadPermutationLowering
  • TransferWritePermutationLowering
  • TransferWriteNonPermutationLowering
  • TransferOpReduceRank
  • TransferReadToVectorLoadLowering
  • TransferWriteToVectorStoreLowering

However, there are only 3 new tests :) And it's hard to tell which test corresponds to which pattern. You could address that by creating multiple PRs and grouping similar patterns together (e.g. TransferReadToVectorLoadLowering and TransferWriteToVectorStoreLowering). Just a loose suggestion, not a requirement.

In any case, it would be good to see new tests for each of these patterns. In particular, if a pattern doesn't support masking, we should still be able to test that.

Comment on lines +103 to +104
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Contributor

Could it be supported?

Comment on lines +182 to +183
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Contributor

Could it be supported?

Comment on lines +252 to +253
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Contributor

Could it be supported?

Comment on lines +327 to +328
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Contributor

Could it be supported?

Comment on lines +438 to +439
if (maskOp)
return rewriter.notifyMatchFailure(read, "Masked case not supported");
Contributor

Could it be supported?

Comment on lines +589 to +590
if (maskOp)
return rewriter.notifyMatchFailure(write, "Masked case not supported");
Contributor

Could it be supported?

Contributor

Which patterns do these two tests check? And is there a "non-masked" version of the tests that you added?

Contributor Author

I can't find a unit-test calling transfer_permutation_patterns only with a non masked version. Might be worth implementing.

Contributor Author

@nujaa nujaa May 10, 2024

There is one: func.func @transfer_write_broadcast_unit_dim( in mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir, for the pattern TransferWriteNonPermutationLowering (not well placed, IMO). Do you think I should add the masked case with it or move them both to vector-transfer-permutation-lowering.mlir ?

Contributor

Thanks for checking!

From what I can tell, the tests in "mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir" are "meta" tests that:

  • focus on checking transform.apply_patterns.vector.lower_transfer max_transfer_rank,
  • in some cases have transfer_permutation_patterns added on top.

That particular test (transfer_write_broadcast_unit_dim) doesn't check for vector.load/vector.store, so doesn't seem to belong in "vector-transfer-to-vector-load-store.mlir". IMHO, that file should be audited first 😅 In the meantime, let's focus on "vector-transfer-permutation-lowering.mlir".

Now, I think that this file could also benefit from some additional comments and small re-org. This way it will be easier to see what cases are being tested. ATM that's not really clear and I'm to blame 😅 Trying to fix here:

Could you take a look?

Do you think I should add the masked case with it or move them both to vector-transfer-permutation-lowering.mlir ?

I think that what you have here is sufficient. There are 3 possibilities:

  1. non-masked,
  2. vector.xfer_read with mask,
  3. masked vector.xfer_read|write (i.e. with vector.mask)

Option 2 is already tested and that effectively covers 1. as well. So we are only missing 3., right? And that's what you are testing.

banach-space added a commit to banach-space/llvm-project that referenced this pull request May 13, 2024
Updates tests "vector-transfer-permutation-lowering.mlir" to make a
clearer split into tests for :
  * xfer_read vs xfer_write
  * fixed-width vs scalable tests

This is in preparation for llvm#90835 and also for adding more tests for
scalable vectors.
Contributor

@banach-space banach-space left a comment

Thanks for the updates. I think that it would be good to document "vector-transfer-permutation-lowering.mlir" a tiny bit first:

Is that OK with you?

banach-space added a commit that referenced this pull request May 13, 2024
Updates tests in "vector-transfer-permutation-lowering.mlir" to make a
clearer split into cases for :

* xfer_read vs xfer_write
* fixed-width vs scalable tests

A new test case is added for fixed-width vectors for vector.transfer_read.
This is to complement an existing test for scalable vectors.

This is in preparation for #90835 and also for adding more tests for
scalable vectors.
@banach-space
Contributor

Is the plan to rebase this on top of #91987?

@nujaa
Contributor Author

nujaa commented May 14, 2024

Is the plan to rebase this on top of #91987?

Yes. As mentioned in your previous comment, there are lots of different patterns updated.

Thanks for the updates, looks good!

I've noticed that the following patterns have been updated:

  • TransferReadOfExtractSliceOpFolder
  • TransferReadPermutationLowering
  • TransferWritePermutationLowering
  • TransferWriteNonPermutationLowering
  • TransferOpReduceRank
  • TransferReadToVectorLoadLowering
  • TransferWriteToVectorStoreLowering

However, there are only 3 new tests :) And it's hard to tell which test corresponds to which pattern. You could address that by creating multiple PRs and grouping similar patterns together (e.g. TransferReadToVectorLoadLowering and TransferWriteToVectorStoreLowering). Just a loose suggestion, not a requirement.

I have submitted :

@banach-space
Contributor

Sorry, I missed your reply!

I have submitted :

I think that you can land these? Do you have commit access?

  • I have a branch ready to be pushed on top of them to address TransferOpReduceRank.

Nice!

  • I can reuse this PR to merge Transfer{Read|write}ToVector{Load|store}Lowering or create another one if you think it is a better practice.

I would create a new PR. Otherwise this discussion will look disconnected from the actual PR. But there are no hard rules, I'm still trying to figure out the "canonical GitHub way" 😅

@nujaa
Contributor Author

nujaa commented May 16, 2024

I think that you can land these? Do you have commit access?

Nope 🤗

I would create a new PR. Otherwise this discussion will look disconnected from the actual PR. But there are no hard rules, I'm still trying to figure out the "canonical GitHub way" 😅

OK. Will do in due time 😃.

banach-space pushed a commit that referenced this pull request May 16, 2024
…writePattern (#91960)

Split of #90835
Adds support for `TransferReadOfExtractSliceOpFolder` when the
`TransferReadOp` is inside a `MaskOp`.
banach-space pushed a commit that referenced this pull request May 20, 2024
…ewritePattern (#91987)

* Implements `TransferWritePermutationLowering`,
`TransferReadPermutationLowering` and
`TransferWriteNonPermutationLowering` as a MaskableOpRewritePattern.
Allowing to exit gracefully when such use of a xferOp is inside a
`vector::MaskOp`
* Updates MaskableOpRewritePattern to handle MemRefs and buffer
semantics providing empty `Value()` as a return value for
`matchAndRewriteMaskableOp` now represents successful rewriting without
value to replace the original op.

Split of #90835
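
As a rough illustration of the return-value convention described in this commit message, here is a standalone toy model of the three outcomes: failure, success with a replacement value, and success with nothing to replace. `ToyValue`, `ToyResult`, `Outcome`, `classify` and `lowerTransfer` are hypothetical names, and `std::optional` merely stands in for whatever failure-or-value type the real hook returns; this is a sketch, not the actual `MaskableOpRewritePattern` code.

```cpp
#include <optional>
#include <string>

// Toy stand-in for an SSA value; an empty name plays the role of a null
// Value(). NOT an MLIR class.
struct ToyValue {
  std::string name;
  bool isNull() const { return name.empty(); }
};

// std::nullopt models a match failure; any engaged optional models success.
using ToyResult = std::optional<ToyValue>;

enum class Outcome { NoMatch, ReplacedWithValue, RewrittenNoValue };

// Interprets a pattern result the way the commit message describes it.
Outcome classify(const ToyResult &result) {
  if (!result)
    return Outcome::NoMatch;
  return result->isNull() ? Outcome::RewrittenNoValue
                          : Outcome::ReplacedWithValue;
}

// Hypothetical pattern body: declines masked ops, returns a null value for
// buffer (memref) semantics, and a replacement value for tensor semantics.
ToyResult lowerTransfer(bool insideMask, bool onMemRef) {
  if (insideMask)
    return std::nullopt; // exit gracefully, IR untouched
  if (onMemRef)
    return ToyValue{}; // rewritten, but no result to replace
  return ToyValue{"%new_tensor"}; // rewritten, result replaces the old op
}
```

The point of the third state is that a write to a memref produces no SSA result, so a driver that always expects a replacement value would otherwise have to treat that successful rewrite as a failure.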
@nujaa
Contributor Author

nujaa commented May 21, 2024

Closing, as this was split into multiple PRs.

@nujaa nujaa closed this May 21, 2024
nujaa added a commit that referenced this pull request Jun 12, 2024
…tern (#92426)

Implements `TransferOpReduceRank` as a `MaskableOpRewritePattern`.
Allowing to exit gracefully when run on a `vector::transfer_read`
located inside a `vector::MaskOp` instead of generating  `error: 'vector.mask'
op expects only one operation to mask` because the
pattern generated multiple ops inside the MaskOp.

Split of #90835
nujaa added a commit that referenced this pull request Jun 18, 2024
…RewritePattern (#92892)

Implements `TransferReadToVectorLoadLowering` and
`TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`.
Allowing to exit gracefully when run on an xferOp located inside a
`vector::MaskOp` instead of breaking because the pattern generated
multiple ops in the MaskOp with `error: 'vector.mask' op expects only
one operation to mask`.

Split of #90835
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…RewritePattern (llvm#92892)