// RUN: --entry-point-result=void \
// RUN: | FileCheck %s

- // CHECK: Correct Results : 16384
- // CHECK: Incorrect Results : 0
+ // CHECK: Correct Results :
+ // CHECK: 16384
+ // CHECK: Incorrect Results :
+ // CHECK: 0

// This program performs 128x128x128 GEMM (F32 += F16 * F16)
//

!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
- !rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+ !rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
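+ // With swizzle_128b a TMA box row covers at most 128 bytes (64 f16 elements), so the RHS tensor map describes a 64x64 tile; each 64x128 stage of B is loaded as two 64x64 halves (%halfFirst / %halfSecond below).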

func.func private @printMemrefF32(memref<*xf32>)
llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
@@ -59,7 +61,9 @@ memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}

func.func @main() {
- %c214016_i32 = arith.constant 214016 : i32
+ // matrix A (128*64) * matrix B (64*128) * stages(2)
+ // matrix A [128][64] * matrix B[64][128] * stages(2)
+ %shmemSize = arith.constant 65536 : i32
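+ // 2 stages x (128x64 + 64x128) f16 elements x 2 bytes/element = 65536 bytes of dynamic shared memory.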
%hc1 = arith.constant 1 : index
%hc4096 = arith.constant 4096 : index
%hc0 = arith.constant 0 : index
@@ -114,7 +118,7 @@ func.func @main() {
// Step 4. Launch GPU Kernel
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
- dynamic_shared_memory_size %c214016_i32
+ dynamic_shared_memory_size %shmemSize
{
memref.assume_alignment %matrixD, 16 : memref<128x128xf32>
@@ -141,9 +145,9 @@ func.func @main() {
%tidx = gpu.thread_id x
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [7, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<7x128x64xf16, 3>
- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [14, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<14x64x128xf16, 3>
- %rhsShmem = memref.subview %rhsShmem2[7, 0, 0][7, 64, 128][1, 1, 1] : memref<14x64x128xf16, 3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
+ %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
+ %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<4x64x128xf16, 3>
+ %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16, 3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
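+ // Shared memory layout: the two 128x64 LHS stages occupy f16 elements [0, 16384); the RHS stages follow immediately after, which is why the RHS view is the tail subview of %rhsShmem2 starting at stage 2 (element offset 16384).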

// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
%barrier = nvgpu.mbarrier.create -> !barrierType
@@ -156,30 +160,32 @@ func.func @main() {
// Step 3. [GPU] Prefetch TMA Descriptors to L1 Cache
nvgpu.tma.prefetch.descriptor %descA : !lhsTensorMap
nvgpu.tma.prefetch.descriptor %descB : !rhsTensorMap
-
+
// Step 4.1 [GPU] TMA Load Pipeline 1
scf.if %cnd {
%pipe = arith.constant 0 : index
- %lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16, 3> to memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
- %rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
- %rhsSlice2 = memref.subview %rhsSlice[0, 32, 0][1, 128, 64][1, 1, 1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+ %lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
+ %rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
+ %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+ %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
%dim = arith.muli %pipe, %c64 : index
- nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
- nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %rhsSlice : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
- nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %rhsSlice2 : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+ nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
+ nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+ nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
}
// Step 4.2 [GPU] TMA Load Pipeline 2
scf.if %cnd {
%pipe = arith.constant 1 : index
- %lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16, 3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
- %rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
- %rhsSlice2 = memref.subview %rhsSlice[0, 32, 0][1, 128, 64][1, 1, 1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+ %lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+ %rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
+ %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+ %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
%dim = arith.muli %pipe, %c64 : index
- nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
- nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %rhsSlice : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
- nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %rhsSlice2 : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+ nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+ nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+ nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
}

// Step 5. [GPU] Initialize accumulator matrix
@@ -192,11 +198,11 @@ func.func @main() {
%ticks = arith.constant 10000000 : index
// TMA wait
nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
- %lhsSlice = memref.subview %lhsShmem[%i, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16, 3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>
- %rhsSlice = memref.subview %rhsShmem[%i, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>
+ %lhsSlice = memref.subview %lhsShmem[%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
+ %rhsSlice = memref.subview %rhsShmem[%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
// Descriptor WGMMA
- %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
- %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
+ %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+ %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
// Perform WGMMA 128x128x64
%md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16, 3>>, <tensor = memref<64x128xf16, 3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
@@ -258,21 +264,10 @@ func.func @main() {
scf.yield %ec2, %cc2 : i32, i32
}

- %s0 = llvm.mlir.addressof @str_correct : !llvm.ptr<array<18 x i8>>
- %s1 = llvm.mlir.constant(0 : index) : i64
- %s2 = llvm.getelementptr %s0[%s1, %s1]
- : (!llvm.ptr<array<18 x i8>>, i64, i64) -> !llvm.ptr<i8>
- func.call @printCString(%s2) : (!llvm.ptr<i8>) -> ()
+ vector.print str "Correct Results :"
vector.print %correctCount : i32
- %s3 = llvm.mlir.addressof @str_incorrect : !llvm.ptr<array<20 x i8>>
- %s4 = llvm.getelementptr %s3[%s1, %s1]
- : (!llvm.ptr<array<20 x i8>>, i64, i64) -> !llvm.ptr<i8>
- func.call @printCString(%s4) : (!llvm.ptr<i8>) -> ()
+ vector.print str "Incorrect Results :"
vector.print %errorCount : i32

return
}
- llvm.mlir.global internal constant @str_correct("Correct Results : ") {addr_space = 0 : i32}
- llvm.mlir.global internal constant @str_incorrect("Incorrect Results : ") {addr_space = 0 : i32}
- func.func private @printCString(!llvm.ptr<i8>)
-