Skip to content

Commit d30dccd

Browse files
author
Peiming Liu
committed
[mlir][sparse] Favors synthetic tensor over other undefined tensors
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135385
1 parent ddb3553 commit d30dccd

File tree

2 files changed

+104
-31
lines changed

2 files changed

+104
-31
lines changed

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,26 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
265265

266266
BitVector simple = latPoints[p0].bits;
267267
bool reset = isSingleton && hasAnySparse(simple);
268-
unsigned offset = 0;
268+
unsigned be = simple.size();
269+
unsigned offset = 0; // relative to the end
269270
if (!reset)
270271
// Starts resetting from a dense dimension, so that the first bit (if kept)
271272
// is not undefined dimension type.
272-
for (unsigned b = 0, be = simple.size(); b < be; b++)
273-
if (simple[b] && isDimLevelType(b, DimLvlType::kDense))
274-
offset = b;
273+
for (unsigned b = 0; b < be; b++) {
274+
if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) {
275+
offset = be - b - 1; // relative to the end
276+
break;
277+
}
278+
}
275279

276-
// Now apply the two basic rules.
277-
for (unsigned b = 0, be = simple.size(); b < be; b++) {
278-
unsigned i = (offset + b) % be;
279-
if (simple[i] && (!isDimLevelType(i, DimLvlType::kCompressed) &&
280-
!isDimLevelType(i, DimLvlType::kSingleton))) {
280+
// Now apply the two basic rules. We also iterate the bits in reverse to always
281+
// keep the rightmost bit (which could possibly be a synthetic tensor).
282+
for (unsigned b = be - 1 - offset, i = 0; i < be;
283+
b = b == 0 ? be - 1 : b - 1, i++) {
284+
if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
285+
!isDimLevelType(b, DimLvlType::kSingleton))) {
281286
if (reset)
282-
simple.reset(i);
287+
simple.reset(b);
283288
reset = true;
284289
}
285290
}

mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Lines changed: 89 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,16 @@ class MergerTest3T1LD : public MergerTestBase {
380380
///
381381
/// Tests with both undef and dense input.
382382
///
383-
class MergerTest3T1LU : public MergerTestBase {
383+
384+
class MergerTest4T1LU : public MergerTestBase {
384385
protected:
385386
// Our four tensors (three inputs, one output).
386-
const unsigned t0 = 0, t1 = 1, t2 = 2;
387+
const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
387388

388389
// Our single loop.
389390
const unsigned l0 = 0;
390391

391-
MergerTest3T1LU() : MergerTestBase(3, 1) {
392+
MergerTest4T1LU() : MergerTestBase(4, 1) {
392393
// Tensor 0: undef input vector.
393394
merger.addExp(Kind::kTensor, t0, -1u);
394395
merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
@@ -397,43 +398,110 @@ class MergerTest3T1LU : public MergerTestBase {
397398
merger.addExp(Kind::kTensor, t1, -1u);
398399
merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
399400

400-
// Tensor 2: dense output vector.
401+
// Tensor 2: undef input vector.
401402
merger.addExp(Kind::kTensor, t2, -1u);
402-
merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
403+
merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kUndef));
404+
405+
// Tensor 3: dense output vector.
406+
merger.addExp(Kind::kTensor, t3, -1u);
407+
merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
408+
}
409+
};
410+
411+
///
412+
/// Tests with operation on sparse output.
413+
///
414+
415+
class MergerTest3T1L_SO : public MergerTestBase {
416+
protected:
417+
// Our three tensors (two inputs, one output), plus the synthetic tensor.
418+
const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
419+
420+
// Our single loop.
421+
const unsigned l0 = 0;
422+
423+
MergerTest3T1L_SO() : MergerTestBase(3, 1) {
424+
merger.setHasSparseOut(true);
425+
426+
// Tensor 0: undef input vector.
427+
merger.addExp(Kind::kTensor, t0, -1u);
428+
merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
429+
430+
// Tensor 1: undef input vector.
431+
merger.addExp(Kind::kTensor, t1, -1u);
432+
merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kUndef));
433+
434+
// Tensor 2: sparse output vector.
435+
merger.addExp(Kind::kTensor, t2, -1u);
436+
merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
403437
}
404438
};
439+
405440
} // namespace
406441

407-
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
408-
/// a(i) = b(i) * c(i)
442+
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
443+
/// a(i) = b(i) * c(i) * d(i)
409444
/// which should form the single lattice point
410445
/// {
411-
/// lat( i_00_U i_01_D / (tensor_0 * tensor_1) )
446+
/// lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor_2) )
412447
/// }
413448
/// after optimization, the dense dimension should be kept, even though it appears
414-
/// after the undef dimension
449+
/// in the middle
415450
/// {
416-
/// lat( i_01_D / (tensor_0 * tensor_1) )
451+
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor_2) )
417452
/// }
418-
#define IMPL_MERGER_TEST_CONJ(OP) \
419-
TEST_F(MergerTest3T1LU, vector_##OP) { \
420-
auto e = OP##Expr(t0, t1); \
453+
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
454+
TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
455+
auto em = CONJ1##Expr(t0, t1); \
456+
auto e = CONJ2##Expr(em, t2); \
421457
auto p0 = tensorPattern(t0); \
422458
auto p1 = tensorPattern(t1); \
459+
auto p2 = tensorPattern(t2); \
423460
auto s = merger.buildLattices(e, l0); \
424-
\
425461
expectNumLatPoints(s, 1); \
426-
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
427-
loopsToBits({{l0, t0}, {l0, t1}})); \
428-
\
462+
expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
463+
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
429464
s = merger.optimizeSet(s); \
430465
expectNumLatPoints(s, 1); \
431-
expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t1}}), \
432-
true); \
466+
expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
467+
loopsToBits({{l0, t1}}), true); \
433468
}
434-
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
435469

436-
#undef IMPL_MERGER_TEST_CONJ
470+
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
471+
472+
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
473+
474+
/// Vector multiplication (conjunction) of 2 vectors with a sparse output, i.e.;
475+
/// o(i) = b(i) * c(i) * o(i)
476+
/// which should form the single lattice point (note how a synthetic tensor
477+
/// i_03_U is created for the sparse output)
478+
/// {
479+
/// lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
480+
/// }
481+
/// after optimization, the synthetic tensor should be preserved.
482+
/// {
483+
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
484+
/// }
485+
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
486+
TEST_F(MergerTest3T1L_SO, vector_##CONJ1##_##CONJ2) { \
487+
auto em = CONJ1##Expr(t0, t1); \
488+
auto e = CONJ2##Expr(em, t2); \
489+
auto p0 = tensorPattern(t0); \
490+
auto p1 = tensorPattern(t1); \
491+
auto p2 = tensorPattern(t2); \
492+
auto s = merger.buildLattices(e, l0); \
493+
expectNumLatPoints(s, 1); \
494+
expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
495+
loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \
496+
s = merger.optimizeSet(s); \
497+
expectNumLatPoints(s, 1); \
498+
expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
499+
loopsToBits({{l0, t3}}), true); \
500+
}
501+
502+
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
503+
504+
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
437505

438506
/// Vector addition (disjunction) of 2 vectors. i.e.;
439507
/// a(i) = b(i) + c(i)

0 commit comments

Comments
 (0)