@@ -1,10 +1,14 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s
 
-#packed_maps = [
-  affine_map<(d0, d1, d2) -> (d0, d2)>,
-  affine_map<(d0, d1, d2) -> (d1, d2)>,
-  affine_map<(d0, d1, d2) -> (d0, d1)>
-]
+#attrs = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  kind = #vector.kind<add>
+}
 
 // CHECK-LABEL: @test_vector_contract_to_smmla
 
@@ -85,10 +89,93 @@ func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>,
 
   %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
   %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
-  %2 = vector.contract {indexing_maps = #packed_maps,
-                        iterator_types = ["parallel", "parallel", "reduction"],
-                        kind = #vector.kind<add>} %0, %1, %acc
+  %2 = vector.contract #attrs %0, %1, %acc
     : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
 
   return %2 : vector<4x[4]xi32>
 }
+
+// CHECK-LABEL: @test_vector_contract_to_smmla_implicit_sext
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+// Test a variant where the sign-extension of the operands is
+// implicit. The output is identical to that of the previous test.
+func.func @test_vector_contract_to_smmla_implicit_sext(%lhs: vector<4x8xi8>,
+                                                       %rhs: vector<[4]x8xi8>,
+                                                       %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = vector.contract #attrs %lhs, %rhs, %acc
+    : vector<4x8xi8>, vector<[4]x8xi8> into vector<4x[4]xi32>
+
+  return %0 : vector<4x[4]xi32>
+}
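
For orientation, below is a minimal sketch of the single sub-tile building block that the CHECK lines above match four times (%[[T35]] through %[[T38]]). It is not part of the commit: the function name @smmla_subtile is made up, and it assumes the arm_sve dialect is registered; the only operation used is the "arm_sve.intr.smmla" intrinsic with exactly the signature and operand order (accumulator, LHS tile, RHS tile) shown in the CHECK lines. Per 128-bit vector segment, one such call multiplies a 2x8 tile of signed i8 values by an 8x2 tile of signed i8 values and accumulates into a 2x2 tile of i32 results, which is why the 4x[4] output of the tests is assembled from four calls.

  // Hypothetical helper, not taken from the commit above.
  func.func @smmla_subtile(%acc: vector<[4]xi32>, %lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]xi32> {
    // Widening signed i8 matrix multiply-accumulate:
    // per 128-bit segment, (2x8 i8) x (8x2 i8) + (2x2 i32) -> (2x2 i32).
    %0 = "arm_sve.intr.smmla"(%acc, %lhs, %rhs)
      : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
    return %0 : vector<[4]xi32>
  }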