-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern #92426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
CC @MacDue @banach-space.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Hugo Trachino (nujaa) Changes: Implements TransferOpReduceRank as a MaskableOpRewritePattern. Split of #90835. Full diff: https://github.com/llvm/llvm-project/pull/92426.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f..63d3ec91e512f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -90,14 +90,19 @@ namespace {
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_read inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
- transposePerm);
- return success();
+ return rewriter
+ .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+ .getResult();
}
};
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_write inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -207,11 +217,14 @@ struct TransferWritePermutationLowering
op.getLoc(), op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- op.getMask(), newInBoundsAttr);
-
- return success();
+ auto newWrite = rewriter.create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
+ if (newWrite.hasPureTensorSemantics())
+ return newWrite.getResult();
+ // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+ rewriter.eraseOp(op);
+ return failure();
}
};
@@ -231,14 +244,19 @@ struct TransferWritePermutationLowering
/// vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_write inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -285,10 +303,14 @@ struct TransferWriteNonPermutationLowering
newInBoundsValues.push_back(op.isDimInBounds(i));
}
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- newMask, newInBoundsAttr);
- return success();
+ auto newWrite = rewriter.create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+ if (newWrite.hasPureTensorSemantics())
+ return newWrite.getResult();
+ // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+ rewriter.eraseOp(op);
+ return failure();
}
};
@@ -300,14 +322,19 @@ struct TransferWriteNonPermutationLowering
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
+struct TransferOpReduceRank
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
AffineMap map = op.getPermutationMap();
unsigned numLeadingBroadcast = 0;
@@ -347,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
op.getLoc(), originalVecType.getElementType(), op.getSource(),
op.getIndices());
}
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
- newRead);
- return success();
+ return rewriter
+ .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+ .getVector();
}
SmallVector<int64_t> newShape(
@@ -371,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
- newRead);
- return success();
+ return rewriter
+ .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+ .getVector();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index e48af3cd7aace..349dc1ab31d4c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -46,6 +46,51 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
return
}
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
+// CHECK-NOT: vector.transpose
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]]{{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
+ %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+ return %r : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
+// CHECK-SAME: -> tensor<?x?x?x?xf32> {
+// CHECK-NOT: vector.transpose
+// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+
+ return %r : tensor<?x?x?x?xf32>
+}
+
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+// CHECK-NOT: vector.broadcast
+// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+func.func @masked_non_permutation_xfer_write_fixed_width(
+ %arg0 : tensor<?x?x?x?xf32>,
+ %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
+ %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
+
///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------
@@ -101,6 +146,37 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
return %1 : vector<8x[4]x2xf32>
}
+// transfer_read in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
+// CHECK-NOT: vector.transpose
+// CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+ call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
+ return
+}
+func.func private @test.some_use(vector<1x4x4xf32>)
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+// CHECK-NOT: vector.transpose
+// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+
+ %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
+ {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
+ : tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+ return %1 : vector<8x[4]x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 2f2bdcaab5b3e..7c50bfa155472 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -219,6 +219,21 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
return %res : vector<4x4xf32>
}
+// CHECK-LABEL: func @masked_transfer_read_reduce_rank_with_broadcast(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8x8x8xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x4xi1>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4x4x4x4xf32>
+#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
+ {in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
+ : memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
+ return %res : vector<4x4x4x4xf32>
+}
+
// More complex broadcasting case (here a `vector.load` is generated).
// CHECK-LABEL: func @transfer_broadcasting_complex(
// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
dca2419
to
d1acbb1
Compare
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32> | ||
// CHECK-NEXT: return %[[RES]] : vector<4x4x4x4xf32> | ||
#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)> | ||
func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a “masked” version of an existing test? If yes, which one? (sorry if this is obvious and I missed it). If “not”, why not? :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is the masked version of the example in the Doxygen of TransferOpReduceRank
.
d1acbb1
to
1c3f8c8
Compare
Ping 🏓 , thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks okay to me. Is it NFC, or does it enable support for scalable vectors? Can you add that information to the PR description? Thanks!
} : !transform.any_op | ||
transform.yield | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a new line at the end.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
1c3f8c8
to
9a547ad
Compare
Hi, it is not really an NFC, as the added test case would segfault before this MR. It is more of a bugfix. Is there a flag to mention bugfixes?
// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32> | ||
// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32> | ||
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32> | ||
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this test different to @permutation_with_mask_xfer_read_scalable
? Both seem to check for vector.transfer_read
from memref
into a scalable vector?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a transposition in the permutation map of permutation_with_mask_xfer_read_scalable
. It would hence also trigger TransferReadPermutationLowering
which IIRC would fail without #91987 if I masked it with vector.mask
. Changes were not completely independent, this one is a bit more unit test-like
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a transposition in the permutation map of permutation_with_mask_xfer_read_scalable.
Cool, right now that's very unclear, unless you happen to know what to look for. Such distinction should be captured either:
- through comments, and/or
- test function name.
(preferably both). That's currently not the case, so lets fix that. I'm suggesting the following plan-of-action:
- add a big bold comment to document what's being tested here, something similar to what's on L94-L101
- add
no_transpose
to the test functions names that you added.
I can take it from there.
I have two more high-level points. A bit tangential to this PR - more like TODOs for myself 😅
What test should be included in vector-transfer-permutation-lowering.mlir?
Btw, based on the file name and the patterns being tested here:
- vector-transfer-permutation-lowering.mlir
transform.apply_patterns.vector.transfer_permutation_patterns
,
there should be no non-permutation examples in this file, right? To make things even more confusing, these patterns are also tested in:
- transform-vector.mlir
- vector-transfer-to-vector-load-store.mlir
😱 :) I'm on a mission to tidy this up a bit (while adding tests for scalable vectors).
Note on populateVectorTransferPermutationMapLoweringPatterns
Separately, it's bit confusing that TransferOpReduceRank
is added to populateVectorTransferPermutationMapLoweringPatterns
:
llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Lines 410 to 416 in 77db8b0
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( | |
RewritePatternSet &patterns, PatternBenefit benefit) { | |
patterns | |
.add<TransferReadPermutationLowering, TransferWritePermutationLowering, | |
TransferOpReduceRank, TransferWriteNonPermutationLowering>( | |
patterns.getContext(), benefit); | |
} |
I think that it would be easier to "categorise" tests if the granularity was lower. Another PR though. I can handle this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me, it deserves refactoring and standardization. I could have done better on commenting and naming.
add a big bold comment to document what's being tested here, something similar to what's on L94-L101
Alternatively, separating tests with a big bold comment by pattern they are actually testing from the set could make sense.
add no_transpose to the test functions names that you added.
I think this quite goes against the idea of having a permutation specific file. as you mention there :
there should be no non-permutation examples in this file, right? To make things even more confusing, these patterns are also tested in:
But I think we should rather emphasize the fact that it is not a permutation lowering but a lowering based on the permutation map. Hence the no_transpose
answer works.
these patterns are also tested in:
I would not mind about transform-vector.mlir
as it is a full pipeline test. The goal is to show and test a complete lowering which makes sense. A bit like the e-2-e examples you added for SME.
For vector-transfer-to-vector-load-store.mlir, I think this file deserves improvements too (such as the repetition of the same transform sequences). The concept is to combine it with lower_transfer to test the complete lowering of xfer ops. It is not very unit-test-like, but it shows how to properly lower xfer ops, which is a benefit. Mixed feelings. I think there should not be transfer_permutation_patterns
-specific tests here, but I find it sensible to test this combination of patterns. We should maybe ask @ftynse.
it's bit confusing that TransferOpReduceRank is added to populateVectorTransferPermutationMapLoweringPatterns:
Indeed, but after all it is a lowering based on the permutation map of a TransferRead. Hence, I think that rather than removing TransferOpReduceRank
, the ideal solution is to rename TransferOpReduceRank
to make clear that it matches patterns in the permutation map, or to rename populateVectorTransferPermutationMapLoweringPatterns
to reflect that it is not necessarily a permutation. This set of patterns is useful as a whole for the community as a lowering pass.
I would also add something. As @hanhanW mentioned in #93664 (review), we could get rid of the /// ``` markers in Doxygen descriptions and standardize them for improved readability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in Doxygen descriptions for improved readability.
It is also helpful when I parse the comment in the code directly. I'd appreciate if someone can help make it consistent. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The syntax is mostly for markdown to me, so I don't know why people started adding such style to comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I think we should rather emphasize on the fact it is not a permutation lowering but a lowering based on the permutation map . Hence the no_transpose answer works.
That's spot-on and something that I missed. In fact, this is documented here:
llvm-project/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
Lines 175 to 228 in 77db8b0
/// Collect a set of transfer read/write lowering patterns that simplify the | |
/// permutation map (e.g., converting it to a minor identity map) by inserting | |
/// broadcasts and transposes. More specifically: | |
/// | |
/// [TransferReadPermutationLowering] | |
/// Lower transfer_read op with permutation into a transfer_read with a | |
/// permutation map composed of leading zeros followed by a minor identity + | |
/// vector.transpose op. | |
/// Ex: | |
/// vector.transfer_read ... | |
/// permutation_map: (d0, d1, d2) -> (0, d1) | |
/// into: | |
/// %v = vector.transfer_read ... | |
/// permutation_map: (d0, d1, d2) -> (d1, 0) | |
/// vector.transpose %v, [1, 0] | |
/// | |
/// vector.transfer_read ... | |
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) | |
/// into: | |
/// %v = vector.transfer_read ... | |
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) | |
/// vector.transpose %v, [0, 1, 3, 2, 4] | |
/// Note that an alternative is to transform it to linalg.transpose + | |
/// vector.transfer_read to do the transpose in memory instead. | |
/// | |
/// [TransferWritePermutationLowering] | |
/// Lower transfer_write op with permutation into a transfer_write with a | |
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.) | |
/// Ex: | |
/// vector.transfer_write %v ... | |
/// permutation_map: (d0, d1, d2) -> (d2, d0, d1) | |
/// into: | |
/// %tmp = vector.transpose %v, [2, 0, 1] | |
/// vector.transfer_write %tmp ... | |
/// permutation_map: (d0, d1, d2) -> (d0, d1, d2) | |
/// | |
/// vector.transfer_write %v ... | |
/// permutation_map: (d0, d1, d2, d3) -> (d3, d2) | |
/// into: | |
/// %tmp = vector.transpose %v, [1, 0] | |
/// %v = vector.transfer_write %tmp ... | |
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) | |
/// | |
/// [TransferOpReduceRank] | |
/// Lower transfer_read op with broadcast in the leading dimensions into | |
/// transfer_read of lower rank + vector.broadcast. | |
/// Ex: vector.transfer_read ... | |
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) | |
/// into: | |
/// %v = vector.transfer_read ... | |
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) | |
/// vector.broadcast %v | |
void populateVectorTransferPermutationMapLoweringPatterns( | |
RewritePatternSet &patterns, PatternBenefit benefit = 1); |
With this in mind, please ignore my comments under this heading:
Note on populateVectorTransferPermutationMapLoweringPatterns
In fact, sounds like vector-transfer-permutation-lowering.mlir is the right file for all patterns under populateVectorTransferPermutationMapLoweringPatterns
.
Going back to the problem at hand ... This
Makes sense to me, it deserves refactoring and standardization.
and this:
Alternatively, separating tests with a big bold comment by pattern they are actually testing from the set could make sense.
First step in this direction: #95529. Let me know if that makes sense and I will prepare more tests.
I could have done better on commenting and naming.
Not your fault, it's hard to navigate this ATM (this is also why it takes me ages to review things). Clearly our test coverage is not great either. In general, Vector needs some TLC 😅 I really appreciate you helping us here 🙏🏻
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implements TransferOpReduceRank as a MaskableOpRewritePattern, allowing the pattern to exit gracefully when run on a vector::transfer_read located inside a vector::MaskOp, instead of breaking because the pattern generated multiple ops in the MaskOp with error: 'vector.mask' op expects only one operation to mask.

Split of #90835