@@ -12043,6 +12043,9 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
12043
12043
/// Adds 2 input vectors and the mask for their shuffling.
12044
12044
void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
12045
12045
assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors.");
12046
+ assert(isa<FixedVectorType>(V1->getType()) &&
12047
+ isa<FixedVectorType>(V2->getType()) &&
12048
+ "castToScalarTyElem expects V1 and V2 to be FixedVectorType");
12046
12049
V1 = castToScalarTyElem(V1);
12047
12050
V2 = castToScalarTyElem(V2);
12048
12051
if (InVectors.empty()) {
@@ -12072,22 +12075,18 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
12072
12075
}
12073
12076
/// Adds another one input vector and the mask for the shuffling.
12074
12077
void add(Value *V1, ArrayRef<int> Mask, bool = false) {
12078
+ assert(isa<FixedVectorType>(V1->getType()) &&
12079
+ "castToScalarTyElem expects V1 to be FixedVectorType");
12075
12080
V1 = castToScalarTyElem(V1);
12076
12081
if (InVectors.empty()) {
12077
- if (!isa<FixedVectorType>(V1->getType())) {
12078
- V1 = createShuffle(V1, nullptr, CommonMask);
12079
- CommonMask.assign(Mask.size(), PoisonMaskElem);
12080
- transformMaskAfterShuffle(CommonMask, Mask);
12081
- }
12082
12082
InVectors.push_back(V1);
12083
12083
CommonMask.assign(Mask.begin(), Mask.end());
12084
12084
return;
12085
12085
}
12086
12086
const auto *It = find(InVectors, V1);
12087
12087
if (It == InVectors.end()) {
12088
12088
if (InVectors.size() == 2 ||
12089
- InVectors.front()->getType() != V1->getType() ||
12090
- !isa<FixedVectorType>(V1->getType())) {
12089
+ InVectors.front()->getType() != V1->getType()) {
12091
12090
Value *V = InVectors.front();
12092
12091
if (InVectors.size() == 2) {
12093
12092
V = createShuffle(InVectors.front(), InVectors.back(), CommonMask);
@@ -12121,9 +12120,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
12121
12120
break;
12122
12121
}
12123
12122
}
12124
- int VF = CommonMask.size();
12125
- if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType()))
12126
- VF = FTy->getNumElements();
12123
+ int VF = cast<FixedVectorType>(V1->getType())->getNumElements();
12127
12124
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
12128
12125
if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
12129
12126
CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
0 commit comments