Skip to content

Commit faa86e5

Browse files
committed
Address comments and use inferSalarType
1 parent 6a38817 commit faa86e5

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,11 +2030,19 @@ static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
20302030
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
20312031
VPCostContext &Ctx) const {
20322032
RecurKind RdxKind = RdxDesc.getRecurrenceKind();
2033-
Type *ElementTy = RdxDesc.getRecurrenceType();
2033+
Type *ElementTy = Ctx.Types.inferScalarType(this->getVPSingleValue());
2034+
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID());
2035+
20342036
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
20352037
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
20362038
unsigned Opcode = RdxDesc.getOpcode();
20372039

2040+
// TODO: Remove the assertion when we support any-of reduction in VPlan-base
2041+
// cost model.
2042+
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(
2043+
RdxDesc.getRecurrenceKind()) &&
2044+
"VPlan-base cost model not support any-of reduction.");
2045+
20382046
InstructionCost BaseCost;
20392047
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
20402048
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
@@ -2085,24 +2093,22 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
20852093

20862094
// Try to match reduce.add(ext(mul(...)))
20872095
auto *ExtTy = cast<VectorType>(
2088-
ToVectorTy(Ext->getOperand(0)->getUnderlyingValue()->getType(), VF));
2096+
ToVectorTy(Ctx.Types.inferScalarType(Ext->getOperand(0)), VF));
20892097
auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
20902098
Ext->getOperand(0)->getDefiningRecipe());
20912099
if (Mul && Mul->getOpcode() == Instruction::Mul &&
20922100
Opcode == Instruction::Add) {
20932101
auto *MulTy = cast<VectorType>(
2094-
ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
2102+
ToVectorTy(Ctx.Types.inferScalarType(Mul->getVPSingleValue()), VF));
20952103
auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
20962104
auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
20972105

20982106
// Match reduce.add(ext(mul(ext(A), ext(B))))
20992107
if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
21002108
isZExtOrSExt(InnerExt1->getOpcode()) &&
21012109
InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
2102-
Type *InnerExt0Ty =
2103-
InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
2104-
Type *InnerExt1Ty =
2105-
InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
2110+
Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
2111+
Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
21062112
// Get the largest type.
21072113
auto *MaxExtVecTy = cast<VectorType>(
21082114
ToVectorTy(InnerExt0Ty->getIntegerBitWidth() >
@@ -2145,16 +2151,14 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
21452151
// Match reduce.add(mul(ext(A), ext(B)))
21462152
auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
21472153
auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
2148-
auto *MulTy =
2149-
cast<VectorType>(ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
2154+
auto *MulTy = cast<VectorType>(
2155+
ToVectorTy(Ctx.Types.inferScalarType(Mul->getVPSingleValue()), VF));
21502156
InstructionCost MulCost =
21512157
Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
21522158
if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
21532159
InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
2154-
Type *InnerExt0Ty =
2155-
InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
2156-
Type *InnerExt1Ty =
2157-
InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
2160+
Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
2161+
Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
21582162
auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy(
21592163
InnerExt0Ty->getIntegerBitWidth() > InnerExt1Ty->getIntegerBitWidth()
21602164
? InnerExt0Ty

0 commit comments

Comments
 (0)