|
| 1 | +// REQUIRES: arm-emulator |
| 2 | + |
| 3 | +// DEFINE: %{compile} = mlir-opt %s \ |
| 4 | +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \ |
| 5 | +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \ |
| 6 | +// DEFINE: -o %t |
| 7 | + |
| 8 | +// DEFINE: %{entry_point} = main |
| 9 | + |
| 10 | +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \ |
| 11 | +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils |
| 12 | + |
| 13 | +// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s |
| 14 | + |
| 15 | +#packed_maps = [ |
| 16 | + affine_map<(d0, d1, d2) -> (d0, d2)>, |
| 17 | + affine_map<(d0, d1, d2) -> (d1, d2)>, |
| 18 | + affine_map<(d0, d1, d2) -> (d0, d1)> |
| 19 | +] |
| 20 | + |
| 21 | +func.func private @setArmVLBits(%bits : i32) |
| 22 | + |
| 23 | +func.func @main() { |
| 24 | + %c256 = arith.constant 256 : i32 |
| 25 | + func.call @setArmVLBits(%c256) : (i32) -> () |
| 26 | + |
| 27 | + %c0 = arith.constant 0 : index |
| 28 | + %c0_i32 = arith.constant 0 : i32 |
| 29 | + %c0_i8 = arith.constant 0 : i8 |
| 30 | + |
| 31 | + |
| 32 | + // Accumulator test data |
| 33 | + %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26], |
| 34 | + [-20, -36, -3, 39, -48, -31, -25, -21], |
| 35 | + [-35, -27, -36, -31, 23, -34, -8, -33], |
| 36 | + [-20, 17, -32, -47, 37, 22, -7, -21], |
| 37 | + [ -7, -35, 20, -4, 39, 46, -23, 40], |
| 38 | + [ 40, 27, 37, 43, 38, -6, 37, 49], |
| 39 | + [-17, -50, -1, 48, -13, 22, 39, 33], |
| 40 | + [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32> |
| 41 | + %acc_m = memref.alloca() : memref<8x8xi32> |
| 42 | + vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32> |
| 43 | + |
| 44 | + %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32> |
| 45 | + %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32> |
| 46 | + %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32> |
| 47 | + |
| 48 | + vector.print str "ACC:\n" |
| 49 | + %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32> |
| 50 | + %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32> |
| 51 | + %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32> |
| 52 | + %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32> |
| 53 | + %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32> |
| 54 | + %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32> |
| 55 | + %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32> |
| 56 | + %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32> |
| 57 | + vector.print %acc0 : vector<[4]xi32> |
| 58 | + vector.print %acc1 : vector<[4]xi32> |
| 59 | + vector.print %acc2 : vector<[4]xi32> |
| 60 | + vector.print %acc3 : vector<[4]xi32> |
| 61 | + vector.print %acc4 : vector<[4]xi32> |
| 62 | + vector.print %acc5 : vector<[4]xi32> |
| 63 | + vector.print %acc6 : vector<[4]xi32> |
| 64 | + vector.print %acc7 : vector<[4]xi32> |
| 65 | + |
| 66 | + // LHS test data |
| 67 | + %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35], |
| 68 | + [-23, 39, 48, 26, -23, 32, -39, -38], |
| 69 | + [ -3, 9, 43, -30, -32, 39, 41, -39], |
| 70 | + [-13, -21, -25, 27, 47, -36, -11, -11], |
| 71 | + [ -4, -20, 36, 11, 13, -23, 24, -13], |
| 72 | + [-20, 30, -5, 1, 42, -37, -22, 35], |
| 73 | + [-22, 38, -4, 44, 25, -31, 23, -39], |
| 74 | + [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8> |
| 75 | + |
| 76 | + %lhs_m = memref.alloca() : memref<8x8xi8> |
| 77 | + vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8> |
| 78 | + %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8> |
| 79 | + |
| 80 | + vector.print str "LHS:\n" |
| 81 | + %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8> |
| 82 | + %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8> |
| 83 | + %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8> |
| 84 | + %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8> |
| 85 | + %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8> |
| 86 | + %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8> |
| 87 | + %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8> |
| 88 | + %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8> |
| 89 | + vector.print %lhs0 : vector<8xi8> |
| 90 | + vector.print %lhs1 : vector<8xi8> |
| 91 | + vector.print %lhs2 : vector<8xi8> |
| 92 | + vector.print %lhs3 : vector<8xi8> |
| 93 | + vector.print %lhs4 : vector<8xi8> |
| 94 | + vector.print %lhs5 : vector<8xi8> |
| 95 | + vector.print %lhs6 : vector<8xi8> |
| 96 | + vector.print %lhs7 : vector<8xi8> |
| 97 | + |
| 98 | + // RHS test data |
| 99 | + %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32], |
| 100 | + [ 46, -45, -48, -46, -24, 31, -36, 22], |
| 101 | + [ 2, 36, 45, -29, -37, -49, -20, -35], |
| 102 | + [ -6, 23, 23, 15, 20, 4, -8, -2], |
| 103 | + [-35, -6, 16, 49, -50, 9, -44, 13], |
| 104 | + [ 24, 1, -4, -44, 41, 15, -43, 44], |
| 105 | + [ 44, 0, -10, 41, 22, 44, -40, 0], |
| 106 | + [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8> |
| 107 | + |
| 108 | + %rhs_m = memref.alloca() : memref<8x8xi8> |
| 109 | + vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8> |
| 110 | + |
| 111 | + %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8> |
| 112 | + %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8> |
| 113 | + |
| 114 | + vector.print str "RHS:\n" |
| 115 | + %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8> |
| 116 | + %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8> |
| 117 | + vector.print %rhs0 : vector<[16]xi8> |
| 118 | + vector.print %rhs1 : vector<[16]xi8> |
| 119 | + |
| 120 | + %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8> |
| 121 | + |
| 122 | + // Matrix multiplication |
| 123 | + %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32> |
| 124 | + %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> |
| 125 | + %2 = vector.contract {indexing_maps = #packed_maps, |
| 126 | + iterator_types = ["parallel", "parallel", "reduction"], |
| 127 | + kind = #vector.kind<add>} %0, %1, %acc |
| 128 | + : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32> |
| 129 | + |
| 130 | + // Display the result of the multilication |
| 131 | + vector.print str "Result:\n" |
| 132 | + %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32> |
| 133 | + %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32> |
| 134 | + %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32> |
| 135 | + %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32> |
| 136 | + %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32> |
| 137 | + %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32> |
| 138 | + %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32> |
| 139 | + %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32> |
| 140 | + vector.print %u0 : vector<[4]xi32> |
| 141 | + vector.print %u1 : vector<[4]xi32> |
| 142 | + vector.print %u2 : vector<[4]xi32> |
| 143 | + vector.print %u3 : vector<[4]xi32> |
| 144 | + vector.print %u4 : vector<[4]xi32> |
| 145 | + vector.print %u5 : vector<[4]xi32> |
| 146 | + vector.print %u6 : vector<[4]xi32> |
| 147 | + vector.print %u7 : vector<[4]xi32> |
| 148 | + |
| 149 | + |
| 150 | +// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 ) |
| 151 | +// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 ) |
| 152 | +// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 ) |
| 153 | +// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 ) |
| 154 | +// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 ) |
| 155 | +// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 ) |
| 156 | +// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 ) |
| 157 | +// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 ) |
| 158 | + return |
| 159 | +} |
0 commit comments