Skip to content

Commit 3231a36

Browse files
authored
[mlir][sparse][gpu] add CSC to libgen GPU sparsification using cuSparse (#67713)
Adds CSC, and also adds BSR as a future format. Coming soon!
1 parent de7881e commit 3231a36

File tree

3 files changed

+167
-70
lines changed

3 files changed

+167
-70
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 91 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ using namespace mlir::sparse_tensor;
3333

3434
namespace {
3535

36+
// Sparse formats supported by cuSparse.
37+
enum class CuSparseFormat {
38+
kNone,
39+
kCOO,
40+
kCSR,
41+
kCSC,
42+
kBSR, // TODO: coming soon!
43+
};
44+
3645
//===----------------------------------------------------------------------===//
3746
// Helper methods.
3847
//===----------------------------------------------------------------------===//
@@ -385,73 +394,92 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
385394
return false;
386395
}
387396

388-
/// Determines if the given value is a dense tensor instead of a sparse one.
397+
/// Test for dense tensor.
389398
static bool isDenseTensor(Value v) {
390-
return (sparse_tensor::getSparseTensorType(v).isAllDense());
399+
auto sTp = getSparseTensorType(v);
400+
return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
401+
}
402+
403+
/// Test for suitable positions/coordinates width.
404+
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
405+
return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
406+
(aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
391407
}
392408

393-
/// Test for sorted COO with suitable data and coordinates types.
409+
/// Test for sorted COO matrix with suitable metadata.
394410
static bool isAdmissibleCOO(SparseTensorType &aTp) {
395-
return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
411+
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
412+
aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
396413
aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
397-
(aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
398-
(aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
399-
aTp.getCrdWidth() == 64);
414+
isAdmissibleMetaData(aTp);
400415
}
401416

402-
/// Test for CSR with suitable data and coordinates types.
417+
/// Test for CSR matrix with suitable metadata.
403418
static bool isAdmissibleCSR(SparseTensorType &aTp) {
404-
return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
405-
aTp.isUniqueLvl(1) &&
406-
(aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
407-
(aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
408-
aTp.getCrdWidth() == 64);
419+
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
420+
aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
421+
aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
409422
}
410423

411-
/// Test for admissible types on operands (with output parameter `isCOO`).
412-
static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
413-
SparseTensorType cTp, bool enableRT,
414-
bool isMatVec, bool &isCOO) {
424+
/// Test for CSC matrix with suitable metadata.
425+
static bool isAdmissibleCSC(SparseTensorType &aTp) {
426+
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
427+
aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
428+
aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
429+
}
430+
431+
/// Returns a suitable sparse format for the operation and given operand
432+
/// types with cuSparse, or kNone if none is available.
433+
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
434+
SparseTensorType bTp,
435+
SparseTensorType cTp, bool enableRT,
436+
bool isMatVec) {
437+
// The other operands have a dense type.
415438
if (bTp.hasEncoding() || cTp.hasEncoding())
416-
return false;
417-
if (isAdmissibleCOO(aTp)) {
418-
isCOO = true;
439+
return CuSparseFormat::kNone;
440+
// Now check for suitable operand type for the main operand.
441+
if (isAdmissibleCOO(aTp))
419442
#ifdef CUSPARSE_COO_AOS
420-
return isMatVec;
443+
return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
421444
#else
422-
return enableRT;
445+
return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
423446
#endif
424-
}
425-
return isAdmissibleCSR(aTp);
447+
if (isAdmissibleCSR(aTp))
448+
return CuSparseFormat::kCSR;
449+
if (isAdmissibleCSC(aTp))
450+
return CuSparseFormat::kCSC;
451+
return CuSparseFormat::kNone;
426452
}
427453

428454
/// Generates the first positions/coordinates of a sparse matrix.
429455
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
430-
bool isCOO, bool enableRT) {
431-
if (isCOO) {
456+
CuSparseFormat format, bool enableRT) {
457+
if (format == CuSparseFormat::kCOO) {
432458
// Library uses SoA COO, direct IR uses AoS COO.
433459
if (enableRT)
434460
return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
435461
return genToCoordinatesBuffer(builder, loc, a);
436462
}
437-
// CSR uses positions.
463+
// Formats CSR/CSC and BSR use positions at 1.
438464
return genToPositions(builder, loc, a, 1);
439465
}
440466

441467
/// Generates the second coordinates of a sparse matrix.
442468
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
443-
bool isCOO, bool enableRT) {
469+
CuSparseFormat format, bool enableRT) {
470+
bool isCOO = format == CuSparseFormat::kCOO;
444471
if (isCOO && !enableRT)
445472
return Value(); // nothing needed
473+
// Formats CSR/CSC and BSR use coordinates at 1.
446474
return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
447475
}
448476

449-
/// Generates the sparse matrix multiplication.
477+
/// Generates the sparse matrix handle.
450478
static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
451479
Type tokenTp, Value token, Value sz1, Value sz2,
452480
Value nseA, Value rowA, Value colA, Value valA,
453-
bool isCOO, bool enableRT) {
454-
if (isCOO) {
481+
CuSparseFormat format, bool enableRT) {
482+
if (format == CuSparseFormat::kCOO) {
455483
// Library uses SoA COO, direct IR uses AoS COO.
456484
if (enableRT) {
457485
assert(colA);
@@ -467,7 +495,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
467495
#endif
468496
}
469497
assert(colA);
470-
return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
498+
if (format == CuSparseFormat::kCSR)
499+
return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
500+
sz2, nseA, rowA, colA, valA);
501+
assert(format == CuSparseFormat::kCSC);
502+
return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
471503
sz2, nseA, rowA, colA, valA);
472504
}
473505

