-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Vector] Fix transferOps optimization inside maskOp #90835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Hugo Trachino (nujaa). Changes: Some optimizations on `vector.transfer_[read|write]` generate other ops. This commit fixes two patterns. Origin discussed here. Full diff: https://github.com/llvm/llvm-project/pull/90835.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb7314..ac63f93c1d756b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -100,11 +100,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
SmallVector<Value> indices(readOp.getIndices().begin(),
readOp.getIndices().end());
SmallVector<Value> sourceIndices;
+ // In case transfer_read is located inside a MaskOp we want to avoid creating
+ // more ops inside it.
+ if (isa<vector::MaskOp>(readOp->getParentOp()))
+ rewriter.setInsertionPoint(readOp->getParentOp());
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
indices, sourceIndices);
+ // Reset the insertion point.
+ rewriter.setInsertionPoint(readOp);
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
AffineMapAttr::get(expandDimsToRank(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f4..51a9d52cbe3880 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -98,6 +98,9 @@ struct TransferReadPermutationLowering
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ if (isa<vector::MaskOp>(op->getParentOp()))
+ return rewriter.notifyMatchFailure(
+ op, "Cannot expand transfer read inside a Mask Op");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -173,6 +176,9 @@ struct TransferWritePermutationLowering
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ if (isa<vector::MaskOp>(op->getParentOp()))
+ return rewriter.notifyMatchFailure(
+ op, "Cannot expand transfer write inside a Mask Op");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -239,6 +245,9 @@ struct TransferWriteNonPermutationLowering
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ if (isa<vector::MaskOp>(op->getParentOp()))
+ return rewriter.notifyMatchFailure(
+ op, "Cannot expand transfer write inside a Mask Op");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
index 6213db3956f9a1..214b41461b98f6 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
@@ -111,3 +111,18 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
return %1 : tensor<?x?x12xf32>
}
+
+// CHECK-LABEL: func @masked_transfer_read_of_extract_slice
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[m:.*]] = vector.create_mask{{.*}} : vector<5x6xi1>
+// CHECK-DAG: %[[a:.*]] = affine.apply {{.*}}[[s1]]
+// CHECK: vector.mask %[[m]] { vector.transfer_read %[[t]]{{.*}}: tensor<?x?xf32>, vector<5x6xf32> } : vector<5x6xi1> -> vector<5x6xf32>
+func.func @masked_transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %mask = vector.create_mask %c3, %c4 : vector<5x6xi1>
+ %1 = vector.mask %mask {vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>} : vector<5x6xi1> -> vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 31bd19c0be8e83..ec2cd478923cca 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -59,6 +59,37 @@ func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16
return
}
+
+
+#map = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+#map1 = affine_map<(d0, d1) -> (d0, 0, d1)>
+// CHECK-LABEL: func @masked_permutation_transfer_read
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
+// CHECK: vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<4x4x1xf32> } : vector<4x1xi1> -> vector<4x4x1xf32>
+func.func @masked_permutation_transfer_read(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map1} : tensor<?x1xf32>, vector<4x4x1xf32> } : vector<4x1xi1> -> vector<4x4x1xf32>
+ call @dostuff(%3) : (vector<4x4x1xf32>) -> ()
+ return
+}
+func.func private @dostuff(vector<4x4x1xf32>)
+
+
+// CHECK-LABEL: func @masked_permutation_transfer_write
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+// CHECK: return %[[RES]]
+func.func @masked_permutation_transfer_write(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?x?xf32> {
+ %r = vector.mask %m0 { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+ return %r : tensor<?x?xf32>
+}
+
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
LGTM, but I would rather defer to somebody more versed into masking approve this. |
Thanks for working on this and sorry for the delay! Any chance you could use: ? Feels like that's exactly what we need here. And you might be able to use it for |
ba819cb
to
d2bd488
Compare
d2bd488
to
fa53c9c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates, looks good!
I've noticed that the following patterns have been updated:
TransferReadOfExtractSliceOpFolder
TransferReadPermutationLowering
TransferWritePermutationLowering
TransferWriteNonPermutationLowering
TransferOpReduceRank
TransferReadToVectorLoadLowering
TransferWriteToVectorStoreLowering
However, there are only 3 new tests :) And it's hard to tell which test corresponds to which pattern. You could address that by creating multiple PRs and grouping similar patterns together (e.g. TransferReadToVectorLoadLowering
and TransferWriteToVectorStoreLowering
). Just a loose suggestion, not a requirement.
In any case, it would be good to see new tests for each of these patterns. In particular, if a pattern doesn't support masking, we should still be able to test that.
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?
if (maskOp) | ||
return rewriter.notifyMatchFailure(read, "Masked case not supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?
if (maskOp) | ||
return rewriter.notifyMatchFailure(write, "Masked case not supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?
mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which patterns do these two tests check? And is there a "non-masked" version of the tests that you added?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't find a unit-test calling transfer_permutation_patterns
only with a non masked version. Might be worth implementing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is : func.func @transfer_write_broadcast_unit_dim(
in mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
for the pattern TransferWriteNonPermutationLowering
(not well placed imo). Do you think I should add the masked case with it or move them both to vector-transfer-permutation-lowering.mlir
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for checking!
From what I can tell, the tests in "mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir" are "meta" tests that:
- focus on checking
transform.apply_patterns.vector.lower_transfer max_transfer_rank
, - in some cases have
transfer_permutation_patterns
added on top.
That particular test (transfer_write_broadcast_unit_dim
) doesn't check for vector.load
/vector.store
, so doesn't seem to belong in "vector-transfer-to-vector-load-store.mlir". IMHO, that file should be audited first 😅 In the meantime, let's focus on "vector-transfer-permutation-lowering.mlir".
Now, I think that this file could also benefit from some additional comments and small re-org. This way it will be easier to see what cases are being tested. ATM that's not really clear and I'm to blame 😅 Trying to fix here:
Could you take a look?
Do you think I should add the masked case with it or move them both to vector-transfer-permutation-lowering.mlir ?
I think that what you have here is sufficient. There are 3 possibilities:
- non-masked,
vector.xfer_read
with mask,- masked
vector.xfer_read|write
(i.e. withvector.mask
)
Option 2 is already tested and that effectively covers 1. as well. So we are only missing 3., right? And that's what you are testing.
Updates tests "vector-transfer-permutation-lowering.mlir" to make a clearer split into tests for : * xfer_read vs xfer_write * fixed-width vs scalable tests This is in preparation for llvm#90835 and also for adding more tests for scalable vectors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates. I think that it would be good to document "vector-transfer-permutation-lowering.mlir" a tiny bit first:
Is that OK with you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for checking!
From what I can tell, the tests in "mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir" are "meta" tests that:
- focus on checking
transform.apply_patterns.vector.lower_transfer max_transfer_rank
, - in some cases have
transfer_permutation_patterns
added on top.
That particular test (transfer_write_broadcast_unit_dim
) doesn't check for vector.load
/vector.store
, so doesn't seem to belong in "vector-transfer-to-vector-load-store.mlir". IMHO, that file should be audited first 😅 In the meantime, let's focus on "vector-transfer-permutation-lowering.mlir".
Now, I think that this file could also benefit from some additional comments and small re-org. This way it will be easier to see what cases are being tested. ATM that's not really clear and I'm to blame 😅 Trying to fix here:
Could you take a look?
Do you think I should add the masked case with it or move them both to vector-transfer-permutation-lowering.mlir ?
I think that what you have here is sufficient. There are 3 possibilities:
- non-masked,
vector.xfer_read
with mask,- masked
vector.xfer_read|write
(i.e. withvector.mask
)
Option 2 is already tested and that effectively covers 1. as well. So we are only missing 3., right? And that's what you are testing.
Updates tests in "vector-transfer-permutation-lowering.mlir" to make a clearer split into cases for : * xfer_read vs xfer_write * fixed-width vs scalable tests A new test case is added for fixed-width vectors for vector.transfer_read. This is to complement an existing test for scalable vectors. This is in preparation for #90835 and also for adding more tests for scalable vectors.
Is the plan to rebase this on top of #91987? |
Yes. As mentioned in your previous comment, there are lots of different patterns updated.
I have submitted :
|
Sorry, I missed your reply!
I think that you can land these? Do you have commit access?
Nice!
I would create a new PR. Otherwise this discussion will look disconnected from the actual PR. But there are no hard rules, I'm still trying to figure out the "canonical GitHub way" 😅 |
Nope 🤗
OK. Will do in due time 😃.
…ewritePattern (#91987) * Implements `TransferWritePermutationLowering`, `TransferReadPermutationLowering` and `TransferWriteNonPermutationLowering` as a MaskableOpRewritePattern. Allowing to exit gracefully when such use of a xferOp is inside a `vector::MaskOp` * Updates MaskableOpRewritePattern to handle MemRefs and buffer semantics providing empty `Value()` as a return value for `matchAndRewriteMaskableOp` now represents successful rewriting without value to replace the original op. Split of #90835
Closed as split into multiple PRs. |
…tern (#92426) Implements `TransferOpReduceRank` as a `MaskableOpRewritePattern`. Allowing to exit gracefully when run on a `vector::transfer_read` located inside a `vector::MaskOp` instead of generating `error: 'vector.mask' op expects only one operation to mask` because the pattern generated multiple ops inside the MaskOp. Split of #90835
…RewritePattern (#92892) Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`. Allowing to exit gracefully when run on an xferOp located inside a `vector::MaskOp` instead of breaking because the pattern generated multiple ops in the MaskOp with `error: 'vector.mask' op expects only one operation to mask`. Split of #90835
…RewritePattern (llvm#92892) Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`. Allowing to exit gracefully when run on an xferOp located inside a `vector::MaskOp` instead of breaking because the pattern generated multiple ops in the MaskOp with `error: 'vector.mask' op expects only one operation to mask`. Split of llvm#90835
Some optimizations on
`vector.transfer_[read|write]`
generate other ops. When the transfer op is located inside a MaskOp, compilation will break with `error: 'vector.mask' op expects only one operation to mask`.
This commit fixes two patterns:
`tensor.fold_tensor_subset_ops_into_vector_transfers`
and `TransferPermutationMapLoweringPatterns`.
In the `TransferPermutationMapLoweringPatterns`, I preferred returning a failure, as hoisting a transpose out of the
`MaskOp`
would require updating the maskOp, which is cumbersome, and I think it is already taken care of in `vector.lower_masked_transfers`.
Origin discussed here.