-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Add tests xfer-permute-lowering (nfc)(2/n) #96033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add tests xfer-permute-lowering (nfc)(2/n) #96033
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/96033.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2bf4f16f96e6a..d6855d59204cd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3859,9 +3859,6 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
if (op.getPermutationMap().isMinorIdentity())
elidedAttrs.push_back(op.getPermutationMapAttrName());
- // Elide in_bounds attribute if all dims are out-of-bounds.
- if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
- elidedAttrs.push_back(op.getInBoundsAttrName());
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index e1babdd2f1f63..ea57e2afbaa2b 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -174,7 +174,7 @@ func.func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK: scf.for %[[I6:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK: %[[S0:.*]] = affine.apply #[[$ADD]](%[[I2]], %[[I6]])
// CHECK: %[[VEC:.*]] = memref.load %[[VECTOR_VIEW3]][%[[I4]], %[[I5]], %[[I6]]] : memref<3x4x1xvector<5xf32>>
- // CHECK: vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] : vector<5xf32>, memref<?x?x?x?xf32>
+ // CHECK: vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] {in_bounds = [false]} : vector<5xf32>, memref<?x?x?x?xf32>
// CHECK: }
// CHECK: }
// CHECK: }
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index d7ff1ded9d933..7176dff9bc857 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -950,7 +950,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-NOT: tensor.pad
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5.0
-// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] {in_bounds = [false, false]} : tensor<5x6xf32>, vector<7x9xf32>
// CHECK: return %[[RESULT]]
func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
%c0 = arith.constant 0 : index
@@ -984,7 +984,7 @@ func.func private @make_vector() -> vector<7x9xf32>
// CHECK-NOT: tensor.pad
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
-// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] {in_bounds = [false, false]} : vector<7x9xf32>, tensor<5x6xf32>
// CHECK: return %[[RESULT]]
func.func @pad_and_transfer_write_static(
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
@@ -1021,7 +1021,7 @@ func.func private @make_vector() -> vector<7x9xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
-// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] {in_bounds = [false, false]} : vector<7x9xf32>, tensor<?x6xf32>
// CHECK: return %[[RESULT]]
func.func @pad_and_transfer_write_dynamic_static(
%arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c868c881d079a..afe62a2427fb0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -78,7 +78,7 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
- // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
@@ -135,7 +135,7 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%9 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
%10 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
- // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
%11 = vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
%12 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 35418b38df9b2..f2aec3ada4f01 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// TODO: Replace %arg0 with %vec
+
///----------------------------------------------------------------------------------------
/// vector.transfer_write -> vector.transpose + vector.transfer_write
/// [Pattern: TransferWritePermutationLowering]
@@ -31,6 +33,27 @@ func.func @xfer_write_transposing_permutation_map(
return
}
+// Even with out-of-bounds, it is safe to apply this pattern
+// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
+// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
+// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
+// CHECK: vector.transfer_write
+// CHECK-NOT: permutation_map
+// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [false, false]} : vector<8x4xi16>, memref<2x2x?x?xi16>
+func.func @xfer_write_transposing_permutation_map_out_of_bounds(
+ %arg0: vector<4x8xi16>,
+ %mem: memref<2x2x?x?xi16>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+ in_bounds = [false, false],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ } : vector<4x8xi16>, memref<2x2x?x?xi16>
+
+ return
+}
+
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
@@ -83,19 +106,44 @@ func.func @xfer_write_transposing_permutation_map_masked(
/// * vector.broadcast + vector.transpose + vector.transfer_write with a map
/// which _is_ a permutation of a minor identity
-// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
-// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
-// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
-// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
-// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
-// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
-func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index,
- %base2 : index) {
-
- %fn1 = arith.constant -2.0 : f32
- %vf0 = vector.splat %fn1 : vector<7xf32>
- %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
- vector.transfer_write %vf0, %mem[%base1, %base2], %mask
+// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map(
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
+// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index) {
+// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
+// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
+// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
+func.func @xfer_write_non_transposing_permutation_map(
+ %mem : memref<?x?xf32>,
+ %arg0 : vector<7xf32>,
+ %base1 : index,
+ %base2 : index) {
+
+ vector.transfer_write %arg0, %mem[%base1, %base2]
+ {permutation_map = affine_map<(d0, d1) -> (d0)>}
+ : vector<7xf32>, memref<?x?xf32>
+ return
+}
+
+// The broadcast dimension is in bounds, so the transformation is safe
+// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
+// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<7xi1>) {
+// CHECK: %[[BC_VEC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
+// CHECK: %[[BC_MASK:.*]] = vector.broadcast %[[MASK]] : vector<7xi1> to vector<1x7xi1>
+// CHECK: %[[TR_MASK:.*]] = vector.transpose %[[BC_MASK]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
+// CHECK: %[[TR_VEC:.*]] = vector.transpose %[[BC_VEC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
+// CHECK: vector.transfer_write %[[TR_VEC]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]], %[[TR_MASK]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
+func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
+ %mem : memref<?x?xf32>,
+ %arg0 : vector<7xf32>,
+ %base1 : index,
+ %base2 : index,
+ %mask : vector<7xi1>) {
+
+ vector.transfer_write %arg0, %mem[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
: vector<7xf32>, memref<?x?xf32>
return
|
Adds more tests to: * vector-transfer-permutation-lowering.mlir Specifically, adds tests for: * out-of-bounds access for the `TransferWritePermutationLowering` pattern * in-bounds access for `TransferWriteNonPermutationLowering` + `TransferWritePermutationLowering` Also renames `@permutation_with_mask_xfer_write_fixed_width` as `@xfer_write_non_transposing_permutation_map`. This is a part of a larger effort to make sure that all key cases for patterns under populateVectorTransferPermutationMapLoweringPatterns (*) are tested. I also want to make sure that tests use consistent function and variable names. (*) transform.apply_patterns.vector.transfer_permutation_patterns in TD parlance) Depends on llvm#96031
Rebase and updated based on recent changes
1b16266
to
96f30e9
Compare
@nujaa I've finally landed a few updates for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, a couple of NITs depending on what you want to do with this list of MRs in the future. Thanks for pushing with this. LGTM. Suggested a change to keep track of potential other refactoring to take care in this file.
@@ -1,5 +1,7 @@ | |||
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s | |||
|
|||
// TODO: Replace %arg0 with %vec |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, do you think it would make sense to refactor all args the same way you did in vector-transfer-flatten.mlir
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huge +1 - I will update the comment to reflect that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you updated the file AND accepted my suggestion generating repetition 😕
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yikes! Fixed in the latest commit :)
%mem : memref<?x?xf32>, | ||
%arg0 : vector<7xf32>, | ||
%base1 : index, | ||
%base2 : index) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: I'd suggest to use %idx1
to match with other tests such as masked_permutation_xfer_write_fixed_width
and your refacto of transfer-flatten.
I am not unhappy if you make it as part of a separate commit to tackle the above TODO. Also getting rid of %dim
args and use %idx
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you don't mind me expanding this PR, I'll do it here. Ultimately, this needs to happen and it's all about making the review process as smooth/easy as possible. If you are happy then I'm also happy :)
Also getting rid of %dim args and use %idx.
This one I'm a bit unsure of - %dim
is used for mask dimension rather than xfer_{read|write} index 🤔
@@ -31,6 +33,31 @@ func.func @xfer_write_transposing_permutation_map( | |||
return | |||
} | |||
|
|||
// Even with out-of-bounds, it is safe to apply this pattern | |||
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds | |||
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: Argument definitions are aligned with function body in the test below. Although xfer_write_transposing_permutation_map
also follows this formatting. Both are valid I guess. 🤷
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's prioritise consistency - I will update this before landing. Thanks!
Apply suggestions from Hugo
…nfc)(2/n) Remove repeated comment
@nujaa Here's a follow-up to address some of your suggestions re formatting; More to come 😅 |
Adds more tests to:
Specifically, adds tests for:
TransferWritePermutationLowering
pattern
TransferWriteNonPermutationLowering
+TransferWritePermutationLowering
Also renames
@permutation_with_mask_xfer_write_fixed_width
as@xfer_write_non_transposing_permutation_map
.This is a part of a larger effort to make sure that all key cases for
patterns under populateVectorTransferPermutationMapLoweringPatterns
(*) are tested. I also want to make sure that tests use consistent
function and variable names.
(*) transform.apply_patterns.vector.transfer_permutation_patterns in
TD parlance)