@@ -2030,11 +2030,19 @@ static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
2030
2030
InstructionCost VPReductionRecipe::computeCost (ElementCount VF,
2031
2031
VPCostContext &Ctx) const {
2032
2032
RecurKind RdxKind = RdxDesc.getRecurrenceKind ();
2033
- Type *ElementTy = RdxDesc.getRecurrenceType ();
2033
+ Type *ElementTy = Ctx.Types .inferScalarType (this ->getVPSingleValue ());
2034
+ assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID ());
2035
+
2034
2036
auto *VectorTy = cast<VectorType>(ToVectorTy (ElementTy, VF));
2035
2037
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2036
2038
unsigned Opcode = RdxDesc.getOpcode ();
2037
2039
2040
+ // TODO: Remove the assertion when we support any-of reduction in VPlan-based
2041
+ // cost model.
2042
+ assert (!RecurrenceDescriptor::isAnyOfRecurrenceKind (
2043
+ RdxDesc.getRecurrenceKind ()) &&
2044
+ " VPlan-base cost model not support any-of reduction." );
2045
+
2038
2046
InstructionCost BaseCost;
2039
2047
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RdxKind)) {
2040
2048
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp (RdxKind);
@@ -2085,24 +2093,22 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
2085
2093
2086
2094
// Try to match reduce.add(ext(mul(...)))
2087
2095
auto *ExtTy = cast<VectorType>(
2088
- ToVectorTy (Ext->getOperand (0 )-> getUnderlyingValue ()-> getType ( ), VF));
2096
+ ToVectorTy (Ctx. Types . inferScalarType ( Ext->getOperand (0 )), VF));
2089
2097
auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
2090
2098
Ext->getOperand (0 )->getDefiningRecipe ());
2091
2099
if (Mul && Mul->getOpcode () == Instruction::Mul &&
2092
2100
Opcode == Instruction::Add) {
2093
2101
auto *MulTy = cast<VectorType>(
2094
- ToVectorTy (Mul->getUnderlyingValue ()-> getType ( ), VF));
2102
+ ToVectorTy (Ctx. Types . inferScalarType ( Mul->getVPSingleValue () ), VF));
2095
2103
auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand (0 ));
2096
2104
auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand (1 ));
2097
2105
2098
2106
// Match reduce.add(ext(mul(ext(A), ext(B))))
2099
2107
if (InnerExt0 && isZExtOrSExt (InnerExt0->getOpcode ()) && InnerExt1 &&
2100
2108
isZExtOrSExt (InnerExt1->getOpcode ()) &&
2101
2109
InnerExt0->getOpcode () == InnerExt1->getOpcode ()) {
2102
- Type *InnerExt0Ty =
2103
- InnerExt0->getOperand (0 )->getUnderlyingValue ()->getType ();
2104
- Type *InnerExt1Ty =
2105
- InnerExt1->getOperand (0 )->getUnderlyingValue ()->getType ();
2110
+ Type *InnerExt0Ty = Ctx.Types .inferScalarType (InnerExt0->getOperand (0 ));
2111
+ Type *InnerExt1Ty = Ctx.Types .inferScalarType (InnerExt1->getOperand (0 ));
2106
2112
// Get the largest type.
2107
2113
auto *MaxExtVecTy = cast<VectorType>(
2108
2114
ToVectorTy (InnerExt0Ty->getIntegerBitWidth () >
@@ -2145,16 +2151,14 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
2145
2151
// Match reduce.add(mul(ext(A), ext(B)))
2146
2152
auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand (0 ));
2147
2153
auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand (1 ));
2148
- auto *MulTy =
2149
- cast<VectorType>( ToVectorTy (Mul->getUnderlyingValue ()-> getType ( ), VF));
2154
+ auto *MulTy = cast<VectorType>(
2155
+ ToVectorTy (Ctx. Types . inferScalarType ( Mul->getVPSingleValue () ), VF));
2150
2156
InstructionCost MulCost =
2151
2157
Ctx.TTI .getArithmeticInstrCost (Instruction::Mul, MulTy, CostKind);
2152
2158
if (InnerExt0 && isZExtOrSExt (InnerExt0->getOpcode ()) && InnerExt1 &&
2153
2159
InnerExt0->getOpcode () == InnerExt1->getOpcode ()) {
2154
- Type *InnerExt0Ty =
2155
- InnerExt0->getOperand (0 )->getUnderlyingValue ()->getType ();
2156
- Type *InnerExt1Ty =
2157
- InnerExt1->getOperand (0 )->getUnderlyingValue ()->getType ();
2160
+ Type *InnerExt0Ty = Ctx.Types .inferScalarType (InnerExt0->getOperand (0 ));
2161
+ Type *InnerExt1Ty = Ctx.Types .inferScalarType (InnerExt1->getOperand (0 ));
2158
2162
auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy (
2159
2163
InnerExt0Ty->getIntegerBitWidth () > InnerExt1Ty->getIntegerBitWidth ()
2160
2164
? InnerExt0Ty
0 commit comments