// RUN: mlir-opt %s \
// RUN: -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
// RUN: --shared-libs=%mlir_c_runner_utils \
// RUN: --entry-point-result=void \
// RUN: | FileCheck %s

// CHECK: Correct Results :
// CHECK: 16384
// CHECK: Incorrect Results :
// CHECK: 0
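
// The kernel computes a 128x128 output tile, i.e. 16384 elements; every
// element is expected to match the host reference, hence 0 incorrect results.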

// This program performs 128x128x128 GEMM (F32 += F16 * F16)
//
// ## Sequential
// for(128)
//  for(128)
//   for(128)
//    D += A * B
//
// ## Parallel 1 CTA with 1 Warpgroup with 2 pipelining stages
//
// cuda kernel() {
//    mbarriers.init[2]
//    for(i = 0;...2) {
//      tma.load shmem_buffer<i x...>
//      mbarrier.expect_tx group[i]
//    }
//    result =
//      for(i = 0;...2) {
//        pipe = i % 2
//        mbarrier.wait [pipe]
//        lhs = shmem_buffer_lhs<pipe x 128 x 64>
//        rhs = shmem_buffer_rhs<pipe x 64 x 128>
//        yield nvgpu.warpgroup.mma (lhs, rhs)
//        ---------------------------------------------------------------------
//        Expanded : nvgpu.warpgroup.mma [128][128] += [128][64] * [64][128]
//                   wgmma.m64n128k16 (A[0:64][0:16]    * B[0:16][0:128])
//                   wgmma.m64n128k16 (A[0:64][16:32]   * B[16:32][0:128])
//                   wgmma.m64n128k16 (A[0:64][32:48]   * B[32:48][0:128])
//                   wgmma.m64n128k16 (A[0:64][48:64]   * B[48:64][0:128])
//                   wgmma.m64n128k16 (A[64:128][0:16]  * B[0:16][0:128])
//                   wgmma.m64n128k16 (A[64:128][16:32] * B[16:32][0:128])
//                   wgmma.m64n128k16 (A[64:128][32:48] * B[32:48][0:128])
//                   wgmma.m64n128k16 (A[64:128][48:64] * B[48:64][0:128])
//        ---------------------------------------------------------------------
//      }
//    nvgpu.store result -> shmem_buffer_result
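//
// Shared memory budget (matching the constants used below): each stage holds
// an A tile of 128x64 f16 (16 KB) and a B tile of 64x128 f16 (16 KB), i.e.
// 32 KB per stage; the two stages together fill the 65536 bytes requested as
// dynamic shared memory.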

!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>

func.func private @printMemrefF32(memref<*xf32>)

memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}

func.func @main() {
  // matrix A [128][64] * matrix B [64][128] * stages(2)
  %shmemSize = arith.constant 65536 : i32
  %hc1 = arith.constant 1 : index
  %hc4096 = arith.constant 4096 : index
  %hc0 = arith.constant 0 : index
  %hc64 = arith.constant 64 : index
  %hc16 = arith.constant 16 : index
  %hc8 = arith.constant 8 : index
  %hc128 = arith.constant 128 : index
  %hc32 = arith.constant 32 : index
  %hc256 = arith.constant 256 : index
  %f0 = arith.constant 0.0 : f32

  // Step 1. Allocate and Initialize LHS and RHS Matrices
  %matrixAHost = memref.alloc() : memref<128x128xf16>
  %matrixBHost = memref.alloc() : memref<128x128xf16>
  %matrixDHost = memref.alloc() : memref<128x128xf32>
  %matrixRefHost = memref.alloc() : memref<128x128xf32>
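
  // The loops below fill B[i][j] with ((i * 128 + j) / 8) mod 16 and
  // A[j][i] with ((j * 64 + i) / 8) mod 16, and zero both D and the host
  // reference. Keeping the inputs to small integers (0..15) appears to be
  // deliberate: every product and partial sum is then exactly representable,
  // so the device result can be compared against the host f32 reference with
  // the tight tolerance used in the verification step.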
  scf.for %i = %hc0 to %hc128 step %hc1 {
    scf.for %j = %hc0 to %hc128 step %hc1 {
      %v0 = arith.muli %i, %hc128 : index     // i * 128
      %v00 = arith.addi %v0, %j : index       // i * 128 + j
      %v01 = arith.divui %v00, %hc8 : index   // (i * 128 + j) / 8
      %v02 = arith.remui %v01, %hc16 : index  // ((i * 128 + j) / 8) mod 16
      %v2 = arith.index_cast %v02 : index to i32
      %vR = arith.sitofp %v2 : i32 to f16
      memref.store %vR, %matrixBHost[%i, %j] : memref<128x128xf16>
      %b0 = arith.muli %j, %hc64 : index
      %b00 = arith.addi %b0, %i : index
      %b01 = arith.divui %b00, %hc8 : index
      %b02 = arith.remui %b01, %hc16 : index  // ((j * 64 + i) / 8) mod 16
      %v1 = arith.index_cast %b02 : index to i32
      %vL = arith.sitofp %v1 : i32 to f16
      memref.store %vL, %matrixAHost[%j, %i] : memref<128x128xf16>
      memref.store %f0, %matrixDHost[%i, %j] : memref<128x128xf32>
      memref.store %f0, %matrixRefHost[%i, %j] : memref<128x128xf32>
    }
  }

  // Step 2. Allocate Device Memory for LHS and RHS Matrices and Copy H2D
  %token = gpu.wait async
  %matrixA:2 = gpu.alloc async [%token] () : memref<128x128xf16>
  %matrixB:2 = gpu.alloc async [%token] () : memref<128x128xf16>
  %matrixD:2 = gpu.alloc async [%token] () : memref<128x128xf32>
  %1 = gpu.memcpy async [%token] %matrixA#0, %matrixAHost : memref<128x128xf16>, memref<128x128xf16>
  %2 = gpu.memcpy async [%token] %matrixB#0, %matrixBHost : memref<128x128xf16>, memref<128x128xf16>
  %castA = memref.cast %matrixA#0 : memref<128x128xf16> to memref<*xf16>
  %castB = memref.cast %matrixB#0 : memref<128x128xf16> to memref<*xf16>

  // Step 3. Create TMA Descriptors
  %descA = nvgpu.tma.create.descriptor %castA box[%hc128, %hc64] : memref<*xf16> -> !lhsTensorMap
  %descB = nvgpu.tma.create.descriptor %castB box[%hc64, %hc64] : memref<*xf16> -> !rhsTensorMap
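  // Note that the B descriptor describes a 64x64 box even though a full B
  // tile is 64x128: with swizzle_128b the innermost box extent is limited to
  // 128 bytes (64 f16 elements), so each B tile is transferred as two 64x64
  // boxes (hence the two tma.async.load ops on %descB per stage below).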

  // Step 4. Launch GPU Kernel
  gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
             threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
             dynamic_shared_memory_size %shmemSize
  {
    memref.assume_alignment %matrixD#0, 16 : memref<128x128xf32>

    %c256 = arith.constant 256 : index
    %c10000000 = arith.constant 10000000 : index
    %c32768 = arith.constant 32768 : index
    %c320 = arith.constant 320 : index
    %c192 = arith.constant 192 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c3 = arith.constant 3 : index
    %c7 = arith.constant 7 : index
    %c64 = arith.constant 64 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c32 = arith.constant 32 : index
    %c16 = arith.constant 16 : index
    %c4096 = arith.constant 4096 : index
    %c8 = arith.constant 8 : index
    %txcount = arith.constant 32768 : index
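    // %txcount is the number of bytes each mbarrier expects per stage:
    // (128x64 + 64x128) f16 elements * 2 bytes = 32768 bytes, covering the
    // three tma.async.load ops issued against that barrier.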

    %tidx = gpu.thread_id x
    %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<4x64x128xf16, 3>
    %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16, 3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
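    // Dynamic shared memory layout, in f16 elements: the two A stages occupy
    // [0, 16384) (128x64 each) and the two B stages occupy [16384, 32768)
    // (64x128 each); %rhsShmem is the view that skips past the A stages.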

    // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
    %barrier = nvgpu.mbarrier.create -> !barrierType
    %cnd = arith.cmpi eq, %tidx, %c0 : index

    // Step 2. [GPU] Initialize mbarriers
    nvgpu.mbarrier.init %barrier[%c0], %c1 : !barrierType
    nvgpu.mbarrier.init %barrier[%c1], %c1 : !barrierType

    // Step 3. [GPU] Prefetch TMA Descriptors to L1 Cache
    nvgpu.tma.prefetch.descriptor %descA : !lhsTensorMap
    nvgpu.tma.prefetch.descriptor %descB : !rhsTensorMap

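    // In the two prefill pipelines below, thread 0 issues the TMA loads for
    // both stages. Each 64x128 B tile arrives as two 64x64 boxes; the second
    // box lands 4096 elements after the first, which in the row-major 64x128
    // view is subview offset [32, 0] rather than [0, 64]. This keeps the two
    // swizzled boxes contiguous in shared memory (offsets 16384/20480 and
    // 24576/28672 below).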
    // Step 4.1 [GPU] TMA Load Pipeline 1
    scf.if %cnd {
      %pipe = arith.constant 0 : index
      %lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
      %rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
      %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
      %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
      nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
      %dim = arith.muli %pipe, %c64 : index
      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
    }
    // Step 4.2 [GPU] TMA Load Pipeline 2
    scf.if %cnd {
      %pipe = arith.constant 1 : index
      %lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
      %rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
      %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
      %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
      nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
      %dim = arith.muli %pipe, %c64 : index
      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
    }

    // Step 5. [GPU] Initialize accumulator matrix
    %14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>

    // Step 6. [GPU] Main Loop Starts
    %15 = scf.for %i = %c0 to %c2 step %c1 iter_args(%mc = %14)
          -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
    {
      %ticks = arith.constant 10000000 : index
      // Wait until the TMA loads for stage %i have arrived (phase parity 0,
      // since each barrier completes exactly once).
      nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
      %lhsSlice = memref.subview %lhsShmem[%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
      %rhsSlice = memref.subview %rhsShmem[%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
      // Descriptor WGMMA
      %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
      %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
      // Perform WGMMA 128x128x64
      %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16, 3>>, <tensor = memref<64x128xf16, 3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
      scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    }

    // Step 7. Wait for all WGMMA operations to finish
    nvvm.wgmma.wait.group.sync.aligned 0
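    // The f32 accumulator tile (128x128x4 B = 64 KB) is staged through
    // @accShmem, which, like @dynamicShmem, is a zero-sized workgroup global
    // and so presumably aliases the same dynamic shared memory that held the
    // f16 input tiles; overwriting it is safe here because the wait above
    // guarantees all WGMMA reads of those tiles have completed.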

    // Step 8. [GPU] Epilogue, store fragmented registers to shared memory
    %accShmem = memref.get_global @accShmem : memref<0xf32, 3>
    %accShmemPtr = memref.reinterpret_cast %accShmem to offset: [0], sizes: [128, 128], strides: [128, 1] : memref<0xf32, 3> to memref<128x128xf32, 3>
    nvgpu.warpgroup.mma.store %15, %accShmemPtr : <fragmented = vector<128x128xf32>> to memref<128x128xf32, 3>

    // Step 9. [GPU] Epilogue, shared memory to global memory
    %17 = arith.divui %tidx, %c32 : index
    %18 = arith.remui %tidx, %c32 : index
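    // %17 is the warp id (0..3) and %18 the lane id (0..31): each of the 4
    // warps copies every 4th row, and each lane stores 4 consecutive f32
    // values, so the 128 threads cover the whole 128x128 tile.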
    scf.for %arg12 = %17 to %c128 step %c4 {
      %19 = arith.muli %18, %c4 : index
      %20 = vector.load %accShmemPtr[%arg12, %19] : memref<128x128xf32, 3>, vector<4xf32>
      vector.store %20, %matrixD#0[%arg12, %19] : memref<128x128xf32>, vector<4xf32>
    }
    gpu.terminator
  }

  // Step 5. Copy D2H
  %5 = gpu.memcpy async [%token] %matrixDHost, %matrixD#0 : memref<128x128xf32>, memref<128x128xf32>
  gpu.wait [%token]

  // Step 6. Compute on host
  linalg.matmul ins(%matrixAHost, %matrixBHost : memref<128x128xf16>, memref<128x128xf16>) outs(%matrixRefHost : memref<128x128xf32>)

  // Step 7. Verify
  %ic1 = arith.constant 1 : i32
  %ic0 = arith.constant 0 : i32
  %tolerance = arith.constant 0.00000001 : f32
  %errorCount, %correctCount =
    scf.for %i = %hc0 to %hc128 step %hc1 iter_args(%ec1 = %ic0, %cc1 = %ic0) -> (i32, i32) {
      %ec2, %cc2 =
        scf.for %j = %hc0 to %hc128 step %hc1 iter_args(%ec2 = %ec1, %cc2 = %cc1) -> (i32, i32) {
          %v1 = memref.load %matrixRefHost[%i, %j] : memref<128x128xf32>
          %v2 = memref.load %matrixDHost[%i, %j] : memref<128x128xf32>
          %g1 = arith.subf %v1, %v2 : f32
          %g2 = math.absf %g1 : f32
          %g3 = arith.cmpf ult, %tolerance, %g2 : f32
          %ec3, %cc3 = scf.if %g3 -> (i32, i32) {
            %coor = arith.constant dense<-1> : vector<2xi32>
            %i32 = arith.index_cast %i : index to i32
            %j32 = arith.index_cast %j : index to i32
            %coord1 = vector.insert %i32, %coor[0] : i32 into vector<2xi32>
            %coord2 = vector.insert %j32, %coord1[1] : i32 into vector<2xi32>
            %ec3 = arith.addi %ec2, %ic1 : i32
            scf.yield %ec3, %cc2 : i32, i32
          } else {
            %cc3 = arith.addi %cc2, %ic1 : i32
            scf.yield %ec2, %cc3 : i32, i32
          }
          scf.yield %ec3, %cc3 : i32, i32
        }
      scf.yield %ec2, %cc2 : i32, i32
    }

  vector.print str "Correct Results :"
  vector.print %correctCount : i32
  vector.print str "Incorrect Results :"
  vector.print %errorCount : i32

  return
}