Skip to content

Commit 285bc69

Browse files
committed
[SLP]Fix PR80027: Fix costs processing for minbitwidth types.
The type arguments need to be swapped: getCastInstrCost takes the destination type first, followed by the source type.
1 parent 5f6640e commit 285bc69

File tree

2 files changed

+80
-24
lines changed

2 files changed

+80
-24
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 32 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -7888,19 +7888,18 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
78887888
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
78897889
unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
78907890
unsigned VecOpcode;
7891-
auto *SrcVecTy =
7891+
auto *UserVecTy =
78927892
FixedVectorType::get(UserScalarTy, E->getVectorFactor());
78937893
if (BWSz > SrcBWSz)
78947894
VecOpcode = Instruction::Trunc;
78957895
else
78967896
VecOpcode =
78977897
It->second.second ? Instruction::SExt : Instruction::ZExt;
78987898
TTI::CastContextHint CCH = GetCastContextHint(VL0);
7899-
VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
7899+
VecCost += TTI->getCastInstrCost(VecOpcode, UserVecTy, VecTy, CCH,
79007900
CostKind);
7901-
ScalarCost +=
7902-
Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
7903-
CCH, CostKind);
7901+
ScalarCost += Sz * TTI->getCastInstrCost(VecOpcode, UserScalarTy,
7902+
ScalarTy, CCH, CostKind);
79047903
}
79057904
}
79067905
}
@@ -8981,7 +8980,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
89818980
SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
89828981
SmallVector<APInt> DemandedElts;
89838982
SmallDenseSet<Value *, 4> UsedInserts;
8984-
DenseSet<Value *> VectorCasts;
8983+
DenseSet<std::pair<const TreeEntry *, Type *>> VectorCasts;
89858984
for (ExternalUser &EU : ExternalUses) {
89868985
// We only add extract cost once for the same scalar.
89878986
if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -9051,11 +9050,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
90519050
DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
90529051
VecId = FirstUsers.size() - 1;
90539052
auto It = MinBWs.find(ScalarTE);
9054-
if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
9053+
if (It != MinBWs.end() &&
9054+
VectorCasts
9055+
.insert(std::make_pair(ScalarTE, FTy->getElementType()))
9056+
.second) {
90559057
unsigned BWSz = It->second.second;
9056-
unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
9058+
unsigned DstBWSz = DL->getTypeSizeInBits(FTy->getElementType());
90579059
unsigned VecOpcode;
9058-
if (BWSz < SrcBWSz)
9060+
if (DstBWSz < BWSz)
90599061
VecOpcode = Instruction::Trunc;
90609062
else
90619063
VecOpcode =
@@ -9108,17 +9110,20 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
91089110
}
91099111
// Add reduced value cost, if resized.
91109112
if (!VectorizedVals.empty()) {
9111-
auto BWIt = MinBWs.find(VectorizableTree.front().get());
9113+
const TreeEntry &Root = *VectorizableTree.front().get();
9114+
auto BWIt = MinBWs.find(&Root);
91129115
if (BWIt != MinBWs.end()) {
9113-
Type *DstTy = VectorizableTree.front()->Scalars.front()->getType();
9116+
Type *DstTy = Root.Scalars.front()->getType();
91149117
unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
9115-
unsigned Opcode = Instruction::Trunc;
9116-
if (OriginalSz < BWIt->second.first)
9117-
Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
9118-
Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
9119-
Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
9120-
TTI::CastContextHint::None,
9121-
TTI::TCK_RecipThroughput);
9118+
if (OriginalSz != BWIt->second.first) {
9119+
unsigned Opcode = Instruction::Trunc;
9120+
if (OriginalSz < BWIt->second.first)
9121+
Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
9122+
Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
9123+
Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
9124+
TTI::CastContextHint::None,
9125+
TTI::TCK_RecipThroughput);
9126+
}
91229127
}
91239128
}
91249129

@@ -11419,9 +11424,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1141911424
VecOpcode = Instruction::BitCast;
1142011425
} else if (BWSz < SrcBWSz) {
1142111426
VecOpcode = Instruction::Trunc;
11422-
} else if (It != MinBWs.end()) {
11427+
} else if (SrcIt != MinBWs.end()) {
1142311428
assert(BWSz > SrcBWSz && "Invalid cast!");
11424-
VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
11429+
VecOpcode =
11430+
SrcIt->second.second ? Instruction::SExt : Instruction::ZExt;
1142511431
}
1142611432
}
1142711433
Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
@@ -11929,7 +11935,7 @@ Value *BoUpSLP::vectorizeTree(
1192911935
// basic block. Only one extractelement per block should be emitted.
1193011936
DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
1193111937
SmallDenseSet<Value *, 4> UsedInserts;
11932-
DenseMap<Value *, Value *> VectorCasts;
11938+
DenseMap<std::pair<Value *, Type *>, Value *> VectorCasts;
1193311939
SmallDenseSet<Value *, 4> ScalarsWithNullptrUser;
1193411940
// Extract all of the elements with the external uses.
1193511941
for (const auto &ExternalUse : ExternalUses) {
@@ -12050,18 +12056,20 @@ Value *BoUpSLP::vectorizeTree(
1205012056
// Need to use original vector, if the root is truncated.
1205112057
auto BWIt = MinBWs.find(E);
1205212058
if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
12053-
auto VecIt = VectorCasts.find(Scalar);
12059+
auto *ScalarTy = FTy->getElementType();
12060+
auto Key = std::make_pair(Vec, ScalarTy);
12061+
auto VecIt = VectorCasts.find(Key);
1205412062
if (VecIt == VectorCasts.end()) {
1205512063
IRBuilder<>::InsertPointGuard Guard(Builder);
1205612064
if (auto *IVec = dyn_cast<Instruction>(Vec))
1205712065
Builder.SetInsertPoint(IVec->getNextNonDebugInstruction());
1205812066
Vec = Builder.CreateIntCast(
1205912067
Vec,
1206012068
FixedVectorType::get(
12061-
cast<VectorType>(VU->getType())->getElementType(),
12069+
ScalarTy,
1206212070
cast<FixedVectorType>(Vec->getType())->getNumElements()),
1206312071
BWIt->second.second);
12064-
VectorCasts.try_emplace(Scalar, Vec);
12072+
VectorCasts.try_emplace(Key, Vec);
1206512073
} else {
1206612074
Vec = VecIt->second;
1206712075
}
Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,48 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt -S --passes=slp-vectorizer -mtriple=s390x-unknown-linux -mcpu=z14 < %s | FileCheck %s
3+
4+
define void @test() {
5+
; CHECK-LABEL: define void @test(
6+
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
7+
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 0 to i32
8+
; CHECK-NEXT: [[TMP2:%.*]] = zext i8 0 to i32
9+
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> <i32 0, i32 poison, i32 0, i32 0>, i32 [[TMP2]], i32 1
10+
; CHECK-NEXT: [[TMP4:%.*]] = select <4 x i1> zeroinitializer, <4 x i32> zeroinitializer, <4 x i32> [[TMP3]]
11+
; CHECK-NEXT: [[TMP5:%.*]] = select i1 false, i32 0, i32 0
12+
; CHECK-NEXT: [[TMP6:%.*]] = select i1 false, i32 0, i32 [[TMP1]]
13+
; CHECK-NEXT: [[TMP7:%.*]] = select i1 false, i32 0, i32 [[TMP2]]
14+
; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP4]])
15+
; CHECK-NEXT: [[OP_RDX:%.*]] = xor i32 [[TMP8]], [[TMP5]]
16+
; CHECK-NEXT: [[OP_RDX1:%.*]] = xor i32 [[TMP6]], [[TMP7]]
17+
; CHECK-NEXT: [[OP_RDX2:%.*]] = xor i32 [[OP_RDX]], [[OP_RDX1]]
18+
; CHECK-NEXT: [[TMP9:%.*]] = trunc i32 [[OP_RDX2]] to i16
19+
; CHECK-NEXT: store i16 [[TMP9]], ptr null, align 2
20+
; CHECK-NEXT: ret void
21+
;
22+
%1 = zext i8 0 to i32
23+
%.not = icmp sgt i32 0, %1
24+
%2 = zext i8 0 to i32
25+
%3 = select i1 %.not, i32 0, i32 0
26+
%4 = zext i8 0 to i32
27+
%.not.1 = icmp sgt i32 0, %4
28+
%5 = zext i8 0 to i32
29+
%6 = select i1 %.not.1, i32 0, i32 %5
30+
%7 = xor i32 %6, %3
31+
%8 = zext i8 0 to i32
32+
%.not.2 = icmp sgt i32 0, %8
33+
%9 = select i1 %.not.2, i32 0, i32 0
34+
%10 = xor i32 %9, %7
35+
%11 = zext i8 0 to i32
36+
%.not.3 = icmp sgt i32 0, %11
37+
%12 = select i1 %.not.3, i32 0, i32 0
38+
%13 = xor i32 %12, %10
39+
%14 = select i1 false, i32 0, i32 0
40+
%15 = xor i32 %14, %13
41+
%16 = select i1 false, i32 0, i32 %2
42+
%17 = xor i32 %16, %15
43+
%18 = select i1 false, i32 0, i32 %5
44+
%19 = xor i32 %18, %17
45+
%20 = trunc i32 %19 to i16
46+
store i16 %20, ptr null, align 2
47+
ret void
48+
}

0 commit comments

Comments
 (0)