Skip to content

Commit 3aeb28b

Browse files
author
Peiming Liu
authored
[mlir][sparse] fold sparse convert into producer linalg op. (#89999)
1 parent c49b74a commit 3aeb28b

File tree

5 files changed

+151
-71
lines changed

5 files changed

+151
-71
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,21 @@ inline MemRefType getMemRefType(T &&t) {
8989
/// Returns null-attribute for any type without an encoding.
9090
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
9191

92+
/// Returns true iff the type range has any sparse tensor type.
93+
inline bool hasAnySparseType(TypeRange types) {
94+
return llvm::any_of(types, [](Type type) {
95+
return getSparseTensorEncoding(type) != nullptr;
96+
});
97+
}
98+
9299
/// Returns true iff the MLIR operation has any sparse operand.
93100
inline bool hasAnySparseOperand(Operation *op) {
94-
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
95-
return getSparseTensorEncoding(t) != nullptr;
96-
});
101+
return hasAnySparseType(op->getOperands().getTypes());
97102
}
98103

99104
/// Returns true iff the MLIR operation has any sparse result.
100105
inline bool hasAnySparseResult(Operation *op) {
101-
return llvm::any_of(op->getResults().getTypes(), [](Type t) {
102-
return getSparseTensorEncoding(t) != nullptr;
103-
});
106+
return hasAnySparseType(op->getResults().getTypes());
104107
}
105108

106109
/// Returns true iff the MLIR operation has any sparse operand or result.

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
289289
}
290290
};
291291

292+
/// Rewriting rule that fuses sparse_tensor.convert into producer.
293+
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294+
public:
295+
using OpRewritePattern::OpRewritePattern;
296+
297+
LogicalResult matchAndRewrite(ConvertOp op,
298+
PatternRewriter &rewriter) const override {
299+
auto producer = op.getSource().getDefiningOp<GenericOp>();
300+
if (!producer || producer.getDpsInits().size() != 1 ||
301+
!isMaterializing(producer.getDpsInitOperand(0), false) ||
302+
!producer.getResult(0).hasOneUse()) {
303+
return failure();
304+
}
305+
rewriter.modifyOpInPlace(producer, [&]() {
306+
producer.getResult(0).setType(op.getResult().getType());
307+
});
308+
309+
Operation *materializeOp =
310+
producer.getDpsInitOperand(0)->get().getDefiningOp();
311+
312+
rewriter.modifyOpInPlace(materializeOp, [&]() {
313+
materializeOp->getResult(0).setType(op.getResult().getType());
314+
});
315+
316+
rewriter.replaceAllOpUsesWith(op, producer);
317+
op->erase();
318+
319+
return success();
320+
}
321+
};
322+
292323
/// Rewriting rule that converts direct yield of zero with initial allocation.
293324
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
294325
public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
15061537
//===---------------------------------------------------------------------===//
15071538

15081539
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1509-
patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
1510-
FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
1511-
GenSemiRingSelect, PrintRewriter>(patterns.getContext());
1540+
patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1541+
FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1542+
GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1543+
patterns.getContext());
15121544
}
15131545

15141546
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
403403
return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
404404
}
405405

