// RUN: --entry-point-result=void \
// RUN: | FileCheck %s

- // CHECK: Correct Results : 16384
- // CHECK: Incorrect Results : 0
+ // CHECK: Correct Results :
+ // CHECK: 16384
+ // CHECK: Incorrect Results :
+ // CHECK: 0

// This program performs 128x128x128 GEMM (F32 += F16 * F16)
//
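
For reference, the test computes one 128x128x128 matmul with f16 operands and f32 accumulation and checks every element of the 128x128 result, which is where the 16384 in the CHECK lines comes from (128 * 128 = 16384). A minimal host-side sketch of that reference computation, in NumPy, illustrative only and not part of the test:

    import numpy as np

    M = N = K = 128
    a = np.random.rand(M, K).astype(np.float16)      # LHS, f16
    b = np.random.rand(K, N).astype(np.float16)      # RHS, f16
    d = a.astype(np.float32) @ b.astype(np.float32)  # F32 += F16 * F16
    print(d.size)                                    # 16384 results to verify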
@@ -50,7 +52,7 @@

!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 2>
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
- !rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+ !rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>

func.func private @printMemrefF32(memref<*xf32>)
llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
@@ -59,7 +61,9 @@ memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}

func.func @main() {
- %c214016_i32 = arith.constant 214016 : i32
+ // matrix A (128*64) * matrix B (64*128) * stages(2)
+ // matrix A [128][64] * matrix B [64][128] * stages(2)
+ %shmemSize = arith.constant 65536 : i32
%hc1 = arith.constant 1 : index
%hc4096 = arith.constant 4096 : index
%hc0 = arith.constant 0 : index
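
The new %shmemSize constant follows directly from the comment above it: one 128x64 f16 tile of A plus one 64x128 f16 tile of B per stage, double-buffered over two stages. A quick arithmetic check (Python, illustrative only):

    f16_bytes = 2
    a_tile = 128 * 64 * f16_bytes      # 16384 bytes of A per stage
    b_tile = 64 * 128 * f16_bytes      # 16384 bytes of B per stage
    stages = 2
    print((a_tile + b_tile) * stages)  # 65536, passed as dynamic_shared_memory_size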
@@ -114,7 +118,7 @@ func.func @main() {
// Step 4. Launch GPU Kernel
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %hc1, %arg7 = %hc1, %arg8 = %hc1)
    threads(%arg3, %arg4, %arg5) in (%arg9 = %hc128, %arg10 = %hc1, %arg11 = %hc1)
-   dynamic_shared_memory_size %c214016_i32
+   dynamic_shared_memory_size %shmemSize
{
memref.assume_alignment %matrixD, 16 : memref<128x128xf32>
@@ -136,14 +140,14 @@ func.func @main() {
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%c4096 = arith.constant 4096 : index
- %c8 = arith.constant 8 : index
+ %c8 = arith.constant 8 : index
%txcount = arith.constant 32768 : index

%tidx = gpu.thread_id x
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [7, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<7x128x64xf16, 3>
- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [14, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<14x64x128xf16, 3>
- %rhsShmem = memref.subview %rhsShmem2[7, 0, 0][7, 64, 128][1, 1, 1] : memref<14x64x128xf16, 3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
+ %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
+ %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<4x64x128xf16, 3>
+ %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16, 3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>

// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
%barrier = nvgpu.mbarrier.create -> !barrierType
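
The corrected shared-memory views carve the dynamic buffer into the same two-stage budget: %lhsShmem holds 2 x 128 x 64 f16 elements of A, and %rhsShmem starts right after it with 2 x 64 x 128 f16 elements of B. %txcount is the byte count one pipeline stage transfers via TMA, which the expect_tx calls below announce to the mbarrier. A quick check (Python, illustrative only):

    f16_bytes = 2
    per_stage = 128 * 64 * f16_bytes + 64 * 128 * f16_bytes
    print(per_stage)     # 32768, matching %txcount
    print(2 * 128 * 64)  # 16384, the element offset where the B buffer (%rhsShmem) begins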
@@ -171,25 +175,27 @@ func.func @main() {

// Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
%pipe1 = arith.constant 0 : index
- %p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16, 3> to memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
- %p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
- %p1rhsSlice2 = memref.subview %p1rhsSlice[0, 32, 0][1, 128, 64][1, 1, 1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+ %p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
+ %p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
+ %p1halfFirst = memref.subview %p1rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+ %p1halfSecond = memref.subview %p1rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType
%dim1 = arith.muli %pipe1, %c64 : index
- nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1]>, 3>
- nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1rhsSlice, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 57344>, 3>
- nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1rhsSlice2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 61440>, 3>
+ nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
+ nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
+ nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>

// Step 5. [GPU] TMA Load Pipeline 2 (predicated)
%pipe2 = arith.constant 1 : index
- %p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16, 3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
- %p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
- %p2rhsSlice2 = memref.subview %p2rhsSlice[0, 32, 0][1, 128, 64][1, 1, 1] : memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+ %p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+ %p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
+ %p2halfFirst = memref.subview %p2rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+ %p2halfSecond = memref.subview %p2rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
%dim2 = arith.muli %pipe2, %c64 : index
- nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<1x64x128xf16, strided<[8192, 64, 1], offset: 8192>, 3>
- nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2rhsSlice, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 65536>, 3>
- nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2rhsSlice2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<1x128x64xf16, strided<[8192, 128, 1], offset: 69632>, 3>
+ nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
+ nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
+ nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>

// Step 6. [GPU] Initialize accumulator matrix
%14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
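
Since the rhs TMA descriptor now describes a 64x64 box, each stage's 64x128 B tile is loaded as two 64x64 halves, and the constant offsets in the subview types above follow from the element strides. A quick check of those constants (Python, illustrative only):

    b_base    = 2 * 128 * 64  # 16384: B region starts after the two A stages
    b_stage   = 64 * 128      # 8192 elements per B stage
    half_step = 32 * 128      # 4096: the [32, 0] subview step at row stride 128
    print(b_base, b_base + half_step)                      # 16384, 20480 (pipeline 1 halves)
    print(b_base + b_stage, b_base + b_stage + half_step)  # 24576, 28672 (pipeline 2 halves)
    print(1 * 128 * 64)                                    # 8192: offset of the stage-1 A slice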
@@ -201,11 +207,11 @@ func.func @main() {
%ticks = arith.constant 10000000 : index
// TMA wait
nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
- %lhsSlice = memref.subview %lhsShmem[%i, 0, 0][1, 64, 128][1, 1, 1] : memref<7x128x64xf16, 3> to memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>
- %rhsSlice = memref.subview %rhsShmem[%i, 0, 0][1, 128, 64][1, 1, 1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>
+ %lhsSlice = memref.subview %lhsShmem[%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
+ %rhsSlice = memref.subview %rhsShmem[%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
// Descriptor WGMMA
- %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<1x64x128xf16, strided<[8192, 64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
- %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<1x128x64xf16, strided<[8192, 128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
+ %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+ %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
// Perform WGMMA 128x128x64
%md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16, 3>>, <tensor = memref<64x128xf16, 3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
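
Each iteration of this loop consumes one pipeline stage: it waits on that stage's mbarrier, takes the %i-th 128x64 slice of A and 64x128 slice of B (the '?' offsets resolve to multiples of the 8192-element stage stride), and issues a 128x128x64 WGMMA into the f32 accumulator; with K = 128 and a K-step of 64, two such iterations complete the GEMM. A small NumPy sketch of that accumulation pattern, illustrative only:

    import numpy as np

    A = np.random.rand(128, 128).astype(np.float16)
    B = np.random.rand(128, 128).astype(np.float16)
    acc = np.zeros((128, 128), dtype=np.float32)
    for i in range(2):                                         # two stages, K = 2 * 64
        a_tile = A[:, i * 64:(i + 1) * 64].astype(np.float32)  # 128x64 slice of A
        b_tile = B[i * 64:(i + 1) * 64, :].astype(np.float32)  # 64x128 slice of B
        acc += a_tile @ b_tile                                 # one 128x128x64 MMA per stage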
@@ -267,21 +273,12 @@ func.func @main() {
scf.yield %ec2, %cc2 : i32, i32
}

- %s0 = llvm.mlir.addressof @str_correct : !llvm.ptr<array<18 x i8>>
- %s1 = llvm.mlir.constant(0 : index) : i64
- %s2 = llvm.getelementptr %s0[%s1, %s1]
-   : (!llvm.ptr<array<18 x i8>>, i64, i64) -> !llvm.ptr<i8>
- func.call @printCString(%s2) : (!llvm.ptr<i8>) -> ()
+ vector.print str "Correct Results :"
vector.print %correctCount : i32
- %s3 = llvm.mlir.addressof @str_incorrect : !llvm.ptr<array<20 x i8>>
- %s4 = llvm.getelementptr %s3[%s1, %s1]
-   : (!llvm.ptr<array<20 x i8>>, i64, i64) -> !llvm.ptr<i8>
- func.call @printCString(%s4) : (!llvm.ptr<i8>) -> ()
+ vector.print str "Incorrect Results :"
vector.print %errorCount : i32

return
}
- llvm.mlir.global internal constant @str_correct("Correct Results : ") {addr_space = 0 : i32}
- llvm.mlir.global internal constant @str_incorrect("Incorrect Results : ") {addr_space = 0 : i32}
- func.func private @printCString(!llvm.ptr<i8>)
+