@@ -33,6 +33,15 @@ using namespace mlir::sparse_tensor;
33
33
34
34
namespace {
35
35
36
// Sparse formats supported by cuSparse.
// kNone acts as the "not admissible" sentinel returned by format selection.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR, // TODO: coming soon!
};
44
+
36
45
// ===----------------------------------------------------------------------===//
37
46
// Helper methods.
38
47
// ===----------------------------------------------------------------------===//
@@ -385,73 +394,92 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
385
394
return false ;
386
395
}
387
396
388
- // / Determines if the given value is a dense tensor instead of a sparse one .
397
+ // / Test for dense tensor.
389
398
static bool isDenseTensor (Value v) {
390
- return (sparse_tensor::getSparseTensorType (v).isAllDense ());
399
+ auto sTp = getSparseTensorType (v);
400
+ return sTp .getDimRank () == sTp .getLvlRank () && sTp .isAllDense ();
401
+ }
402
+
403
+ // / Test for suitable positions/coordinates width.
404
+ static bool isAdmissibleMetaData (SparseTensorType &aTp) {
405
+ return (aTp.getPosWidth () == 0 || aTp.getPosWidth () >= 16 ) &&
406
+ (aTp.getCrdWidth () == 0 || aTp.getCrdWidth () >= 16 );
391
407
}
392
408
393
- // / Test for sorted COO with suitable data and coordinates types .
409
+ // / Test for sorted COO matrix with suitable metadata .
394
410
static bool isAdmissibleCOO (SparseTensorType &aTp) {
395
- return aTp.isCompressedLvl (0 ) && aTp.isOrderedLvl (0 ) && !aTp.isUniqueLvl (0 ) &&
411
+ return aTp.getDimRank () == 2 && aTp.getLvlRank () == 2 && aTp.isIdentity () &&
412
+ aTp.isCompressedLvl (0 ) && aTp.isOrderedLvl (0 ) && !aTp.isUniqueLvl (0 ) &&
396
413
aTp.isSingletonLvl (1 ) && aTp.isOrderedLvl (1 ) && aTp.isUniqueLvl (1 ) &&
397
- (aTp.getElementType ().isF64 () || aTp.getElementType ().isF32 ()) &&
398
- (aTp.getCrdWidth () == 0 || aTp.getCrdWidth () == 32 ||
399
- aTp.getCrdWidth () == 64 );
414
+ isAdmissibleMetaData (aTp);
400
415
}
401
416
402
- // / Test for CSR with suitable data and coordinates types .
417
+ // / Test for CSR matrix with suitable metadata .
403
418
static bool isAdmissibleCSR (SparseTensorType &aTp) {
404
- return aTp.isDenseLvl (0 ) && aTp.isCompressedLvl (1 ) && aTp.isOrderedLvl (1 ) &&
405
- aTp.isUniqueLvl (1 ) &&
406
- (aTp.getElementType ().isF64 () || aTp.getElementType ().isF32 ()) &&
407
- (aTp.getCrdWidth () == 0 || aTp.getCrdWidth () == 32 ||
408
- aTp.getCrdWidth () == 64 );
419
+ return aTp.getDimRank () == 2 && aTp.getLvlRank () == 2 && aTp.isIdentity () &&
420
+ aTp.isDenseLvl (0 ) && aTp.isCompressedLvl (1 ) && aTp.isOrderedLvl (1 ) &&
421
+ aTp.isUniqueLvl (1 ) && isAdmissibleMetaData (aTp);
409
422
}
410
423
411
- // / Test for admissible types on operands (with output parameter `isCOO`).
412
- static bool areAdmissibleTypes (SparseTensorType aTp, SparseTensorType bTp,
413
- SparseTensorType cTp, bool enableRT,
414
- bool isMatVec, bool &isCOO) {
424
+ // / Test for CSC matrix with suitable metadata.
425
+ static bool isAdmissibleCSC (SparseTensorType &aTp) {
426
+ return aTp.getDimRank () == 2 && aTp.getLvlRank () == 2 && !aTp.isIdentity () &&
427
+ aTp.isPermutation () && aTp.isDenseLvl (0 ) && aTp.isCompressedLvl (1 ) &&
428
+ aTp.isOrderedLvl (1 ) && aTp.isUniqueLvl (1 ) && isAdmissibleMetaData (aTp);
429
+ }
430
+
431
/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
///
/// `aTp` is the (possibly sparse) main operand; `bTp`/`cTp` must be dense.
/// `enableRT` selects the runtime-library (SoA COO) path; `isMatVec`
/// restricts AoS COO to matrix-vector products when CUSPARSE_COO_AOS is set.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands have a dense type.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for suitable operand type for the main operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    // AoS COO in cuSparse only supports matrix-vector products.
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    // SoA COO requires the sparse runtime support library.
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  return CuSparseFormat::kNone;
}
427
453
428
454
/// Generates the first positions/coordinates of a sparse matrix.
/// For COO this is the level-0 coordinates (or the AoS coordinate buffer
/// when the runtime library is not used); for all other formats it is the
/// positions array of the compressed level.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
    return genToCoordinatesBuffer(builder, loc, a);
  }
  // Formats CSR/CSC and BSR use positions at 1.
  return genToPositions(builder, loc, a, 1);
}
440
466
441
467
/// Generates the second coordinates of a sparse matrix.
/// Returns an empty Value for AoS COO (direct IR), where the first buffer
/// already interleaves both coordinates and nothing more is needed.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at 1.
  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
}
448
476
449
- // / Generates the sparse matrix multiplication .
477
+ // / Generates the sparse matrix handle .
450
478
static Operation *genSpMat (OpBuilder &builder, Location loc, Type handleTp,
451
479
Type tokenTp, Value token, Value sz1, Value sz2,
452
480
Value nseA, Value rowA, Value colA, Value valA,
453
- bool isCOO , bool enableRT) {
454
- if (isCOO ) {
481
+ CuSparseFormat format , bool enableRT) {
482
+ if (format == CuSparseFormat:: kCOO ) {
455
483
// Library uses SoA COO, direct IR uses AoS COO.
456
484
if (enableRT) {
457
485
assert (colA);
@@ -467,7 +495,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
467
495
#endif
468
496
}
469
497
assert (colA);
470
- return builder.create <gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
498
+ if (format == CuSparseFormat::kCSR )
499
+ return builder.create <gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
500
+ sz2, nseA, rowA, colA, valA);
501
+ assert (format == CuSparseFormat::kCSC );
502
+ return builder.create <gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
471
503
sz2, nseA, rowA, colA, valA);
472
504
}
473
505
@@ -484,12 +516,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
484
516
bool isZeroCopy =
485
517
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy ;
486
518
487
- // Only admissible sparse matrix format and dense vectors.
488
- bool isCOO = false ;
519
+ // Only admissible sparse matrix format and dense vectors (no BSR).
489
520
SparseTensorType aTp = getSparseTensorType (a);
490
521
SparseTensorType xTp = getSparseTensorType (x);
491
522
SparseTensorType yTp = getSparseTensorType (y);
492
- if (!areAdmissibleTypes (aTp, xTp, yTp, enableRT, /* isMatVec=*/ true , isCOO))
523
+ auto format = getCuSparseFormat (aTp, xTp, yTp, enableRT, /* isMatVec=*/ true );
524
+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR )
493
525
return failure ();
494
526
495
527
// Start sparse kernel and copy data from host to device.
@@ -499,8 +531,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
499
531
Value nseA = rewriter.create <NumberOfEntriesOp>(loc, a);
500
532
Value szY = linalg::createOrFoldDimOp (rewriter, loc, a, 0 );
501
533
Value szX = linalg::createOrFoldDimOp (rewriter, loc, a, 1 );
502
- Value memR = genFirstPosOrCrds (rewriter, loc, a, isCOO , enableRT);
503
- Value memC = genSecondCrds (rewriter, loc, a, isCOO , enableRT);
534
+ Value memR = genFirstPosOrCrds (rewriter, loc, a, format , enableRT);
535
+ Value memC = genSecondCrds (rewriter, loc, a, format , enableRT);
504
536
Value memV = genToValues (rewriter, loc, a);
505
537
Value memX, memY;
506
538
Value castR, castC, castV, castX, castY;
@@ -535,7 +567,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
535
567
Value token = genFirstWait (rewriter, loc);
536
568
Operation *spGenA =
537
569
genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
538
- rowA, colA, valA, isCOO , enableRT);
570
+ rowA, colA, valA, format , enableRT);
539
571
Value spMatA = spGenA->getResult (0 );
540
572
token = spGenA->getResult (1 );
541
573
auto dvecX = rewriter.create <gpu::CreateDnTensorOp>(
@@ -546,7 +578,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
546
578
loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
547
579
Value dnY = dvecY.getResult (0 );
548
580
token = dvecY.getAsyncToken ();
549
-
550
581
auto dnYType = llvm::cast<ShapedType>(y.getType ()).getElementType ();
551
582
552
583
// Precompute buffersize for SpMV.
@@ -610,12 +641,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
610
641
bool isZeroCopy =
611
642
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy ;
612
643
613
- // Only admissible sparse matrix format and dense matrices.
614
- bool isCOO = false ;
644
+ // Only admissible sparse matrix format and dense matrices (no BSR).
615
645
SparseTensorType aTp = getSparseTensorType (a);
616
646
SparseTensorType bTp = getSparseTensorType (b);
617
647
SparseTensorType cTp = getSparseTensorType (c);
618
- if (!areAdmissibleTypes (aTp, bTp, cTp, enableRT, /* isMatVec=*/ false , isCOO))
648
+ auto format = getCuSparseFormat (aTp, bTp, cTp, enableRT, /* isMatVec=*/ false );
649
+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR )
619
650
return failure ();
620
651
621
652
// Start sparse kernel and copy data from host to device.
@@ -626,8 +657,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
626
657
Value szm = linalg::createOrFoldDimOp (rewriter, loc, a, 0 );
627
658
Value szk = linalg::createOrFoldDimOp (rewriter, loc, a, 1 );
628
659
Value szn = linalg::createOrFoldDimOp (rewriter, loc, b, 1 );
629
- Value memR = genFirstPosOrCrds (rewriter, loc, a, isCOO , enableRT);
630
- Value memC = genSecondCrds (rewriter, loc, a, isCOO , enableRT);
660
+ Value memR = genFirstPosOrCrds (rewriter, loc, a, format , enableRT);
661
+ Value memC = genSecondCrds (rewriter, loc, a, format , enableRT);
631
662
Value memV = genToValues (rewriter, loc, a);
632
663
Value bufB, bufC;
633
664
Value castR, castC, castV, castB, castBufC;
@@ -661,7 +692,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
661
692
Value token = genFirstWait (rewriter, loc);
662
693
Operation *spGenA =
663
694
genSpMat (rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
664
- rowA, colA, valA, isCOO , enableRT);
695
+ rowA, colA, valA, format , enableRT);
665
696
Value spMatA = spGenA->getResult (0 );
666
697
token = spGenA->getResult (1 );
667
698
auto dmatB = rewriter.create <gpu::CreateDnTensorOp>(
@@ -674,7 +705,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
674
705
SmallVector<Value>{szm, szn});
675
706
Value dnC = dmatC.getResult (0 );
676
707
token = dmatC.getAsyncToken ();
677
-
678
708
auto dmatCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
679
709
680
710
// Precompute buffersize for SpMM.
@@ -686,7 +716,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
686
716
auto buf = genAllocBuffer (rewriter, loc, bufferSz, token);
687
717
Value buffer = buf.getResult (0 );
688
718
token = buf.getAsyncToken ();
689
-
690
719
auto dnCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
691
720
692
721
// Perform the SpMM.
@@ -738,7 +767,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
738
767
SmallVector<Value> tokens;
739
768
740
769
// Only CSR <- CSR x CSR supported.
741
- bool isCOO = false ;
770
+ auto format = CuSparseFormat:: kCSR ;
742
771
SparseTensorType aTp = getSparseTensorType (a);
743
772
SparseTensorType bTp = getSparseTensorType (b);
744
773
SparseTensorType cTp = getSparseTensorType (c);
@@ -755,11 +784,11 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
755
784
Value szm = linalg::createOrFoldDimOp (rewriter, loc, a, 0 );
756
785
Value szk = linalg::createOrFoldDimOp (rewriter, loc, a, 1 );
757
786
Value szn = linalg::createOrFoldDimOp (rewriter, loc, b, 1 );
758
- Value amemR = genFirstPosOrCrds (rewriter, loc, a, isCOO , enableRT);
759
- Value amemC = genSecondCrds (rewriter, loc, a, isCOO , enableRT);
787
+ Value amemR = genFirstPosOrCrds (rewriter, loc, a, format , enableRT);
788
+ Value amemC = genSecondCrds (rewriter, loc, a, format , enableRT);
760
789
Value amemV = genToValues (rewriter, loc, a);
761
- Value bmemR = genFirstPosOrCrds (rewriter, loc, b, isCOO , enableRT);
762
- Value bmemC = genSecondCrds (rewriter, loc, b, isCOO , enableRT);
790
+ Value bmemR = genFirstPosOrCrds (rewriter, loc, b, format , enableRT);
791
+ Value bmemC = genSecondCrds (rewriter, loc, b, format , enableRT);
763
792
Value bmemV = genToValues (rewriter, loc, b);
764
793
Value rowA = genAllocCopy (rewriter, loc, amemR, tokens);
765
794
Value colA = genAllocCopy (rewriter, loc, amemC, tokens);
@@ -778,12 +807,12 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
778
807
Value token = genFirstWait (rewriter, loc);
779
808
Operation *spGenA =
780
809
genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
781
- rowA, colA, valA, isCOO , enableRT);
810
+ rowA, colA, valA, format , enableRT);
782
811
Value spMatA = spGenA->getResult (0 );
783
812
token = spGenA->getResult (1 );
784
813
Operation *spGenB =
785
814
genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
786
- rowB, colB, valB, isCOO , enableRT);
815
+ rowB, colB, valB, format , enableRT);
787
816
Value spMatB = spGenB->getResult (0 );
788
817
token = spGenB->getResult (1 );
789
818
@@ -802,7 +831,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
802
831
token = e3 .getAsyncToken ();
803
832
Operation *spGenC =
804
833
genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
805
- rowC, colC, valC, isCOO , enableRT);
834
+ rowC, colC, valC, format , enableRT);
806
835
Value spMatC = spGenC->getResult (0 );
807
836
token = spGenC->getResult (1 );
808
837
@@ -1046,14 +1075,13 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
1046
1075
bool isZeroCopy =
1047
1076
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy ;
1048
1077
1049
- // Only admissible sparse matrix format and dense matrices, no COO.
1050
- bool isCOO = false ;
1078
+ // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1051
1079
SparseTensorType aTp = getSparseTensorType (a);
1052
1080
SparseTensorType bTp = getSparseTensorType (b);
1053
1081
SparseTensorType cTp = getSparseTensorType (c);
1054
- if (! areAdmissibleTypes ( cTp, bTp, aTp, enableRT, false , isCOO))
1055
- return failure ();
1056
- if (isCOO )
1082
+ auto format = getCuSparseFormat ( cTp, bTp, aTp, enableRT, /* isMatVec= */ false );
1083
+ if (format == CuSparseFormat:: kNone || format == CuSparseFormat:: kCOO ||
1084
+ format == CuSparseFormat:: kCSC )
1057
1085
return failure ();
1058
1086
1059
1087
// The SDDMM does the in-place operation.
@@ -1072,8 +1100,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
1072
1100
Value bufB = genTensorToMemref (rewriter, loc, b);
1073
1101
if (!isZeroCopy)
1074
1102
matB = isZeroCopy ? bufB : genAllocCopy (rewriter, loc, bufB, tokens);
1075
- Value memR = genFirstPosOrCrds (rewriter, loc, c, isCOO , enableRT);
1076
- Value memC = genSecondCrds (rewriter, loc, c, isCOO , enableRT);
1103
+ Value memR = genFirstPosOrCrds (rewriter, loc, c, format , enableRT);
1104
+ Value memC = genSecondCrds (rewriter, loc, c, format , enableRT);
1077
1105
Value memV = genToValues (rewriter, loc, c);
1078
1106
Value castB, castA, castR, castC, castV;
1079
1107
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA ) {
@@ -1108,10 +1136,9 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
1108
1136
loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
1109
1137
Value dnB = dmatB.getResult (0 );
1110
1138
token = dmatB.getAsyncToken ();
1111
-
1112
1139
Operation *spGenC =
1113
1140
genSpMat (rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
1114
- rowC, colC, valC, isCOO , enableRT);
1141
+ rowC, colC, valC, format , enableRT);
1115
1142
Value spMatC = spGenC->getResult (0 );
1116
1143
token = spGenC->getResult (1 );
1117
1144
auto dnCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
0 commit comments