
Commit 51916f0

[mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) (#69913)
This PR adds a test that performs GEMM 128x128x128 (F32 += F16 * F16). It uses `sm_90a` features in the NVGPU dialect. The simplified algorithm is as follows:

**Prologue**
```
mgroup = mbarriers.init x 2
tma.load ... shmem_buffer_lhs<0 x 128 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
mbarrier.expect_tx 32768
tma.load ... shmem_buffer_lhs<1 x 128 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
mbarrier.expect_tx 32768
```

**Mainloop**
```
matrixD =
  for(i = 0;...2) {
    mbarrier.try_wait [i]
    lhs = shmem_buffer_lhs<pipe x 128 x 64>
    rhs = shmem_buffer_rhs<pipe x 64 x 128>
    yield nvgpu.warpgroup.mma (lhs, rhs)
    // Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128]
    //  wgmma.m64n128k16(A[0:64][0:16] * B[0:16][0:128])
    //  wgmma.m64n128k16(A[0:64][16:32] * B[16:32][0:128])
    //  wgmma.m64n128k16(A[0:64][32:48] * B[32:48][0:128])
    //  wgmma.m64n128k16(A[0:64][48:64] * B[48:64][0:128])
    //  wgmma.m64n128k16(A[64:128][0:16] * B[0:16][0:128])
    //  wgmma.m64n128k16(A[64:128][16:32] * B[16:32][0:128])
    //  wgmma.m64n128k16(A[64:128][32:48] * B[32:48][0:128])
    //  wgmma.m64n128k16(A[64:128][48:64] * B[48:64][0:128])
  }
```

**Epilogue**
```
// reg -> shmem
warpgroup.mma.store matrixD, shmem
// shmem -> glbmem
parallel-for(i=0;...128)
  parallel-for(j=0;...128)
    store shmem, globalmem
```
1 parent a00caad commit 51916f0

File tree

1 file changed

+272
-0
lines changed

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
// RUN: mlir-opt %s \
// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
// RUN:  | mlir-cpu-runner \
// RUN:   --shared-libs=%mlir_cuda_runtime \
// RUN:   --shared-libs=%mlir_runner_utils \
// RUN:   --shared-libs=%mlir_c_runner_utils \
// RUN:   --entry-point-result=void \
// RUN:  | FileCheck %s

// CHECK: Correct Results :
// CHECK: 16384
// CHECK: Incorrect Results :
// CHECK: 0

// This program performs 128x128x128 GEMM (F32 += F16 * F16)
//
// ## Sequential
// for(128)
//   for(128)
//     for(128)
//       D += A * B
//
// ## Parallel 1 CTA with 1 Warpgroup with 2 pipelining stage
//
// cuda kernel() {
//   mbarriers.init[2]
//   for(i = 0;...2) {
//     tma.load shmem_buffer<i x...>
//     mbarrier.expect_tx group[i]
//   }
//   result =
//     for(i = 0;...2) {
//       pipe = i % 2
//       mbarrier.wait [pipe]
//       lhs = shmem_buffer_lhs<pipe x 128 x 64>
//       rhs = shmem_buffer_rhs<pipe x 64 x 128>
//       yield nvgpu.warpgroup.mma (lhs, rhs)
//       ---------------------------------------------------------------------
//       Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128]
//         wgmma.m64n128k16(A[0:64][0:16] * B[0:16][0:128])
//         wgmma.m64n128k16(A[0:64][16:32] * B[16:32][0:128])
//         wgmma.m64n128k16(A[0:64][32:48] * B[32:48][0:128])
//         wgmma.m64n128k16(A[0:64][48:64] * B[48:64][0:128])
//         wgmma.m64n128k16(A[64:128][0:16] * B[0:16][0:128])
//         wgmma.m64n128k16(A[64:128][16:32] * B[16:32][0:128])
//         wgmma.m64n128k16(A[64:128][32:48] * B[32:48][0:128])
//         wgmma.m64n128k16(A[64:128][48:64] * B[48:64][0:128])
//       ---------------------------------------------------------------------
//     }
//   nvgpu.store result -> shmem_buffer_result

!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
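// Note: the !nvgpu.tensormap.descriptor types above describe the TMA boxes
// copied into workgroup shared memory: a 128x64 f16 tile for the LHS and a
// 64x64 f16 tile for the RHS. Because the RHS box is only 64x64, each 64x128
// RHS stage is brought in as two TMA loads (see Steps 4.1/4.2 in the kernel).
// swizzle_128b selects 128-byte shared-memory swizzling; the wgmma descriptors
// generated in the main loop are built from these same tensor maps, so the
// producer and consumer sides agree on the layout.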

func.func private @printMemrefF32(memref<*xf32>)

memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}

func.func @main() {
  // matrix A (128*64) * matrix B (64*128) * stages(2)
  %shmemSize = arith.constant 65536 : i32
  %hc1 = arith.constant 1 : index
  %hc4096 = arith.constant 4096 : index
  %hc0 = arith.constant 0 : index
  %hc64 = arith.constant 64 : index
  %hc16 = arith.constant 16 : index
  %hc8 = arith.constant 8 : index
  %hc128 = arith.constant 128 : index
  %hc32 = arith.constant 32 : index
  %hc256 = arith.constant 256 : index
  %f0 = arith.constant 0.0 : f32

  // Step 1. Allocate and Initialize LHS and RHS Matrices
  %matrixAHost = memref.alloc() : memref<128x128xf16>
  %matrixBHost = memref.alloc() : memref<128x128xf16>
  %matrixDHost = memref.alloc() : memref<128x128xf32>
  %matrixRefHost = memref.alloc() : memref<128x128xf32>
  scf.for %i = %hc0 to %hc128 step %hc1 {
    scf.for %j = %hc0 to %hc128 step %hc1 {
      %v0 = arith.muli %i, %hc128 : index     // i * 128
      %v00 = arith.addi %v0, %j : index       // i * 128 + j
      %v01 = arith.divui %v00, %hc8 : index   // (i * 128 + j) / 8
      %v02 = arith.remui %v01, %hc16 : index  // ((i * 128 + j) / 8) % 16
      %v2 = arith.index_cast %v02 : index to i32
      %vR = arith.sitofp %v2 : i32 to f16
      memref.store %vR, %matrixBHost[%i, %j] : memref<128x128xf16>
      %b0 = arith.muli %j, %hc64 : index      // j * 64
      %b00 = arith.addi %b0, %i : index       // j * 64 + i
      %b01 = arith.divui %b00, %hc8 : index   // (j * 64 + i) / 8
      %b02 = arith.remui %b01, %hc16 : index  // ((j * 64 + i) / 8) % 16
      %v1 = arith.index_cast %b02 : index to i32
      %vL = arith.sitofp %v1 : i32 to f16
      memref.store %vL, %matrixAHost[%j, %i] : memref<128x128xf16>
      memref.store %f0, %matrixDHost[%i, %j] : memref<128x128xf32>
      memref.store %f0, %matrixRefHost[%i, %j] : memref<128x128xf32>
    }
  }
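  // Note: with the pattern above every element of A and B is an integer in
  // [0, 15], so each f16 input is exact and a full dot product is at most
  // 128 * 15 * 15 = 28800, which is exactly representable in f32. The device
  // result is therefore expected to match the host linalg.matmul reference
  // exactly, which is why the tiny tolerance used in Step 7 is sufficient.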

  // Step 2. Allocate Device Memory for LHS and RHS Matrices and Copy H2D
  %token = gpu.wait async
  %matrixA:2 = gpu.alloc async [%token] () : memref<128x128xf16>
  %matrixB:2 = gpu.alloc async [%token] () : memref<128x128xf16>
  %matrixD:2 = gpu.alloc async [%token] () : memref<128x128xf32>
  %1 = gpu.memcpy async [%token] %matrixA#0, %matrixAHost : memref<128x128xf16>, memref<128x128xf16>
  %2 = gpu.memcpy async [%token] %matrixB#0, %matrixBHost : memref<128x128xf16>, memref<128x128xf16>
  %castA = memref.cast %matrixA#0 : memref<128x128xf16> to memref<*xf16>
  %castB = memref.cast %matrixB#0 : memref<128x128xf16> to memref<*xf16>

  // Step 3. Create TMA Descriptor
  %descA = nvgpu.tma.create.descriptor %castA box[%hc128, %hc64] : memref<*xf16> -> !lhsTensorMap
  %descB = nvgpu.tma.create.descriptor %castB box[%hc64, %hc64] : memref<*xf16> -> !rhsTensorMap
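  // Note: nvgpu.tma.create.descriptor builds the TMA descriptor (tensor map)
  // on the host side from the unranked device buffer; the box sizes match the
  // 128x64 and 64x64 shared-memory tiles declared in !lhsTensorMap and
  // !rhsTensorMap, i.e. the shape moved by a single nvgpu.tma.async.load.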

  // Step 4. Launch GPU Kernel
  gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
             threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
             dynamic_shared_memory_size %shmemSize
  {
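    // Note: the kernel runs as a single thread block of 128 threads, i.e. one
    // warpgroup (4 warps), which is the granularity the warpgroup MMA ops
    // below operate on. The 64 KB of dynamic shared memory requested here
    // holds the two double-buffered input stages (see the views below).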
    memref.assume_alignment %matrixD#0, 16 : memref<128x128xf32>

    %c256 = arith.constant 256 : index
    %c10000000 = arith.constant 10000000 : index
    %c32768 = arith.constant 32768 : index
    %c320 = arith.constant 320 : index
    %c192 = arith.constant 192 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c3 = arith.constant 3 : index
    %c7 = arith.constant 7 : index
    %c64 = arith.constant 64 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c32 = arith.constant 32 : index
    %c16 = arith.constant 16 : index
    %c4096 = arith.constant 4096 : index
    %c8 = arith.constant 8 : index
    %txcount = arith.constant 32768 : index

    %tidx = gpu.thread_id x
    %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<4x64x128xf16, 3>
    %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16, 3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
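    // Shared memory layout (64 KB of dynamic shared memory):
    //   [0 KB,  32 KB) : 2 stages of the 128x64 f16 LHS tile (%lhsShmem)
    //   [32 KB, 64 KB) : 2 stages of the 64x128 f16 RHS tile (%rhsShmem)
    // %rhsShmem is carved out by viewing the same buffer as 4x64x128 and
    // skipping the first two 16 KB stages, i.e. starting at element 16384.
    // One stage is 128*64*2 + 64*128*2 = 32768 bytes, which is exactly the
    // %txcount the mbarriers are armed with below.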

    // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
    %barrier = nvgpu.mbarrier.create -> !barrierType
    %cnd = arith.cmpi eq, %tidx, %c0 : index

    // Step 2. [GPU] Initialize mbarriers
    nvgpu.mbarrier.init %barrier[%c0], %c1 : !barrierType
    nvgpu.mbarrier.init %barrier[%c1], %c1 : !barrierType
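    // Note: one mbarrier per pipeline stage. Each is initialized with an
    // arrival count of 1 because only thread 0 (%cnd) issues the TMA loads
    // and the expect_tx arrivals; all other threads simply poll the barrier's
    // phase parity in the main loop.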

    // Step 3. [GPU] Prefetch TMA Descriptors to L1 Cache
    nvgpu.tma.prefetch.descriptor %descA : !lhsTensorMap
    nvgpu.tma.prefetch.descriptor %descB : !rhsTensorMap

    // Step 4.1 [GPU] TMA Load Pipeline 1
    scf.if %cnd {
      %pipe = arith.constant 0 : index
      %lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
      %rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
      %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
      %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
      nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
      %dim = arith.muli %pipe, %c64 : index
      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
    }
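    // Note: nvgpu.mbarrier.arrive.expect_tx arms the stage's mbarrier with the
    // number of bytes it should expect (%txcount = 32768) and performs the
    // single required thread arrival. The three TMA loads that follow (one
    // 128x64 LHS box and two 64x64 RHS boxes) deliver exactly that many bytes,
    // and the barrier completes its phase once all of them have arrived.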
    // Step 4.2 [GPU] TMA Load Pipeline 2
    scf.if %cnd {
      %pipe = arith.constant 1 : index
      %lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
      %rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
      %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
      %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
      nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
      %dim = arith.muli %pipe, %c64 : index
      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
    }

    // Step 5. [GPU] Initialize accumulator matrix
    %14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
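    // Note: the accumulator is not a memref; it is a "fragmented" 128x128 f32
    // tile whose elements live in the registers of the 128 threads of the
    // warpgroup. It is threaded through the loop below as an iter_arg and only
    // materialized in shared memory by warpgroup.mma.store in the epilogue.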

    // Step 6. [GPU] Main Loop Starts
    %15 = scf.for %i = %c0 to %c2 step %c1 iter_args(%mc = %14)
          -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
    {
      %ticks = arith.constant 10000000 : index
      // TMA wait
      nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
      %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
      %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
      // Descriptor WGMMA
      %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
      %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
      // Perform WGMMA 128x128x64
      %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
      scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    }

    // Step 7. Wait for all wgmma operations to finish
    nvvm.wgmma.wait.group.sync.aligned 0
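    // Note: nvgpu.warpgroup.mma issues the underlying wgmma instructions
    // asynchronously; wgmma.wait.group.sync.aligned 0 blocks until no wgmma
    // groups remain pending, so the accumulator registers are safe to read
    // in the epilogue below.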

    // Step 8. [GPU] Epilogue, store fragmented register to shared memory
    %accShmem = memref.get_global @accShmem : memref<0xf32, 3>
    %accShmemPtr = memref.reinterpret_cast %accShmem to offset: [0], sizes: [128, 128], strides: [128, 1] : memref<0xf32, 3> to memref<128x128xf32, 3>
    nvgpu.warpgroup.mma.store %15, %accShmemPtr : <fragmented = vector<128x128xf32>> to memref<128x128xf32, 3>
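    // Note: warpgroup.mma.store writes each thread's accumulator fragments
    // into the 128x128 f32 shared-memory tile, so the copy to global memory
    // below can be done with simple, coalesced vector stores.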

    // Step 9. [GPU] Epilogue, shared memory to global memory
    %17 = arith.divui %tidx, %c32 : index
    %18 = arith.remui %tidx, %c32 : index
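    // Note: %17 is the warp id (0..3) and %18 the lane id (0..31). Each warp
    // walks the rows with a stride of 4, and each lane copies one contiguous
    // vector<4xf32>, so a single iteration moves a full 128-element row and
    // the four warps together cover all 128 rows.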
221+
scf.for %arg12 = %17 to %c128 step %c4 {
222+
%19 = arith.muli %18, %c4 : index
223+
%20 = vector.load %accShmemPtr[%arg12, %19] : memref<128x128xf32, 3>, vector<4xf32>
224+
vector.store %20, %matrixD[%arg12, %19] : memref<128x128xf32>, vector<4xf32>
225+
}
226+
gpu.terminator
227+
}

  // Step 5. Copy D2H
  %5 = gpu.memcpy async [%token] %matrixDHost, %matrixD#0 : memref<128x128xf32>, memref<128x128xf32>
  gpu.wait [%token]

  // Step 6. Compute on host
  linalg.matmul ins(%matrixAHost, %matrixBHost : memref<128x128xf16>, memref<128x128xf16>) outs(%matrixRefHost : memref<128x128xf32>)

  // Step 7. Verify
  %ic1 = arith.constant 1 : i32
  %ic0 = arith.constant 0 : i32
  %tolerance = arith.constant 0.00000001 : f32
  %errorCount, %correctCount =
    scf.for %i = %hc0 to %hc128 step %hc1 iter_args(%ec1 = %ic0, %cc1 = %ic0) -> (i32, i32) {
      %ec2, %cc2 =
        scf.for %j = %hc0 to %hc128 step %hc1 iter_args(%ec2 = %ec1, %cc2 = %cc1) -> (i32, i32) {
          %v1 = memref.load %matrixRefHost[%i, %j] : memref<128x128xf32>
          %v2 = memref.load %matrixDHost[%i, %j] : memref<128x128xf32>
          %g1 = arith.subf %v1, %v2 : f32
          %g2 = math.absf %g1 : f32
          %g3 = arith.cmpf ult, %tolerance, %g2 : f32
          %ec3, %cc3 = scf.if %g3 -> (i32, i32) {
            %coor = arith.constant dense<-1> : vector<2xi32>
            %i32 = arith.index_cast %i : index to i32
            %j32 = arith.index_cast %j : index to i32
            %coord1 = vector.insert %i32, %coor[0] : i32 into vector<2xi32>
            %coord2 = vector.insert %j32, %coord1[1] : i32 into vector<2xi32>
            %ec3 = arith.addi %ec2, %ic1 : i32
            scf.yield %ec3, %cc2 : i32, i32
          } else {
            %cc3 = arith.addi %cc2, %ic1 : i32
            scf.yield %ec2, %cc3 : i32, i32
          }
          scf.yield %ec3, %cc3 : i32, i32
        }
      scf.yield %ec2, %cc2 : i32, i32
    }
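  // Note: all 128 x 128 = 16384 entries are compared, so the FileCheck lines
  // at the top expect 16384 correct results and 0 incorrect ones.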

  vector.print str "Correct Results :"
  vector.print %correctCount : i32
  vector.print str "Incorrect Results :"
  vector.print %errorCount : i32

  return
}
