[mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose #92562


Merged 3 commits on May 21, 2024
16 changes: 14 additions & 2 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
return failure();

auto loc = transposeOp.getLoc();
Value input = transposeOp.getVector();

if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
xferOp && xferOp->hasOneUse()) {
// Fold transpose into transfer_read to enable in-flight transpose when
// converting to arm_sme.tile_load.
Comment on lines +363 to +364

Member:
not blocking: this feels like something that should be part of the transfer_write lowering, rather than the transpose lowering. I think we could just remove this allocating vector.transpose lowering (as we've never actually needed it).

Collaborator Author:
Perhaps we might, if we enable SME for transposes in IREE and they don't match this pattern?

Member:
The memref is a problem for IREE though, and if we hit somewhere where we need to lower a vector.transpose alone, it'd probably be better just to use an extra tile.

Collaborator Author:
Good point, we should update this lowering to use an extra tile then, rather than remove it.

rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getPermutationMapAttrName(),
AffineMapAttr::get(AffineMap::getPermutationMap(
permutation, transposeOp.getContext())));
});
rewriter.replaceOp(transposeOp, xferOp);
return success();
}

// Allocate buffer to store input tile to.
Value vscale =
@@ -372,8 +386,6 @@
auto buffer = rewriter.create<memref::AllocaOp>(
loc, bufferType, ValueRange{numTileSlices, numTileSlices});

Value input = transposeOp.getVector();

// Store input tile.
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
loc, input, buffer, ValueRange{c0, c0});
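In effect, the new pattern folds the transpose into the producing transfer_read by transposing the read's permutation map, and the existing transfer_read lowering then emits the transpose in-flight as a vertical tile load. A minimal before/after sketch, using the shapes from the tests below (SSA names and operands are illustrative):

// Before: a plain read followed by an explicit transpose.
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>

// After the fold: the permutation has moved into the read's permutation map.
%1 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>

// The transfer_read lowering then converts this into an in-flight (vertical) tile load.
%1 = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>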
33 changes: 33 additions & 0 deletions mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -150,6 +150,39 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas

// -----

// CHECK-LABEL: @fold_transpose_into_load
// CHECK-NOT: arm_sme.tile_store
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-NOT: arm_sme.tile_store
func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
"prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
  return
}

// -----

/// Transposes with more than a single use cannot be folded into the load and
/// will instead be transposed via memory.

// CHECK-LABEL: @fold_transpose_into_load_multi_use
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: %[[TILE_TRANSPOSED_VIA_MEM:.*]] = arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: "prevent.dce"(%[[TILE_TRANSPOSED_VIA_MEM]]) : (vector<[4]x[4]xf32>) -> ()
func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
"test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
"prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
  return
}

// -----

//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//
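For reference, tests like these are driven by a RUN line at the top of the test file; a sketch of one that would exercise this conversion (the checked-in file's exact flags may differ, and -allow-unregistered-dialect is needed for the test-only ops such as "prevent.dce"):

// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s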