@@ -968,6 +968,16 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
968
968
TypeInfo.inferScalarType (R.getOperand (1 )) ==
969
969
TypeInfo.inferScalarType (R.getVPSingleValue ()))
970
970
return R.getVPSingleValue ()->replaceAllUsesWith (R.getOperand (1 ));
971
+
972
+ if (match (&R, m_VPInstruction<VPInstruction::WideIVStep>(
973
+ m_VPValue (X), m_SpecificInt (1 ), m_VPValue (Y)))) {
974
+ if (TypeInfo.inferScalarType (X) != TypeInfo.inferScalarType (Y)) {
975
+ X = new VPWidenCastRecipe (Instruction::Trunc, X,
976
+ TypeInfo.inferScalarType (Y));
977
+ X->getDefiningRecipe ()->insertBefore (&R);
978
+ }
979
+ R.getVPSingleValue ()->replaceAllUsesWith (X);
980
+ }
971
981
}
972
982
973
983
/// Try to simplify the recipes in \p Plan. Use \p CanonicalIVTy as type for all
@@ -2050,9 +2060,10 @@ void VPlanTransforms::createInterleaveGroups(
2050
2060
}
2051
2061
}
2052
2062
2053
- void VPlanTransforms::convertToConcreteRecipes (VPlan &Plan) {
2054
- Type *CanonicalIVType = Plan.getCanonicalIV ()->getScalarType ();
2055
- VPTypeAnalysis TypeInfo (CanonicalIVType);
2063
+ void VPlanTransforms::convertToConcreteRecipes (VPlan &Plan,
2064
+ Type *CanonicalIVTy) {
2065
+ using namespace llvm ::VPlanPatternMatch;
2066
+ VPTypeAnalysis TypeInfo (CanonicalIVTy);
2056
2067
2057
2068
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
2058
2069
vp_depth_first_deep (Plan.getEntry ()))) {
@@ -2070,42 +2081,44 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
2070
2081
continue ;
2071
2082
}
2072
2083
2073
- auto *VPI = dyn_cast<VPInstruction>(&R);
2074
- if (VPI && VPI->getOpcode () == VPInstruction::WideIVStep) {
2075
- VPBuilder Builder (VPI->getParent (), VPI->getIterator ());
2076
- VPValue *VectorStep = VPI->getOperand (0 );
2077
- Type *IVTy = TypeInfo.inferScalarType (VPI->getOperand (2 ));
2078
- if (TypeInfo.inferScalarType (VectorStep) != IVTy) {
2079
- Instruction::CastOps CastOp = IVTy->isFloatingPointTy ()
2080
- ? Instruction::UIToFP
2081
- : Instruction::Trunc;
2082
- VectorStep = Builder.createWidenCast (CastOp, VectorStep, IVTy);
2083
- }
2084
-
2085
- VPValue *ScalarStep = VPI->getOperand (1 );
2086
- auto *ConstStep =
2087
- ScalarStep->isLiveIn ()
2088
- ? dyn_cast<ConstantInt>(ScalarStep->getLiveInIRValue ())
2089
- : nullptr ;
2090
- if (!ConstStep || ConstStep->getValue () != 1 ) {
2091
- if (TypeInfo.inferScalarType (ScalarStep) != IVTy) {
2092
- ScalarStep =
2093
- Builder.createWidenCast (Instruction::Trunc, ScalarStep, IVTy);
2094
- }
2095
-
2096
- std::optional<FastMathFlags> FMFs;
2097
- if (IVTy->isFloatingPointTy ())
2098
- FMFs = VPI->getFastMathFlags ();
2084
+ VPValue *VectorStep;
2085
+ VPValue *ScalarStep;
2086
+ VPValue *IVTyOp;
2087
+ if (!match (&R, m_VPInstruction<VPInstruction::WideIVStep>(
2088
+ m_VPValue (VectorStep), m_VPValue (ScalarStep),
2089
+ m_VPValue (IVTyOp))))
2090
+ continue ;
2091
+ auto *VPI = cast<VPInstruction>(&R);
2092
+ VPBuilder Builder (VPI->getParent (), VPI->getIterator ());
2093
+ Type *IVTy = TypeInfo.inferScalarType (IVTyOp);
2094
+ if (TypeInfo.inferScalarType (VectorStep) != IVTy) {
2095
+ Instruction::CastOps CastOp = IVTy->isFloatingPointTy ()
2096
+ ? Instruction::UIToFP
2097
+ : Instruction::Trunc;
2098
+ VectorStep = Builder.createWidenCast (CastOp, VectorStep, IVTy);
2099
+ }
2099
2100
2100
- unsigned MulOpc =
2101
- IVTy-> isFloatingPointTy () ? Instruction::FMul : Instruction::Mul;
2102
- VPInstruction *Mul = Builder. createNaryOp (
2103
- MulOpc, {VectorStep, ScalarStep}, FMFs, R. getDebugLoc ()) ;
2104
- VectorStep = Mul ;
2105
- }
2106
- VPI-> replaceAllUsesWith (VectorStep);
2107
- VPI-> eraseFromParent ( );
2101
+ auto *ConstStep =
2102
+ ScalarStep-> isLiveIn ()
2103
+ ? dyn_cast<ConstantInt>(ScalarStep-> getLiveInIRValue ())
2104
+ : nullptr ;
2105
+ assert (!ConstStep || ConstStep-> getValue () != 1 ) ;
2106
+ if (TypeInfo. inferScalarType (ScalarStep) != IVTy) {
2107
+ ScalarStep =
2108
+ Builder. createWidenCast (Instruction::Trunc, ScalarStep, IVTy );
2108
2109
}
2110
+
2111
+ std::optional<FastMathFlags> FMFs;
2112
+ if (IVTy->isFloatingPointTy ())
2113
+ FMFs = VPI->getFastMathFlags ();
2114
+
2115
+ unsigned MulOpc =
2116
+ IVTy->isFloatingPointTy () ? Instruction::FMul : Instruction::Mul;
2117
+ VPInstruction *Mul = Builder.createNaryOp (
2118
+ MulOpc, {VectorStep, ScalarStep}, FMFs, R.getDebugLoc ());
2119
+ VectorStep = Mul;
2120
+ VPI->replaceAllUsesWith (VectorStep);
2121
+ VPI->eraseFromParent ();
2109
2122
}
2110
2123
}
2111
2124
}
0 commit comments