-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM #140573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/momchil-velikov/vector-contract-i8mm-transform-dialect
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesPatch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff 5 Files Affected:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17],
+ [-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c256 = arith.constant 256 : i32
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
+ [-20, -36, -3, 39, -48, -31, -25, -21],
+ [-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49],
+ [-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
+ %acc_m = memref.alloca() : memref<8x8xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+ vector.print %acc4 : vector<[4]xi32>
+ vector.print %acc5 : vector<[4]xi32>
+ vector.print %acc6 : vector<[4]xi32>
+ vector.print %acc7 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38],
+ [ -3, 9, 43, -30, -32, 39, 41, -39],
+ [-13, -21, -25, 27, 47, -36, -11, -11],
+ [ -4, -20, 36, 11, 13, -23, 24, -13],
+ [-20, 30, -5, 1, 42, -37, -22, 35],
+ [-22, 38, -4, 44, 25, -31, 23, -39],
+ [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
+
+ %lhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+ %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+ %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+ %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+ %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+ vector.print %lhs4 : vector<8xi8>
+ vector.print %lhs5 : vector<8xi8>
+ vector.print %lhs6 : vector<8xi8>
+ vector.print %lhs7 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
+ [ 46, -45, -48, -46, -24, 31, -36, 22],
+ [ 2, 36, 45, -29, -37, -49, -20, -35],
+ [ -6, 23, 23, 15, 20, 4, -8, -2],
+ [-35, -6, 16, 49, -50, 9, -44, 13],
+ [ 24, 1, -4, -44, 41, 15, -43, 44],
+ [ 44, 0, -10, 41, 22, 44, -40, 0],
+ [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
+
+ %rhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+ // Display the result of the multilication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+ vector.print %u4 : vector<[4]xi32>
+ vector.print %u5 : vector<[4]xi32>
+ vector.print %u6 : vector<[4]xi32>
+ vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
+ [221, 25, 164, 97, 156, 221, 218, 177],
+ [171, 160, 219, 191, 144, 45, 161, 210],
+ [223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesPatch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff 5 Files Affected:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17],
+ [-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c256 = arith.constant 256 : i32
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
+ [-20, -36, -3, 39, -48, -31, -25, -21],
+ [-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49],
+ [-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
+ %acc_m = memref.alloca() : memref<8x8xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+ vector.print %acc4 : vector<[4]xi32>
+ vector.print %acc5 : vector<[4]xi32>
+ vector.print %acc6 : vector<[4]xi32>
+ vector.print %acc7 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38],
+ [ -3, 9, 43, -30, -32, 39, 41, -39],
+ [-13, -21, -25, 27, 47, -36, -11, -11],
+ [ -4, -20, 36, 11, 13, -23, 24, -13],
+ [-20, 30, -5, 1, 42, -37, -22, 35],
+ [-22, 38, -4, 44, 25, -31, 23, -39],
+ [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
+
+ %lhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+ %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+ %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+ %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+ %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+ vector.print %lhs4 : vector<8xi8>
+ vector.print %lhs5 : vector<8xi8>
+ vector.print %lhs6 : vector<8xi8>
+ vector.print %lhs7 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
+ [ 46, -45, -48, -46, -24, 31, -36, 22],
+ [ 2, 36, 45, -29, -37, -49, -20, -35],
+ [ -6, 23, 23, 15, 20, 4, -8, -2],
+ [-35, -6, 16, 49, -50, 9, -44, 13],
+ [ 24, 1, -4, -44, 41, 15, -43, 44],
+ [ 44, 0, -10, 41, 22, 44, -40, 0],
+ [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
+
+ %rhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+ // Display the result of the multilication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+ vector.print %u4 : vector<[4]xi32>
+ vector.print %u5 : vector<[4]xi32>
+ vector.print %u6 : vector<[4]xi32>
+ vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
+ [221, 25, 164, 97, 156, 221, 218, 177],
+ [171, 160, 219, 191, 144, 45, 161, 210],
+ [223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
|
No description provided.