406+
static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
407+
Value sparseOut, ValueRange ivs, Value v) {
408+
scf::IfOp condInsert =
409+
builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
410+
// True branch.
411+
builder.setInsertionPointToStart(condInsert.thenBlock());
412+
Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
413+
builder.create<scf::YieldOp>(loc, res);
414+
// False branch.
415+
builder.setInsertionPointToStart(condInsert.elseBlock());
416+
builder.create<scf::YieldOp>(loc, sparseOut);
417+
// Value assignment.
418+
builder.setInsertionPointAfter(condInsert);
419+
return condInsert.getResult(0);
420+
}
421+
406422
/// Generates insertion code to implement dynamic tensor store.
407423
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
408424
Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
423439
// return updated chain
424440
// else
425441
// return unmodified chain
426-
scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
427-
loc, chain.getType(), env.getValidLexInsert(),
428-
/*else=*/true);
429-
// True branch.
430-
builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
431-
Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
432-
builder.create<scf::YieldOp>(loc, res);
433-
// False branch.
434-
builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
435-
builder.create<scf::YieldOp>(loc, chain);
436-
// Value assignment.
437-
builder.setInsertionPointAfter(ifValidLexInsert);
438-
env.updateInsertionChain(ifValidLexInsert.getResult(0));
442+
Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
443+
chain, ivs, rhs);
444+
env.updateInsertionChain(out);
439445
} else {
446+
Value sparseOut;
447+
if (!hasAnySparseType(env.op().getInputs().getTypes())) {
448+
// This is an all-dense -> sparse kernel, test rhs != 0 before
449+
// insertion.
450+
Value nz = genIsNonzero(builder, loc, rhs);
451+
sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
452+
} else {
453+
sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
454+
}
440455
// Generates regular insertion chain.
441-
env.updateInsertionChain(
442-
builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
456+
env.updateInsertionChain(sparseOut);
443457
}
444458
return;
445459
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
2+
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
3+
4+
#trait = {
5+
indexing_maps = [
6+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
7+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
8+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
9+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
10+
],
11+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
12+
}
13+
14+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
15+
16+
#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
17+
#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
18+
19+
// CHECK-LABEL: func.func @fold_convert(
20+
// CHECK: scf.for
21+
// CHECK: scf.for
22+
// CHECK: scf.for
23+
// CHECK: scf.if
24+
// CHECK-NEXT: tensor.insert
25+
// CHECK-NEXT: scf.yield
26+
// CHECK-NEXT: else
27+
// CHECK-NEXT: scf.yield
28+
// CHECK: scf.yield
29+
// CHECK: scf.yield
30+
// CHECK: scf.yield
31+
// CHECK: sparse_tensor.load
32+
33+
// CHECK-FOLD-LABEL: func.func @fold_convert(
34+
// CHECK-FOLD-NOT: sparse_tensor.convert
35+
func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
36+
%cst = arith.constant 0.000000e+00 : f32
37+
%cst_0 = arith.constant 1.000000e+00 : f32
38+
%cst_1 = arith.constant 1.000000e+00 : f32
39+
%0 = tensor.empty() : tensor<128x32x32x1xf32>
40+
%1 = linalg.generic #trait
41+
ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
42+
outs(%0 : tensor<128x32x32x1xf32>) {
43+
^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
44+
%3 = arith.subf %cst_0, %in_2 : f32
45+
%4 = arith.mulf %in, %3 : f32
46+
%5 = arith.mulf %4, %cst_1 : f32
47+
%6 = arith.addf %5, %in_3 : f32
48+
%7 = arith.subf %6, %cst_0 : f32
49+
%8 = arith.cmpf uge, %7, %cst : f32
50+
%9 = arith.uitofp %8 : i1 to f32
51+
linalg.yield %9 : f32
52+
} -> tensor<128x32x32x1xf32>
53+
%2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
54+
return %2 : tensor<128x32x32x1xf32, #CCCD>
55+
}
56+
57+
58+
// FIXME: The following kernel is not sparsifiable because the `arith.select`
59+
// operation is not handled by the sparse compiler at the moment.
60+
//
61+
// CHECK-FOLD-LABEL: func.func @fold_cast(
62+
// CHECK-FOLD-NOT: sparse_tensor.convert
63+
func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
64+
%cst = arith.constant 0.000000e+00 : f64
65+
%1 = tensor.empty() : tensor<10x20x30xf64>
66+
%2 = linalg.generic { indexing_maps = [#map, #map],
67+
iterator_types = ["parallel", "parallel", "parallel"]
68+
}
69+
ins (%0 : tensor<10x20x30xf64, #COO>)
70+
outs(%1 : tensor<10x20x30xf64>) {
71+
^bb0(%in: f64, %out: f64):
72+
%4 = arith.cmpf ugt, %in, %cst : f64
73+
%5 = arith.select %4, %in, %cst : f64
74+
linalg.yield %5 : f64
75+
} -> tensor<10x20x30xf64>
76+
%cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
77+
return %cast : tensor<10x20x30xf64, #COO>
78+
}

mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments
 (0)