Skip to content

Commit 894c077

Browse files
committed
[mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) with predicate
The llvm#69913 added GEMM test 128x128x128 (F32 += F16 * F16) with if-statement. This PR adds the same test by using predicates in PTX. The predicate support is used via `BasicPtxBuilderInterface` The predicate is calculated in `Step 2.` as follows. We influence from cutlass. ``` lane_predicate = nvvm.elect.sync warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0) warp_idx_in_warp_group = warp_idx % 4 predicate = (lane_predicate & warp_idx_in_warp_group) ```
1 parent f4d5952 commit 894c077

File tree

1 file changed

+287
-0
lines changed

1 file changed

+287
-0
lines changed
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
// RUN: mlir-opt %s \
2+
// RUN: -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
3+
// RUN: | mlir-cpu-runner \
4+
// RUN: --shared-libs=%mlir_cuda_runtime \
5+
// RUN: --shared-libs=%mlir_runner_utils \
6+
// RUN: --shared-libs=%mlir_c_runner_utils \
7+
// RUN: --entry-point-result=void \
8+
// RUN: | FileCheck %s
9+
10+
// CHECK: Correct Results : 16384
11+
// CHECK: Incorrect Results : 0
12+
13+
// This program performs 128x128x128 GEMM (F32 += F16 * F16)
14+
//
15+
// ## Sequential
16+
// for(128)
17+
// for(128)
18+
// for(128)
19+
// D += A * B
20+
//
21+
// ## Parallel 1 CTA with 1 Warpgroup with 2 pipelining stage
22+
//
23+
// cuda kernel() {
24+
// mbarriers.init[2]
25+
// for(i = 0;...2) {
26+
// tma.load shmem_buffer<i x...>
27+
// mbarrier.expect_tx group[i]
28+
// }
29+
// result =
30+
// for(i = 0;...2) {
31+
// pipe = i % 2
32+
// mbarrier.wait [pipe]
33+
// lhs = shmem_buffer_lhs<pipe x 128 x 64>
34+
// rhs = shmem_buffer_rhs<pipe x 64 x 128>
35+
// yield nvgpu.warpgroup.mma (lhs, rhs)
36+
// ---------------------------------------------------------------------
37+
// Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128]
38+
// wgmma.m64n128k16(A[0:64][0:16] * B[0:16][0:128])
39+
// wgmma.m64n128k16(A[0:64][16:32] * B[16:32][0:128])
40+
// wgmma.m64n128k16(A[0:64][32:48] * B[32:48][0:128])
41+
// wgmma.m64n128k16(A[0:64][48:64] * B[48:64][0:128])
42+
// wgmma.m64n128k16(A[64:128][0:16] * B[0:16][0:128])
43+
// wgmma.m64n128k16(A[64:128][16:32] * B[16:32][0:128])
44+
// wgmma.m64n128k16(A[64:128][32:48] * B[32:48][0:128])
45+
// wgmma.m64n128k16(A[64:128][48:64] * B[48:64][0:128])
46+
// ---------------------------------------------------------------------
47+
// }
48+
// nvgpu.store result -> shmem_buffer_result
49+
50+
51+
!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
52+
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
53+
!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
54+
55+
func.func private @printMemrefF32(memref<*xf32>)
56+
llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
57+
58+
memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
59+
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
60+
61+
func.func @main() {
62+
%c214016_i32 = arith.constant 214016 : i32
63+
%hc1 = arith.constant 1 : index
64+
%hc4096 = arith.constant 4096 : index
65+
%hc0 = arith.constant 0 : index
66+
%hc64 = arith.constant 64 : index
67+
%hc16 = arith.constant 16 : index
68+
%hc8 = arith.constant 8 : index
69+
%hc128 = arith.constant 128 : index
70+
%hc32 = arith.constant 32 : index
71+
%hc256 = arith.constant 256 : index
72+
%f0 = arith.constant 0.0 : f32
73+
74+
// Step 1. Allocate and Initilize LHS and RHS Matrices
75+
%matrixAHost = memref.alloc() : memref<128x128xf16>
76+
%matrixBHost = memref.alloc() : memref<128x128xf16>
77+
%matrixDHost = memref.alloc() : memref<128x128xf32>
78+
%matrixRefHost = memref.alloc() : memref<128x128xf32>
79+
scf.for %i = %hc0 to %hc128 step %hc1 {
80+
scf.for %j = %hc0 to %hc128 step %hc1 {
81+
%v0 = arith.muli %i, %hc128 : index // i * 128
82+
%v00 = arith.addi %v0, %j : index // i * 128 + j
83+
%v01 = arith.divui %v00, %hc8 : index // (i * 128 + j) / 8
84+
%v02 = arith.remui %v01, %hc16 : index // <<<<< mod 128
85+
%v2 = arith.index_cast %v02 : index to i32
86+
%vR = arith.sitofp %v2 : i32 to f16
87+
memref.store %vR, %matrixBHost[%i, %j] : memref<128x128xf16>
88+
%b0 = arith.muli %j, %hc64 : index
89+
%b00 = arith.addi %b0, %i : index
90+
%b01 = arith.divui %b00, %hc8 : index
91+
%b02 = arith.remui %b01, %hc16 : index // <<<<< mod 128
92+
%v1 = arith.index_cast %b02 : index to i32
93+
%vL = arith.sitofp %v1 : i32 to f16
94+
memref.store %vL, %matrixAHost[%j, %i] : memref<128x128xf16>
95+
memref.store %f0, %matrixDHost[%i, %j] : memref<128x128xf32>
96+
memref.store %f0, %matrixRefHost[%i, %j] : memref<128x128xf32>
97+
}
98+
}
99+
100+
// Step 2. Allocate Device Memory for LHS and RHS Matrices and Copy H2D
101+
%token = gpu.wait async
102+
%matrixA:2 = gpu.alloc async [%token] () : memref<128x128xf16>
103+
%matrixB:2 = gpu.alloc async [%token] () : memref<128x128xf16>
104+
%matrixD:2 = gpu.alloc async [%token] () : memref<128x128xf32>
105+
%1 = gpu.memcpy async [%token] %matrixA, %matrixAHost : memref<128x128xf16>, memref<128x128xf16>
106+
%2 = gpu.memcpy async [%token] %matrixB, %matrixBHost : memref<128x128xf16>, memref<128x128xf16>
107+
%castA = memref.cast %matrixA : memref<128x128xf16> to memref<*xf16>
108+
%castB = memref.cast %matrixB : memref<128x128xf16> to memref<*xf16>
109+
110+
// Step 3. Create TMA Descriptor
111+
%descA = nvgpu.tma.create.descriptor %castA box[%hc128, %hc64] : memref<*xf16> -> !lhsTensorMap
112+
%descB = nvgpu.tma.create.descriptor %castB box[%hc64, %hc64] : memref<*xf16> -> !rhsTensorMap
113+
114+
// Step 4. Launch GPU Kernel
115+
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
116+
threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
117+
dynamic_shared_memory_size %c214016_i32
118+
{
119+
memref.assume_alignment %matrixD, 16 : memref<128x128xf32>
120+
121+
%c256 = arith.constant 256 : index
122+
%c10000000 = arith.constant 10000000 : index
123+
%c32768 = arith.constant 32768 : index
124+
%c320 = arith.constant 320 : index
125+
%c192 = arith.constant 192 : index
126+
%c6 = arith.constant 6 : index
127+
%c5 = arith.constant 5 : index
128+
%c4 = arith.constant 4 : index
129+
%c3 = arith.constant 3 : index
130+
%c7 = arith.constant 7 : index
131+
%c64 = arith.constant 64 : index
132+
%c1 = arith.constant 1 : index
133+
%c2 = arith.constant 2 : index
134+
%c0 = arith.constant 0 : index
135+
%c128 = arith.constant 128 : index
136+
%c32 = arith.constant 32 : index
137+
%c16 = arith.constant 16 : index
138+
%c4096 = arith.constant 4096 : index
139+
%c8 = arith.constant 8 : index
140+
%txcount = arith.constant 32768 : index
141+
142+
%tidx = gpu.thread_id x
143+
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
144+
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [7, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<7x128x64xf16, 3>
145+
%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
146+
%rhsShmem = memref.subview %rhsShmem2[7, 0, 0][7, 64, 128][1, 1, 1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
147+
148+
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
149+
%barrier = nvgpu.mbarrier.create -> !barrierType
150+
151+
// Step 2. [GPU] Elect fastest thread in CTA
152+
%mask = arith.constant -1 : i32
153+
%i0 = arith.constant 0 : i32
154+
%i32 = arith.constant 32 : i32
155+
%i4 = arith.constant 4 : i32
156+
%lanePredicate = nvvm.elect.sync -> i1
157+
%warpIdx = arith.divui %tidx, %c32 : index
158+
%warpIdxi32 = index.casts %warpIdx : index to i32
159+
%canonical_warp_idx = nvvm.shfl.sync idx %i32, %warpIdxi32, %i0, %mask : i32 -> i32
160+
%warp_idx_in_group = arith.remui %canonical_warp_idx, %i4 : i32
161+
%cnd1 = arith.cmpi eq, %warp_idx_in_group, %i0 : i32
162+
%cnd = arith.andi %cnd1, %lanePredicate : i1
163+
164+
// Step 3. [GPU] Initialize mbarriers (predicated threadIdx==0)
165+
nvgpu.mbarrier.init %barrier[%c0], %c1, predicate = %cnd : !barrierType
166+
nvgpu.mbarrier.init %barrier[%c1], %c1, predicate = %cnd : !barrierType
167+
168+
// Step 4.1 [GPU] Prefetch TMA Descriptors to L1 Cache (predicated)
169+
nvgpu.tma.prefetch.descriptor %descA, predicate = %cnd : !lhsTensorMap
170+
nvgpu.tma.prefetch.descriptor %descB, predicate = %cnd : !rhsTensorMap
171+
172+
// Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
173+
%pipe1 = arith.constant 0 : index
174+
%p1lhsSlice = memref.subview %lhsShmem [0, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
175+
%p1rhsSlice = memref.subview %rhsShmem [0, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
176+
%p1rhsSlice2 = memref.subview %p1rhsSlice[0, 32, 0][1, 128, 64][1,1,1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
177+
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType
178+
%dim1 = arith.muli %pipe1, %c64 : index
179+
nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
180+
nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1rhsSlice, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
181+
nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1rhsSlice2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
182+
183+
// Step 5. [GPU] TMA Load Pipeline 2 (predicated)
184+
%pipe2 = arith.constant 1 : index
185+
%p2lhsSlice = memref.subview %lhsShmem [1, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
186+
%p2rhsSlice = memref.subview %rhsShmem [1, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
187+
%p2rhsSlice2 = memref.subview %p2rhsSlice[0, 32, 0][1, 128, 64][1,1,1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
188+
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
189+
%dim2 = arith.muli %pipe2, %c64 : index
190+
nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
191+
nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2rhsSlice, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
192+
nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2rhsSlice2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
193+
194+
// Step 6. [GPU] Initiliaze accumulator matrix
195+
%14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
196+
197+
// Step 7. [GPU] Main Loop Starts
198+
%15 = scf.for %i = %c0 to %c2 step %c1 iter_args(%mc = %14)
199+
-> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
200+
{
201+
%ticks = arith.constant 10000000 : index
202+
// TMA wait
203+
nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
204+
%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>
205+
%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>
206+
// Descriptor WGMMA
207+
%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
208+
%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
209+
// Perform WGMMA 128x128x64
210+
%md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
211+
scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
212+
}
213+
214+
// Step 8. Wait all to finish mma
215+
nvvm.wgmma.wait.group.sync.aligned 0
216+
217+
// Step 9. [GPU] Epilogue, store fragmented register to shared memory
218+
%accShmem = memref.get_global @accShmem : memref<0xf32, 3>
219+
%accShmemPtr = memref.reinterpret_cast %accShmem to offset: [0], sizes: [128, 128], strides: [128, 1] : memref<0xf32, 3> to memref<128x128xf32, 3>
220+
nvgpu.warpgroup.mma.store %15, %accShmemPtr : <fragmented = vector<128x128xf32>> to memref<128x128xf32, 3>
221+
222+
// Step 10. [GPU] Epilogue, shared memory to global memory
223+
%17 = arith.divui %tidx, %c32 : index
224+
%18 = arith.remui %tidx, %c32 : index
225+
scf.for %arg12 = %17 to %c128 step %c4 {
226+
%19 = arith.muli %18, %c4 : index
227+
%20 = vector.load %accShmemPtr[%arg12, %19] : memref<128x128xf32, 3>, vector<4xf32>
228+
vector.store %20, %matrixD[%arg12, %19] : memref<128x128xf32>, vector<4xf32>
229+
}
230+
gpu.terminator
231+
}
232+
233+
// Step 5. Copy D2H
234+
%5 = gpu.memcpy async [%token] %matrixDHost, %matrixD : memref<128x128xf32>, memref<128x128xf32>
235+
gpu.wait [%token]
236+
237+
// Step 6. Compute on host
238+
linalg.matmul ins(%matrixAHost, %matrixBHost : memref<128x128xf16>, memref<128x128xf16>) outs(%matrixRefHost : memref<128x128xf32>)
239+
240+
// Step 7. Verify
241+
%ic1 = arith.constant 1 : i32
242+
%ic0 = arith.constant 0 : i32
243+
%tolerance = arith.constant 0.00000001 : f32
244+
%errorCount, %correctCount =
245+
scf.for %i = %hc0 to %hc128 step %hc1 iter_args(%ec1 = %ic0, %cc1 = %ic0) -> (i32,i32) {
246+
%ec2, %cc2 =
247+
scf.for %j = %hc0 to %hc128 step %hc1 iter_args(%ec2 = %ec1, %cc2 = %cc1) -> (i32,i32){
248+
%v1 = memref.load %matrixRefHost[%i,%j] : memref<128x128xf32>
249+
%v2 = memref.load %matrixDHost[%i,%j] : memref<128x128xf32>
250+
%g1 = arith.subf %v1,%v2 : f32
251+
%g2 = math.absf %g1: f32
252+
%g3 = arith.cmpf ult, %tolerance, %g2 : f32
253+
%ec3, %cc3 = scf.if %g3 -> (i32, i32) {
254+
%coor = arith.constant dense<-1> : vector<2xi32>
255+
%i32 = arith.index_cast %i : index to i32
256+
%j32 = arith.index_cast %j : index to i32
257+
%coord1 = vector.insert %i32, %coor[0] : i32 into vector<2xi32>
258+
%coord2 = vector.insert %j32, %coord1[1] : i32 into vector<2xi32>
259+
%ec3 = arith.addi %ec2, %ic1 : i32
260+
scf.yield %ec3, %cc2 : i32, i32
261+
} else {
262+
%cc3 = arith.addi %cc2, %ic1 : i32
263+
scf.yield %ec2, %cc3 : i32, i32
264+
}
265+
scf.yield %ec3, %cc3 : i32,i32
266+
}
267+
scf.yield %ec2,%cc2 : i32,i32
268+
}
269+
270+
%s0 = llvm.mlir.addressof @str_correct : !llvm.ptr<array<18 x i8>>
271+
%s1 = llvm.mlir.constant(0 : index) : i64
272+
%s2 = llvm.getelementptr %s0[%s1, %s1]
273+
: (!llvm.ptr<array<18 x i8>>, i64, i64) -> !llvm.ptr<i8>
274+
func.call @printCString(%s2) : (!llvm.ptr<i8>) -> ()
275+
vector.print %correctCount : i32
276+
%s3 = llvm.mlir.addressof @str_incorrect : !llvm.ptr<array<20 x i8>>
277+
%s4 = llvm.getelementptr %s3[%s1, %s1]
278+
: (!llvm.ptr<array<20 x i8>>, i64, i64) -> !llvm.ptr<i8>
279+
func.call @printCString(%s4) : (!llvm.ptr<i8>) -> ()
280+
vector.print %errorCount : i32
281+
282+
return
283+
}
284+
llvm.mlir.global internal constant @str_correct("Correct Results : ") {addr_space = 0 : i32}
285+
llvm.mlir.global internal constant @str_incorrect("Incorrect Results : ") {addr_space = 0 : i32}
286+
func.func private @printCString(!llvm.ptr<i8>)
287+

0 commit comments

Comments
 (0)