Skip to content

Commit bee71e7

Browse files
committed
[mlir][ArmSME] Lower transfer_write + transpose to vertical store
This patch extends the lowering of vector.transfer_write in VectorToArmSME to support in-flight transpose via SME vertical store.
1 parent ed350bb commit bee71e7

File tree

3 files changed

+124
-3
lines changed

3 files changed

+124
-3
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
136136

137137
/// Conversion pattern for vector.transfer_write.
138138
///
139-
/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
140-
/// memref<?x?xi8>
139+
/// ---
140+
///
141+
/// Example 1: op with identity permutation map to horizontal
142+
/// arm_sme.tile_store:
143+
///
144+
/// vector.transfer_write %vector, %source[%c0, %c0]
145+
/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
141146
///
142147
/// is converted to:
143148
///
144149
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
145150
/// vector<[16]x[16]xi8>
151+
/// ---
152+
///
153+
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
154+
/// (in-flight transpose):
155+
///
156+
/// vector.transfer_write %vector, %source[%c0, %c0]
157+
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
158+
/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
159+
///
160+
/// is converted to:
161+
///
162+
/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
163+
/// : memref<?x?xi8>, vector<[16]x[16]xi8>
146164
struct TransferWriteToArmSMELowering
147165
: public OpRewritePattern<vector::TransferWriteOp> {
148166
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
153171
if (!arm_sme::isValidSMETileVectorType(vType))
154172
return failure();
155173

174+
assert(writeOp.getTransferRank() == 2 &&
175+
"expected a permutation_map with result dims of the same rank as "
176+
"the vector type");
177+
156178
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
157179
return failure();
158180

181+
// Out-of-bounds dims are not supported.
182+
if (writeOp.hasOutOfBoundsDim())
183+
return rewriter.notifyMatchFailure(writeOp,
184+
"not inbounds transfer write");
185+
186+
arm_sme::TileSliceLayout layout;
187+
188+
AffineExpr d0, d1;
189+
bindDims(writeOp.getContext(), d0, d1);
190+
AffineMap map = writeOp.getPermutationMap();
191+
if (map.isIdentity())
192+
layout = arm_sme::TileSliceLayout::Horizontal;
193+
else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
194+
writeOp.getContext()))
195+
layout = arm_sme::TileSliceLayout::Vertical;
196+
else
197+
return rewriter.notifyMatchFailure(writeOp,
198+
"unsupported permutation map");
199+
159200
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
160201
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
161-
writeOp.getMask());
202+
writeOp.getMask(), layout);
162203
return success();
163204
}
164205
};

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
337337

338338
// -----
339339

340+
/// in-flight transpose via vertical store.
341+
342+
// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
343+
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
344+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
345+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
346+
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
347+
func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
348+
%c0 = arith.constant 0 : index
349+
vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
350+
return
351+
}
352+
353+
// -----
354+
355+
/// in-flight transpose via vertical store with mask.
356+
357+
// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
358+
// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
359+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>,
360+
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x[8]xi1>) {
361+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
362+
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
363+
func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
364+
%c0 = arith.constant 0 : index
365+
vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
366+
return
367+
}
368+
369+
// -----
370+
340371
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
341372
// lowering only occurs for vector types of correct rank, shape, element size
342373
// and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
398429
return
399430
}
400431

432+
// -----
433+
434+
// CHECK-LABEL: @transfer_write_2d__out_of_bounds
435+
// CHECK: vector.transfer_write
436+
// CHECK-NOT: arm_sme.tile_store
437+
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
438+
%c0 = arith.constant 0 : index
439+
vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
440+
return
441+
}
442+
401443
//===----------------------------------------------------------------------===//
402444
// vector.broadcast
403445
//===----------------------------------------------------------------------===//

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
3232
return
3333
}
3434

35+
// Vector store + transpose.
36+
func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
37+
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
38+
vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
39+
vector<[4]x[4]xf32>, memref<?x?xf32>
40+
return
41+
}
42+
43+
// Masked vector store + transpose.
44+
func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
45+
%c2 = arith.constant 2 : index
46+
%c4 = arith.constant 4 : index
47+
%mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
48+
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
49+
vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
50+
vector<[4]x[4]xf32>, memref<?x?xf32>
51+
return
52+
}
53+
3554
// Vector load + print.
3655
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
3756
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -116,6 +135,25 @@ func.func @entry() {
116135
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
117136
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
118137

138+
// 4. Reload 3. + store + transpose.
139+
// CHECK-LABEL: TILE BEGIN:
140+
// CHECK-NEXT: ( 0, 0, 20, 30
141+
// CHECK-NEXT: ( 0, 0, 21, 31
142+
// CHECK-NEXT: ( 0, 0, 0, 0
143+
// CHECK-NEXT: ( 3, 13, 0, 0
144+
call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
145+
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
146+
147+
// 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
148+
// The mask applies after permutation
149+
// CHECK-LABEL: TILE BEGIN:
150+
// CHECK-NEXT: ( 0, 0, 20, 30
151+
// CHECK-NEXT: ( 0, 0, 21, 31
152+
// CHECK-NEXT: ( 20, 21, 0, 0
153+
// CHECK-NEXT: ( 30, 31, 0, 0
154+
call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
155+
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
156+
119157
memref.dealloc %A : memref<?x?xf32>
120158

121159
return

0 commit comments

Comments
 (0)