Skip to content

Commit eaf1590

Browse files
authored
[mlir][ArmSME] Add support for vector.transpose (#66760)
This patch adds support for lowering vector.transpose to ArmSME. It's implemented by storing the input tile of the transpose to memory and reloading vertically, building on top of the tile slice layout support. Transposing via memory is obviously expensive; the current intention is to avoid the transpose if possible, so this is intended as a fallback and to provide base support for Vector ops. If it turns out transposes can't be avoided then this should be replaced with a more optimal implementation, perhaps with tile <-> vector (MOVA) ops. Depends on #66758.
1 parent f5cb9cb commit eaf1590

File tree

6 files changed

+316
-4
lines changed

6 files changed

+316
-4
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def ArmSME_Dialect : Dialect {
3636
https://developer.arm.com/documentation/ddi0616
3737
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
3838
}];
39-
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
39+
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
40+
"memref::MemRefDialect"];
4041
let useDefaultAttributePrinterParser = 1;
4142
}
4243

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1212
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
13+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1314
#include "mlir/IR/BuiltinTypes.h"
1415
#include "llvm/Support/Casting.h"
1516

@@ -239,11 +240,84 @@ struct BroadcastOpToArmSMELowering
239240
}
240241
};
241242

243+
/// Conversion pattern for vector.transpose.
244+
///
245+
/// Stores the input tile to memory and reloads vertically.
246+
///
247+
/// Example:
248+
///
249+
/// %transposed_src = vector.transpose %src, [1, 0]
250+
/// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
251+
///
252+
/// is converted to:
253+
///
254+
/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
255+
/// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
256+
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
257+
/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
258+
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
259+
///
260+
/// NOTE: Tranposing via memory is obviously expensive, the current intention
261+
/// is to avoid the transpose if possible, this is therefore intended as a
262+
/// fallback and to provide base support for Vector ops. If it turns out
263+
/// transposes can't be avoided then this should be replaced with a more optimal
264+
/// implementation, perhaps with tile <-> vector (MOVA) ops.
265+
struct TransposeOpToArmSMELowering
266+
: public OpRewritePattern<vector::TransposeOp> {
267+
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
268+
269+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
270+
PatternRewriter &rewriter) const final {
271+
auto tileType = transposeOp.getResultVectorType();
272+
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
273+
return failure();
274+
275+
SmallVector<int64_t> transp;
276+
for (auto attr : transposeOp.getTransp())
277+
transp.push_back(cast<IntegerAttr>(attr).getInt());
278+
279+
// Bail unless this is a true 2-D matrix transpose.
280+
if (transp[0] != 1 || transp[1] != 0)
281+
return failure();
282+
283+
OpBuilder::InsertionGuard g(rewriter);
284+
auto loc = transposeOp.getLoc();
285+
286+
// Allocate buffer to store input tile to.
287+
Value vscale =
288+
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
289+
Value minTileSlices = rewriter.create<arith::ConstantOp>(
290+
loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
291+
Value c0 =
292+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
293+
Value numTileSlices =
294+
rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
295+
auto bufferType =
296+
MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
297+
tileType.getElementType());
298+
auto buffer = rewriter.create<memref::AllocaOp>(
299+
loc, bufferType, ValueRange{numTileSlices, numTileSlices});
300+
301+
Value input = transposeOp.getVector();
302+
303+
// Store input tile.
304+
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
305+
loc, input, buffer, ValueRange{c0, c0});
306+
307+
// Reload input tile vertically.
308+
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
309+
transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
310+
arm_sme::TileSliceLayout::Vertical);
311+
312+
return success();
313+
}
314+
};
315+
242316
} // namespace
243317

244318
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
245319
MLIRContext &ctx) {
246320
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
247321
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
248-
BroadcastOpToArmSMELowering>(&ctx);
322+
BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
249323
}

mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1516
#include "mlir/IR/DialectImplementation.h"
1617
#include "mlir/IR/TypeUtilities.h"
1718
#include "llvm/ADT/TypeSwitch.h"

mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
1111
LINK_LIBS PUBLIC
1212
MLIRIR
1313
MLIRLLVMDialect
14+
MLIRMemRefDialect
1415
MLIRSCFDialect
1516
MLIRSideEffectInterfaces
1617
MLIRVectorDialect

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
22

3+
//===----------------------------------------------------------------------===//
4+
// vector.transfer_write
5+
//===----------------------------------------------------------------------===//
6+
37
// CHECK-LABEL: func.func @transfer_write_2d_i8(
48
// CHECK-SAME: %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
59
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>) {
@@ -165,9 +169,9 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
165169
return
166170
}
167171

168-
// =============================================================================
172+
//===----------------------------------------------------------------------===//
169173
// vector.broadcast
170-
// =============================================================================
174+
//===----------------------------------------------------------------------===//
171175

172176
// -----
173177

