Skip to content

Commit af7587d

Browse files
committed
[InstCombine] reduce code duplication in visitTrunc(); NFC
1 parent 2e5bba6 commit af7587d

File tree

1 file changed

+32
-35
lines changed

1 file changed

+32
-35
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -685,48 +685,50 @@ static Instruction *shrinkInsertElt(CastInst &Trunc,
685685
return nullptr;
686686
}
687687

688-
Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
689-
if (Instruction *Result = commonCastTransforms(CI))
688+
Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
689+
if (Instruction *Result = commonCastTransforms(Trunc))
690690
return Result;
691691

692-
Value *Src = CI.getOperand(0);
693-
Type *DestTy = CI.getType(), *SrcTy = Src->getType();
692+
Value *Src = Trunc.getOperand(0);
693+
Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
694+
unsigned DestWidth = DestTy->getScalarSizeInBits();
695+
unsigned SrcWidth = SrcTy->getScalarSizeInBits();
694696
ConstantInt *Cst;
695697

696698
// Attempt to truncate the entire input expression tree to the destination
697699
// type. Only do this if the dest type is a simple type, don't convert the
698700
// expression tree to something weird like i93 unless the source is also
699701
// strange.
700702
if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
701-
canEvaluateTruncated(Src, DestTy, *this, &CI)) {
703+
canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {
702704

703705
// If this cast is a truncate, evaluting in a different type always
704706
// eliminates the cast, so it is always a win.
705707
LLVM_DEBUG(
706708
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
707709
" to avoid cast: "
708-
<< CI << '\n');
710+
<< Trunc << '\n');
709711
Value *Res = EvaluateInDifferentType(Src, DestTy, false);
710712
assert(Res->getType() == DestTy);
711-
return replaceInstUsesWith(CI, Res);
713+
return replaceInstUsesWith(Trunc, Res);
712714
}
713715

714716
// Test if the trunc is the user of a select which is part of a
715717
// minimum or maximum operation. If so, don't do any more simplification.
716718
// Even simplifying demanded bits can break the canonical form of a
717719
// min/max.
718720
Value *LHS, *RHS;
719-
if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0)))
720-
if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN)
721+
if (SelectInst *Sel = dyn_cast<SelectInst>(Src))
722+
if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN)
721723
return nullptr;
722724

723725
// See if we can simplify any instructions used by the input whose sole
724726
// purpose is to compute bits we don't care about.
725-
if (SimplifyDemandedInstructionBits(CI))
726-
return &CI;
727+
if (SimplifyDemandedInstructionBits(Trunc))
728+
return &Trunc;
727729

728-
if (DestTy->getScalarSizeInBits() == 1) {
729-
Value *Zero = Constant::getNullValue(Src->getType());
730+
if (DestWidth == 1) {
731+
Value *Zero = Constant::getNullValue(SrcTy);
730732
if (DestTy->isIntegerTy()) {
731733
// Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only).
732734
// TODO: We canonicalize to more instructions here because we are probably
@@ -743,14 +745,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
743745
const APInt *C;
744746
if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
745747
// trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0
746-
APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C);
748+
APInt MaskC = APInt(SrcWidth, 1).shl(*C);
747749
Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC));
748750
return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
749751
}
750752
if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)),
751753
m_Deferred(X))))) {
752754
// trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
753-
APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1;
755+
APInt MaskC = APInt(SrcWidth, 1).shl(*C) | 1;
754756
Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC));
755757
return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
756758
}
@@ -772,7 +774,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
772774
// If the shift amount is larger than the size of A, then the result is
773775
// known to be zero because all the input bits got shifted out.
774776
if (Cst->getZExtValue() >= ASize)
775-
return replaceInstUsesWith(CI, Constant::getNullValue(DestTy));
777+
return replaceInstUsesWith(Trunc, Constant::getNullValue(DestTy));
776778

