Skip to content

Commit ab6334d

Browse files
authored
[mlir][sparse] add expanded size to API (#68614)
Used for asserting we do not run out of bounds on the expanded access pattern.
1 parent d5622de commit ab6334d

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ class SparseTensorStorageBase {
214214
/// * `added` a map from `[0..count)` to last-level coordinates for
215215
/// which `filled` is true and `values` contains the assotiated value.
216216
/// * `count` the size of `added`.
217+
/// * `expsz` the size of the expanded vector (verification only).
217218
#define DECL_EXPINSERT(VNAME, V) \
218-
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t);
219+
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t, \
220+
uint64_t);
219221
MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
220222
#undef DECL_EXPINSERT
221223

@@ -426,7 +428,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
426428

427429
/// Partially specialize expanded insertions based on template types.
428430
void expInsert(uint64_t *lvlCoords, V *values, bool *filled, uint64_t *added,
429-
uint64_t count) final {
431+
uint64_t count, uint64_t expsz) final {
430432
assert((lvlCoords && values && filled && added) && "Received nullptr");
431433
if (count == 0)
432434
return;
@@ -435,6 +437,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
435437
// Restore insertion path for first insert.
436438
const uint64_t lastLvl = getLvlRank() - 1;
437439
uint64_t c = added[0];
440+
assert(c <= expsz);
438441
assert(filled[c] && "added coordinate is not filled");
439442
lvlCoords[lastLvl] = c;
440443
lexInsert(lvlCoords, values[c]);
@@ -444,6 +447,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
444447
for (uint64_t i = 1; i < count; ++i) {
445448
assert(c < added[i] && "non-lexicographic insertion");
446449
c = added[i];
450+
assert(c <= expsz);
447451
assert(filled[c] && "added coordinate is not filled");
448452
lvlCoords[lastLvl] = c;
449453
insPath(lvlCoords, lastLvl, added[i - 1] + 1, values[c]);

mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
9090

9191
#define IMPL_EXPINSERT(VNAME, V) \
9292
void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
93-
uint64_t) { \
93+
uint64_t, uint64_t) { \
9494
FATAL_PIV("expInsert" #VNAME); \
9595
}
9696
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
480480
V *values = MEMREF_GET_PAYLOAD(vref); \
481481
bool *filled = MEMREF_GET_PAYLOAD(fref); \
482482
index_type *added = MEMREF_GET_PAYLOAD(aref); \
483-
tensor.expInsert(lvlCoords, values, filled, added, count); \
483+
uint64_t expsz = vref->sizes[0]; \
484+
tensor.expInsert(lvlCoords, values, filled, added, count, expsz); \
484485
}
485486
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
486487
#undef IMPL_EXPINSERT

0 commit comments

Comments
 (0)