Skip to content

[mlir][ArmSME] Lower transfer_write + transpose to vertical store #71181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

c-rhodes
Copy link
Collaborator

@c-rhodes c-rhodes commented Nov 3, 2023

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
@c-rhodes c-rhodes force-pushed the mlir-arm-sme-vertical-tile-store-lowering branch from 0c13506 to bee71e7 Compare November 6, 2023 11:28
@c-rhodes c-rhodes requested review from MacDue, banach-space and dcaballe and removed request for MacDue November 6, 2023 11:30
@c-rhodes c-rhodes marked this pull request as ready for review November 6, 2023 11:30
@llvmbot
Copy link
Member

llvmbot commented Nov 6, 2023

@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.


Full diff: https://github.com/llvm/llvm-project/pull/71181.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+44-3)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+42)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+38)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 5491f7dd30629ad..a8956f0d38fba9d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
 
 /// Conversion pattern for vector.transfer_write.
 ///
-///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-///                                                      memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_store:
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
 ///                                                   vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
+///            (in-flight transpose):
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
     if (!arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
+    assert(writeOp.getTransferRank() == 2 &&
+           "expected a permutation_map with result dims of the same rank as "
+           "the vector type");
+
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
+    // Out-of-bounds dims are not supported.
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+
+    arm_sme::TileSliceLayout layout;
+
+    AffineExpr d0, d1;
+    bindDims(writeOp.getContext(), d0, d1);
+    AffineMap map = writeOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   writeOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
-        writeOp.getMask());
+        writeOp.getMask(), layout);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ed33f8508dba0bf..a1ad25ed77aa8ef 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
 
 // -----
 
+/// in-flight transpose via vertical store.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+  return
+}
+
+// -----
+
+/// in-flight transpose via vertical store with mask.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME:                                                        %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME:                                                        %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME:                                                        %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index b599b976c3e1592..174cec857437a7d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
   return
 }
 
+// Vector store + transpose.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store + transpose.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 // Vector load + print.
 func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -116,6 +135,25 @@ func.func @entry() {
   call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
   call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
 
+  // 4. Reload 3. + store + transpose.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 3, 13, 0, 0
+  call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
+  // The mask applies after permutation
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
   memref.dealloc %A : memref<?x?xf32>
 
   return

/// arm_sme.tile_store:
///
/// vector.transfer_write %vector, %source[%c0, %c0]
/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this is now always passing layout to the builder, shouldn't this be printed as:

arm_sme.tile_store %vector, %source[%c0, %c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

horizontal layout is the default and isn't printed, but I could add it to be explicit if you think that's clearer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made an invalid assumption (I was expecting it to be printed since it was being passed to the builder). But, from: https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/#optional-and-default-valued-parameters

An optional parameter is omitted when it is equal to its default value.

So this is correct.

I could add it to be explicit if you think that's clearer

That would be my personal preference, but then what's the point of "default" values? No strong opinion, to me it's a bit confusing, but we can revisit some other time.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My preference is the omit expected default values (and I think horizontal would be expected) :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, just some usual nitpicking from me :)

Thanks!

/// arm_sme.tile_store:
///
/// vector.transfer_write %vector, %source[%c0, %c0]
/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made an invalid assumption (I was expecting it to be printed since it was being passed to the builder). But, from: https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/#optional-and-default-valued-parameters

An optional parameter is omitted when it is equal to its default value.

So this is correct.

I could add it to be explicit if you think that's clearer

That would be my personal preference, but then what's the point of "default" values? No strong opinion, to me it's a bit confusing, but we can revisit some other time.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for addressing my comments!

@c-rhodes c-rhodes merged commit 4240b17 into llvm:main Nov 10, 2023
@c-rhodes c-rhodes deleted the mlir-arm-sme-vertical-tile-store-lowering branch November 10, 2023 07:51
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…vm#71181)

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants