
Commit e9c4e33

Improvements
- Use vector.print str
- Fix matrix-B tensor descriptor to 64x64
- Use memref rank reduction
1 parent f8ec169 commit e9c4e33
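
The commit message refers to two MLIR idioms: printing a string literal with vector.print str, and dropping unit dimensions with a rank-reducing memref.subview. Below is a minimal sketch of both; it is not part of this commit, and the function name @demo and argument %buf are hypothetical, with shapes chosen to mirror the test's staging buffer.

// Minimal illustration (hypothetical, not from the commit):
func.func @demo(%buf : memref<2x128x64xf16>) {
  // Rank-reducing subview: take a size-1 slice along the leading (stage)
  // dimension and drop that unit dimension, so the result is a 2-D memref.
  %slice = memref.subview %buf[0, 0, 0][1, 128, 64][1, 1, 1]
      : memref<2x128x64xf16> to memref<128x64xf16>
  // vector.print str prints a string constant directly, which is what replaces
  // the llvm.mlir.global / llvm.getelementptr / @printCString boilerplate below.
  vector.print str "Correct Results :"
  return
}

In the diff below, the same rank reduction is what lets the TMA loads and WGMMA descriptor generation operate on plain 128x64 / 64x128 memrefs instead of the earlier 1x...-shaped slices.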


mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir

Lines changed: 33 additions & 38 deletions
@@ -7,8 +7,10 @@
 // RUN: --entry-point-result=void \
 // RUN: | FileCheck %s
 
-// CHECK: Correct Results : 16384
-// CHECK: Incorrect Results : 0
+// CHECK: Correct Results :
+// CHECK: 16384
+// CHECK: Incorrect Results :
+// CHECK: 0
 
 // This program performs 128x128x128 GEMM (F32 += F16 * F16)
 //
@@ -50,7 +52,7 @@
 
 !barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
 !lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
 
 func.func private @printMemrefF32(memref<*xf32>)
 llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
@@ -59,7 +61,9 @@ memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
 memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
 
 func.func @main() {
-%c214016_i32 = arith.constant 214016 : i32
+// matrix A (128*64) * matrix B (64*128) * stages(2)
+// matrix A [128][64] * matrix B[64][128] * stages(2)
+%shmemSize = arith.constant 65536 : i32
 %hc1 = arith.constant 1 : index
 %hc4096 = arith.constant 4096 : index
 %hc0 = arith.constant 0 : index
@@ -114,7 +118,7 @@ func.func @main() {
 // Step 4. Launch GPU Kernel
 gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
 threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
-dynamic_shared_memory_size %c214016_i32
+dynamic_shared_memory_size %shmemSize
 {
 memref.assume_alignment %matrixD, 16 : memref<128x128xf32>
 
@@ -141,9 +145,9 @@ func.func @main() {
 
 %tidx = gpu.thread_id x
 %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [7, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<7x128x64xf16, 3>
-%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
-%rhsShmem = memref.subview %rhsShmem2[7, 0, 0][7, 64, 128][1, 1, 1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
+%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
+%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
+%rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
 
 // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
 %barrier = nvgpu.mbarrier.create -> !barrierType
@@ -156,30 +160,32 @@ func.func @main() {
 // Step 3. [GPU] Prefetch TMA Descriptors to L1 Cache
 nvgpu.tma.prefetch.descriptor %descA : !lhsTensorMap
 nvgpu.tma.prefetch.descriptor %descB : !rhsTensorMap
-
+
 // Step 4.1 [GPU] TMA Load Pipeline 1
 scf.if %cnd {
 %pipe = arith.constant 0 : index
-%lhsSlice = memref.subview %lhsShmem [0, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
-%rhsSlice = memref.subview %rhsShmem [0, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
-%rhsSlice2 = memref.subview %rhsSlice[0, 32, 0][1, 128, 64][1,1,1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+%lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
+%rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
+%halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+%halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
 nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
 %dim = arith.muli %pipe, %c64 : index
-nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
-nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %rhsSlice : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
-nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %rhsSlice2 : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
+nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
 }
 // Step 4.2 [GPU] TMA Load Pipeline 2
 scf.if %cnd {
 %pipe = arith.constant 1 : index
-%lhsSlice = memref.subview %lhsShmem [1, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
-%rhsSlice = memref.subview %rhsShmem [1, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
-%rhsSlice2 = memref.subview %rhsSlice[0, 32, 0][1, 128, 64][1,1,1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+%lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+%rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
+%halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+%halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
 nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
 %dim = arith.muli %pipe, %c64 : index
-nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
-nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %rhsSlice : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
-nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %rhsSlice2 : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
 }
 
 // Step 5. [GPU] Initiliaze accumulator matrix
@@ -192,11 +198,11 @@ func.func @main() {
 %ticks = arith.constant 10000000 : index
 // TMA wait
 nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
-%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>
-%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>
+%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
+%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
 // Descriptor WGMMA
-%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
-%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
+%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
+%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
 // Perform WGMMA 128x128x64
 %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
 scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
@@ -258,21 +264,10 @@ func.func @main() {
 scf.yield %ec2,%cc2 : i32,i32
 }
 
-%s0 = llvm.mlir.addressof @str_correct : !llvm.ptr<array<18 x i8>>
-%s1 = llvm.mlir.constant(0 : index) : i64
-%s2 = llvm.getelementptr %s0[%s1, %s1]
-  : (!llvm.ptr<array<18 x i8>>, i64, i64) -> !llvm.ptr<i8>
-func.call @printCString(%s2) : (!llvm.ptr<i8>) -> ()
+vector.print str "Correct Results :"
 vector.print %correctCount : i32
-%s3 = llvm.mlir.addressof @str_incorrect : !llvm.ptr<array<20 x i8>>
-%s4 = llvm.getelementptr %s3[%s1, %s1]
-  : (!llvm.ptr<array<20 x i8>>, i64, i64) -> !llvm.ptr<i8>
-func.call @printCString(%s4) : (!llvm.ptr<i8>) -> ()
+vector.print str "Incorrect Results :"
 vector.print %errorCount : i32
 
 return
 }
-llvm.mlir.global internal constant @str_correct("Correct Results : ") {addr_space = 0 : i32}
-llvm.mlir.global internal constant @str_incorrect("Incorrect Results : ") {addr_space = 0 : i32}
-func.func private @printCString(!llvm.ptr<i8>)
-
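
As a quick back-of-the-envelope check on the new dynamic_shared_memory_size constant above (assuming 2-byte f16 elements and the two pipeline stages named in the added comment):

(128 × 64 + 64 × 128) elements × 2 bytes × 2 stages = 16384 × 4 = 65536 bytes

which matches the %shmemSize constant introduced in place of the previous %c214016_i32.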