@@ -142,17 +142,13 @@ func.func @main() {
142
142
%c4096 = arith.constant 4096 : index
143
143
%c8 = arith.constant 8 : index
144
144
%txcount = arith.constant 32768 : index
145
- %c24576 = arith.constant 24576 : index
146
- %c16384 = arith.constant 16384 : index
147
- %c49152 = arith.constant 49152 : index
148
- %c57344 = arith.constant 57344 : index
149
145
150
146
%tidx = gpu.thread_id x
151
147
%dynamicMem = memref.get_global @dynamicShmem : memref <0 xf16 , 3 >
152
148
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset : [0 ], sizes : [2 , 128 , 64 ], strides : [8192 , 64 , 1 ] : memref <0 xf16 , 3 > to memref <2 x128 x64 xf16 , 3 >
153
149
%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset : [0 ], sizes : [4 , 64 , 128 ], strides : [8192 ,128 ,1 ] : memref <0 xf16 , 3 > to memref <4 x64 x128 xf16 ,3 >
154
150
%rhsShmem = memref.subview %rhsShmem2 [2 , 0 , 0 ][2 , 64 , 128 ][1 , 1 , 1 ] : memref <4 x64 x128 xf16 ,3 > to memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 >
155
- %dynsmem = gpu.dynamic_shared_memory : memref <?x i8 , #gpu.address_space < workgroup >>
151
+
156
152
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
157
153
%barrier = nvgpu.mbarrier.create -> !barrierType
158
154
@@ -179,25 +175,28 @@ func.func @main() {
179
175
180
176
// Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
181
177
%pipe1 = arith.constant 0 : index
182
- %lhsSlice1 = memref.view %dynsmem [%c0 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <128 x64 xf16 , #gpu.address_space <workgroup >>
183
- %halfFirst1 = memref.view %dynsmem [%c16384 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
184
- %halfSecond1 = memref.view %dynsmem [%c24576 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
178
+ %p1lhsSlice = memref.subview %lhsShmem [0 , 0 , 0 ][1 , 128 , 64 ][1 , 1 , 1 ] : memref <2 x128 x64 xf16 , 3 > to memref <128 x64 xf16 , 3 >
179
+ %p1rhsSlice = memref.subview %rhsShmem [0 , 0 , 0 ][1 , 64 , 128 ][1 , 1 , 1 ] : memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 > to memref <64 x128 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 >
180
+ %p1halfFirst = memref.subview %p1rhsSlice [0 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 >
181
+ %p1halfSecond = memref.subview %p1rhsSlice [32 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 20480 >, 3 >
185
182
nvgpu.mbarrier.arrive.expect_tx %barrier [%pipe1 ], %txcount , predicate = %cnd : !barrierType
186
183
%dim1 = arith.muli %pipe1 , %c64 : index
187
- nvgpu.tma.async.load %descA [%dim1 , %c0 ], %barrier [%pipe1 ] to %lhsSlice1 , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , #gpu.address_space < workgroup > >
188
- nvgpu.tma.async.load %descB [%c0 , %dim1 ], %barrier [%pipe1 ] to %halfFirst1 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space < workgroup > >
189
- nvgpu.tma.async.load %descB [%c64 , %dim1 ], %barrier [%pipe1 ] to %halfSecond1 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space < workgroup > >
184
+ nvgpu.tma.async.load %descA [%dim1 , %c0 ], %barrier [%pipe1 ] to %p1lhsSlice , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , 3 >
185
+ nvgpu.tma.async.load %descB [%c0 , %dim1 ], %barrier [%pipe1 ] to %p1halfFirst , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[ 128 , 1 ], offset : 16384 >, 3 >
186
+ nvgpu.tma.async.load %descB [%c64 , %dim1 ], %barrier [%pipe1 ] to %p1halfSecond , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[ 128 , 1 ], offset : 20480 >, 3 >
190
187
191
188
// Step 5. [GPU] TMA Load Pipeline 2 (predicated)
192
189
%pipe2 = arith.constant 1 : index
193
- %lhsSlice2 = memref.view %dynsmem [%c32768 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <128 x64 xf16 , #gpu.address_space <workgroup >>
194
- %halfFirst2 = memref.view %dynsmem [%c49152 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
195
- %halfSecond2 = memref.view %dynsmem [%c57344 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
190
+ %p2lhsSlice = memref.subview %lhsShmem [1 , 0 , 0 ][1 , 128 , 64 ][1 , 1 , 1 ] : memref <2 x128 x64 xf16 , 3 > to memref <128 x64 xf16 , strided <[64 , 1 ], offset : 8192 >, 3 >
191
+ %p2rhsSlice = memref.subview %rhsShmem [1 , 0 , 0 ][1 , 64 , 128 ][1 , 1 , 1 ] : memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 > to memref <64 x128 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 >
192
+ %p2halfFirst = memref.subview %p2rhsSlice [0 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 >
193
+ %p2halfSecond = memref.subview %p2rhsSlice [32 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 28672 >, 3 >
196
194
nvgpu.mbarrier.arrive.expect_tx %barrier [%pipe2 ], %txcount , predicate = %cnd : !barrierType
197
195
%dim2 = arith.muli %pipe2 , %c64 : index
198
- nvgpu.tma.async.load %descA [%dim2 , %c0 ], %barrier [%pipe2 ] to %lhsSlice2 , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , #gpu.address_space <workgroup >>
199
- nvgpu.tma.async.load %descB [%c0 , %dim2 ], %barrier [%pipe2 ] to %halfFirst2 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space <workgroup >>
200
- nvgpu.tma.async.load %descB [%c64 , %dim2 ], %barrier [%pipe2 ] to %halfSecond2 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space <workgroup >>
196
+ nvgpu.tma.async.load %descA [%dim2 , %c0 ], %barrier [%pipe2 ] to %p2lhsSlice , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , strided <[64 , 1 ], offset : 8192 >, 3 >
197
+ nvgpu.tma.async.load %descB [%c0 , %dim2 ], %barrier [%pipe2 ] to %p2halfFirst , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 >
198
+ nvgpu.tma.async.load %descB [%c64 , %dim2 ], %barrier [%pipe2 ] to %p2halfSecond , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[128 , 1 ], offset : 28672 >, 3 >
199
+
201
200
// Step 6. [GPU] Initialize accumulator matrix
202
201
%14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector <128 x128 xf32 >>
203
202
0 commit comments