Skip to content

Commit d977624

Browse files
committed
ConstantFolding: add constant folding for some floating point intrinsics
* `sitofp` signed integer to floating point * `rint` round floating point to integral * `bitcast` between integer and floating point Constant folding `bitcast`s also made it necessary to rewrite constant folding for Nan and inf values, because the old code explicitly checked for `bitcast` intrinsics. Relying on constant folded `bitcast`s makes the new version much simpler. It is important to constant fold these intrinsics already in SIL because it enables other optimizations.
1 parent 17a1b83 commit d977624

File tree

3 files changed

+137
-308
lines changed

3 files changed

+137
-308
lines changed

lib/SILOptimizer/Utils/ConstantFolding.cpp

Lines changed: 68 additions & 289 deletions
Original file line numberDiff line numberDiff line change
@@ -378,44 +378,26 @@ static SILValue constantFoldIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID,
378378
return constantFoldBinaryWithOverflow(BI, ID,
379379
/* ReportOverflow */ false,
380380
ResultsInError);
381+
case llvm::Intrinsic::rint:
382+
if (auto *floatLiteral = dyn_cast<FloatLiteralInst>(BI->getArguments()[0])) {
383+
SILBuilderWithScope builder(BI);
384+
APFloat result = floatLiteral->getValue();
385+
// The following code is taken from LLVM's constant folder.
386+
result.roundToIntegral(APFloat::rmNearestTiesToEven);
387+
return builder.createFloatLiteral(BI->getLoc(), BI->getType(), result);
388+
}
381389
}
382390
return nullptr;
383391
}
384392

