35
35
36
36
!rhs = memref <64 x128 xf16 >
37
37
!shmemrhs = memref <64 x128 xf16 , 3 >
38
- !rhsTensorMap = !nvgpu.tensormap.descriptor <tensor = !shmemrhs , swizzle = swizzle_128b , l2promo =none , oob =zero , interleave =none >
38
+ !rhsTensorMap = !nvgpu.tensormap.descriptor <tensor = memref < 64 x 64 x f16 , 3 > , swizzle = swizzle_128b , l2promo =none , oob =zero , interleave =none >
39
39
40
40
module @mymod {
41
41
func.func private @printMemrefF32 (memref <*xf32 >)
@@ -99,7 +99,8 @@ module @mymod {
99
99
%6 = gpu.thread_id x
100
100
%lhsShmem = memref.get_global @bufferLhsGlobal : !shmemlhs
101
101
%rhsShmem = memref.get_global @bufferRhsGlobal : !shmemrhs
102
- %rhsShmem2 = memref.subview %rhsShmem [32 , 0 ][128 , 64 ][1 , 1 ] : !shmemrhs to memref <128 x64 xf16 , strided <[128 , 1 ], offset : 4096 >, 3 >
102
+ %rhsShmem1 = memref.subview %rhsShmem [0 , 0 ][64 , 64 ][1 , 1 ] : !shmemrhs to memref <64 x64 xf16 , strided <[128 , 1 ]>, 3 >
103
+ %rhsShmem2 = memref.subview %rhsShmem [32 , 0 ][64 , 64 ][1 , 1 ] : !shmemrhs to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 4096 >, 3 >
103
104
104
105
// Step 5. Initialize the mbarrier
105
106
%9 = nvgpu.mbarrier.create -> !barrierType
@@ -110,8 +111,8 @@ module @mymod {
110
111
scf.if %10 {
111
112
gpu.printf " [GPU] TMA SIZE %d\0A" %c32768 : index
112
113
nvgpu.tma.async.load %d_lhsTensorMap [%c0 , %c0 ], %9 [%c0 ] to %lhsShmem : !lhsTensorMap , !barrierType -> !shmemlhs
113
- nvgpu.tma.async.load %d_rhsTensorMap [%c0 , %c0 ], %9 [%c0 ] to %rhsShmem : !rhsTensorMap , !barrierType -> !shmemrhs
114
- nvgpu.tma.async.load %d_rhsTensorMap [%c64 , %c0 ], %9 [%c0 ] to %rhsShmem2 : !rhsTensorMap , !barrierType -> memref <128 x 64 x f16 , strided <[128 , 1 ], offset : 4096 >, 3 >
114
+ nvgpu.tma.async.load %d_rhsTensorMap [%c0 , %c0 ], %9 [%c0 ] to %rhsShmem1 : !rhsTensorMap , !barrierType -> memref < 64 x 64 x f16 , strided <[ 128 , 1 ]>, 3 >
115
+ nvgpu.tma.async.load %d_rhsTensorMap [%c64 , %c0 ], %9 [%c0 ] to %rhsShmem2 : !rhsTensorMap , !barrierType -> memref <64 x 64 x f16 , strided <[128 , 1 ], offset : 4096 >, 3 >
115
116
nvgpu.mbarrier.arrive.expect_tx %9 [%c0 ], %c32768 : !barrierType
116
117
} else {
117
118
nvgpu.mbarrier.arrive.expect_tx %9 [%c0 ], %c0 : !barrierType
0 commit comments