@@ -484,12 +516,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
484516
bool isZeroCopy =
485517
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
486518

487-
// Only admissible sparse matrix format and dense vectors.
488-
bool isCOO = false;
519+
// Only admissible sparse matrix format and dense vectors (no BSR).
489520
SparseTensorType aTp = getSparseTensorType(a);
490521
SparseTensorType xTp = getSparseTensorType(x);
491522
SparseTensorType yTp = getSparseTensorType(y);
492-
if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
523+
auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
524+
if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
493525
return failure();
494526

495527
// Start sparse kernel and copy data from host to device.
@@ -499,8 +531,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
499531
Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
500532
Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
501533
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
502-
Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
503-
Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
534+
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
535+
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
504536
Value memV = genToValues(rewriter, loc, a);
505537
Value memX, memY;
506538
Value castR, castC, castV, castX, castY;
@@ -535,7 +567,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
535567
Value token = genFirstWait(rewriter, loc);
536568
Operation *spGenA =
537569
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
538-
rowA, colA, valA, isCOO, enableRT);
570+
rowA, colA, valA, format, enableRT);
539571
Value spMatA = spGenA->getResult(0);
540572
token = spGenA->getResult(1);
541573
auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
@@ -546,7 +578,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
546578
loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
547579
Value dnY = dvecY.getResult(0);
548580
token = dvecY.getAsyncToken();
549-
550581
auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
551582

552583
// Precompute buffersize for SpMV.
@@ -610,12 +641,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
610641
bool isZeroCopy =
611642
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
612643

613-
// Only admissible sparse matrix format and dense matrices.
614-
bool isCOO = false;
644+
// Only admissible sparse matrix format and dense matrices (no BSR).
615645
SparseTensorType aTp = getSparseTensorType(a);
616646
SparseTensorType bTp = getSparseTensorType(b);
617647
SparseTensorType cTp = getSparseTensorType(c);
618-
if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
648+
auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
649+
if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
619650
return failure();
620651

