Skip to content

Commit 0062975

Browse files
authored
[SLP][REVEC] Fix cost model for getGatherCost with FixedVectorType ScalarTy. (#109369)
1 parent f77bbc0 commit 0062975

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12326,8 +12326,7 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
1232612326
// Find the cost of inserting/extracting values from the vector.
1232712327
// Check if the same elements are inserted several times and count them as
1232812328
// shuffle candidates.
12329-
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
12330-
APInt ShuffledElements = APInt::getZero(VecTy->getNumElements());
12329+
APInt ShuffledElements = APInt::getZero(VL.size());
1233112330
DenseMap<Value *, unsigned> UniqueElements;
1233212331
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1233312332
InstructionCost Cost;
@@ -12347,8 +12346,7 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
1234712346
Value *V = VL[I];
1234812347
// No need to shuffle duplicates for constants.
1234912348
if ((ForPoisonSrc && isConstant(V)) || isa<UndefValue>(V)) {
12350-
ShuffledElements.setBits(I * ScalarTyNumElements,
12351-
I * ScalarTyNumElements + ScalarTyNumElements);
12349+
ShuffledElements.setBit(I);
1235212350
ShuffleMask[I] = isa<PoisonValue>(V) ? PoisonMaskElem : I;
1235312351
continue;
1235412352
}
@@ -12361,14 +12359,27 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
1236112359
}
1236212360

1236312361
DuplicateNonConst = true;
12364-
ShuffledElements.setBits(I * ScalarTyNumElements,
12365-
I * ScalarTyNumElements + ScalarTyNumElements);
12362+
ShuffledElements.setBit(I);
1236612363
ShuffleMask[I] = Res.first->second;
1236712364
}
12368-
if (ForPoisonSrc)
12369-
Cost =
12370-
TTI->getScalarizationOverhead(VecTy, ~ShuffledElements, /*Insert*/ true,
12371-
/*Extract*/ false, CostKind);
12365+
if (ForPoisonSrc) {
12366+
if (isa<FixedVectorType>(ScalarTy)) {
12367+
assert(SLPReVec && "Only supported by REVEC.");
12368+
// We don't need to insert elements one by one. Instead, we can insert the
12369+
// entire vector into the destination.
12370+
Cost = 0;
12371+
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
12372+
for (unsigned I : seq<unsigned>(VL.size()))
12373+
if (!ShuffledElements[I])
12374+
Cost += TTI->getShuffleCost(
12375+
TTI::SK_InsertSubvector, VecTy, std::nullopt, CostKind,
12376+
I * ScalarTyNumElements, cast<FixedVectorType>(ScalarTy));
12377+
} else {
12378+
Cost = TTI->getScalarizationOverhead(VecTy, ~ShuffledElements,
12379+
/*Insert*/ true,
12380+
/*Extract*/ false, CostKind);
12381+
}
12382+
}
1237212383
if (DuplicateNonConst)
1237312384
Cost += ::getShuffleCost(*TTI, TargetTransformInfo::SK_PermuteSingleSrc,
1237412385
VecTy, ShuffleMask);
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -mtriple=riscv64 -mcpu=sifive-x280 -passes=slp-vectorizer -S -slp-revec -slp-max-reg-size=1024 -slp-threshold=-10 -pass-remarks-output=%t %s | FileCheck %s
3+
; RUN: FileCheck --input-file=%t --check-prefix=YAML %s
4+
5+
; YAML: --- !Passed
6+
; YAML: Pass: slp-vectorizer
7+
; YAML: Name: StoresVectorized
8+
; YAML: Function: test
9+
; YAML: Args:
10+
; YAML: - String: 'Stores SLP vectorized with cost '
11+
; YAML: - Cost: '6'
12+
; YAML: - String: ' and with tree size '
13+
; YAML: - TreeSize: '5'
14+
15+
define void @test(<4 x float> %load6, <4 x float> %load7, <4 x float> %load8, <4 x float> %load17, <4 x float> %fmuladd7, <4 x float> %fmuladd16, ptr %out_ptr) {
16+
; CHECK-LABEL: @test(
17+
; CHECK-NEXT: entry:
18+
; CHECK-NEXT: [[VEXT165_I:%.*]] = shufflevector <4 x float> [[LOAD6:%.*]], <4 x float> [[LOAD7:%.*]], <4 x i32> <i32 2, i32 3, i32 4, i32 5>
19+
; CHECK-NEXT: [[VEXT309_I:%.*]] = shufflevector <4 x float> [[LOAD7]], <4 x float> [[LOAD8:%.*]], <4 x i32> <i32 2, i32 3, i32 4, i32 5>
20+
; CHECK-NEXT: [[TMP0:%.*]] = call <8 x float> @llvm.vector.insert.v8f32.v4f32(<8 x float> poison, <4 x float> [[VEXT165_I]], i64 0)
21+
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x float> @llvm.vector.insert.v8f32.v4f32(<8 x float> [[TMP0]], <4 x float> [[VEXT309_I]], i64 4)
22+
; CHECK-NEXT: [[TMP2:%.*]] = call <8 x float> @llvm.vector.insert.v8f32.v4f32(<8 x float> poison, <4 x float> poison, i64 4)
23+
; CHECK-NEXT: [[TMP3:%.*]] = call <8 x float> @llvm.vector.insert.v8f32.v4f32(<8 x float> [[TMP2]], <4 x float> [[LOAD17:%.*]], i64 0)
24+
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <8 x float> [[TMP3]], <8 x float> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
25+
; CHECK-NEXT: [[TMP5:%.*]] = call <8 x float> @llvm.vector.insert.v8f32.v4f32(<8 x float> poison, <4 x float> [[FMULADD7:%.*]], i64 0)
26+
; CHECK-NEXT: [[TMP6:%.*]] = call <8 x float> @llvm.vector.insert.v8f32.v4f32(<8 x float> [[TMP5]], <4 x float> [[FMULADD16:%.*]], i64 4)
27+
; CHECK-NEXT: [[TMP7:%.*]] = call <8 x float> @llvm.fmuladd.v8f32(<8 x float> [[TMP1]], <8 x float> [[TMP4]], <8 x float> [[TMP6]])
28+
; CHECK-NEXT: store <8 x float> [[TMP7]], ptr [[OUT_PTR:%.*]], align 4
29+
; CHECK-NEXT: ret void
30+
;
31+
entry:
32+
%vext165.i = shufflevector <4 x float> %load6, <4 x float> %load7, <4 x i32> <i32 2, i32 3, i32 4, i32 5>
33+
%vext309.i = shufflevector <4 x float> %load7, <4 x float> %load8, <4 x i32> <i32 2, i32 3, i32 4, i32 5>
34+
%fmuladd8 = tail call noundef <4 x float> @llvm.fmuladd.v4f32(<4 x float> %vext165.i, <4 x float> %load17, <4 x float> %fmuladd7)
35+
%fmuladd17 = tail call noundef <4 x float> @llvm.fmuladd.v4f32(<4 x float> %vext309.i, <4 x float> %load17, <4 x float> %fmuladd16)
36+
%add.ptr.i.i = getelementptr inbounds i8, ptr %out_ptr, i64 16
37+
store <4 x float> %fmuladd8, ptr %out_ptr, align 4
38+
store <4 x float> %fmuladd17, ptr %add.ptr.i.i, align 4
39+
ret void
40+
}
41+
42+
declare <4 x float> @llvm.fmuladd.v4f32(<4 x float>, <4 x float>, <4 x float>)

0 commit comments

Comments
 (0)