Skip to content

Commit 70e227a

Browse files
authored
[mlir][sparse] recognize ReLu operation during sparsification (#92016)
This is a proof-of-concept recognition of the most basic forms of ReLU operations, used to showcase sparsification of end-to-end PyTorch models. In the long run, we must avoid lowering such constructs too early, which then creates the need to raise them back. See the discussion at https://discourse.llvm.org/t/min-max-abs-relu-recognition-starter-project/78918
1 parent ef9090f commit 70e227a

File tree

4 files changed

+127
-11
lines changed

4 files changed

+127
-11
lines changed

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ enum class TensorExp::Kind {
144144
kExpm1C,
145145
kLog1pF,
146146
kLog1pC,
147+
kRelu,
147148
kSinF,
148149
kSinC,
149150
kTanhF,
@@ -316,7 +317,7 @@ class Merger {
316317
/// lattice point on an expression E is simply copied over, but with OP E
317318
/// as new expression. Returns the identifier of the new set.
318319
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
319-
Operation *op = nullptr);
320+
Operation *op = nullptr, Attribute attr = nullptr);
320321

321322
/// Maps the binary operator to the same operation but with one of its operand
322323
/// set to zero, i.e. each lattice point on an expression E is simply copied

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
4444
case TensorExp::Kind::kExpm1C:
4545
case TensorExp::Kind::kLog1pF:
4646
case TensorExp::Kind::kLog1pC:
47+
case TensorExp::Kind::kRelu:
4748
case TensorExp::Kind::kSinF:
4849
case TensorExp::Kind::kSinC:
4950
case TensorExp::Kind::kTanhF:
@@ -104,7 +105,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
104105

105106
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
106107
Operation *o, Attribute a)
107-
: kind(k), val(v), op(o) {
108+
: kind(k), val(v), op(o), attr(a) {
108109
switch (kind) {
109110
// Leaf.
110111
case TensorExp::Kind::kTensor:
@@ -133,6 +134,7 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
133134
case TensorExp::Kind::kExpm1C:
134135
case TensorExp::Kind::kLog1pF:
135136
case TensorExp::Kind::kLog1pC:
137+
case TensorExp::Kind::kRelu:
136138
case TensorExp::Kind::kSinF:
137139
case TensorExp::Kind::kSinC:
138140
case TensorExp::Kind::kTanhF:
@@ -201,7 +203,6 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
201203
case TensorExp::Kind::kCmpF:
202204
case TensorExp::Kind::kCmpI:
203205
assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
204-
attr = a;
205206
children.e0 = x;
206207
children.e1 = y;
207208
return;
@@ -337,7 +338,6 @@ LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
337338
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
338339
const LatSetId sNew = conjSet(e, s0, s1, op);
339340
TensorExp::Kind kind = exp(e).kind;
340-
341341
// Followed by all in s0.
342342
latSets[sNew].append(latSets[s0]);
343343
// Map binary 0-y to unary -y.
@@ -381,31 +381,32 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381381
bool includeLeft, TensorExp::Kind ltrans,
382382
Operation *opleft, bool includeRight,
383383
TensorExp::Kind rtrans, Operation *opright) {
384+
Attribute a = exp(e).attr;
384385
const LatSetId sNew = conjSet(e, s0, s1, orig);
385386
// Left Region.
386387
if (includeLeft) {
387388
if (opleft)
388-
s0 = mapSet(ltrans, s0, Value(), opleft);
389+
s0 = mapSet(ltrans, s0, Value(), opleft, a);
389390
latSets[sNew].append(latSets[s0]);
390391
}
391392
// Right Region.
392393
if (includeRight) {
393394
if (opright)
394-
s1 = mapSet(rtrans, s1, Value(), opright);
395+
s1 = mapSet(rtrans, s1, Value(), opright, a);
395396
latSets[sNew].append(latSets[s1]);
396397
}
397398
return sNew;
398399
}
399400

400401
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
401-
Operation *op) {
402+
Operation *op, Attribute a) {
402403
assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
403404
TensorExp::Kind::kDenseOp == kind);
404405
const LatSetId sNew = addSet();
405406
auto &setNew = latSets[sNew];
406407
for (const LatPointId p : set(s0)) {
407408
const auto &point = latPoints[p];
408-
setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
409+
setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
409410
}
410411
return sNew;
411412
}
@@ -596,6 +597,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
596597
case TensorExp::Kind::kExpm1C:
597598
case TensorExp::Kind::kLog1pF:
598599
case TensorExp::Kind::kLog1pC:
600+
case TensorExp::Kind::kRelu:
599601
case TensorExp::Kind::kSinF:
600602
case TensorExp::Kind::kSinC:
601603
case TensorExp::Kind::kTanhF:
@@ -717,6 +719,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
717719
case TensorExp::Kind::kLog1pF:
718720
case TensorExp::Kind::kLog1pC:
719721
return "log1p";
722+
case TensorExp::Kind::kRelu:
723+
return "relu";
720724
case TensorExp::Kind::kSinF:
721725
case TensorExp::Kind::kSinC:
722726
return "sin";
@@ -824,6 +828,7 @@ void Merger::dumpExp(ExprId e) const {
824828
case TensorExp::Kind::kExpm1C:
825829
case TensorExp::Kind::kLog1pF:
826830
case TensorExp::Kind::kLog1pC:
831+
case TensorExp::Kind::kRelu:
827832
case TensorExp::Kind::kSinF:
828833
case TensorExp::Kind::kSinC:
829834
case TensorExp::Kind::kTanhF:
@@ -972,6 +977,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
972977
case TensorExp::Kind::kExpm1C:
973978
case TensorExp::Kind::kLog1pF:
974979
case TensorExp::Kind::kLog1pC:
980+
case TensorExp::Kind::kRelu:
975981
case TensorExp::Kind::kSinF:
976982
case TensorExp::Kind::kSinC:
977983
case TensorExp::Kind::kTanhF:
@@ -1001,7 +1007,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
10011007
{
10021008
const ExprId e0 = expr.children.e0;
10031009
const Value v = expr.val;
1004-
return mapSet(kind, buildLattices(e0, i), v);
1010+
Attribute a = expr.attr;
1011+
return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
10051012
}
10061013
case TensorExp::Kind::kBinaryBranch:
10071014
case TensorExp::Kind::kSelect:
@@ -1190,10 +1197,26 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
11901197
return buildTensorExp(op, yield->getOperand(0)).first;
11911198
}
11921199

