@@ -335,34 +335,83 @@ LogicalResult LdMatrixOp::verify() {
 // NVGPU_TmaAsyncLoadOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult TmaAsyncLoadOp::verify() {
-  // Destination memref
-  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
+    Operation *op, nvgpu::TensorMapDescriptorType descType,
+    std::optional<MemRefType> memrefType = std::nullopt) {
+  MemRefType descMemref = descType.getTensor();
+  // Limitation
+  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
+    return op->emitError() << "Interleave options are not supported yet.";
+
+  // Address space check for shared memory check
+  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
+    return op->emitError() << "the tensor map descriptor has incorrect address "
+                              "space, it must be shared memory address space.";
+  }
+  // Support only static shape for the time being
+  if (!descMemref.hasStaticShape())
+    return op->emitError() << "the tensor map descriptor must be static shaped";
+
+  // No verification if memref type is not provided
+  if (!memrefType.has_value())
+    return std::nullopt;
+
+  MemRefType dstMemref = memrefType.value();
+
+  // Check element type
+  if (descMemref.getElementType() != dstMemref.getElementType()) {
+    return op->emitError() << "the element type of tensor map descriptor and "
+                              "memref must be same";
+  }
+
   if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
-    return emitError()
-        << "The operation stores data to shared memory, but "
-           "the destination memref does not have a memory space of "
-        << NVGPUDialect::kSharedMemoryAddressSpace;
+    return op->emitError() << "the destination memref has incorrect address "
+                              "space, it must be shared memory address space.";
   }
-  if (getCoordinates().size() > 5) {
-    return emitError() << "Maximum 5 coordinates are supported.";
+  if (!dstMemref.hasStaticShape())
+    return op->emitError() << "the destination memref must be static shaped";
+
+  if (dstMemref.getRank() != descMemref.getRank()) {
+    return op->emitError() << "the shape of tensor map descriptor and "
+                              "memref must have same rank";
   }
-  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
-    return emitError() << "Destination memref rank is "
-                       << size_t(dstMemref.getRank()) << " but there are "
-                       << getCoordinates().size()
-                       << " coordinates. They must match.";
+  if (!descMemref.getShape().equals(dstMemref.getShape())) {
+    return op->emitError() << "memref and tensor map shapes mismatch "
+                           << descMemref << " != " << dstMemref;
   }
+
+  return std::nullopt;
+}
+
+LogicalResult TmaAsyncLoadOp::verify() {
+  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+      *this, getTensorMapDescriptor().getType(), getDst().getType());
+  if (error.has_value())
+    return error.value();
+
+  if (getCoordinates().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
+  }
+  if (getCoordinates().size() !=
+      getTensorMapDescriptor().getType().getTensor().getRank()) {
+    return emitError() << "number of coordinates do not match with the rank of "
+                          "tensor descriptor map.";
+  }
+
   return success();
 }
 
 LogicalResult TmaCreateDescriptorOp::verify() {
-  if (getBoxDimensions().size() > 5) {
-    return emitError() << "Maximum 5 dimensional box is supported.";
+  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
   }
-  nvgpu::TensorMapDescriptorType desc = getTensorMap().getType();
-  if (desc.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
-    return emitError() << "Interleave options are not supported yet.";
+
+  std::optional<InFlightDiagnostic> error =
+      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
+  if (error.has_value())
+    return error.value();
 
   return success();
 }
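
Note: both verifiers above bound the tensor rank by kMaxTMATensorDimension, replacing the previous hard-coded 5. The constant's definition is not part of the hunks shown here; a minimal sketch of a plausible definition, inferred only from the old `> 5` checks, would be:

  // Hypothetical spelling and placement; this patch only shows uses of the constant.
  // CUDA's TMA engine handles tensors of at most five dimensions, matching the old check.
  constexpr unsigned kMaxTMATensorDimension = 5;

Using one named limit keeps TmaAsyncLoadOp and TmaCreateDescriptorOp verification in sync.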
@@ -372,17 +421,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult WarpgroupGenerateDescriptorOp::verify() {
-  MemRefType memrefType = getTensor().getType();
-  MemRefType tensorMapType = getTensorMap().getType().getTensor();
-
-  if (memrefType != tensorMapType)
-    return emitError() << "memref and tensor map type mismatch";
-
-  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
-    return emitError() << "supports only static shapes";
-
-  if (memrefType.getRank() != 2)
-    return emitError() << "supports only 2d memref is supported for now";
+  std::optional<InFlightDiagnostic> error =
+      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
+  if (error.has_value())
+    return error.value();
 
   if (getTensorMap().getType().getSwizzle() !=
       TensorMapSwizzleKind::SWIZZLE_128B) {