621652
// Start sparse kernel and copy data from host to device.
@@ -626,8 +657,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
626657
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
627658
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
628659
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
629-
Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
630-
Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
660+
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
661+
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
631662
Value memV = genToValues(rewriter, loc, a);
632663
Value bufB, bufC;
633664
Value castR, castC, castV, castB, castBufC;
@@ -661,7 +692,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
661692
Value token = genFirstWait(rewriter, loc);
662693
Operation *spGenA =
663694
genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
664-
rowA, colA, valA, isCOO, enableRT);
695+
rowA, colA, valA, format, enableRT);
665696
Value spMatA = spGenA->getResult(0);
666697
token = spGenA->getResult(1);
667698
auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -674,7 +705,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
674705
SmallVector<Value>{szm, szn});
675706
Value dnC = dmatC.getResult(0);
676707
token = dmatC.getAsyncToken();
677-
678708
auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
679709

680710
// Precompute buffersize for SpMM.
@@ -686,7 +716,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
686716
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
687717
Value buffer = buf.getResult(0);
688718
token = buf.getAsyncToken();
689-
690719
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
691720

692721
// Perform the SpMM.
@@ -738,7 +767,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
738767
SmallVector<Value> tokens;
739768

740769
// Only CSR <- CSR x CSR supported.
741-
bool isCOO = false;
770+
auto format = CuSparseFormat::kCSR;
742771
SparseTensorType aTp = getSparseTensorType(a);
743772
SparseTensorType bTp = getSparseTensorType(b);
744773
SparseTensorType cTp = getSparseTensorType(c);
@@ -755,11 +784,11 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
755784
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
756785
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
757786
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
758-
Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
759-
Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
787+
Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
788+
Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
760789
Value amemV = genToValues(rewriter, loc, a);
761-
Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
762-
Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
790+
Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
791+
Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
763792
Value bmemV = genToValues(rewriter, loc, b);
764793
Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
765794
Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
@@ -778,12 +807,12 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
778807
Value token = genFirstWait(rewriter, loc);
779808
Operation *spGenA =
780809
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
781-
rowA, colA, valA, isCOO, enableRT);
810+
rowA, colA, valA, format, enableRT);
782811
Value spMatA = spGenA->getResult(0);
783812
token = spGenA->getResult(1);
784813
Operation *spGenB =
785814
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
786-
rowB, colB, valB, isCOO, enableRT);
815+
rowB, colB, valB, format, enableRT);
787816
Value spMatB = spGenB->getResult(0);
788817
token = spGenB->getResult(1);
789818

@@ -802,7 +831,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
802831
token = e3.getAsyncToken();
803832
Operation *spGenC =
804833
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
805-
rowC, colC, valC, isCOO, enableRT);
834+
rowC, colC, valC, format, enableRT);
806835
Value spMatC = spGenC->getResult(0);
807836
token = spGenC->getResult(1);
808837

@@ -1046,14 +1075,13 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
10461075
bool isZeroCopy =
10471076
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
10481077

1049-
// Only admissible sparse matrix format and dense matrices, no COO.
1050-
bool isCOO = false;
1078+
// Only admissible sparse matrix format (no COO/CSC) and dense matrices.
10511079
SparseTensorType aTp = getSparseTensorType(a);
10521080
SparseTensorType bTp = getSparseTensorType(b);
10531081
SparseTensorType cTp = getSparseTensorType(c);
1054-
if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
1055-
return failure();
1056-
if (isCOO)
1082+
auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1083+
if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1084+
format == CuSparseFormat::kCSC)
10571085
return failure();
10581086

10591087
// The SDDMM does the in-place operation.
@@ -1072,8 +1100,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
10721100
Value bufB = genTensorToMemref(rewriter, loc, b);
10731101
if (!isZeroCopy)
10741102
matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
1075-
Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
1076-
Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
1103+
Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1104+
Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
10771105
Value memV = genToValues(rewriter, loc, c);
10781106
Value castB, castA, castR, castC, castV;
10791107
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
@@ -1108,10 +1136,9 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
11081136
loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
11091137
Value dnB = dmatB.getResult(0);
11101138
token = dmatB.getAsyncToken();
1111-
11121139
Operation *spGenC =
11131140
genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
1114-
rowC, colC, valC, isCOO, enableRT);
1141+
rowC, colC, valC, format, enableRT);
11151142
Value spMatC = spGenC->getResult(0);
11161143
token = spGenC->getResult(1);
11171144
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

0 commit comments

Comments
 (0)