
[mlir][sparse][gpu] add CSC to libgen GPU sparsification using cuSparse #67713


Merged · 2 commits · Sep 28, 2023
155 changes: 91 additions & 64 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -33,6 +33,15 @@ using namespace mlir::sparse_tensor;

namespace {

+// Sparse formats supported by cuSparse.
+enum class CuSparseFormat {
+  kNone,
+  kCOO,
+  kCSR,
+  kCSC,
+  kBSR, // TODO: coming soon!
+};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
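Note (not part of the patch): for readers less familiar with the formats named in the new `CuSparseFormat` enum, here is a minimal sketch, with hypothetical data, of how one small matrix is laid out in COO, CSR, and CSC; `kBSR` is only a placeholder for now.

```cpp
#include <cstdint>
#include <vector>

// A 3x4 example matrix with four stored entries:
//   [ 1 . 2 . ]
//   [ . 3 . . ]
//   [ . . . 4 ]

// COO: one (row, col) coordinate pair per entry, sorted lexicographically.
std::vector<int32_t> cooRow = {0, 0, 1, 2};
std::vector<int32_t> cooCol = {0, 2, 1, 3};
std::vector<double>  cooVal = {1.0, 2.0, 3.0, 4.0};

// CSR: positions delimit each row's slice of the coordinate/value arrays
// (dense level 0, compressed level 1 in sparse tensor dialect terms).
std::vector<int32_t> csrPos = {0, 2, 3, 4}; // row i spans [pos[i], pos[i+1])
std::vector<int32_t> csrCrd = {0, 2, 1, 3}; // column coordinates
std::vector<double>  csrVal = {1.0, 2.0, 3.0, 4.0};

// CSC: the same scheme with rows and columns swapped, i.e. CSR of the
// transpose; values end up in column-major order.
std::vector<int32_t> cscPos = {0, 1, 2, 3, 4}; // col j spans [pos[j], pos[j+1])
std::vector<int32_t> cscCrd = {0, 1, 0, 2};    // row coordinates
std::vector<double>  cscVal = {1.0, 3.0, 2.0, 4.0};
```

In sparse tensor dialect terms, CSC is simply CSR with a transposing dimension-to-level map — an encoding along the lines of `#sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed) }>` (illustrative syntax) — which is why the new `isAdmissibleCSC` predicate below requires a non-identity permutation where `isAdmissibleCSR` requires the identity.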
@@ -385,73 +394,92 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
return false;
}

-/// Determines if the given value is a dense tensor instead of a sparse one.
+/// Test for dense tensor.
static bool isDenseTensor(Value v) {
-  return (sparse_tensor::getSparseTensorType(v).isAllDense());
+  auto sTp = getSparseTensorType(v);
+  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}

+/// Test for suitable positions/coordinates width.
+static bool isAdmissibleMetaData(SparseTensorType &aTp) {
+  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
+         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
+}

-/// Test for sorted COO with suitable data and coordinates types.
+/// Test for sorted COO matrix with suitable metadata.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
-  return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
+         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
          aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
-         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
-         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
-          aTp.getCrdWidth() == 64);
+         isAdmissibleMetaData(aTp);
}

-/// Test for CSR with suitable data and coordinates types.
+/// Test for CSR matrix with suitable metadata.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
-  return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
-         aTp.isUniqueLvl(1) &&
-         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
-         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
-          aTp.getCrdWidth() == 64);
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
+         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
+         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

-/// Test for admissible types on operands (with output parameter `isCOO`).
-static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
-                               SparseTensorType cTp, bool enableRT,
-                               bool isMatVec, bool &isCOO) {
+/// Test for CSC matrix with suitable metadata.
+static bool isAdmissibleCSC(SparseTensorType &aTp) {
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
+         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
+         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
+}
+
+/// Returns a suitable sparse format for the operation and given operand
+/// types with cuSparse, or kNone if none is available.
+static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
+                                        SparseTensorType bTp,
+                                        SparseTensorType cTp, bool enableRT,
+                                        bool isMatVec) {
// The other operands have a dense type.
if (bTp.hasEncoding() || cTp.hasEncoding())
-    return false;
-  if (isAdmissibleCOO(aTp)) {
-    isCOO = true;
+    return CuSparseFormat::kNone;
+  // Now check for suitable operand type for the main operand.
+  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
-    return isMatVec;
+    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
-    return enableRT;
+    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
-  }
-  return isAdmissibleCSR(aTp);
+  if (isAdmissibleCSR(aTp))
+    return CuSparseFormat::kCSR;
+  if (isAdmissibleCSC(aTp))
+    return CuSparseFormat::kCSC;
+  return CuSparseFormat::kNone;
}

/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
-                               bool isCOO, bool enableRT) {
-  if (isCOO) {
+                               CuSparseFormat format, bool enableRT) {
+  if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
return genToCoordinatesBuffer(builder, loc, a);
}
-  // CSR uses positions.
+  // Formats CSR/CSC and BSR use positions at 1.
return genToPositions(builder, loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
-                           bool isCOO, bool enableRT) {
+                           CuSparseFormat format, bool enableRT) {
+  bool isCOO = format == CuSparseFormat::kCOO;
if (isCOO && !enableRT)
return Value(); // nothing needed
+  // Formats CSR/CSC and BSR use coordinates at 1.
return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
}

-/// Generates the sparse matrix multiplication.
+/// Generates the sparse matrix handle.
static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
Type tokenTp, Value token, Value sz1, Value sz2,
Value nseA, Value rowA, Value colA, Value valA,
-                           bool isCOO, bool enableRT) {
-  if (isCOO) {
+                           CuSparseFormat format, bool enableRT) {
+  if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT) {
assert(colA);
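The SoA/AoS distinction driving these COO branches is purely a buffer-layout one; a rough sketch with hypothetical data (not the runtime's actual buffers):

```cpp
#include <cstdint>
#include <vector>

// SoA COO (library path, enableRT): one buffer per coordinate dimension,
// so genSecondCrds must materialize the second coordinates separately.
std::vector<int32_t> soaRow = {0, 0, 1, 2};
std::vector<int32_t> soaCol = {0, 2, 1, 3};

// AoS COO (direct IR path): (row, col) pairs interleaved in a single
// buffer, so genSecondCrds returns an empty Value() in that case.
std::vector<int32_t> aosCrd = {0, 0, /*|*/ 0, 2, /*|*/ 1, 1, /*|*/ 2, 3};
```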
@@ -467,7 +495,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
#endif
}
assert(colA);
-  return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
+  if (format == CuSparseFormat::kCSR)
+    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
+                                            sz2, nseA, rowA, colA, valA);
+  assert(format == CuSparseFormat::kCSC);
+  return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
sz2, nseA, rowA, colA, valA);
}

@@ -484,12 +516,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
bool isZeroCopy =
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

-  // Only admissible sparse matrix format and dense vectors.
-  bool isCOO = false;
+  // Only admissible sparse matrix format and dense vectors (no BSR).
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType xTp = getSparseTensorType(x);
SparseTensorType yTp = getSparseTensorType(y);
-  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
+  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
+  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
return failure();

// Start sparse kernel and copy data from host to device.
@@ -499,8 +531,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
-  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
Value memV = genToValues(rewriter, loc, a);
Value memX, memY;
Value castR, castC, castV, castX, castY;
@@ -535,7 +567,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value token = genFirstWait(rewriter, loc);
Operation *spGenA =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
-               rowA, colA, valA, isCOO, enableRT);
+               rowA, colA, valA, format, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
@@ -546,7 +578,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
Value dnY = dvecY.getResult(0);
token = dvecY.getAsyncToken();

auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

// Precompute buffersize for SpMV.
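Loosely speaking, the gpu dialect ops emitted here (create sparse/dense handles, precompute the buffer size, run the kernel) mirror the following cuSPARSE host sequence — a condensed sketch, not this patch's runtime wrappers; enum spellings vary across CUDA versions, and error handling and stream setup are omitted:

```cpp
#include <cuda_runtime.h>
#include <cusparse.h>

// Sketch: y += A * x with A given as CSR device buffers (for CSC, the
// analogous cusparseCreateCsc call takes column positions and row
// coordinates instead).
void spmvSketch(cusparseHandle_t handle, int64_t rows, int64_t cols,
                int64_t nnz, int32_t *dPos, int32_t *dCrd, double *dVal,
                double *dX, double *dY) {
  cusparseSpMatDescr_t matA;
  cusparseDnVecDescr_t vecX, vecY;
  cusparseCreateCsr(&matA, rows, cols, nnz, dPos, dCrd, dVal,
                    CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
                    CUSPARSE_INDEX_BASE_ZERO, CUDA_R_64F);
  cusparseCreateDnVec(&vecX, cols, dX, CUDA_R_64F);
  cusparseCreateDnVec(&vecY, rows, dY, CUDA_R_64F);
  double alpha = 1.0, beta = 1.0;
  size_t bufferSz = 0;
  void *buffer = nullptr;
  // The "precompute buffersize" step in the rewriting above.
  cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
                          matA, vecX, &beta, vecY, CUDA_R_64F,
                          CUSPARSE_SPMV_ALG_DEFAULT, &bufferSz);
  cudaMalloc(&buffer, bufferSz);
  // The actual SpMV kernel launch.
  cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX,
               &beta, vecY, CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, buffer);
  cudaFree(buffer);
  cusparseDestroySpMat(matA);
  cusparseDestroyDnVec(vecX);
  cusparseDestroyDnVec(vecY);
}
```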
@@ -610,12 +641,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
bool isZeroCopy =
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

-  // Only admissible sparse matrix format and dense matrices.
-  bool isCOO = false;
+  // Only admissible sparse matrix format and dense matrices (no BSR).
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
-  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
+  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
+  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
return failure();

// Start sparse kernel and copy data from host to device.
@@ -626,8 +657,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
-  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
Value memV = genToValues(rewriter, loc, a);
Value bufB, bufC;
Value castR, castC, castV, castB, castBufC;
@@ -661,7 +692,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value token = genFirstWait(rewriter, loc);
Operation *spGenA =
genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, isCOO, enableRT);
+               rowA, colA, valA, format, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -674,7 +705,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
SmallVector<Value>{szm, szn});
Value dnC = dmatC.getResult(0);
token = dmatC.getAsyncToken();

auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();

// Precompute buffersize for SpMM.
@@ -686,7 +716,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
Value buffer = buf.getResult(0);
token = buf.getAsyncToken();

auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

// Perform the SpMM.
@@ -738,7 +767,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
SmallVector<Value> tokens;

// Only CSR <- CSR x CSR supported.
-  bool isCOO = false;
+  auto format = CuSparseFormat::kCSR;
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
@@ -755,11 +784,11 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
-  Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
-  Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
Value amemV = genToValues(rewriter, loc, a);
-  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
-  Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
+  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
+  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
Value bmemV = genToValues(rewriter, loc, b);
Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
@@ -778,12 +807,12 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value token = genFirstWait(rewriter, loc);
Operation *spGenA =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, isCOO, enableRT);
+               rowA, colA, valA, format, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
Operation *spGenB =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
-               rowB, colB, valB, isCOO, enableRT);
+               rowB, colB, valB, format, enableRT);
Value spMatB = spGenB->getResult(0);
token = spGenB->getResult(1);

@@ -802,7 +831,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
token = e3.getAsyncToken();
Operation *spGenC =
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
-               rowC, colC, valC, isCOO, enableRT);
+               rowC, colC, valC, format, enableRT);
Value spMatC = spGenC->getResult(0);
token = spGenC->getResult(1);
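For context: the CSR-only restriction in this rewriting tracks cuSPARSE itself. As of recent CUDA releases, the cusparseSpGEMM API (the multi-phase cusparseSpGEMM_workEstimation / cusparseSpGEMM_compute / cusparseSpGEMM_copy sequence these ops ultimately target) is defined only for CSR operands, so `format` is simply pinned to `kCSR` here rather than derived from the operand types.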

@@ -1045,14 +1074,13 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
bool isZeroCopy =
gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;

-  // Only admissible sparse matrix format and dense matrices, no COO.
-  bool isCOO = false;
+  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
-  if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
-    return failure();
-  if (isCOO)
+  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
+  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
+      format == CuSparseFormat::kCSC)
return failure();

// The SDDMM does the in-place operation.
@@ -1071,8 +1099,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
Value bufB = genTensorToMemref(rewriter, loc, b);
if (!isZeroCopy)
matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
-  Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
+  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
Value memV = genToValues(rewriter, loc, c);
Value castB, castA, castR, castC, castV;
if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
@@ -1107,10 +1135,9 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
Value dnB = dmatB.getResult(0);
token = dmatB.getAsyncToken();

Operation *spGenC =
genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
-               rowC, colC, valC, isCOO, enableRT);
+               rowC, colC, valC, format, enableRT);
Value spMatC = spGenC->getResult(0);
token = spGenC->getResult(1);
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
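Likewise for SDDMM: cusparseSDDMM currently accepts only CSR sampling matrices, which is presumably why both COO and the newly added CSC are rejected up front in this rewriting.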