385-
static bool isFiniteFloatLiteral(SILValue v) {
386-
if (auto *lit = dyn_cast<FloatLiteralInst>(v)) {
387-
return lit->getValue().isFinite();
393+
static bool isNanLiteral(SILValue v) {
394+
if (auto *literal = dyn_cast<FloatLiteralInst>(v)) {
395+
return literal->getValue().isNaN();
388396
}
389397
return false;
390398
}
391399

392400
static SILValue constantFoldCompareFloat(BuiltinInst *BI, BuiltinValueKind ID) {
393-
static auto hasIEEEFloatNanBitRepr = [](const APInt val) -> bool {
394-
auto bitWidth = val.getBitWidth();
395-
if (bitWidth == 32) {
396-
APInt nanBitRepr =
397-
APFloat::getNaN(llvm::APFloatBase::IEEEsingle()).bitcastToAPInt();
398-
return bitWidth == nanBitRepr.getBitWidth() && val == nanBitRepr;
399-
} else {
400-
APInt nanBitRepr =
401-
APFloat::getNaN(llvm::APFloatBase::IEEEdouble()).bitcastToAPInt();
402-
return bitWidth == nanBitRepr.getBitWidth() && val == nanBitRepr;
403-
}
404-
};
405-
406-
static auto hasIEEEFloatPosInfBitRepr = [](const APInt val) -> bool {
407-
auto bitWidth = val.getBitWidth();
408-
if (bitWidth == 32) {
409-
APInt infBitRepr =
410-
APFloat::getInf(llvm::APFloatBase::IEEEsingle()).bitcastToAPInt();
411-
return bitWidth == infBitRepr.getBitWidth() && val == infBitRepr;
412-
} else {
413-
APInt infBitRepr =
414-
APFloat::getInf(llvm::APFloatBase::IEEEdouble()).bitcastToAPInt();
415-
return bitWidth == infBitRepr.getBitWidth() && val == infBitRepr;
416-
}
417-
};
418-
419401
OperandValueArrayRef Args = BI->getArguments();
420402

421403
// Fold for floating point constant arguments.
@@ -428,270 +410,32 @@ static SILValue constantFoldCompareFloat(BuiltinInst *BI, BuiltinValueKind ID) {
428410
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), Res);
429411
}
430412

431-
using namespace swift::PatternMatch;
432-
433-
// Ordered comparisons with NaN always return false
434-
SILValue Other;
435-
IntegerLiteralInst *builtinArg;
436-
if (match(BI, m_CombineOr(
437-
// x == NaN
438-
m_BuiltinInst(BuiltinValueKind::FCMP_OEQ,
439-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
440-
// x == NaN
441-
m_BuiltinInst(BuiltinValueKind::FCMP_OGT,
442-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
443-
// x >= NaN
444-
m_BuiltinInst(BuiltinValueKind::FCMP_OGE,
445-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
446-
// x < NaN
447-
m_BuiltinInst(BuiltinValueKind::FCMP_OLT,
448-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
449-
// x <= NaN
450-
m_BuiltinInst(BuiltinValueKind::FCMP_OLE,
451-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
452-
// x != NaN
453-
m_BuiltinInst(BuiltinValueKind::FCMP_ONE,
454-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
455-
// NaN == x
456-
m_BuiltinInst(BuiltinValueKind::FCMP_OEQ,
457-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
458-
// NaN > x
459-
m_BuiltinInst(BuiltinValueKind::FCMP_OGT,
460-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
461-
// NaN >= x
462-
m_BuiltinInst(BuiltinValueKind::FCMP_OGE,
463-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
464-
// NaN < x
465-
m_BuiltinInst(BuiltinValueKind::FCMP_OLT,
466-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
467-
// NaN <= x
468-
m_BuiltinInst(BuiltinValueKind::FCMP_OLE,
469-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
470-
// NaN != x
471-
m_BuiltinInst(BuiltinValueKind::FCMP_ONE,
472-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other))))) {
473-
APInt val = builtinArg->getValue();
474-
if (hasIEEEFloatNanBitRepr(val)) {
475-
SILBuilderWithScope B(BI);
476-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 0));
477-
} else {
478-
// An edge case where we're comparing NaN with another value
479-
// defined using the BitCast builtin instruction.
480-
//
481-
// In this case, the `builtinArg` capture does not actually represent the NaN
482-
// argument that we want. Therefore we need to pattern-match
483-
// the definition of the SILValue `Other`, to see if it represents a NaN.
484-
if (auto *bci = dyn_cast<BuiltinInst>(Other)) {
485-
if (bci->getBuiltinInfo().ID == BuiltinValueKind::BitCast) {
486-
if (auto *arg = dyn_cast<IntegerLiteralInst>(bci->getArguments()[0])) {
487-
if (hasIEEEFloatNanBitRepr(arg->getValue())) {
488-
SILBuilderWithScope B(BI);
489-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 0));
490-
}
491-
}
492-
}
413+
if (isNanLiteral(Args[0]) || isNanLiteral(Args[1])) {
414+
switch (BI->getBuiltinInfo().ID) {
415+
// Ordered comparisons with NaN always return false
416+
case BuiltinValueKind::FCMP_OEQ: // ==
417+
case BuiltinValueKind::FCMP_OGT: // >=
418+
case BuiltinValueKind::FCMP_OGE: // <
419+
case BuiltinValueKind::FCMP_OLT: // <
420+
case BuiltinValueKind::FCMP_OLE: // <=
421+
case BuiltinValueKind::FCMP_ONE: { // !=
422+
SILBuilderWithScope B(BI);
423+
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 0));
493424
}
494-
}
495-
}
496-
497-
// Unordered comparisons with NaN always return true
498-
if (match(BI,
499-
m_CombineOr(
500-
// x == NaN
501-
m_BuiltinInst(BuiltinValueKind::FCMP_UEQ,
502-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
503-
// x == NaN
504-
m_BuiltinInst(BuiltinValueKind::FCMP_UGT,
505-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
506-
// x >= NaN
507-
m_BuiltinInst(BuiltinValueKind::FCMP_UGE,
508-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
509-
// x < NaN
510-
m_BuiltinInst(BuiltinValueKind::FCMP_ULT,
511-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
512-
// x <= NaN
513-
m_BuiltinInst(BuiltinValueKind::FCMP_ULE,
514-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
515-
// x != NaN
516-
m_BuiltinInst(BuiltinValueKind::FCMP_UNE,
517-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
518-
// NaN == x
519-
m_BuiltinInst(BuiltinValueKind::FCMP_UEQ,
520-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
521-
// NaN > x
522-
m_BuiltinInst(BuiltinValueKind::FCMP_UGT,
523-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
524-
// NaN >= x
525-
m_BuiltinInst(BuiltinValueKind::FCMP_UGE,
526-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
527-
// NaN < x
528-
m_BuiltinInst(BuiltinValueKind::FCMP_ULT,
529-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
530-
// NaN <= x
531-
m_BuiltinInst(BuiltinValueKind::FCMP_ULE,
532-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
533-
// NaN != x
534-
m_BuiltinInst(BuiltinValueKind::FCMP_UNE,
535-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other))))) {
536-
APInt val = builtinArg->getValue();
537-
if (hasIEEEFloatNanBitRepr(val)) {
538-
SILBuilderWithScope B(BI);
539-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 1));
540-
} else {
541-
// An edge case where we're comparing NaN with another value
542-
// defined using the BitCast builtin instruction.
543-
//
544-
// In this case, the `builtinArg` capture does not actually represent the NaN
545-
// argument that we want. Therefore we need to pattern-match
546-
// the definition of the SILValue `Other`, to see if it represents a NaN.
547-
if (auto *bci = dyn_cast<BuiltinInst>(Other)) {
548-
if (bci->getBuiltinInfo().ID == BuiltinValueKind::BitCast) {
549-
if (auto *arg = dyn_cast<IntegerLiteralInst>(bci->getArguments()[0])) {
550-
if (hasIEEEFloatNanBitRepr(arg->getValue())) {
551-
SILBuilderWithScope B(BI);
552-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 1));
553-
}
554-
}
555-
}
425+
// Unordered comparisons with NaN always return true
426+
case BuiltinValueKind::FCMP_UEQ: // ==
427+
case BuiltinValueKind::FCMP_UGT: // >=
428+
case BuiltinValueKind::FCMP_UGE: // <
429+
case BuiltinValueKind::FCMP_ULT: // <
430+
case BuiltinValueKind::FCMP_ULE: // <=
431+
case BuiltinValueKind::FCMP_UNE: { // !=
432+
SILBuilderWithScope B(BI);
433+
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 1));
556434
}
435+
default:
436+
break;
557437
}
558438
}
559-
560-
// Infinity is equal to, greater than equal to and less than equal to itself
561-
IntegerLiteralInst *inf1;
562-
IntegerLiteralInst *inf2;
563-
564-
if (match(BI,
565-
m_CombineOr(
566-
// Inf == Inf
567-
m_BuiltinInst(BuiltinValueKind::FCMP_OEQ,
568-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
569-
// Inf >= Inf
570-
m_BuiltinInst(BuiltinValueKind::FCMP_OGE,
571-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
572-
// Inf <= Inf
573-
m_BuiltinInst(BuiltinValueKind::FCMP_OLE,
574-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
575-
// Inf == Inf
576-
m_BuiltinInst(BuiltinValueKind::FCMP_UEQ,
577-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
578-
// Inf >= Inf
579-
m_BuiltinInst(BuiltinValueKind::FCMP_UGE,
580-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
581-
// Inf <= Inf
582-
m_BuiltinInst(BuiltinValueKind::FCMP_ULE,
583-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2)))))) {
584-
APInt val1 = inf1->getValue();
585-
APInt val2 = inf2->getValue();
586-
587-
if (hasIEEEFloatPosInfBitRepr(val1) && hasIEEEFloatPosInfBitRepr(val2)) {
588-
SILBuilderWithScope B(BI);
589-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 1));
590-
}
591-
}
592-
593-
// Infinity cannot be unequal to, greater than or less than itself
594-
if (match(BI,
595-
m_CombineOr(
596-
// Inf != Inf
597-
m_BuiltinInst(BuiltinValueKind::FCMP_ONE,
598-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
599-
// Inf > Inf
600-
m_BuiltinInst(BuiltinValueKind::FCMP_OGT,
601-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
602-
// Inf < Inf
603-
m_BuiltinInst(BuiltinValueKind::FCMP_OLT,
604-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
605-
// Inf != Inf
606-
m_BuiltinInst(BuiltinValueKind::FCMP_UNE,
607-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
608-
// Inf > Inf
609-
m_BuiltinInst(BuiltinValueKind::FCMP_UGT,
610-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2))),
611-
// Inf < Inf
612-
m_BuiltinInst(BuiltinValueKind::FCMP_ULT,
613-
m_BitCast(m_IntegerLiteralInst(inf1)), m_BitCast(m_IntegerLiteralInst(inf2)))))) {
614-
APInt val1 = inf1->getValue();
615-
APInt val2 = inf2->getValue();
616-
617-
if (hasIEEEFloatPosInfBitRepr(val1) && hasIEEEFloatPosInfBitRepr(val2)) {
618-
SILBuilderWithScope B(BI);
619-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 0));
620-
}
621-
}
622-
623-
// Everything is less than or less than equal to positive infinity
624-
if (match(BI,
625-
m_CombineOr(
626-
// Inf > x
627-
m_BuiltinInst(BuiltinValueKind::FCMP_OGT,
628-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
629-
// Inf >= x
630-
m_BuiltinInst(BuiltinValueKind::FCMP_OGE,
631-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
632-
// x < Inf
633-
m_BuiltinInst(BuiltinValueKind::FCMP_OLT,
634-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
635-
// x <= Inf
636-
m_BuiltinInst(BuiltinValueKind::FCMP_OLE,
637-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
638-
// Inf > x
639-
m_BuiltinInst(BuiltinValueKind::FCMP_UGT,
640-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
641-
// Inf >= x
642-
m_BuiltinInst(BuiltinValueKind::FCMP_UGE,
643-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
644-
// x < Inf
645-
m_BuiltinInst(BuiltinValueKind::FCMP_ULT,
646-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
647-
// x <= Inf
648-
m_BuiltinInst(BuiltinValueKind::FCMP_ULE,
649-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg)))))) {
650-
APInt val = builtinArg->getValue();
651-
if (hasIEEEFloatPosInfBitRepr(val) &&
652-
// Only if `Other` is a literal we can be sure that it's not Inf or NaN.
653-
isFiniteFloatLiteral(Other)) {
654-
SILBuilderWithScope B(BI);
655-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 1));
656-
}
657-
}
658-
659-
// Positive infinity is not less than or less than equal to anything
660-
if (match(BI,
661-
m_CombineOr(
662-
// x > Inf
663-
m_BuiltinInst(BuiltinValueKind::FCMP_OGT,
664-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
665-
// x >= Inf
666-
m_BuiltinInst(BuiltinValueKind::FCMP_OGE,
667-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
668-
// Inf < x
669-
m_BuiltinInst(BuiltinValueKind::FCMP_OLT,
670-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
671-
// Inf <= x
672-
m_BuiltinInst(BuiltinValueKind::FCMP_OLE,
673-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
674-
// x > Inf
675-
m_BuiltinInst(BuiltinValueKind::FCMP_UGT,
676-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
677-
// x >= Inf
678-
m_BuiltinInst(BuiltinValueKind::FCMP_UGE,
679-
m_SILValue(Other), m_BitCast(m_IntegerLiteralInst(builtinArg))),
680-
// Inf < x
681-
m_BuiltinInst(BuiltinValueKind::FCMP_ULT,
682-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other)),
683-
// Inf <= x
684-
m_BuiltinInst(BuiltinValueKind::FCMP_ULE,
685-
m_BitCast(m_IntegerLiteralInst(builtinArg)), m_SILValue(Other))))) {
686-
APInt val = builtinArg->getValue();
687-
if (hasIEEEFloatPosInfBitRepr(val) &&
688-
// Only if `Other` is a literal we can be sure that it's not Inf or NaN.
689-
isFiniteFloatLiteral(Other)) {
690-
SILBuilderWithScope B(BI);
691-
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), APInt(1, 0));
692-
}
693-
}
694-
695439
return nullptr;
696440
}
697441

@@ -1561,6 +1305,41 @@ case BuiltinValueKind::id:
15611305
return constantFoldAndCheckIntegerConversions(BI, Builtin, ResultsInError);
15621306
}
15631307

1308+
case BuiltinValueKind::SIToFP: {
1309+
auto *intLiteral = dyn_cast<IntegerLiteralInst>(Args[0]);
1310+
if (!intLiteral)
1311+
return nullptr;
1312+
APInt api = intLiteral->getValue();
1313+
auto *destTy = Builtin.Types[1]->castTo<BuiltinFloatType>();
1314+
1315+
// The following code is taken from LLVM's constant folder.
1316+
APFloat apf(destTy->getAPFloatSemantics(),
1317+
APInt::getZero(destTy->getBitWidth()));
1318+
apf.convertFromAPInt(api, Builtin.ID==BuiltinValueKind::SIToFP,
1319+
APFloat::rmNearestTiesToEven);
1320+
1321+
SILBuilderWithScope B(BI);
1322+
return B.createFloatLiteral(BI->getLoc(), BI->getType(), apf);
1323+
}
1324+
1325+
case BuiltinValueKind::BitCast: {
1326+
auto destTy = Builtin.Types[1];
1327+
// The following code is taken from LLVM's constant folder.
1328+
if (auto *intLiteral = dyn_cast<IntegerLiteralInst>(Args[0])) {
1329+
if (auto *floatDestTy = destTy->getAs<BuiltinFloatType>()) {
1330+
SILBuilderWithScope B(BI);
1331+
return B.createFloatLiteral(BI->getLoc(), BI->getType(),
1332+
APFloat(floatDestTy->getAPFloatSemantics(), intLiteral->getValue()));
1333+
}
1334+
} else if (auto *floatLiteral = dyn_cast<FloatLiteralInst>(Args[0])) {
1335+
if (destTy->is<BuiltinIntegerType>()) {
1336+
SILBuilderWithScope B(BI);
1337+
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), floatLiteral->getValue().bitcastToAPInt());
1338+
}
1339+
}
1340+
return nullptr;
1341+
}
1342+
15641343
case BuiltinValueKind::IntToFPWithOverflow: {
15651344
// Get the value. It should be a constant in most cases.
15661345
// Note, this will not always be a constant, for example, when analyzing

0 commit comments

Comments
 (0)