[mlir][ArmSME] Lower transfer_write + transpose to vertical store #71181
Conversation
This patch extends the lowering of vector.transfer_write in VectorToArmSME to support in-flight transpose via SME vertical store.
Commits 0c13506 to bee71e7
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

This patch extends the lowering of vector.transfer_write in VectorToArmSME to support in-flight transpose via SME vertical store.

Full diff: https://github.com/llvm/llvm-project/pull/71181.diff

3 Files Affected:
- mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
- mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
- mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
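For a quick before/after illustration, here is a minimal sketch modelled on the tests in this patch (the function and value names are placeholders):

  func.func @transfer_write_2d_transpose_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
    %c0 = arith.constant 0 : index
    // Transfer write with a transpose permutation map.
    vector.transfer_write %vector, %dest[%c0, %c0]
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
      : vector<[4]x[4]xf32>, memref<?x?xf32>
    return
  }

  // After the VectorToArmSME lowering, the vector.transfer_write above becomes a
  // vertical SME tile store, i.e. the transpose happens in-flight:
  //   arm_sme.tile_store %vector, %dest[%c0, %c0] layout<vertical>
  //     : memref<?x?xf32>, vector<[4]x[4]xf32>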
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 5491f7dd30629ad..a8956f0d38fba9d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
/// Conversion pattern for vector.transfer_write.
///
-/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-/// memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+/// arm_sme.tile_store:
+///
+/// vector.transfer_write %vector, %source[%c0, %c0]
+/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
/// vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
+/// (in-flight transpose):
+///
+/// vector.transfer_write %vector, %source[%c0, %c0]
+/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
if (!arm_sme::isValidSMETileVectorType(vType))
return failure();
+ assert(writeOp.getTransferRank() == 2 &&
+ "expected a permutation_map with result dims of the same rank as "
+ "the vector type");
+
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();
+ // Out-of-bounds dims are not supported.
+ if (writeOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(writeOp,
+ "not inbounds transfer write");
+
+ arm_sme::TileSliceLayout layout;
+
+ AffineExpr d0, d1;
+ bindDims(writeOp.getContext(), d0, d1);
+ AffineMap map = writeOp.getPermutationMap();
+ if (map.isIdentity())
+ layout = arm_sme::TileSliceLayout::Horizontal;
+ else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+ writeOp.getContext()))
+ layout = arm_sme::TileSliceLayout::Vertical;
+ else
+ return rewriter.notifyMatchFailure(writeOp,
+ "unsupported permutation map");
+
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
- writeOp.getMask());
+ writeOp.getMask(), layout);
return success();
}
};
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ed33f8508dba0bf..a1ad25ed77aa8ef 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
// -----
+/// in-flight transpose via vertical store.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+ return
+}
+
+// -----
+
+/// in-flight transpose via vertical store with mask.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+ return
+}
+
+// -----
+
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
return
}
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// vector.broadcast
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index b599b976c3e1592..174cec857437a7d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
return
}
+// Vector store + transpose.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+ vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
+// Masked vector store + transpose.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
+ %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+ vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
// Vector load + print.
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -116,6 +135,25 @@ func.func @entry() {
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ // 4. Reload 3. + store + transpose.
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 0, 20, 30
+ // CHECK-NEXT: ( 0, 0, 21, 31
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ // CHECK-NEXT: ( 3, 13, 0, 0
+ call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+ // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
+ // The mask applies after permutation
+ // CHECK-LABEL: TILE BEGIN:
+ // CHECK-NEXT: ( 0, 0, 20, 30
+ // CHECK-NEXT: ( 0, 0, 21, 31
+ // CHECK-NEXT: ( 20, 21, 0, 0
+ // CHECK-NEXT: ( 30, 31, 0, 0
+ call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+ call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
memref.dealloc %A : memref<?x?xf32>
return
/// arm_sme.tile_store:
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
Given that this is now always passing layout to the builder, shouldn't this be printed as:
arm_sme.tile_store %vector, %source[%c0, %c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
horizontal layout is the default and isn't printed, but I could add it to be explicit if you think that's clearer?
I made an invalid assumption (I was expecting it to be printed since it was being passed to the builder). But, from: https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/#optional-and-default-valued-parameters
An optional parameter is omitted when it is equal to its default value.
So this is correct.
I could add it to be explicit if you think that's clearer
That would be my personal preference, but then what's the point of "default" values? No strong opinion, to me it's a bit confusing, but we can revisit some other time.
My preference is to omit expected default values (and I think horizontal would be expected) :)
Makes sense, just some usual nitpicking from me :)
Thanks!
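For reference, assuming the printer behaviour described above (horizontal is the default tile-slice layout and is elided when printed), the two forms below would be equivalent, with the explicit spelling assumed to parse to the same op:

  // As printed: the default horizontal layout is omitted.
  arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>

  // Written explicitly: same op, layout spelled out.
  arm_sme.tile_store %vector, %source[%c0, %c0] layout<horizontal>
    : memref<?x?xi8>, vector<[16]x[16]xi8>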
LGTM, thanks for addressing my comments!
[mlir][ArmSME] Lower transfer_write + transpose to vertical store (llvm#71181)

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.