
Commit 53ffafb

[mlir][sparse] support sparse constant to BSR conversion. (#71114)
Support direct conversion from a constant tensor defined by a SparseElementsAttr to BSR.
1 parent 89d5635 commit 53ffafb
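
In effect, a constant defined by a SparseElementsAttr can now be converted straight to a block format, without bouncing through CSR first. A minimal sketch of the enabled pattern, adapted from the integration test updated below; the #BSR encoding and the function wrapper are assumptions (2x2 blocks, consistent with the test's expected value order), since the test's encoding declarations sit outside the shown hunks:

#BSR = #sparse_tensor.encoding<{
  map = (i, j) -> (i floordiv 2 : dense, j floordiv 2 : compressed,
                   i mod 2 : dense, j mod 2 : dense)
}>

func.func @sparse_constant_to_bsr() {
  // Constant given as (coordinate, value) pairs, i.e. a SparseElementsAttr.
  %t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
                              [1, 0], [1, 1], [1, 2], [1, 3]],
                             [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : tensor<2x4xf64>
  // Direct conversion to BSR; no intermediate CSR hop is needed anymore.
  %bsr = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
  return
}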

4 files changed (+45, -35 lines)

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 3 additions & 4 deletions
@@ -1666,11 +1666,10 @@ LogicalResult ForeachOp::verify() {
   const Dimension dimRank = t.getDimRank();
   const auto args = getBody()->getArguments();
 
-  if (getOrder().has_value() &&
-      (t.getEncoding() || !getOrder()->isPermutation()))
-    return emitError("Only support permuted order on non encoded dense tensor");
+  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
+    return emitError("Level traverse order does not match tensor's level rank");
 
-  if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
+  if (dimRank + 1 + getInitArgs().size() != args.size())
     return emitError("Unmatched number of arguments in the block");
 
   if (getNumResults() != getInitArgs().size())

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 20 additions & 14 deletions
@@ -421,8 +421,11 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 void sparse_tensor::foreachInSparseConstant(
     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
     function_ref<void(ArrayRef<Value>, Value)> callback) {
-  const Dimension dimRank =
-      SparseTensorType(getRankedTensorType(attr)).getDimRank();
+  if (!order)
+    order = builder.getMultiDimIdentityMap(attr.getType().getRank());
+
+  auto stt = SparseTensorType(getRankedTensorType(attr));
+  const Dimension dimRank = stt.getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();

@@ -446,20 +449,23 @@ void sparse_tensor::foreachInSparseConstant(
 
   // Sorts the sparse element attribute based on coordinates.
   std::sort(elems.begin(), elems.end(),
-            [order, dimRank](const ElementAttr &lhs, const ElementAttr &rhs) {
-              const auto &lhsCoords = lhs.first;
-              const auto &rhsCoords = rhs.first;
-              for (Dimension d = 0; d < dimRank; d++) {
-                // FIXME: This only makes sense for permutations.
-                // And since we don't check that `order` is a permutation,
-                // it can also cause OOB errors when we use `l`.
-                const Level l = order ? order.getDimPosition(d) : d;
-                if (lhsCoords[l].getInt() == rhsCoords[l].getInt())
-                  continue;
-                return lhsCoords[l].getInt() < rhsCoords[l].getInt();
-              }
+            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
               if (std::addressof(lhs) == std::addressof(rhs))
                 return false;
+
+              auto lhsCoords = llvm::map_to_vector(
+                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
+              auto rhsCoords = llvm::map_to_vector(
+                  rhs.first, [](IntegerAttr i) { return i.getInt(); });
+
+              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
+              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
+              // Sort the element based on the lvl coordinates.
+              for (Level l = 0; l < order.getNumResults(); l++) {
+                if (lhsLvlCrds[l] == rhsLvlCrds[l])
+                  continue;
+                return lhsLvlCrds[l] < rhsLvlCrds[l];
+              }
               llvm_unreachable("no equal coordinate in sparse element attr");
             });
 
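A note on the new comparator (a reading of the change above, not part of the patch): each element's dim coordinates are composed with `order` to obtain level coordinates, and elements are sorted lexicographically on those. For example, under a 2x2 block map such as affine_map<(i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2)> (an assumed map consistent with the test's expected output), the dim coordinate (1, 3) composes to level coordinates (0, 1, 1, 1), so values are emitted block by block rather than row by row.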

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 4 additions & 7 deletions
@@ -1129,14 +1129,11 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 
     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
-      // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
-      const Dimension dimRank = stt.getDimRank();
-      SmallVector<Value> dcvs = lcvs; // keep a copy
-      for (Dimension d = 0; d < dimRank; d++) {
-        auto l = op.getOrder()->getDimPosition(d);
-        lcvs[l] = dcvs[d];
-      }
+      // TODO: Support it so that we can do direct conversion from CSR->BSR.
+      llvm_unreachable(
+          "Level order not yet implemented on non-constant input tensors.");
     }
+
     Value vals = loopEmitter.getValBuffer()[0];
     Value pos = loopEmitter.getPosits()[0].back();
     // Loads the value from sparse tensor using position-index;

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir

Lines changed: 18 additions & 10 deletions
@@ -74,26 +74,34 @@ module {
     //
     // Initialize a 2-dim dense tensor.
     //
-    %t = arith.constant dense<[
-       [ 1.0, 2.0, 3.0, 4.0 ],
-       [ 5.0, 6.0, 7.0, 8.0 ]
-    ]> : tensor<2x4xf64>
+    %t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
+                                [1, 0], [1, 1], [1, 2], [1, 3]],
+                               [ 1.0, 2.0, 3.0, 4.0,
+                                 5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
 
+    %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
+                                [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<2x4xf64>
 
-    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
-    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
-    %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+    // constant -> BSR (either from SparseElementsAttr or DenseElementsAttr)
+    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %2 = sparse_tensor.convert %td : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+    %4 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
 
-    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
     %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
-    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v4 = sparse_tensor.values %4 : tensor<2x4xf64, #CSC> to memref<?xf64>
 
-    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
+
+    // CHECK: ( 1, 2, 5, 6, 3, 4, 7, 8 )
     // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
     // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
     call @dumpf64(%v1) : (memref<?xf64>) -> ()
     call @dumpf64(%v2) : (memref<?xf64>) -> ()
     call @dumpf64(%v3) : (memref<?xf64>) -> ()
+    call @dumpf64(%v4) : (memref<?xf64>) -> ()
 
     return
   }
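
The three distinct value orders checked above follow from the target formats (an explanatory reading of the CHECK lines, not part of the patch): both BSR conversions emit values block by block (1, 2, 5, 6 then 3, 4, 7, 8), the BSR->CSR round trip restores row-major order (1 through 8), and the BSR->CSC conversion emits values column by column (1, 5, 2, 6, 3, 7, 4, 8).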
