@@ -73,6 +73,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
73
73
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
74
74
}
75
75
76
+ // Validations for nd instruction arguments is successful if any of these are
77
+ // true:
78
+ // - tensor descriptor and the output vector shapes exactly match.
79
+ // - tensor descriptor has a sg_map attribute and the distributed vector shape
80
+ // matches the tensor descriptor shape when scaled using sg_map factors on
81
+ // each dimension.
82
+ static bool isArgShapesValid (ArrayRef<int64_t > descShape,
83
+ ArrayRef<int64_t > valShape, SGMapAttr sgMap) {
84
+ if (descShape == valShape) {
85
+ if (!sgMap)
86
+ return true ;
87
+
88
+ // this can be relaxed if necessary by supporting non-2d shapes distribution
89
+ // until the constraints are defined this lives here instead of the tensor
90
+ // descriptor type.
91
+ return valShape.size () == sgMap.getWiLayout ().size ();
92
+ }
93
+
94
+ if (!sgMap)
95
+ return false ;
96
+
97
+ if (valShape.size () != descShape.size ())
98
+ return false ;
99
+
100
+ for (const auto &[factor, dim, expected] :
101
+ llvm::zip_equal (sgMap.getWiLayout (), valShape, descShape)) {
102
+ if (factor * dim != expected)
103
+ return false ;
104
+ }
105
+
106
+ return true ;
107
+ }
108
+
76
109
// ===----------------------------------------------------------------------===//
77
110
// XeGPU_CreateNdDescOp
78
111
// ===----------------------------------------------------------------------===//
@@ -210,13 +243,13 @@ LogicalResult PrefetchNdOp::verify() {
210
243
return emitOpError (" Expects a non-scattered TensorDesc.\n " );
211
244
212
245
if (!isReadHintOrNone (getL1HintAttr ()))
213
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
246
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
214
247
215
248
if (!isReadHintOrNone (getL2HintAttr ()))
216
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
249
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
217
250
218
251
if (!isReadHintOrNone (getL3HintAttr ()))
219
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
252
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
220
253
221
254
return success ();
222
255
}
@@ -238,13 +271,13 @@ LogicalResult LoadNdOp::verify() {
238
271
return emitOpError (" Invalid result, it should be a VectorType.\n " );
239
272
240
273
if (!isReadHintOrNone (getL1HintAttr ()))
241
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
274
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
242
275
243
276
if (!isReadHintOrNone (getL2HintAttr ()))
244
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
277
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
245
278
246
279
if (!isReadHintOrNone (getL3HintAttr ()))
247
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
280
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
248
281
249
282
auto array_len = tdescTy.getArrayLength ();
250
283
auto tdescShape = getShapeOf (tdescTy);
@@ -280,8 +313,9 @@ LogicalResult LoadNdOp::verify() {
280
313
auto it = tdescShape.begin ();
281
314
tdescShape.insert (it, array_len);
282
315
}
316
+ auto sgMap = tdescTy.getSGMapAttr ();
283
317
284
- if (tdescShape != valueShape)
318
+ if (! isArgShapesValid ( tdescShape, valueShape, sgMap) )
285
319
return emitOpError () << " Result shape doesn't match TensorDesc shape."
286
320
<< " The expected shape is " << makeString (tdescShape)
287
321
<< " . But the given shape is "
@@ -303,17 +337,26 @@ LogicalResult StoreNdOp::verify() {
303
337
return emitOpError (" Expects a non-scattered TensorDesc.\n " );
304
338
305
339
if (!valTy)
306
- return emitOpError (" Exepcting a VectorType result.\n " );
340
+ return emitOpError (" Expecting a VectorType result.\n " );
307
341
308
342
if (!isWriteHintOrNone (getL1HintAttr ()))
309
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
343
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
310
344
311
345
if (!isWriteHintOrNone (getL2HintAttr ()))
312
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
346
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
313
347
314
348
if (!isWriteHintOrNone (getL3HintAttr ()))
315
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
349
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
350
+
351
+ auto tdescShape = getShapeOf (dstTy);
352
+ auto valueShape = getShapeOf (valTy);
353
+ auto sgMap = dstTy.getSGMapAttr ();
316
354
355
+ if (!isArgShapesValid (tdescShape, valueShape, sgMap))
356
+ return emitOpError () << " Result shape doesn't match TensorDesc shape."
357
+ << " The expected shape is " << makeString (tdescShape)
358
+ << " . But the given shape is "
359
+ << makeString (valueShape) << " .\n " ;
317
360
return success ();
318
361
}
319
362
@@ -423,13 +466,13 @@ LogicalResult PrefetchOp::verify() {
423
466
return emitOpError (" Expects a scattered TensorDesc.\n " );
424
467
425
468
if (!isReadHintOrNone (getL1HintAttr ()))
426
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
469
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
427
470
428
471
if (!isReadHintOrNone (getL2HintAttr ()))
429
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
472
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
430
473
431
474
if (!isReadHintOrNone (getL3HintAttr ()))
432
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
475
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
433
476
434
477
return success ();
435
478
}
@@ -446,13 +489,13 @@ LogicalResult LoadGatherOp::verify() {
446
489
return emitOpError (" Expects a scattered TensorDesc.\n " );
447
490
448
491
if (!isReadHintOrNone (getL1HintAttr ()))
449
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
492
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
450
493
451
494
if (!isReadHintOrNone (getL2HintAttr ()))
452
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
495
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
453
496
454
497
if (!isReadHintOrNone (getL3HintAttr ()))
455
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
498
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
456
499
457
500
auto tdescElemTy = tdescTy.getElementType ();
458
501
auto valueElemTy = getElementType ();
@@ -490,13 +533,13 @@ LogicalResult StoreScatterOp::verify() {
490
533
return emitOpError (" Expects a scattered TensorDesc.\n " );
491
534
492
535
if (!isWriteHintOrNone (getL1HintAttr ()))
493
- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
536
+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
494
537
495
538
if (!isWriteHintOrNone (getL2HintAttr ()))
496
- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
539
+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
497
540
498
541
if (!isWriteHintOrNone (getL3HintAttr ()))
499
- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
542
+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
500
543
501
544
auto maskTy = getMaskType ();
502
545
auto valueTy = getValueType ();
0 commit comments