[mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose #92562
…spose vector.transpose ops whose inputs come from vector.transfer_read can be eliminated by folding the transpose into the xfer op to enable in-flight transposition when converting xfer read to arm_sme.tile_load.
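To illustrate why the fold is valid, here is a minimal standalone sketch (plain C++, not MLIR code). It models a `vector.transfer_read` followed by a `[1, 0]` transpose, and a single read whose indices are permuted as the data is loaded. `Matrix`, `Tile`, `readTile`, `transpose` and `readTileTransposed` are hypothetical names for this illustration only and are not part of the patch or of any MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 2-D memref: a row-major buffer.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  float at(std::size_t r, std::size_t c) const {
    assert(r < rows && c < cols && "read must be in bounds");
    return data[r * cols + c];
  }
};

// Hypothetical stand-in for a square 2-D vector: tile[row][col].
using Tile = std::vector<std::vector<float>>;

// Plain read (identity permutation map): tile[i][j] = src[r0 + i][c0 + j].
Tile readTile(const Matrix &src, std::size_t r0, std::size_t c0,
              std::size_t n) {
  Tile tile(n, std::vector<float>(n));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      tile[i][j] = src.at(r0 + i, c0 + j);
  return tile;
}

// Separate transpose step with permutation [1, 0]: out[i][j] = in[j][i].
Tile transpose(const Tile &in) {
  const std::size_t n = in.size();
  Tile out(n, std::vector<float>(n));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      out[i][j] = in[j][i];
  return out;
}

// Folded read: the [1, 0] permutation is applied to the indices while
// reading, so no separate transpose step is needed afterwards.
Tile readTileTransposed(const Matrix &src, std::size_t r0, std::size_t c0,
                        std::size_t n) {
  Tile tile(n, std::vector<float>(n));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      tile[i][j] = src.at(r0 + j, c0 + i);
  return tile;
}
```

In this model `transpose(readTile(m, r, c, n))` and `readTileTransposed(m, r, c, n)` produce the same tile for any in-bounds read, which is the equivalence the fold relies on. The sketch does not model padding, masks, or scalable vector sizes.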
@llvm/pr-subscribers-mlir
Author: Cullen Rhodes (c-rhodes)
Changes
vector.transpose ops whose inputs come from vector.transfer_read can be eliminated by folding the transpose into the xfer op to enable in-flight transposition when converting xfer read to arm_sme.tile_load.
Full diff: https://github.com/llvm/llvm-project/pull/92562.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d8e473a562e53..b1b84705da7d3 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
return failure();
auto loc = transposeOp.getLoc();
+ Value input = transposeOp.getVector();
+
+ if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>()) {
+ // Fold transpose into transfer_read to enable in-flight transpose when
+ // converting to arm_sme.tile_load.
+ rewriter.modifyOpInPlace(xferOp, [&]() {
+ SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+ xferOp->setAttr(xferOp.getPermutationMapAttrName(),
+ AffineMapAttr::get(AffineMap::getPermutationMap(
+ permutation, transposeOp.getContext())));
+ });
+ rewriter.replaceOp(transposeOp, xferOp);
+ return success();
+ }
// Allocate buffer to store input tile to.
Value vscale =
@@ -372,8 +386,6 @@ struct TransposeOpToArmSMELowering
auto buffer = rewriter.create<memref::AllocaOp>(
loc, bufferType, ValueRange{numTileSlices, numTileSlices});
- Value input = transposeOp.getVector();
-
// Store input tile.
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
loc, input, buffer, ValueRange{c0, c0});
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index ce0b46e0f061a..48e92ce88ed16 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -150,6 +150,18 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
// -----
+// CHECK-LABEL: @fold_transpose_into_load
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+ "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//
a few comments:
// Fold transpose into transfer_read to enable in-flight transpose when
// converting to arm_sme.tile_load.
not blocking: this feels like something that should be part of the transfer_read lowering, rather than the transpose lowering. I think we could just remove this allocating vector.transpose lowering (as we've never actually needed it).
perhaps we might if we enable SME for transposes in IREE and they don't match this pattern?
The memref is a problem for IREE though, and if we hit somewhere where we needed to lower a vector.transpose alone, it'd probably be better just to use an extra tile.
good point, we should update this lowering to use an extra tile then rather than remove.
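For context on the allocating lowering discussed above: the diff shows it allocating a buffer and storing the input tile to it. The sketch below assumes the tile is then reloaded from that buffer with a vertical layout, which yields the transpose. It is a standalone C++ model with hypothetical names (`Tile`, `storeHorizontal`, `loadVertical`, `transposeViaBuffer`), not the actual pattern or any ArmSME API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a square tile: tile[row][col].
using Tile = std::vector<std::vector<float>>;

// Store each horizontal slice (row) of the tile contiguously into a
// scratch buffer. This models the extra memory the lowering allocates.
std::vector<float> storeHorizontal(const Tile &tile) {
  const std::size_t n = tile.size();
  std::vector<float> buffer(n * n);
  for (std::size_t i = 0; i < n; ++i) {
    assert(tile[i].size() == n && "tile must be square");
    for (std::size_t j = 0; j < n; ++j)
      buffer[i * n + j] = tile[i][j];
  }
  return buffer;
}

// Reload the buffer so that each contiguous run of n elements becomes a
// vertical slice (column) of the result.
Tile loadVertical(const std::vector<float> &buffer, std::size_t n) {
  assert(buffer.size() == n * n && "buffer must hold an n x n tile");
  Tile tile(n, std::vector<float>(n));
  for (std::size_t slice = 0; slice < n; ++slice)
    for (std::size_t k = 0; k < n; ++k)
      tile[k][slice] = buffer[slice * n + k];
  return tile;
}

// Transpose by round-tripping through memory: rows out, columns in.
Tile transposeViaBuffer(const Tile &tile) {
  return loadVertical(storeHorizontal(tile), tile.size());
}
```

The round trip costs an extra buffer plus a store and a load. The fold in this PR avoids that whenever the transpose input is produced by a transfer_read.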
Since [1] transposes may get eliminated by folding into defining vector.transfer_read. [1] llvm/llvm-project#92562
Since [1] transposes may get eliminated by folding into defining vector.transfer_read. [1] llvm/llvm-project#92562 Signed-off-by: Cullen Rhodes <[email protected]>
This patch enables transposes for ArmSME for f32 and f64 types now that they may get eliminated by folding into defining `vector.transfer_read` since [1]. This is done by extending `setTransposeLikeOpRootConfig` which currently only supports x86 (no changes there). The transpose is represented by a `linalg.generic`, since `linalg.transpose` gets converted by `GeneralizeLinalgNamedOps`. [1] llvm/llvm-project#92562 ci-extra: build_test_all_arm64 --------- Signed-off-by: Cullen Rhodes <[email protected]>
This patch enables transposes for ArmSME for f32 and f64 types now that they may get eliminated by folding into defining `vector.transfer_read` since [1]. This is done by extending `setTransposeLikeOpRootConfig` which currently only supports x86 (no changes there). The transpose is represented by a `linalg.generic`, since `linalg.transpose` gets converted by `GeneralizeLinalgNamedOps`. [1] llvm/llvm-project#92562 ci-extra: build_test_all_arm64 --------- Signed-off-by: Cullen Rhodes <[email protected]> Signed-off-by: Lubo Litchev <[email protected]>