@@ -215,3 +219,121 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
215219
"prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
216220
return
217221
}
222+
223+
//===----------------------------------------------------------------------===//
// vector.transpose
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func.func @transpose_i8(
// CHECK-SAME:    %[[TILE:.*]]: vector<[16]x[16]xi8>)
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[BUFFER:.*]] = memref.alloca(%[[NUM_TILE_SLICES]], %[[NUM_TILE_SLICES]]) : memref<?x?xi8>
// CHECK: arm_sme.tile_store %[[TILE]], %[[BUFFER]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
// CHECK: arm_sme.tile_load %[[BUFFER]]{{\[}}%[[C0]], %[[C0]]], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_i16
// CHECK: arith.constant 8
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_i32
// CHECK: arith.constant 4
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_i64
// CHECK: arith.constant 2
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
  "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_i128
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[BUFFER:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_f16
// CHECK: arith.constant 8
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_bf16
// CHECK: arith.constant 8
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_f32
// CHECK: arith.constant 4
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: @transpose_f64
// CHECK: arith.constant 2
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
  return
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// DEFINE: %{entry_point} = entry
2+
// DEFINE: %{compile} = mlir-opt %s \
3+
// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
4+
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
5+
// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
6+
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
7+
// DEFINE: %{run} = %mcr_aarch64_cmd \
8+
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
9+
// DEFINE: -e %{entry_point} -entry-point-result=void \
10+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
11+
12+
// RUN: %{compile} | %{run} | FileCheck %s
13+
14+
llvm.func @printCString(!llvm.ptr<i8>)
15+
16+
func.func @printTileBegin() {
17+
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
18+
%1 = llvm.mlir.constant(0 : index) : i64
19+
%2 = llvm.getelementptr %0[%1, %1]
20+
: (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
21+
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
22+
return
23+
}
24+
25+
func.func @printTileEnd() {
26+
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
27+
%1 = llvm.mlir.constant(0 : index) : i64
28+
%2 = llvm.getelementptr %0[%1, %1]
29+
: (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
30+
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
31+
return
32+
}
33+
34+
func.func @entry() {
35+
%c0 = arith.constant 0 : index
36+
%c1 = arith.constant 1 : index
37+
%c1_i32 = arith.constant 1 : i32
38+
39+
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
40+
%vscale = vector.vscale
41+
%min_elts_s = arith.constant 4 : index
42+
%svl_s = arith.muli %min_elts_s, %vscale : index
43+
%za_s_size = arith.muli %svl_s, %svl_s : index
44+
45+
// Allocate memory.
46+
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
47+
%mem2 = memref.alloca(%za_s_size) : memref<?xi32>
48+
49+
// Fill each "row" of "mem1" with row number.
50+
//
51+
// For example, assuming an SVL of 128-bits:
52+
//
53+
// 0, 0, 0, 0
54+
// 1, 1, 1, 1
55+
// 2, 2, 2, 2
56+
// 3, 3, 3, 3
57+
//
58+
%init_0 = arith.constant 0 : i32
59+
scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
60+
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
61+
vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
62+
%val_next = arith.addi %val, %c1_i32 : i32
63+
scf.yield %val_next : i32
64+
}
65+
66+
// Load tile from "mem1".
67+
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
68+
69+
// Transpose tile.
70+
%transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
71+
72+
// Store tile back to "mem2" to print.
73+
// TODO: Replace this with vector.print when
74+
// https://github.com/llvm/llvm-project/pull/66691 lands.
75+
vector.store %transposed_tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
76+
77+
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
78+
// 4x4xi32.
79+
//
80+
// CHECK: TILE BEGIN
81+
// CHECK-NEXT: ( 0, 0, 0, 0
82+
// CHECK-NEXT: ( 1, 1, 1, 1
83+
// CHECK-NEXT: ( 2, 2, 2, 2
84+
// CHECK-NEXT: ( 3, 3, 3, 3
85+
// CHECK: TILE END
86+
func.call @printTileBegin() : () -> ()
87+
scf.for %i = %c0 to %za_s_size step %svl_s {
88+
%tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
89+
vector.print %tileslice : vector<[4]xi32>
90+
}
91+
func.call @printTileEnd() : () -> ()
92+
93+
// Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
94+
// 4x4xi32.
95+
//
96+
// CHECK: TILE BEGIN
97+
// CHECK-NEXT: ( 0, 1, 2, 3
98+
// CHECK-NEXT: ( 0, 1, 2, 3
99+
// CHECK-NEXT: ( 0, 1, 2, 3
100+
// CHECK-NEXT: ( 0, 1, 2, 3
101+
// CHECK: TILE END
102+
func.call @printTileBegin() : () -> ()
103+
scf.for %i = %c0 to %za_s_size step %svl_s {
104+
%tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
105+
vector.print %tileslice : vector<[4]xi32>
106+
}
107+
func.call @printTileEnd() : () -> ()
108+
109+
return
110+
}
111+
112+
llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
113+
llvm.mlir.global internal constant @str_tile_end("TILE END\0A")

0 commit comments

Comments
 (0)