777779
// Since we're doing an lshr and a zero extend, and know that the shift
778780
// amount is smaller than ASize, it is always safe to do the shift in A's
@@ -791,10 +793,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
791793
if (Src->hasOneUse() &&
792794
match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) {
793795
Value *SExt = cast<Instruction>(Src)->getOperand(0);
794-
const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits();
795-
const unsigned ASize = A->getType()->getPrimitiveSizeInBits();
796-
const unsigned CISize = CI.getType()->getPrimitiveSizeInBits();
797-
const unsigned MaxAmt = SExtSize - std::max(CISize, ASize);
796+
unsigned ASize = A->getType()->getPrimitiveSizeInBits();
797+
unsigned MaxAmt = SrcWidth - std::max(DestWidth, ASize);
798798
unsigned ShiftAmt = Cst->getZExtValue();
799799

800800
// This optimization can be only performed when zero bits generated by
@@ -803,24 +803,24 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
803803
// FIXME: Instead of bailing when the shift is too large, use and to clear
804804
// the extra bits.
805805
if (ShiftAmt <= MaxAmt) {
806-
if (CISize == ASize)
807-
return BinaryOperator::CreateAShr(A, ConstantInt::get(CI.getType(),
808-
std::min(ShiftAmt, ASize - 1)));
806+
if (DestWidth == ASize)
807+
return BinaryOperator::CreateAShr(
808+
A, ConstantInt::get(DestTy, std::min(ShiftAmt, ASize - 1)));
809809
if (SExt->hasOneUse()) {
810810
Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1));
811811
Shift->takeName(Src);
812-
return CastInst::CreateIntegerCast(Shift, CI.getType(), true);
812+
return CastInst::CreateIntegerCast(Shift, DestTy, true);
813813
}
814814
}
815815
}
816816

817-
if (Instruction *I = narrowBinOp(CI))
817+
if (Instruction *I = narrowBinOp(Trunc))
818818
return I;
819819

820-
if (Instruction *I = shrinkSplatShuffle(CI, Builder))
820+
if (Instruction *I = shrinkSplatShuffle(Trunc, Builder))
821821
return I;
822822

823-
if (Instruction *I = shrinkInsertElt(CI, Builder))
823+
if (Instruction *I = shrinkInsertElt(Trunc, Builder))
824824
return I;
825825

826826
if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
@@ -831,18 +831,17 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
831831
!match(A, m_Shr(m_Value(), m_Constant()))) {
832832
// Skip shifts of shift by constants. It undoes a combine in
833833
// FoldShiftByConstant and is the extend in reg pattern.
834-
const unsigned DestSize = DestTy->getScalarSizeInBits();
835-
if (Cst->getValue().ult(DestSize)) {
834+
if (Cst->getValue().ult(DestWidth)) {
836835
Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr");
837836

838837
return BinaryOperator::Create(
839838
Instruction::Shl, NewTrunc,
840-
ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize)));
839+
ConstantInt::get(DestTy, Cst->getValue().trunc(DestWidth)));
841840
}
842841
}
843842
}
844843

845-
if (Instruction *I = foldVecTruncToExtElt(CI, *this))
844+
if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
846845
return I;
847846

848847
// Whenever an element is extracted from a vector, and then truncated,
@@ -856,13 +855,11 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
856855
Value *VecOp;
857856
if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
858857
auto *VecOpTy = cast<VectorType>(VecOp->getType());
859-
unsigned DestScalarSize = DestTy->getScalarSizeInBits();
860-
unsigned VecOpScalarSize = VecOpTy->getScalarSizeInBits();
861858
unsigned VecNumElts = VecOpTy->getNumElements();
862859

863860
// A badly fit destination size would result in an invalid cast.
864-
if (VecOpScalarSize % DestScalarSize == 0) {
865-
uint64_t TruncRatio = VecOpScalarSize / DestScalarSize;
861+
if (SrcWidth % DestWidth == 0) {
862+
uint64_t TruncRatio = SrcWidth / DestWidth;
866863
uint64_t BitCastNumElts = VecNumElts * TruncRatio;
867864
uint64_t VecOpIdx = Cst->getZExtValue();
868865
uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1

0 commit comments

Comments
 (0)