Skip to content

Commit 87f2964

Browse files
[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM
1 parent 4ab7765 commit 87f2964

File tree

5 files changed

+630
-0
lines changed

5 files changed

+630
-0
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// REQUIRES: arm-emulator
2+
3+
// DEFINE: %{compile} = mlir-opt %s \
4+
// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
5+
// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
6+
// DEFINE: -o %t
7+
8+
// DEFINE: %{entry_point} = main
9+
10+
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
11+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
12+
13+
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
14+
15+
#packed_maps = [
16+
affine_map<(d0, d1, d2) -> (d0, d2)>,
17+
affine_map<(d0, d1, d2) -> (d1, d2)>,
18+
affine_map<(d0, d1, d2) -> (d0, d1)>
19+
]
20+
21+
func.func private @setArmVLBits(%bits : i32)
22+
23+
func.func @main() {
24+
%c128 = arith.constant 128 : i32
25+
func.call @setArmVLBits(%c128) : (i32) -> ()
26+
27+
%c0 = arith.constant 0 : index
28+
%c0_i32 = arith.constant 0 : i32
29+
%c0_i8 = arith.constant 0 : i8
30+
31+
// Accumulator test data
32+
%acc_cst = arith.constant dense<[[-44, 20, 44, -46],
33+
[ -8, 25, -34, 26],
34+
[-20, -36, -3, 39],
35+
[-48, -31, -25, -21]]> : vector<4x4xi32>
36+
%acc_m = memref.alloca() : memref<4x4xi32>
37+
vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
38+
39+
%acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
40+
%acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
41+
%acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
42+
43+
vector.print str "ACC:\n"
44+
%acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
45+
%acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
46+
%acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
47+
%acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
48+
vector.print %acc0 : vector<[4]xi32>
49+
vector.print %acc1 : vector<[4]xi32>
50+
vector.print %acc2 : vector<[4]xi32>
51+
vector.print %acc3 : vector<[4]xi32>
52+
53+
// LHS test data
54+
%lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
55+
[-20, 17, -32, -47, 37, 22, -7, -21],
56+
[ -7, -35, 20, -4, 39, 46, -23, 40],
57+
[ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
58+
59+
%lhs_m = memref.alloca() : memref<4x8xi8>
60+
vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
61+
%lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
62+
63+
vector.print str "LHS:\n"
64+
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
65+
%lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
66+
%lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
67+
%lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
68+
vector.print %lhs0 : vector<8xi8>
69+
vector.print %lhs1 : vector<8xi8>
70+
vector.print %lhs2 : vector<8xi8>
71+
vector.print %lhs3 : vector<8xi8>
72+
73+
// RHS test data
74+
%rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
75+
[-35, -24, 37, -32, 33, 30, -11, -17],
76+
[-28, 31, 3, -44, -15, -27, 22, 35],
77+
[-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
78+
79+
%rhs_m = memref.alloca() : memref<4x8xi8>
80+
vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
81+
82+
%rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
83+
%rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
84+
85+
vector.print str "RHS:\n"
86+
%rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
87+
%rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
88+
vector.print %rhs0 : vector<[16]xi8>
89+
vector.print %rhs1 : vector<[16]xi8>
90+
91+
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
92+
93+
// Matrix multiplication
94+
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
95+
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
96+
%2 = vector.contract {indexing_maps = #packed_maps,
97+
iterator_types = ["parallel", "parallel", "reduction"],
98+
kind = #vector.kind<add>} %0, %1, %acc
99+
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
100+
101+
// Display the result of the multiplication
102+
vector.print str "Result:\n"
103+
%u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
104+
%u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
105+
%u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
106+
%u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
107+
vector.print %u0 : vector<[4]xi32>
108+
vector.print %u1 : vector<[4]xi32>
109+
vector.print %u2 : vector<[4]xi32>
110+
vector.print %u3 : vector<[4]xi32>
111+
112+
// CHECK: ( -1999, 1941, 685, -2879 )
113+
// CHECK: ( -3705, 2952, 987, -685 )
114+
// CHECK: ( 2565, 4157, -1589, -357 )
115+
// CHECK: ( 2383, -2252, 32, -1365 )
116+
return
117+
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// REQUIRES: arm-emulator
2+
3+
// DEFINE: %{compile} = mlir-opt %s \
4+
// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
5+
// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
6+
// DEFINE: -o %t
7+
8+
// DEFINE: %{entry_point} = main
9+
10+
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
11+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
12+
13+
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
14+
15+
#packed_maps = [
16+
affine_map<(d0, d1, d2) -> (d0, d2)>,
17+
affine_map<(d0, d1, d2) -> (d1, d2)>,
18+
affine_map<(d0, d1, d2) -> (d0, d1)>
19+
]
20+
21+
func.func private @setArmVLBits(%bits : i32)
22+
23+
func.func @main() {
24+
%c256 = arith.constant 256 : i32
25+
func.call @setArmVLBits(%c256) : (i32) -> ()
26+
27+
%c0 = arith.constant 0 : index
28+
%c0_i32 = arith.constant 0 : i32
29+
%c0_i8 = arith.constant 0 : i8
30+
31+
32+
// Accumulator test data
33+
%acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
34+
[-20, -36, -3, 39, -48, -31, -25, -21],
35+
[-35, -27, -36, -31, 23, -34, -8, -33],
36+
[-20, 17, -32, -47, 37, 22, -7, -21],
37+
[ -7, -35, 20, -4, 39, 46, -23, 40],
38+
[ 40, 27, 37, 43, 38, -6, 37, 49],
39+
[-17, -50, -1, 48, -13, 22, 39, 33],
40+
[-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
41+
%acc_m = memref.alloca() : memref<8x8xi32>
42+
vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
43+
44+
%acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
45+
%acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
46+
%acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
47+
48+
vector.print str "ACC:\n"
49+
%acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
50+
%acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
51+
%acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
52+
%acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
53+
%acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
54+
%acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
55+
%acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
56+
%acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
57+
vector.print %acc0 : vector<[4]xi32>
58+
vector.print %acc1 : vector<[4]xi32>
59+
vector.print %acc2 : vector<[4]xi32>
60+
vector.print %acc3 : vector<[4]xi32>
61+
vector.print %acc4 : vector<[4]xi32>
62+
vector.print %acc5 : vector<[4]xi32>
63+
vector.print %acc6 : vector<[4]xi32>
64+
vector.print %acc7 : vector<[4]xi32>
65+
66+
// LHS test data
67+
%lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
68+
[-23, 39, 48, 26, -23, 32, -39, -38],
69+
[ -3, 9, 43, -30, -32, 39, 41, -39],
70+
[-13, -21, -25, 27, 47, -36, -11, -11],
71+
[ -4, -20, 36, 11, 13, -23, 24, -13],
72+
[-20, 30, -5, 1, 42, -37, -22, 35],
73+
[-22, 38, -4, 44, 25, -31, 23, -39],
74+
[-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
75+
76+
%lhs_m = memref.alloca() : memref<8x8xi8>
77+
vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
78+
%lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
79+
80+
vector.print str "LHS:\n"
81+
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
82+
%lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
83+
%lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
84+
%lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
85+
%lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
86+
%lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
87+
%lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
88+
%lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
89+
vector.print %lhs0 : vector<8xi8>
90+
vector.print %lhs1 : vector<8xi8>
91+
vector.print %lhs2 : vector<8xi8>
92+
vector.print %lhs3 : vector<8xi8>
93+
vector.print %lhs4 : vector<8xi8>
94+
vector.print %lhs5 : vector<8xi8>
95+
vector.print %lhs6 : vector<8xi8>
96+
vector.print %lhs7 : vector<8xi8>
97+
98+
// RHS test data
99+
%rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
100+
[ 46, -45, -48, -46, -24, 31, -36, 22],
101+
[ 2, 36, 45, -29, -37, -49, -20, -35],
102+
[ -6, 23, 23, 15, 20, 4, -8, -2],
103+
[-35, -6, 16, 49, -50, 9, -44, 13],
104+
[ 24, 1, -4, -44, 41, 15, -43, 44],
105+
[ 44, 0, -10, 41, 22, 44, -40, 0],
106+
[-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
107+
108+
%rhs_m = memref.alloca() : memref<8x8xi8>
109+
vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
110+
111+
%rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
112+
%rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
113+
114+
vector.print str "RHS:\n"
115+
%rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
116+
%rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
117+
vector.print %rhs0 : vector<[16]xi8>
118+
vector.print %rhs1 : vector<[16]xi8>
119+
120+
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
121+
122+
// Matrix multiplication
123+
%0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
124+
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
125+
%2 = vector.contract {indexing_maps = #packed_maps,
126+
iterator_types = ["parallel", "parallel", "reduction"],
127+
kind = #vector.kind<add>} %0, %1, %acc
128+
: vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
129+
130+
// Display the result of the multilication
131+
vector.print str "Result:\n"
132+
%u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
133+
%u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
134+
%u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
135+
%u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
136+
%u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
137+
%u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
138+
%u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
139+
%u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
140+
vector.print %u0 : vector<[4]xi32>
141+
vector.print %u1 : vector<[4]xi32>
142+
vector.print %u2 : vector<[4]xi32>
143+
vector.print %u3 : vector<[4]xi32>
144+
vector.print %u4 : vector<[4]xi32>
145+
vector.print %u5 : vector<[4]xi32>
146+
vector.print %u6 : vector<[4]xi32>
147+
vector.print %u7 : vector<[4]xi32>
148+
149+
150+
// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
151+
// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
152+
// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
153+
// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
154+
// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
155+
// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
156+
// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
157+
// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
158+
return
159+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// REQUIRES: arm-emulator
2+
3+
// DEFINE: %{compile} = mlir-opt %s \
4+
// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
5+
// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
6+
// DEFINE: -o %t
7+
8+
// DEFINE: %{entry_point} = main
9+
10+
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
11+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
12+
13+
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
14+
15+
#packed_maps = [
16+
affine_map<(d0, d1, d2) -> (d0, d2)>,
17+
affine_map<(d0, d1, d2) -> (d1, d2)>,
18+
affine_map<(d0, d1, d2) -> (d0, d1)>
19+
]
20+
21+
func.func private @setArmVLBits(%bits : i32)
22+
23+
func.func @main() {
24+
%c128 = arith.constant 128 : i32
25+
func.call @setArmVLBits(%c128) : (i32) -> ()
26+
27+
%c0 = arith.constant 0 : index
28+
%c0_i32 = arith.constant 0 : i32
29+
%c0_i8 = arith.constant 0 : i8
30+
31+
// Accumulator test data
32+
%acc_cst = arith.constant dense<[[-44, 20, 44, -46],
33+
[ -8, 25, -34, 26],
34+
[-20, -36, -3, 39],
35+
[-48, -31, -25, -21]]> : vector<4x4xi32>
36+
%acc_m = memref.alloca() : memref<4x4xi32>
37+
vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
38+
39+
%acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
40+
%acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
41+
%acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
42+
43+
vector.print str "ACC:\n"
44+
%acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
45+
%acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
46+
%acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
47+
%acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
48+
vector.print %acc0 : vector<[4]xi32>
49+
vector.print %acc1 : vector<[4]xi32>
50+
vector.print %acc2 : vector<[4]xi32>
51+
vector.print %acc3 : vector<[4]xi32>
52+
53+
// LHS test data
54+
%lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
55+
[-20, 17, -32, -47, 37, 22, -7, -21],
56+
[ -7, -35, 20, -4, 39, 46, -23, 40],
57+
[ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
58+
59+
%lhs_m = memref.alloca() : memref<4x8xi8>
60+
vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
61+
%lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
62+
63+
vector.print str "LHS:\n"
64+
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
65+
%lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
66+
%lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
67+
%lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
68+
vector.print %lhs0 : vector<8xi8>
69+
vector.print %lhs1 : vector<8xi8>
70+
vector.print %lhs2 : vector<8xi8>
71+
vector.print %lhs3 : vector<8xi8>
72+
73+
// RHS test data
74+
%rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
75+
[221, 25, 164, 97, 156, 221, 218, 177],
76+
[171, 160, 219, 191, 144, 45, 161, 210],
77+
[223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
78+
79+
%rhs_m = memref.alloca() : memref<4x8xi8>
80+
vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
81+
82+
%rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
83+
%rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
84+
85+
vector.print str "RHS:\n"
86+
%rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
87+
%rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
88+
vector.print %rhs0 : vector<[16]xi8>
89+
vector.print %rhs1 : vector<[16]xi8>
90+
91+
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
92+
93+
// Matrix multiplication
94+
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
95+
%1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
96+
%2 = vector.contract {indexing_maps = #packed_maps,
97+
iterator_types = ["parallel", "parallel", "reduction"],
98+
kind = #vector.kind<add>} %0, %1, %acc
99+
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
100+
101+
// Display the result of the multiplication
102+
vector.print str "Result:\n"
103+
%u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
104+
%u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
105+
%u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
106+
%u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
107+
vector.print %u0 : vector<[4]xi32>
108+
vector.print %u1 : vector<[4]xi32>
109+
vector.print %u2 : vector<[4]xi32>
110+
vector.print %u3 : vector<[4]xi32>
111+
112+
// CHECK: ( -27190, -28812, -30502, -23575 )
113+
// CHECK: ( -7613, -8386, -15938, -6521 )
114+
// CHECK: ( 9468, 18750, 9199, 5764 )
115+
// CHECK: ( 33655, 41064, 48900, 31627 )
116+
return
117+
}
118+

0 commit comments

Comments
 (0)