// RUN: gc-opt %s --convert-xevm-to-llvm --xevm-attach-target --convert-scf-to-cf --convert-cf-to-llvm --convert-arith-to-llvm --convert-gpu-to-llvm-spv --gpu-to-llvm --reconcile-unrealized-casts --cse --gpu-module-to-binary | gc-cpu-runner -e main -entry-point-result=void --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s

module @gemm attributes {gpu.container_module} {
  gpu.module @kernel {
    // The set of available `matrix_mad` intrinsics depends on the device's *minimum*
    // supported sub-group size, so that sub-group size must be used when calling them.
    // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html

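    // The kernel computes one C(8x16, f32) += A(8x16, f16) x B(16x16, f16) tile. A single
    // 16-wide sub-group cooperates on the tile: each work-item holds a per-lane slice of
    // the operands (A as vector<8xi16>, B bitcast to vector<8xi32>, C as vector<8xf32>),
    // and `xevm.dpas` performs the multiply-accumulate across the sub-group.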
    gpu.func @block_dpas(%a: !llvm.ptr<1>, %b: !llvm.ptr<1>, %c: !llvm.ptr<1>) kernel attributes {intel_reqd_sub_group_size = 16 : i32} {
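      // Load the 8x16 f16 A tile. Base width and pitch are given in bytes
      // (16 elements x 2 bytes = 32), base height in rows.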
      %base_width_a = arith.constant 32 : i32
      %base_height_a = arith.constant 8 : i32
      %base_pitch_a = arith.constant 32 : i32
      %x = arith.constant 0 : i32
      %y = arith.constant 0 : i32
      %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>

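      // Load the 16x16 f16 B tile (32 bytes wide, 16 rows) and bitcast it to
      // vector<8xi32>, the packed form `xevm.dpas` takes for the B operand.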
      %base_width_b = arith.constant 32 : i32
      %base_height_b = arith.constant 16 : i32
      %base_pitch_b = arith.constant 32 : i32
      %loaded_b1 = xevm.blockload2d %b, %base_width_b, %base_height_b, %base_pitch_b, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
      %loaded_b_casted = vector.bitcast %loaded_b1 : vector<16xi16> to vector<8xi32>

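      // Load the 8x16 f32 accumulator tile C (16 elements x 4 bytes = 64 bytes wide).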
      %base_width_c = arith.constant 64 : i32
      %base_height_c = arith.constant 8 : i32
      %base_pitch_c = arith.constant 64 : i32
      %loaded_c = xevm.blockload2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>

      %loaded_c_casted = vector.bitcast %loaded_c : vector<8xi32> to vector<8xf32>
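      // D = C + A x B: f16 inputs with f32 accumulation; pa/pb give the operand
      // precisions and rc = 8 is the DPAS repeat count (rows of the accumulator tile).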
      %c_result = xevm.dpas %loaded_c_casted, %loaded_a, %loaded_b_casted {pa = f16, pb = f16, rc = 8} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
      %c_result_casted = vector.bitcast %c_result : vector<8xf32> to vector<8xi32>

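      // Store the updated accumulator back to C with the same 8x16 x f32 block shape.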
      xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
      gpu.return
    }
  }

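  // Host-side wrapper: each memref is copied into a `host_shared` (USM) allocation that is
  // accessible from both host and device; the raw aligned pointer is then extracted and
  // cast to the global address space (1) expected by the kernel arguments.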
  func.func @test(%a : memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index

    %memref_a = gpu.alloc host_shared () : memref<8x16xf16>
    memref.copy %a, %memref_a : memref<8x16xf16> to memref<8x16xf16>
    %a_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_a : memref<8x16xf16> -> index
    %a_ptr_as_i64 = arith.index_cast %a_ptr_as_idx : index to i64
    %a_ptr = llvm.inttoptr %a_ptr_as_i64 : i64 to !llvm.ptr
    %a_ptr_casted = llvm.addrspacecast %a_ptr : !llvm.ptr to !llvm.ptr<1>

    %memref_b = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %b, %memref_b : memref<16x16xf16> to memref<16x16xf16>
    %b_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_b : memref<16x16xf16> -> index
    %b_ptr_as_i64 = arith.index_cast %b_ptr_as_idx : index to i64
    %b_ptr = llvm.inttoptr %b_ptr_as_i64 : i64 to !llvm.ptr
    %b_ptr_casted = llvm.addrspacecast %b_ptr : !llvm.ptr to !llvm.ptr<1>

    %memref_c = gpu.alloc host_shared () : memref<8x16xf32>
    memref.copy %c, %memref_c : memref<8x16xf32> to memref<8x16xf32>
    %c_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_c : memref<8x16xf32> -> index
    %c_ptr_as_i64 = arith.index_cast %c_ptr_as_idx : index to i64
    %c_ptr = llvm.inttoptr %c_ptr_as_i64 : i64 to !llvm.ptr
    %c_ptr_casted = llvm.addrspacecast %c_ptr : !llvm.ptr to !llvm.ptr<1>

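    // A single work-group of 16 work-items (one sub-group) processes the whole tile.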
    gpu.launch_func @kernel::@block_dpas blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%a_ptr_casted : !llvm.ptr<1>, %b_ptr_casted : !llvm.ptr<1>, %c_ptr_casted : !llvm.ptr<1>)
    return %memref_c : memref<8x16xf32>
  }

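  // Entry point: A is initialized to A[i][j] = i, B to B[i][j] = j and C to zero, so the
  // expected result is C[i][j] = sum_k A[i][k] * B[k][j] = 16 * i * j, which is what the
  // CHECK lines below verify.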
  func.func @main() attributes {llvm.emit_c_interface} {
    %A = memref.alloc() : memref<8x16xf16>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index

    scf.for %i = %c0 to %c8 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        %row_idx = arith.index_cast %i : index to i32
        %row = arith.sitofp %row_idx : i32 to f16
        memref.store %row, %A[%i, %j] : memref<8x16xf16>
      }
    }
    %B = memref.alloc() : memref<16x16xf16>
    scf.for %i = %c0 to %c16 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        %col_idx = arith.index_cast %j : index to i32
        %col = arith.sitofp %col_idx : i32 to f16
        memref.store %col, %B[%i, %j] : memref<16x16xf16>
      }
    }

    %C = memref.alloc() : memref<8x16xf32>
    %c0_f32 = arith.constant 0.0 : f32
    scf.for %i = %c0 to %c8 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        memref.store %c0_f32, %C[%i, %j] : memref<8x16xf32>
      }
    }

    %C_res = call @test(%A, %B, %C) : (memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) -> memref<8x16xf32>
    %C_cast = memref.cast %C_res : memref<8x16xf32> to memref<*xf32>
    %A_cast = memref.cast %A : memref<8x16xf16> to memref<*xf16>
    call @printMemrefF32(%C_cast) : (memref<*xf32>) -> ()

    // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
    // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    // CHECK-NEXT: [0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]
    // CHECK-NEXT: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480]
    // CHECK-NEXT: [0, 48, 96, 144, 192, 240, 288, 336, 384, 432, 480, 528, 576, 624, 672, 720]
    // CHECK-NEXT: [0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960]
    // CHECK-NEXT: [0, 80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120, 1200]
    // CHECK-NEXT: [0, 96, 192, 288, 384, 480, 576, 672, 768, 864, 960, 1056, 1152, 1248, 1344, 1440]
    // CHECK-NEXT: [0, 112, 224, 336, 448, 560, 672, 784, 896, 1008, 1120, 1232, 1344, 1456, 1568, 1680]

    return
  }
  func.func private @printMemrefF16(%ptr : memref<*xf16>) attributes { llvm.emit_c_interface }
  func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }

}