1200+
/// Only returns true if we are certain this is a zero.
1201+
static bool isCertainZero(Value val) {
1202+
if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
1203+
ArrayAttr arrayAttr = c.getValue();
1204+
return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1205+
cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1206+
}
1207+
if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
1208+
return c.value() == 0;
1209+
if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
1210+
return c.value().isZero();
1211+
return false;
1212+
}
1213+
11931214
/// Only returns false if we are certain this is a nonzero.
11941215
bool Merger::maybeZero(ExprId e) const {
11951216
const auto &expr = exp(e);
11961217
if (expr.kind == TensorExp::Kind::kInvariant) {
1218+
// Note that this is different from isCertainZero() in a subtle
1219+
// way by always returning true for non-constants.
11971220
if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
11981221
ArrayAttr arrayAttr = c.getValue();
11991222
return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
@@ -1247,6 +1270,21 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
12471270
return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
12481271
}
12491272

1273+
// Recognizes a direct GT comparison.
1274+
static bool isGreater(TensorExp::Kind kind, Attribute attr) {
1275+
if (kind == TensorExp::Kind::kCmpI) {
1276+
auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
1277+
return pred == arith::CmpIPredicate::ugt ||
1278+
pred == arith::CmpIPredicate::sgt;
1279+
}
1280+
if (kind == TensorExp::Kind::kCmpF) {
1281+
auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
1282+
return pred == arith::CmpFPredicate::UGT ||
1283+
pred == arith::CmpFPredicate::OGT;
1284+
}
1285+
return false;
1286+
}
1287+
12501288
std::pair<std::optional<ExprId>, bool>
12511289
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
12521290
// Recursion leaves.
@@ -1266,6 +1304,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
12661304
// or belonging to an enveloping op) is considered invariant.
12671305
return {addInvariantExp(v), /*hasSpDep=*/false};
12681306
}
1307+
12691308
// Something defined outside is invariant.
12701309
Operation *def = v.getDefiningOp();
12711310
if (def->getBlock() != &op.getRegion().front())
@@ -1352,6 +1391,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
13521391
}
13531392
}
13541393
}
1394+
13551395
// Construct binary operations if subexpressions can be built.
13561396
// See buildLattices() for an explanation of rejecting certain
13571397
// division and shift operations.
@@ -1447,6 +1487,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
14471487
}
14481488
}
14491489
}
1490+
14501491
// Construct ternary operations if subexpressions can be built.
14511492
if (def->getNumOperands() == 3) {
14521493
const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
@@ -1460,6 +1501,26 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
14601501
if (isAdmissibleBranch(redop, redop.getRegion()))
14611502
return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
14621503
}
1504+
if (auto selop = dyn_cast<arith::SelectOp>(def)) {
1505+
// Recognize an integral or floating-point ReLu(x) = Max(x, 0)
1506+
// operation inside a very specific ternary select operation.
1507+
// TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
1508+
const auto &cnd = exp(*x);
1509+
if (isGreater(cnd.kind, cnd.attr) &&
1510+
exp(*y).kind == TensorExp::Kind::kTensor &&
1511+
exp(*z).kind == TensorExp::Kind::kInvariant &&
1512+
isCertainZero(exp(*z).val)) {
1513+
const auto &a = exp(cnd.children.e0);
1514+
const auto &b = exp(cnd.children.e1);
1515+
if (a.kind == TensorExp::Kind::kTensor &&
1516+
a.tensor == exp(*y).tensor &&
1517+
b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
1518+
return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
1519+
nullptr, cnd.attr),
1520+
yDepSp};
1521+
}
1522+
}
1523+
}
14631524
}
14641525
}
14651526

