Skip to content

Commit bfb5fe2

Browse files
authored
[mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose (#92562)
A vector.transpose whose input is produced by a vector.transfer_read with no other uses can be eliminated by folding the transpose into the transfer_read's permutation map. This enables in-flight transposition when the transfer_read is later converted to arm_sme.tile_load.
1 parent 63d8131 commit bfb5fe2

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
356356
return failure();
357357

358358
auto loc = transposeOp.getLoc();
359+
Value input = transposeOp.getVector();
360+
361+
if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
362+
xferOp && xferOp->hasOneUse()) {
363+
// Fold transpose into transfer_read to enable in-flight transpose when
364+
// converting to arm_sme.tile_load.
365+
rewriter.modifyOpInPlace(xferOp, [&]() {
366+
xferOp->setAttr(xferOp.getPermutationMapAttrName(),
367+
AffineMapAttr::get(AffineMap::getPermutationMap(
368+
permutation, transposeOp.getContext())));
369+
});
370+
rewriter.replaceOp(transposeOp, xferOp);
371+
return success();
372+
}
359373

360374
// Allocate buffer to store input tile to.
361375
Value vscale =
@@ -372,8 +386,6 @@ struct TransposeOpToArmSMELowering
372386
auto buffer = rewriter.create<memref::AllocaOp>(
373387
loc, bufferType, ValueRange{numTileSlices, numTileSlices});
374388

375-
Value input = transposeOp.getVector();
376-
377389
// Store input tile.
378390
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
379391
loc, input, buffer, ValueRange{c0, c0});

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,39 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
150150

151151
// -----
152152

153+
// CHECK-LABEL: @fold_transpose_into_load
154+
// CHECK-NOT: arm_sme.tile_store
155+
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
156+
// CHECK-NOT: arm_sme.tile_store
157+
func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
158+
%c0 = arith.constant 0 : index
159+
%pad = arith.constant 0.0 : f32
160+
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
161+
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
162+
"prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
163+
}
164+
165+
// -----
166+
167+
/// Transposes of a transfer_read that has more than one use cannot be folded into the load and will
168+
/// instead be transposed via memory.
169+
170+
// CHECK-LABEL: @fold_transpose_into_load_multi_use
171+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
172+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
173+
// CHECK: %[[TILE_TRANSPOSED_VIA_MEM:.*]] = arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
174+
// CHECK: "prevent.dce"(%[[TILE_TRANSPOSED_VIA_MEM]]) : (vector<[4]x[4]xf32>) -> ()
175+
func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
176+
%c0 = arith.constant 0 : index
177+
%pad = arith.constant 0.0 : f32
178+
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
179+
"test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
180+
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
181+
"prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
182+
}
183+
184+
// -----
185+
153186
//===----------------------------------------------------------------------===//
154187
// vector.transfer_write
155188
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)