
Commit 91eade8

Improvements
- use vector.print str
- fix matrix-b tensor descriptor to 64x64
- memref rank reduction
1 parent 894c077

1 file changed

mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir

Lines changed: 34 additions & 37 deletions
@@ -7,8 +7,10 @@
 // RUN: --entry-point-result=void \
 // RUN: | FileCheck %s

-// CHECK: Correct Results : 16384
-// CHECK: Incorrect Results : 0
+// CHECK: Correct Results :
+// CHECK: 16384
+// CHECK: Incorrect Results :
+// CHECK: 0

 // This program performs 128x128x128 GEMM (F32 += F16 * F16)
 //
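Splitting the CHECK directives keeps the test robust to where `vector.print` breaks lines: FileCheck can match the count either on the same line as its label or on a following line. A passing run should therefore produce output along these lines (a sketch, not output captured from this commit):

```
Correct Results :
16384
Incorrect Results :
0
```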
@@ -50,7 +52,7 @@

 !barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
 !lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>

 func.func private @printMemrefF32(memref<*xf32>)
 llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
@@ -59,7 +61,9 @@ memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
 memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}

 func.func @main() {
-%c214016_i32 = arith.constant 214016 : i32
+// matrix A (128*64) * matrix B (64*128) * stages(2)
+// matrix A [128][64] * matrix B[64][128] * stages(2)
+%shmemSize = arith.constant 65536 : i32
 %hc1 = arith.constant 1 : index
 %hc4096 = arith.constant 4096 : index
 %hc0 = arith.constant 0 : index
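A quick sanity check of the new constant, in Python rather than MLIR (a sketch, not part of the commit; the byte math is the only assumption):

```python
# Per stage the kernel stages one 128x64 A tile and one 64x128 B tile in f16,
# double-buffered across 2 pipeline stages.
f16_bytes = 2
per_stage = (128 * 64 + 64 * 128) * f16_bytes  # 16384 + 16384 = 32768 bytes
assert 2 * per_stage == 65536                  # matches %shmemSize above
```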
@@ -114,7 +118,7 @@ func.func @main() {
 // Step 4. Launch GPU Kernel
 gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
 threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
-dynamic_shared_memory_size %c214016_i32
+dynamic_shared_memory_size %shmemSize
 {
 memref.assume_alignment %matrixD, 16 : memref<128x128xf32>

@@ -136,14 +140,14 @@ func.func @main() {
 %c32 = arith.constant 32 : index
 %c16 = arith.constant 16 : index
 %c4096 = arith.constant 4096 : index
-%c8 = arith.constant 8 : index
+%c8 = arith.constant 8 : index
 %txcount = arith.constant 32768 : index

 %tidx = gpu.thread_id x
 %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [7, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<7x128x64xf16, 3>
-%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
-%rhsShmem = memref.subview %rhsShmem2[7, 0, 0][7, 64, 128][1, 1, 1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
+%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
+%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
+%rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>

 // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
 %barrier = nvgpu.mbarrier.create -> !barrierType
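The static offset in the new `%rhsShmem` type falls out of the strides: the subview skips the two A-tile slots that occupy the front of the dynamic shared memory. A minimal sketch of that arithmetic (Python, element units; not part of the commit):

```python
rhs_strides = (8192, 128, 1)  # strides of %rhsShmem2, in f16 elements
offset = 2 * rhs_strides[0]   # subview index [2, 0, 0] -> 2 * 8192
assert offset == 16384        # matches strided<[8192, 128, 1], offset: 16384>
```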
@@ -171,25 +175,27 @@ func.func @main() {

 // Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
 %pipe1 = arith.constant 0 : index
-%p1lhsSlice = memref.subview %lhsShmem [0, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
-%p1rhsSlice = memref.subview %rhsShmem [0, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
-%p1rhsSlice2 = memref.subview %p1rhsSlice[0, 32, 0][1, 128, 64][1,1,1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+%p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
+%p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
+%p1halfFirst = memref.subview %p1rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+%p1halfSecond = memref.subview %p1rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
 nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType
 %dim1 = arith.muli %pipe1, %c64 : index
-nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
-nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1rhsSlice, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
-nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1rhsSlice2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
+nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>

 // Step 5. [GPU] TMA Load Pipeline 2 (predicated)
 %pipe2 = arith.constant 1 : index
-%p2lhsSlice = memref.subview %lhsShmem [1, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
-%p2rhsSlice = memref.subview %rhsShmem [1, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
-%p2rhsSlice2 = memref.subview %p2rhsSlice[0, 32, 0][1, 128, 64][1,1,1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+%p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+%p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
+%p2halfFirst = memref.subview %p2rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+%p2halfSecond = memref.subview %p2rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
 nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
 %dim2 = arith.muli %pipe2, %c64 : index
-nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
-nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2rhsSlice, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
-nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2rhsSlice2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>

 // Step 6. [GPU] Initialize accumulator matrix
 %14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
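Because `!rhsTensorMap` now describes a 64x64 box, each 64x128 B slice is fetched with two TMA loads. The element offsets in the subview types line up as follows (a Python sketch over the strides above; not part of the commit):

```python
base = 16384                      # %rhsShmem starts after the two A tiles
p1_first = base + 0 * 128         # %p1halfFirst  at [0, 0]
p1_second = base + 32 * 128       # %p1halfSecond at [32, 0] -> +4096 elements
p2_first = base + 1 * 8192        # pipeline 2 slot, %rhsShmem[1, 0, 0]
p2_second = p2_first + 32 * 128
assert (p1_first, p1_second, p2_first, p2_second) == (16384, 20480, 24576, 28672)
```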
@@ -201,11 +207,11 @@ func.func @main() {
 %ticks = arith.constant 10000000 : index
 // TMA wait
 nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
-%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16,3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>
-%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>
+%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
+%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
 // Descriptor WGMMA
-%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
-%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
+%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
+%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
 // Perform WGMMA 128x128x64
 %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
 scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
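The `offset: ?` in the loop-body types is the only dynamic part: each subview drops the leading unit dimension (the commit's "memref rank reduction"), so the WGMMA descriptors see plain 2-D tiles, and the offset advances by one 8192-element slot per stage. A sketch of the expected offsets (Python; assumes the two-stage layout set up above):

```python
def lhs_offset(i): return i * 8192          # %lhsShmem[%i, 0, 0], strides [8192, 64, 1]
def rhs_offset(i): return 16384 + i * 8192  # %rhsShmem starts after the A tiles
assert [lhs_offset(i) for i in (0, 1)] == [0, 8192]
assert [rhs_offset(i) for i in (0, 1)] == [16384, 24576]
```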
@@ -267,21 +273,12 @@ func.func @main() {
 scf.yield %ec2,%cc2 : i32,i32
 }

-%s0 = llvm.mlir.addressof @str_correct : !llvm.ptr<array<18 x i8>>
-%s1 = llvm.mlir.constant(0 : index) : i64
-%s2 = llvm.getelementptr %s0[%s1, %s1]
-  : (!llvm.ptr<array<18 x i8>>, i64, i64) -> !llvm.ptr<i8>
-func.call @printCString(%s2) : (!llvm.ptr<i8>) -> ()
+vector.print str "Correct Results :"
 vector.print %correctCount : i32
-%s3 = llvm.mlir.addressof @str_incorrect : !llvm.ptr<array<20 x i8>>
-%s4 = llvm.getelementptr %s3[%s1, %s1]
-  : (!llvm.ptr<array<20 x i8>>, i64, i64) -> !llvm.ptr<i8>
-func.call @printCString(%s4) : (!llvm.ptr<i8>) -> ()
+vector.print str "Incorrect Results :"
 vector.print %errorCount : i32

 return
 }
-llvm.mlir.global internal constant @str_correct("Correct Results : ") {addr_space = 0 : i32}
-llvm.mlir.global internal constant @str_incorrect("Incorrect Results : ") {addr_space = 0 : i32}
-func.func private @printCString(!llvm.ptr<i8>)
+
