[mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose #92562

Merged
c-rhodes merged 3 commits into llvm:main from mlir-arm-sme-fold-transpose on May 21, 2024

Conversation

c-rhodes
Collaborator

vector.transpose ops whose inputs come from vector.transfer_read can be eliminated by folding the transpose into the xfer op to enable in-flight transposition when converting xfer read to arm_sme.tile_load.

@llvmbot
Member

llvmbot commented May 17, 2024

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

vector.transpose ops whose inputs come from vector.transfer_read can be eliminated by folding the transpose into the xfer op to enable in-flight transposition when converting xfer read to arm_sme.tile_load.
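
As a rough illustration of why the fold is legal, the sketch below models the idea in plain C++ rather than with the MLIR API. `Matrix`, `readTile`, `transpose`, and `foldIsEquivalent` are hypothetical names introduced only for this example, and the fixed 4x4 tile stands in for `vector<[4]x[4]xf32>` at a single vscale; scalable sizes, padding, and masking are not modelled. It checks that reading a tile and then transposing it yields the same values as reading with the `[1, 0]` permutation applied during the read.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ model (not MLIR API): a row-major 2-D buffer standing in for
// memref<?x?xf32>.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
  float &at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
};

// Models a transfer_read of an n x n tile at the origin of `src`.
// transposed == false: identity permutation map, tile(i, j) = src(i, j).
// transposed == true:  [1, 0] permutation map,   tile(i, j) = src(j, i).
// Assumes src is at least n x n, so no padding is needed.
Matrix readTile(const Matrix &src, std::size_t n, bool transposed) {
  Matrix tile{n, n, std::vector<float>(n * n)};
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      tile.at(i, j) = transposed ? src.at(j, i) : src.at(i, j);
  return tile;
}

// Models vector.transpose %v, [1, 0].
Matrix transpose(const Matrix &m) {
  Matrix out{m.cols, m.rows, std::vector<float>(m.data.size())};
  for (std::size_t r = 0; r < m.rows; ++r)
    for (std::size_t c = 0; c < m.cols; ++c)
      out.at(c, r) = m.at(r, c);
  return out;
}

// Compares the unfolded form (read, then transpose) with the folded form
// (read with the permutation applied in flight).
bool foldIsEquivalent() {
  Matrix src{4, 4, std::vector<float>(16)};
  for (std::size_t k = 0; k < src.data.size(); ++k)
    src.data[k] = static_cast<float>(k);
  Matrix unfolded = transpose(readTile(src, 4, /*transposed=*/false));
  Matrix folded = readTile(src, 4, /*transposed=*/true);
  return unfolded.data == folded.data;
}
```

Calling `foldIsEquivalent()` returns `true` for this input. The model only covers a read that starts with an identity permutation map, which is the case exercised by the test added in this PR.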


Full diff: https://github.com/llvm/llvm-project/pull/92562.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+14-2)
  • (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+12)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d8e473a562e53..b1b84705da7d3 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
       return failure();
 
     auto loc = transposeOp.getLoc();
+    Value input = transposeOp.getVector();
+
+    if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>()) {
+      // Fold transpose into transfer_read to enable in-flight transpose when
+      // converting to arm_sme.tile_load.
+      rewriter.modifyOpInPlace(xferOp, [&]() {
+        SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+        xferOp->setAttr(xferOp.getPermutationMapAttrName(),
+                        AffineMapAttr::get(AffineMap::getPermutationMap(
+                            permutation, transposeOp.getContext())));
+      });
+      rewriter.replaceOp(transposeOp, xferOp);
+      return success();
+    }
 
     // Allocate buffer to store input tile to.
     Value vscale =
@@ -372,8 +386,6 @@ struct TransposeOpToArmSMELowering
     auto buffer = rewriter.create<memref::AllocaOp>(
         loc, bufferType, ValueRange{numTileSlices, numTileSlices});
 
-    Value input = transposeOp.getVector();
-
     // Store input tile.
     auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
         loc, input, buffer, ValueRange{c0, c0});
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index ce0b46e0f061a..48e92ce88ed16 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -150,6 +150,18 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
 
 // -----
 
+// CHECK-LABEL: @fold_transpose_into_load
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
 //===----------------------------------------------------------------------===//

Member

@MacDue MacDue left a comment

a few comments:

Comment on lines +362 to +363
// Fold transpose into transfer_read to enable in-flight transpose when
// converting to arm_sme.tile_load.
Member

not blocking: this feels like something that should be part of the transfer_write lowering, rather than the transpose lowering. I think we could just remove this allocating vector.transpose lowering (as we've never actually needed it).

Collaborator Author

perhaps we might need it if we enable SME for transposes in IREE and they don't match this pattern?

Member

The memref is a problem for IREE though, and if we hit a case where we needed to lower a vector.transpose on its own, it'd probably be better just to use an extra tile.

Collaborator Author

good point, we should update this lowering to use an extra tile rather than remove it.

@c-rhodes c-rhodes merged commit bfb5fe2 into llvm:main May 21, 2024
4 checks passed
@c-rhodes c-rhodes deleted the mlir-arm-sme-fold-transpose branch May 21, 2024 07:08
c-rhodes added a commit to c-rhodes/iree that referenced this pull request May 23, 2024
Since [1] transposes may get eliminated by folding into defining
vector.transfer_read.

[1] llvm/llvm-project#92562
c-rhodes added a commit to c-rhodes/iree that referenced this pull request May 23, 2024
Since [1] transposes may get eliminated by folding into defining
vector.transfer_read.

[1] llvm/llvm-project#92562

Signed-off-by: Cullen Rhodes <[email protected]>
c-rhodes added a commit to c-rhodes/iree that referenced this pull request May 30, 2024
c-rhodes added a commit to iree-org/iree that referenced this pull request Jun 4, 2024
This patch enables transposes for ArmSME for f32 and f64 types now that they
may get eliminated by folding into defining `vector.transfer_read` since [1].
This is done by extending `setTransposeLikeOpRootConfig` which currently only
supports x86 (no changes there). The transpose is represented by a
`linalg.generic`, since `linalg.transpose` gets converted by
`GeneralizeLinalgNamedOps`.

[1] llvm/llvm-project#92562

ci-extra: build_test_all_arm64

---------

Signed-off-by: Cullen Rhodes <[email protected]>
bangtianliu pushed a commit to bangtianliu/iree that referenced this pull request Jun 5, 2024
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
Signed-off-by: Lubo Litchev <[email protected]>

4 participants