-// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
@@ -160,3 +160,50 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
 gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @gemm_loop
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK: }
+// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1024 = arith.constant 1024 : index
+  %0 = gpu.block_id x
+  %1 = gpu.block_id y
+  %2 = arith.muli %0, %c8 : index
+  %3 = arith.muli %1, %c16 : index
+  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
+    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %11 : vector<8x16xf32>
+  }
+  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+}