Commit fa6f88a
[MLIR][XeGPU] Allow some nd ops to have argument shape mismatch for the distributed IR case (#120566)

This patch allows `load_nd` and `store_nd` to preserve the tensor descriptor shape during distribution to SIMT. The validation now expects the distributed instruction to retain the `sg_map` attribute and uses it to verify shape consistency.
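
For example (a minimal sketch mirroring the new test added to XeGPUOps.mlir below), an 8x16 descriptor carrying an sg_map with wi_layout = [1, 16] distributes to a per-work-item vector<8x1xf32>, since each of the 16 lanes owns one column:

  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>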
1 parent 75ce2dc commit fa6f88a

4 files changed (+112 -28 lines)

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 2 deletions
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   let hasVerifier = 1;
 }

-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
-                                        AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "stores a n-D block register region back to memory, currently only supports 2D";

   let description = [{

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 63 additions & 20 deletions
@@ -73,6 +73,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }

+// Validation of nd instruction arguments succeeds if any of these are
+// true:
+// - The tensor descriptor and the output vector shapes exactly match.
+// - The tensor descriptor has an sg_map attribute and the distributed vector
+//   shape matches the tensor descriptor shape when scaled using the sg_map
+//   factors on each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+  if (descShape == valShape) {
+    if (!sgMap)
+      return true;
+
+    // This can be relaxed if necessary by supporting non-2D shape
+    // distribution; until the constraints are defined, this lives here
+    // instead of in the tensor descriptor type.
+    return valShape.size() == sgMap.getWiLayout().size();
+  }
+
+  if (!sgMap)
+    return false;
+
+  if (valShape.size() != descShape.size())
+    return false;
+
+  for (const auto &[factor, dim, expected] :
+       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+    if (factor * dim != expected)
+      return false;
+  }
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
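
To make the check above concrete (shapes taken from the tests in this patch; %tdesc stands for an assumed 8x16 descriptor with wi_layout = [1, 16]): each distributed dimension is multiplied by its wi_layout factor and compared against the descriptor shape.

  // Accepted: 1 * 8 == 8 and 16 * 1 == 16.
  %ok = xegpu.load_nd %tdesc : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
  // Rejected: 16 * 2 == 32 != 16 (exercised by test_load_nd_vc_4 in invalid.mlir below).
  %bad = xegpu.load_nd %tdesc : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>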
@@ -210,13 +243,13 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");

   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

   return success();
 }
@@ -238,13 +271,13 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError("Invalid result, it should be a VectorType.\n");

   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

   auto array_len = tdescTy.getArrayLength();
   auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +313,9 @@ LogicalResult LoadNdOp::verify() {
     auto it = tdescShape.begin();
     tdescShape.insert(it, array_len);
   }
+  auto sgMap = tdescTy.getSGMapAttr();

-  if (tdescShape != valueShape)
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
     return emitOpError() << "Result shape doesn't match TensorDesc shape."
                          << "The expected shape is " << makeString(tdescShape)
                          << ". But the given shape is "
@@ -303,17 +337,26 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");

   if (!valTy)
-    return emitOpError("Exepcting a VectorType result.\n");
+    return emitOpError("Expecting a VectorType result.\n");

   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
+  auto sgMap = dstTy.getSGMapAttr();

+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+    return emitOpError() << "Result shape doesn't match TensorDesc shape."
+                         << "The expected shape is " << makeString(tdescShape)
+                         << ". But the given shape is "
+                         << makeString(valueShape) << ".\n";
   return success();
 }
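
The store side applies the same sg_map scaling to the value operand. A sketch based on the new test_store_nd_vc_3 test (%dst is an assumed memref<24x32xf16> argument): a 24x32 descriptor with wi_layout = [1, 16] accepts a vector<24x2xf16> value, since 1 * 24 == 24 and 16 * 2 == 32.

  %val = arith.constant dense<1.0> : vector<24x2xf16>
  %td = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> ->
    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  xegpu.store_nd %val, %td : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>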

@@ -423,13 +466,13 @@ LogicalResult PrefetchOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");

   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

   return success();
 }
@@ -446,13 +489,13 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");

   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

   auto tdescElemTy = tdescTy.getElementType();
   auto valueElemTy = getElementType();
@@ -490,13 +533,13 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");

   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

   auto maskTy = getMaskType();
   auto valueTy = getValueType();

mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Lines changed: 24 additions & 0 deletions
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   gpu.return
 }

+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
   gpu.return
 }

+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+  %1 = arith.constant dense<1.0>: vector<24x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 24 additions & 6 deletions
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
 // -----
 func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
 // -----
 func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
       : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
   return
@@ -77,11 +77,29 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
   return
 }

+// -----
+func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+  return
+}
+
+// -----
+func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+  return
+}
+
 // -----
 func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
   %1 = arith.constant dense<1.0>: vector<24x32xf16>
   %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
   return
 }
@@ -147,7 +165,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
           -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
@@ -168,7 +186,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
         : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
           -> vector<4x2xf32>
@@ -193,7 +211,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
   %1 = arith.constant dense<2.9>: vector<4x2xf32>
   %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
       !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return