@@ -2495,16 +2495,19 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
2495
2495
}
2496
2496
2497
2497
// / Check if instruction depends on ZExt and this ZExt can be moved after the
2498
- // / instruction. Move ZExt if it is profitable
2498
+ // / instruction. Move ZExt if it is profitable. For example:
2499
+ // / logic(zext(x),y) -> zext(logic(x,trunc(y)))
2500
+ // / lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
2501
+ // / Cost model calculations take into account whether zext(x) has other users and
2502
+ // / whether it can be propagated through them too.
2499
2503
bool VectorCombine::shrinkType (llvm::Instruction &I) {
2500
2504
Value *ZExted, *OtherOperand;
2501
2505
if (!match (&I, m_c_BitwiseLogic (m_ZExt (m_Value (ZExted)),
2502
2506
m_Value (OtherOperand))) &&
2503
2507
!match (&I, m_LShr (m_ZExt (m_Value (ZExted)), m_Value (OtherOperand))))
2504
2508
return false ;
2505
2509
2506
- Instruction *ZExtOperand =
2507
- cast<Instruction>(I.getOperand (I.getOperand (0 ) == OtherOperand ? 1 : 0 ));
2510
+ Value *ZExtOperand = I.getOperand (I.getOperand (0 ) == OtherOperand ? 1 : 0 );
2508
2511
2509
2512
auto *BigTy = cast<FixedVectorType>(I.getType ());
2510
2513
auto *SmallTy = cast<FixedVectorType>(ZExted->getType ());
@@ -2519,18 +2522,21 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
2519
2522
2520
2523
// Calculate costs of leaving current IR as it is and moving ZExt operation
2521
2524
// later, along with adding truncates if needed
2525
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2522
2526
InstructionCost ZExtCost = TTI.getCastInstrCost (
2523
2527
Instruction::ZExt, BigTy, SmallTy,
2524
- TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput );
2528
+ TargetTransformInfo::CastContextHint::None, CostKind );
2525
2529
InstructionCost CurrentCost = ZExtCost;
2526
2530
InstructionCost ShrinkCost = 0 ;
2527
2531
2528
2532
// Calculate total cost and check that we can propagate through all ZExt users
2529
2533
for (User *U : ZExtOperand->users ()) {
2530
2534
auto *UI = cast<Instruction>(U);
2531
2535
if (UI == &I) {
2532
- CurrentCost += TTI.getArithmeticInstrCost (UI->getOpcode (), BigTy);
2533
- ShrinkCost += TTI.getArithmeticInstrCost (UI->getOpcode (), SmallTy);
2536
+ CurrentCost +=
2537
+ TTI.getArithmeticInstrCost (UI->getOpcode (), BigTy, CostKind);
2538
+ ShrinkCost +=
2539
+ TTI.getArithmeticInstrCost (UI->getOpcode (), SmallTy, CostKind);
2534
2540
ShrinkCost += ZExtCost;
2535
2541
continue ;
2536
2542
}
@@ -2540,12 +2546,13 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
2540
2546
2541
2547
// Check if we can propagate ZExt through its other users
2542
2548
KB = computeKnownBits (UI, *DL);
2543
- unsigned UBW = KB.getBitWidth () - KB.Zero . countLeadingOnes ();
2549
+ unsigned UBW = KB.getBitWidth () - KB.countMinLeadingZeros ();
2544
2550
if (UBW > BW)
2545
2551
return false ;
2546
2552
2547
- CurrentCost += TTI.getArithmeticInstrCost (UI->getOpcode (), BigTy);
2548
- ShrinkCost += TTI.getArithmeticInstrCost (UI->getOpcode (), SmallTy);
2553
+ CurrentCost += TTI.getArithmeticInstrCost (UI->getOpcode (), BigTy, CostKind);
2554
+ ShrinkCost +=
2555
+ TTI.getArithmeticInstrCost (UI->getOpcode (), SmallTy, CostKind);
2549
2556
ShrinkCost += ZExtCost;
2550
2557
}
2551
2558
@@ -2554,7 +2561,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
2554
2561
if (!isa<Constant>(OtherOperand))
2555
2562
ShrinkCost += TTI.getCastInstrCost (
2556
2563
Instruction::Trunc, SmallTy, BigTy,
2557
- TargetTransformInfo::CastContextHint::None, TTI::TCK_RecipThroughput );
2564
+ TargetTransformInfo::CastContextHint::None, CostKind );
2558
2565
2559
2566
// If the cost of shrinking types and leaving the IR is the same, we'll lean
2560
2567
// towards modifying the IR because shrinking opens opportunities for other
0 commit comments