@@ -44,6 +44,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -104,7 +105,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
 
 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                      Operation *o, Attribute a)
-    : kind(k), val(v), op(o) {
+    : kind(k), val(v), op(o), attr(a) {
   switch (kind) {
   // Leaf.
   case TensorExp::Kind::kTensor:
@@ -133,6 +134,7 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -201,7 +203,6 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   case TensorExp::Kind::kCmpF:
   case TensorExp::Kind::kCmpI:
     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
-    attr = a;
     children.e0 = x;
     children.e1 = y;
     return;
@@ -337,7 +338,6 @@ LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
   const LatSetId sNew = conjSet(e, s0, s1, op);
   TensorExp::Kind kind = exp(e).kind;
-
   // Followed by all in s0.
   latSets[sNew].append(latSets[s0]);
   // Map binary 0-y to unary -y.
@@ -381,31 +381,32 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                           bool includeLeft, TensorExp::Kind ltrans,
                           Operation *opleft, bool includeRight,
                           TensorExp::Kind rtrans, Operation *opright) {
+  Attribute a = exp(e).attr;
   const LatSetId sNew = conjSet(e, s0, s1, orig);
   // Left Region.
   if (includeLeft) {
     if (opleft)
-      s0 = mapSet(ltrans, s0, Value(), opleft);
+      s0 = mapSet(ltrans, s0, Value(), opleft, a);
     latSets[sNew].append(latSets[s0]);
   }
   // Right Region.
   if (includeRight) {
     if (opright)
-      s1 = mapSet(rtrans, s1, Value(), opright);
+      s1 = mapSet(rtrans, s1, Value(), opright, a);
     latSets[sNew].append(latSets[s1]);
   }
   return sNew;
 }
 
 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
-                        Operation *op) {
+                        Operation *op, Attribute a) {
   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
          TensorExp::Kind::kDenseOp == kind);
   const LatSetId sNew = addSet();
   auto &setNew = latSets[sNew];
   for (const LatPointId p : set(s0)) {
     const auto &point = latPoints[p];
-    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
+    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
   }
   return sNew;
 }
@@ -596,6 +597,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -717,6 +719,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
     return "log1p";
+  case TensorExp::Kind::kRelu:
+    return "relu";
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
     return "sin";
@@ -824,6 +828,7 @@ void Merger::dumpExp(ExprId e) const {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -972,6 +977,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -1001,7 +1007,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     {
       const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
-      return mapSet(kind, buildLattices(e0, i), v);
+      Attribute a = expr.attr;
+      return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
     }
   case TensorExp::Kind::kBinaryBranch:
   case TensorExp::Kind::kSelect:
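Why kRelu can join the zero-preserving unary cases above: relu(0) == 0, so an implicit zero of a sparse operand stays zero and the operand's lattice set can simply be mapped onto itself. A minimal standalone C++ sketch of that property (plain scalar model, not the MLIR code; the function name is illustrative only):

#include <algorithm>
#include <cassert>

// Scalar model of the new kRelu kind: relu(x) = max(x, 0).
// Because relu(0) == 0, the operation is zero preserving, which is the
// property buildLattices() exploits when it maps the operand's lattice
// set onto itself via mapSet() instead of widening the iteration space.
static double relu(double x) { return std::max(x, 0.0); }

int main() {
  assert(relu(0.0) == 0.0);   // implicit zeros remain zero
  assert(relu(-2.5) == 0.0);
  assert(relu(3.0) == 3.0);
  return 0;
}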
@@ -1190,10 +1197,26 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   return buildTensorExp(op, yield->getOperand(0)).first;
 }
 
+/// Only returns true if we are certain this is a zero.
+static bool isCertainZero(Value val) {
+  if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
+    ArrayAttr arrayAttr = c.getValue();
+    return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+           cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
+  }
+  if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
+    return c.value() == 0;
+  if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
+    return c.value().isZero();
+  return false;
+}
+
 /// Only returns false if we are certain this is a nonzero.
 bool Merger::maybeZero(ExprId e) const {
   const auto &expr = exp(e);
   if (expr.kind == TensorExp::Kind::kInvariant) {
+    // Note that this is different from isCertainZero() in a subtle
+    // way by always returning true for non-constants.
     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
       ArrayAttr arrayAttr = c.getValue();
       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
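The new comment points at an easy-to-miss asymmetry between the two zero queries. A hedged toy model of that asymmetry (plain C++, hypothetical names, not the Merger API): isCertainZero answers "provably zero?", while maybeZero answers "not provably nonzero?", so an unknown value is not certainly zero yet may still be zero.

#include <cassert>
#include <optional>

// Toy three-valued model: a value is either a known constant or unknown.
// isCertainZero: true only for a known constant equal to zero.
// maybeZero:     false only for a known constant different from zero;
//                unknown values conservatively report "may be zero".
static bool isCertainZero(std::optional<double> v) { return v && *v == 0.0; }
static bool maybeZero(std::optional<double> v) { return !v || *v == 0.0; }

int main() {
  std::optional<double> unknown;  // e.g. a loop-variant, non-constant value
  assert(!isCertainZero(unknown) && maybeZero(unknown));
  assert(isCertainZero(0.0) && maybeZero(0.0));
  assert(!isCertainZero(1.5) && !maybeZero(1.5));
  return 0;
}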
@@ -1247,6 +1270,21 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
 
+// Recognizes a direct GT comparison.
+static bool isGreater(TensorExp::Kind kind, Attribute attr) {
+  if (kind == TensorExp::Kind::kCmpI) {
+    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
+    return pred == arith::CmpIPredicate::ugt ||
+           pred == arith::CmpIPredicate::sgt;
+  }
+  if (kind == TensorExp::Kind::kCmpF) {
+    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
+    return pred == arith::CmpFPredicate::UGT ||
+           pred == arith::CmpFPredicate::OGT;
+  }
+  return false;
+}
+
 std::pair<std::optional<ExprId>, bool>
 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // Recursion leaves.
@@ -1266,6 +1304,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     // or belonging to an enveloping op) is considered invariant.
     return {addInvariantExp(v), /*hasSpDep=*/false};
   }
+
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
@@ -1352,6 +1391,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+
   // Construct binary operations if subexpressions can be built.
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
@@ -1447,6 +1487,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+
   // Construct ternary operations if subexpressions can be built.
   if (def->getNumOperands() == 3) {
     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
@@ -1460,6 +1501,26 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         if (isAdmissibleBranch(redop, redop.getRegion()))
           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
       }
+      if (auto selop = dyn_cast<arith::SelectOp>(def)) {
+        // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
+        // operation inside a very specific ternary select operation.
+        // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
+        const auto &cnd = exp(*x);
+        if (isGreater(cnd.kind, cnd.attr) &&
+            exp(*y).kind == TensorExp::Kind::kTensor &&
+            exp(*z).kind == TensorExp::Kind::kInvariant &&
+            isCertainZero(exp(*z).val)) {
+          const auto &a = exp(cnd.children.e0);
+          const auto &b = exp(cnd.children.e1);
+          if (a.kind == TensorExp::Kind::kTensor &&
+              a.tensor == exp(*y).tensor &&
+              b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
+            return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
+                           nullptr, cnd.attr),
+                    yDepSp};
+          }
+        }
+      }
     }
   }
 
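The shape being matched above is select(x > 0, x, 0), where the compared operand and the true branch refer to the same tensor and both zero operands are provable zeros. A hedged toy sketch of the same structural match over a simplified expression tree (plain C++, all names hypothetical, not the Merger data structures):

#include <cassert>

// Simplified expression nodes standing in for TensorExp.
enum class Kind { Tensor, Zero, Greater, Select, Relu };

struct Expr {
  Kind kind;
  int tensor = -1;                         // which tensor a leaf refers to
  const Expr *a = nullptr, *b = nullptr, *c = nullptr;
};

// Matches select(greater(t, 0), t, 0) and reports the tensor to wrap in relu.
static bool matchRelu(const Expr &e, int &tensorId) {
  if (e.kind != Kind::Select || e.a->kind != Kind::Greater)
    return false;
  const Expr &cmp = *e.a;
  if (cmp.a->kind == Kind::Tensor && cmp.b->kind == Kind::Zero &&
      e.b->kind == Kind::Tensor && e.b->tensor == cmp.a->tensor &&
      e.c->kind == Kind::Zero) {
    tensorId = e.b->tensor;
    return true;
  }
  return false;
}

int main() {
  Expr t{Kind::Tensor, /*tensor=*/0};
  Expr zero{Kind::Zero};
  Expr cmp{Kind::Greater, -1, &t, &zero};
  Expr sel{Kind::Select, -1, &cmp, &t, &zero};
  int id = -1;
  assert(matchRelu(sel, id) && id == 0);   // recognized as relu(tensor 0)
  return 0;
}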
@@ -1469,7 +1530,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // tensors).
   if (def->getNumResults() != 1) // only handle single result operation.
     return {std::nullopt, false};
-
   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
   // Builds all the sub-expressions
   for (Value operand : def->getOperands())
@@ -1489,6 +1549,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       return {e, false};
     }
   }
+
   // Cannot build.
   return {std::nullopt, false};
 }
@@ -1538,6 +1599,22 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
 }
 
+static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
+                       Attribute attr) {
+  Type tp = v0.getType();
+  auto zero =
+      rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+  Value cmp;
+  if (isa<FloatType>(tp)) {
+    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
+    cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
+  } else {
+    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
+    cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
+  }
+  return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
+}
+
 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                        Value v1) const {
   const auto &expr = exp(e);
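The code buildRelu() rebuilds follows the scalar recipe sketched below: materialize a zero of the operand type, compare against it with the predicate that was recorded when the pattern was matched, and select. A hedged plain-C++ model (not the MLIR builder calls; the comparison here is written as a literal > since the real predicate comes from the stored attribute):

#include <cassert>

// Scalar model of the IR produced by buildRelu() for a float operand.
static float emitRelu(float v0) {
  float zero = 0.0f;       // constant zero of the operand type
  bool cmp = v0 > zero;    // compare, using the predicate kept in expr.attr
  return cmp ? v0 : zero;  // select(cmp, v0, zero)
}

int main() {
  assert(emitRelu(-1.0f) == 0.0f);
  assert(emitRelu(4.0f) == 4.0f);
  return 0;
}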
@@ -1574,6 +1651,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
     return rewriter.create<math::Log1pOp>(loc, v0);
   case TensorExp::Kind::kLog1pC:
     return rewriter.create<complex::Log1pOp>(loc, v0);
+  case TensorExp::Kind::kRelu:
+    return buildRelu(rewriter, loc, v0, expr.attr);
   case TensorExp::Kind::kSinF:
     return rewriter.create<math::SinOp>(loc, v0);
   case TensorExp::Kind::kSinC:
@@ -1677,7 +1756,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
   case TensorExp::Kind::kUnary:
     return buildUnaryPresent(rewriter, loc, expr.op, v0);
   case TensorExp::Kind::kSelect:
-    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
+    return insertYieldOp(rewriter, loc,
+                         cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
                          {v0});
   case TensorExp::Kind::kBinary:
     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);