Skip to content

[MLIR][XeGPU] Account for sg_map in LoadNdOp verification #123928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

akroviakov
Copy link
Contributor

The current XeGPU dialect defines the sg_map attribute, but does not allow using it for loads because verification fails with a shape mismatch.

To allow valid load operations on a tensor descriptor carrying an sg_map, and to introduce rules for sg_map usage, this PR modifies the verifiers of CreateNdDescOp and LoadNdOp by adding checks related to the sg_map attribute.

@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

The current XeGPU dialect defines the sg_map attribute, but does not allow using it for loads because verification fails with a shape mismatch.

To allow valid load operations on a tensor descriptor carrying an sg_map, and to introduce rules for sg_map usage, this PR modifies the verifiers of CreateNdDescOp and LoadNdOp by adding checks related to the sg_map attribute.


Full diff: https://github.com/llvm/llvm-project/pull/123928.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+23)
  • (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+27)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..0c5a1ce0e96a38 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -198,6 +198,22 @@ LogicalResult CreateNdDescOp::verify() {
       tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
     return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
 
+  if (auto attr = getType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    auto wiData = attr.getWiData();
+    if (wiData[0] < 1 || wiData[1] < 1 || (wiData[0] > 1 && wiData[1] > 1))
+      return emitOpError() << "`wi_data` values must be >=1 and can only be >1 "
+                              "along one dimension."
+                           << "\n";
+    auto tdescShape = getType().getShape();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      if (tdescShape[i] % wiLayout[i])
+        return emitOpError() << "Work-items must uniformly divide a tile "
+                                "(tdescShape[i] % wiLayout[i] == 0)"
+                             << "\n";
+    }
+  }
+
   return success();
 }
 
@@ -250,6 +266,13 @@ LogicalResult LoadNdOp::verify() {
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
+  if (auto attr = getTensorDescType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      tdescShape[i] /= wiLayout[i];
+    }
+  }
+
   if (getTranspose()) {
     auto trans = getTranspose().value();
 
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..0f92e9cb68db68 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -21,6 +21,33 @@ gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<32x32xi8>) {
+gpu.func @test_load_nd_tdesc_with_sg_map(%src: memref<32x32xi8>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_2(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_2(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_3(%[[arg0:.*]]: memref<32x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_3(%src: memref<32x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
 gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index

@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Artem Kroviakov (akroviakov)

Changes

The current XeGPU dialect defines the sg_map attribute, but does not allow using it for loads because verification fails with a shape mismatch.

To allow valid load operations on a tensor descriptor carrying an sg_map, and to introduce rules for sg_map usage, this PR modifies the verifiers of CreateNdDescOp and LoadNdOp by adding checks related to the sg_map attribute.


Full diff: https://github.com/llvm/llvm-project/pull/123928.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+23)
  • (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+27)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..0c5a1ce0e96a38 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -198,6 +198,22 @@ LogicalResult CreateNdDescOp::verify() {
       tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
     return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
 
+  if (auto attr = getType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    auto wiData = attr.getWiData();
+    if (wiData[0] < 1 || wiData[1] < 1 || (wiData[0] > 1 && wiData[1] > 1))
+      return emitOpError() << "`wi_data` values must be >=1 and can only be >1 "
+                              "along one dimension."
+                           << "\n";
+    auto tdescShape = getType().getShape();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      if (tdescShape[i] % wiLayout[i])
+        return emitOpError() << "Work-items must uniformly divide a tile "
+                                "(tdescShape[i] % wiLayout[i] == 0)"
+                             << "\n";
+    }
+  }
+
   return success();
 }
 
@@ -250,6 +266,13 @@ LogicalResult LoadNdOp::verify() {
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
+  if (auto attr = getTensorDescType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      tdescShape[i] /= wiLayout[i];
+    }
+  }
+
   if (getTranspose()) {
     auto trans = getTranspose().value();
 
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..0f92e9cb68db68 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -21,6 +21,33 @@ gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<32x32xi8>) {
+gpu.func @test_load_nd_tdesc_with_sg_map(%src: memref<32x32xi8>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_2(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_2(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_3(%[[arg0:.*]]: memref<32x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_3(%src: memref<32x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
 gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index

@adam-smnk
Copy link
Contributor

Just a quick fly-by for now, are the changes aligned with #120566?

@akroviakov
Copy link
Contributor Author

Oops, missed that indeed, closing this PR in favor of the mentioned predecessor

@akroviakov akroviakov closed this Jan 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants