Skip to content

Commit 71d6952

Browse files
[LV][SLP] Vectorizers now use getFRemInstrCost for frem costs
SLP vectorization for frem now happens when vector library calls are available, given its type and vector length. This is due to using the updated cost that amounts to a call. Add tests that do SLP vectorization for code that contains 2x double and 4x float frem instructions. LoopVectorizer now also uses getFRemInstrCost.
1 parent ce534dc commit 71d6952

File tree

6 files changed

+51
-51
lines changed

6 files changed

+51
-51
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,18 @@ class TargetTransformInfo {
12551255
ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
12561256
const Instruction *CxtI = nullptr) const;
12571257

1258+
/// Returns the cost of a vector instruction based on the assumption that frem
1259+
/// will be later transformed (by ReplaceWithVecLib) into a call to a
1260+
/// platform specific frem vector math function.
1261+
/// If unsupported, it will return cost using getArithmeticInstrCost.
1262+
InstructionCost getFRemInstrCost(
1263+
const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
1264+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
1265+
TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
1266+
TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None},
1267+
ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
1268+
const Instruction *CxtI = nullptr) const;
1269+
12581270
/// Returns the cost estimation for alternating opcode pattern that can be
12591271
/// lowered to a single instruction on the target. In X86 this is for the
12601272
/// addsub instruction which corrsponds to a Shuffle + Fadd + FSub pattern in

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "llvm/ADT/MapVector.h"
1717
#include "llvm/ADT/SmallVector.h"
1818
#include "llvm/Analysis/LoopAccessAnalysis.h"
19-
#include "llvm/Analysis/TargetTransformInfo.h"
2019
#include "llvm/IR/VFABIDemangler.h"
2120
#include "llvm/Support/CheckedArithmetic.h"
2221

@@ -120,6 +119,7 @@ template <typename InstTy> class InterleaveGroup;
120119
class IRBuilderBase;
121120
class Loop;
122121
class ScalarEvolution;
122+
class TargetTransformInfo;
123123
class Type;
124124
class Value;
125125

@@ -410,14 +410,6 @@ bool maskIsAllOneOrUndef(Value *Mask);
410410
/// for each lane which may be active.
411411
APInt possiblyDemandedEltsInMask(Value *Mask);
412412

413-
/// Returns the cost of a call when a target has a vector library function for
414-
/// the given \p VecTy, otherwise an invalid cost.
415-
InstructionCost getVecLibCallCost(const Instruction *I,
416-
const TargetTransformInfo *TTI,
417-
const TargetLibraryInfo *TLI,
418-
VectorType *VecTy,
419-
TargetTransformInfo::TargetCostKind CostKind);
420-
421413
/// The group of interleaved loads/stores sharing the same stride and
422414
/// close to each other.
423415
///

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "llvm/Analysis/TargetTransformInfo.h"
1010
#include "llvm/Analysis/CFG.h"
1111
#include "llvm/Analysis/LoopIterator.h"
12+
#include "llvm/Analysis/TargetLibraryInfo.h"
1213
#include "llvm/Analysis/TargetTransformInfoImpl.h"
1314
#include "llvm/IR/CFG.h"
1415
#include "llvm/IR/Dominators.h"
@@ -881,6 +882,24 @@ InstructionCost TargetTransformInfo::getArithmeticInstrCost(
881882
return Cost;
882883
}
883884

885+
InstructionCost TargetTransformInfo::getFRemInstrCost(
886+
const TargetLibraryInfo *TLI, unsigned Opcode, Type *Ty,
887+
TTI::TargetCostKind CostKind, OperandValueInfo Op1Info,
888+
OperandValueInfo Op2Info, ArrayRef<const Value *> Args,
889+
const Instruction *CxtI) const {
890+
assert(Opcode == Instruction::FRem && "Instruction must be frem");
891+
892+
VectorType *VecTy = dyn_cast<VectorType>(Ty);
893+
Type *ScalarTy = VecTy ? VecTy->getScalarType() : Ty;
894+
LibFunc Func;
895+
if (VecTy && TLI->getLibFunc(Opcode, ScalarTy, Func) &&
896+
TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
897+
return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind);
898+
899+
return getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info, Args,
900+
CxtI);
901+
}
902+
884903
InstructionCost TargetTransformInfo::getAltInstrCost(
885904
VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
886905
const SmallBitVector &OpcodeMask, TTI::TargetCostKind CostKind) const {

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "llvm/Analysis/LoopIterator.h"
1919
#include "llvm/Analysis/ScalarEvolution.h"
2020
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
21-
#include "llvm/Analysis/TargetLibraryInfo.h"
2221
#include "llvm/Analysis/TargetTransformInfo.h"
2322
#include "llvm/Analysis/ValueTracking.h"
2423
#include "llvm/IR/Constants.h"
@@ -1032,22 +1031,6 @@ APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
10321031
return DemandedElts;
10331032
}
10341033

1035-
InstructionCost
1036-
llvm::getVecLibCallCost(const Instruction *I, const TargetTransformInfo *TTI,
1037-
const TargetLibraryInfo *TLI, VectorType *VecTy,
1038-
TargetTransformInfo::TargetCostKind CostKind) {
1039-
SmallVector<Type *, 4> OpTypes;
1040-
for (auto &Op : I->operands())
1041-
OpTypes.push_back(Op->getType());
1042-
1043-
LibFunc Func;
1044-
if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
1045-
TLI->isFunctionVectorizable(TLI->getName(Func), VecTy->getElementCount()))
1046-
return TTI->getCallInstrCost(nullptr, VecTy, OpTypes, CostKind);
1047-
1048-
return InstructionCost::getInvalid();
1049-
}
1050-
10511034
bool InterleavedAccessInfo::isStrided(int Stride) {
10521035
unsigned Factor = std::abs(Stride);
10531036
return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6899,25 +6899,15 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF,
68996899
Op2Info.Kind = TargetTransformInfo::OK_UniformValue;
69006900

69016901
SmallVector<const Value *, 4> Operands(I->operand_values());
6902-
auto InstrCost = TTI.getArithmeticInstrCost(
6903-
I->getOpcode(), VectorTy, CostKind,
6904-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
6905-
Op2Info, Operands, I);
6906-
6907-
// Some targets can replace frem with vector library calls.
6908-
InstructionCost VecCallCost = InstructionCost::getInvalid();
6909-
if (I->getOpcode() == Instruction::FRem) {
6910-
LibFunc Func;
6911-
if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) &&
6912-
TLI->isFunctionVectorizable(TLI->getName(Func), VF)) {
6913-
SmallVector<Type *, 4> OpTypes;
6914-
for (auto &Op : I->operands())
6915-
OpTypes.push_back(Op->getType());
6916-
VecCallCost =
6917-
TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind);
6918-
}
6919-
}
6920-
return std::min(InstrCost, VecCallCost);
6902+
TTI::OperandValueInfo Op1Info{TargetTransformInfo::OK_AnyValue,
6903+
TargetTransformInfo::OP_None};
6904+
// Some targets replace frem with vector library calls.
6905+
if (I->getOpcode() == Instruction::FRem)
6906+
return TTI.getFRemInstrCost(TLI, I->getOpcode(), VectorTy, CostKind,
6907+
Op1Info, Op2Info, Operands, I);
6908+
6909+
return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, CostKind,
6910+
Op1Info, Op2Info, Operands, I);
69216911
}
69226912
case Instruction::FNeg: {
69236913
return TTI.getArithmeticInstrCost(

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8616,12 +8616,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
86168616
unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1;
86178617
TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0));
86188618
TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx));
8619-
auto VecCost = TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind,
8620-
Op1Info, Op2Info);
8621-
// Some targets can replace frem with vector library calls.
8622-
InstructionCost VecCallCost =
8623-
getVecLibCallCost(VL0, TTI, TLI, VecTy, CostKind);
8624-
return std::min(VecInstrCost, VecCallCost) + CommonCost;
8619+
8620+
// Some targets replace frem with vector library calls.
8621+
if (ShuffleOrOp == Instruction::FRem)
8622+
return TTI->getFRemInstrCost(TLI, ShuffleOrOp, VecTy, CostKind, Op1Info,
8623+
Op2Info) +
8624+
CommonCost;
8625+
8626+
return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info,
8627+
Op2Info) +
8628+
CommonCost;
86258629
};
86268630
return GetCostDiff(GetScalarCost, GetVectorCost);
86278631
}

0 commit comments

Comments
 (0)