@@ -2022,6 +2022,11 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
+static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
+  return CastOpcode == Instruction::CastOps::ZExt ||
+         CastOpcode == Instruction::CastOps::SExt;
+}
+
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
@@ -2030,17 +2035,149 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  InstructionCost BaseCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost = Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost = Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
+  // normal fmul instruction to the cost of the fadd reduction.
+  if (RdxKind == RecurKind::FMulAdd)
+    BaseCost +=
+        Ctx.TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
+
+  // If we're using ordered reductions then we can just return the base cost
+  // here, since getArithmeticReductionCost calculates the full ordered
+  // reduction cost when FP reassociation is not allowed.
+  if (IsOrdered && Opcode == Instruction::FAdd)
+    return BaseCost;
+
+  // Special case for Arm from D93476.
+  // The reduction instruction can be substituted in the following case:
+  //
+  // %sa = sext <16 x i8> A to <16 x i32>
+  // %sb = sext <16 x i8> B to <16 x i32>
+  // %m = mul <16 x i32> %sa, %sb
+  // %r = vecreduce.add(%m)
+  // ->
+  // R = VMLADAV A, B
+  //
+  // There are other instructions for performing add reductions of
+  // v4i32/v8i16/v16i8 into i32 (VADDV), for doing the same with v4i32->i64
+  // (VADDLV) and for performing a v4i32/v8i16 MLA into an i64 (VMLALDAV).
+  //
+  // We try to match one of the following patterns and use the minimal cost:
+  // reduce.add(ext(mul(ext(A), ext(B)))) or
+  // reduce(ext(A)) or
+  // reduce.add(mul(ext(A), ext(B))) or
+  // reduce.add(mul(A, B)) or
+  // reduce(A).
+
+  // Try to match reduce(ext(...))
+  auto *Ext = dyn_cast<VPWidenCastRecipe>(getVecOp());
+  if (Ext && isZExtOrSExt(Ext->getOpcode())) {
+    bool isUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+
+    // Try to match reduce.add(ext(mul(...)))
+    auto *ExtTy = cast<VectorType>(
+        ToVectorTy(Ext->getOperand(0)->getUnderlyingValue()->getType(), VF));
+    auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
+        Ext->getOperand(0)->getDefiningRecipe());
+    if (Mul && Mul->getOpcode() == Instruction::Mul &&
+        Opcode == Instruction::Add) {
+      auto *MulTy = cast<VectorType>(
+          ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+      auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+      auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+
+      // Match reduce.add(ext(mul(ext(A), ext(B))))
+      if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+          isZExtOrSExt(InnerExt1->getOpcode()) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+        Type *InnerExt0Ty =
+            InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+        Type *InnerExt1Ty =
+            InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+        // Get the largest type.
+        auto *MaxExtVecTy = cast<VectorType>(
+            ToVectorTy(InnerExt0Ty->getIntegerBitWidth() >
+                               InnerExt1Ty->getIntegerBitWidth()
+                           ? InnerExt0Ty
+                           : InnerExt1Ty,
+                       VF));
+        InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+            isUnsigned, ElementTy, MaxExtVecTy, CostKind);
+        InstructionCost InnerExtCost =
+            Ctx.TTI.getCastInstrCost(InnerExt0->getOpcode(), MulTy, MaxExtVecTy,
+                                     TTI::CastContextHint::None, CostKind);
+        InstructionCost MulCost =
+            Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+        InstructionCost ExtCost =
+            Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                     TTI::CastContextHint::None, CostKind);
+        if (RedCost.isValid() &&
+            RedCost < InnerExtCost * 2 + MulCost + ExtCost + BaseCost)
+          return RedCost;
+      }
+    }
+
+    // Match reduce(ext(A))
+    InstructionCost RedCost =
+        Ctx.TTI.getExtendedReductionCost(Opcode, isUnsigned, ElementTy, ExtTy,
+                                         RdxDesc.getFastMathFlags(), CostKind);
+    InstructionCost ExtCost =
+        Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                 TTI::CastContextHint::None, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
+      return RedCost;
+  }
+
+  // Try to match reduce.add(mul(...))
+  auto *Mul =
+      dyn_cast_if_present<VPWidenRecipe>(getVecOp()->getDefiningRecipe());
+  if (Mul && Mul->getOpcode() == Instruction::Mul &&
+      Opcode == Instruction::Add) {
+    // Match reduce.add(mul(ext(A), ext(B)))
+    auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+    auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+    auto *MulTy =
+        cast<VectorType>(ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+    InstructionCost MulCost =
+        Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+    if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+        InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+      Type *InnerExt0Ty =
+          InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+      Type *InnerExt1Ty =
+          InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+      auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy(
+          InnerExt0Ty->getIntegerBitWidth() > InnerExt1Ty->getIntegerBitWidth()
+              ? InnerExt0Ty
+              : InnerExt1Ty,
+          VF));
+      bool isUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+          isUnsigned, ElementTy, MaxInnerExtVecTy, CostKind);
+      InstructionCost InnerExtCost = Ctx.TTI.getCastInstrCost(
+          InnerExt0->getOpcode(), MulTy, MaxInnerExtVecTy,
+          TTI::CastContextHint::None, CostKind);
+      if (RedCost.isValid() && RedCost < BaseCost + MulCost + 2 * InnerExtCost)
+        return RedCost;
+    }
+    // Match reduce.add(mul)
+    InstructionCost RedCost =
+        Ctx.TTI.getMulAccReductionCost(true, ElementTy, VectorTy, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + MulCost)
+      return RedCost;
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Normal cost = Reduction cost + BinOp cost
+  return BaseCost + Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
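For context on the pattern this cost code is matching: a scalar loop of roughly the shape below is what becomes reduce.add(mul(sext(A), sext(B))) after vectorization, which on Arm MVE can be lowered to an accumulating instruction such as VMLADAV, priced here via getMulAccReductionCost. This is only an illustrative sketch; the function name and element widths are hypothetical and are not taken from the patch or its tests.

#include <stdint.h>

// Illustrative only: an i8 x i8 -> i32 dot product. The vectorized loop body
// is mul(sext(a), sext(b)) feeding a vecreduce.add, i.e. the
// reduce.add(mul(ext(A), ext(B))) case handled above.
int32_t dot_i8(const int8_t *a, const int8_t *b, int n) {
  int32_t sum = 0;
  for (int i = 0; i < n; ++i)
    sum += (int32_t)a[i] * (int32_t)b[i];
  return sum;
}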