@@ -1469,7 +1530,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
14691530
// tensors).
14701531
if (def->getNumResults() != 1) // only handle single result operation.
14711532
return {std::nullopt, false};
1472-
14731533
SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
14741534
// Builds all the sub-expressions
14751535
for (Value operand : def->getOperands())
@@ -1489,6 +1549,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
14891549
return {e, false};
14901550
}
14911551
}
1552+
14921553
// Cannot build.
14931554
return {std::nullopt, false};
14941555
}
@@ -1538,6 +1599,22 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
15381599
return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
15391600
}
15401601

1602+
/// Materializes ReLu(v0) as "select (v0 <pred> 0), v0, 0", rebuilding the
/// comparison from the predicate attribute captured during recognition.
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
                       Attribute attr) {
  Type type = v0.getType();
  Value zero =
      rewriter.create<arith::ConstantOp>(loc, type, rewriter.getZeroAttr(type));
  // Pick the floating-point or integral comparison based on the operand type.
  Value cond;
  if (isa<FloatType>(type))
    cond = rewriter.create<arith::CmpFOp>(
        loc, llvm::cast<arith::CmpFPredicateAttr>(attr), v0, zero);
  else
    cond = rewriter.create<arith::CmpIOp>(
        loc, llvm::cast<arith::CmpIPredicateAttr>(attr), v0, zero);
  return rewriter.create<arith::SelectOp>(loc, cond, v0, zero);
}
1617+
15411618
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
15421619
Value v1) const {
15431620
const auto &expr = exp(e);
@@ -1574,6 +1651,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
15741651
return rewriter.create<math::Log1pOp>(loc, v0);
15751652
case TensorExp::Kind::kLog1pC:
15761653
return rewriter.create<complex::Log1pOp>(loc, v0);
1654+
case TensorExp::Kind::kRelu:
1655+
return buildRelu(rewriter, loc, v0, expr.attr);
15771656
case TensorExp::Kind::kSinF:
15781657
return rewriter.create<math::SinOp>(loc, v0);
15791658
case TensorExp::Kind::kSinC:
@@ -1677,7 +1756,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
16771756
case TensorExp::Kind::kUnary:
16781757
return buildUnaryPresent(rewriter, loc, expr.op, v0);
16791758
case TensorExp::Kind::kSelect:
1680-
return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1759+
return insertYieldOp(rewriter, loc,
1760+
cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
16811761
{v0});
16821762
case TensorExp::Kind::kBinary:
16831763
return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

#sparse = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
}>

//
// Make sure a simple ReLU passes the sparsifier
//
// CHECK-LABEL: func.func @relu
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: arith.cmpf ugt
// CHECK: arith.select
//
func.func @relu(%arg0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
  %cst = arith.constant 0.000000e+00 : f64
  %0 = tensor.empty() : tensor<10x20x30xf64>
  %1 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<10x20x30xf64, #sparse>)
      outs(%0 : tensor<10x20x30xf64>) {
    ^bb0(%in: f64, %out: f64):
      %2 = arith.cmpf ugt, %in, %cst : f64
      %3 = arith.select %2, %in, %cst : f64
      linalg.yield %3 : f64
  } -> tensor<10x20x30xf64>
  %cast = tensor.cast %1 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
  return %cast : tensor<10x20x30xf64, #sparse>
}

mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
236236
case TensorExp::Kind::kExpm1C:
237237
case TensorExp::Kind::kLog1pF:
238238
case TensorExp::Kind::kLog1pC:
239+
case TensorExp::Kind::kRelu:
239240
case TensorExp::Kind::kSinF:
240241
case TensorExp::Kind::kSinC:
241242
case TensorExp::Kind::kTanhF:

0 commit comments

Comments
 (0)