[mlir][xegpu] Handle scalar uniform ops in SIMT distribution. #138593

Merged
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1463,6 +1463,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
signalPassFailure();
return;
}
// At this point, we have moved the entire function body inside the warpOp.
// Now move any scalar uniform code outside of the warpOp (like GPU index
// ops, scalar constants, etc.). This will simplify the later lowering and
// avoid custom patterns for these ops.
getOperation()->walk([&](Operation *op) {
if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
vector::moveScalarUniformCode(warpOp);
}
});
}
// Finally, do the SIMD to SIMT distribution.
RewritePatternSet patterns(&getContext());
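For context, a rough sketch of the effect of `vector::moveScalarUniformCode` on a `gpu.warp_execute_on_lane_0` region (not part of the patch; the lane id `%laneid`, the warp width of 16, and the particular ops are illustrative, chosen to mirror the gemm_loop test below):

```mlir
// Before hoisting: uniform scalar ops sit inside the warp region.
gpu.warp_execute_on_lane_0(%laneid)[16] {
  %bid = gpu.block_id x                 // same value on every lane
  %c8  = arith.constant 8 : index       // same value on every lane
  %off = arith.muli %bid, %c8 : index   // same value on every lane
  // ... xegpu.create_nd_tdesc / load_nd / dpas using %off ...
}

// After hoisting: the same scalars are moved above the region, so the
// later distribution patterns only need to handle vector/XeGPU ops.
%bid = gpu.block_id x
%c8  = arith.constant 8 : index
%off = arith.muli %bid, %c8 : index
gpu.warp_execute_on_lane_0(%laneid)[16] {
  // ... xegpu.create_nd_tdesc / load_nd / dpas using %off ...
}
```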
49 changes: 48 additions & 1 deletion mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s
// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s

// CHECK-LABEL: gpu.func @store_nd_1d
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
@@ -160,3 +160,50 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
gpu.return
}
}

// -----
// CHECK-LABEL: gpu.func @gemm_loop
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
// CHECK: }
// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.module @test {
gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
%c1024 = arith.constant 1024 : index
%0 = gpu.block_id x
%1 = gpu.block_id y
%2 = arith.muli %0, %c8 : index
%3 = arith.muli %1, %c16 : index
%4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
%5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
%7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
%8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
%9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
%10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
%11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
scf.yield %11 : vector<8x16xf32>
}
xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}