@@ -674,7 +674,7 @@ module @mymodule {
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
memref.global "private" @dynamicShmem : memref<0xf16,3>
// CHECK-LABEL: func @create_wgmma_descriptor(
- func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
+ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>{
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
// CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
@@ -706,22 +706,22 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
// CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
// CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
- // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+ // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
// CHECK: return %[[ret]]
- %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
- func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+ %descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
+ func.return %descA : !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
}

// CHECK-LABEL: @warpgroup_mma_128_128_64(
- // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+ // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
func.func @warpgroup_mma_128_128_64(
- %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
%acc1: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
%acc2: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
{
- // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
- // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+ // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+ // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.fence.aligned
@@ -762,8 +762,8 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: nvvm.wgmma.commit.group.sync.aligned
// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
- !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
->