@@ -832,8 +832,107 @@ struct InstructionsState {
832
832
: OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {}
833
833
};
834
834
835
+ struct InterchangeableInstruction {
836
+ unsigned Opcode;
837
+ SmallVector<Value *> Ops;
838
+ template <class... ArgTypes>
839
+ InterchangeableInstruction(unsigned Opcode, ArgTypes &&...Args)
840
+ : Opcode(Opcode), Ops{std::forward<decltype(Args)>(Args)...} {}
841
+ };
842
+
843
+ bool operator<(const InterchangeableInstruction &LHS,
844
+ const InterchangeableInstruction &RHS) {
845
+ return LHS.Opcode < RHS.Opcode;
846
+ }
847
+
835
848
} // end anonymous namespace
836
849
850
+ /// \returns a sorted list of interchangeable instructions by instruction opcode
851
+ /// that \p I can be converted to.
852
+ /// e.g.,
853
+ /// x << y -> x * (2^y)
854
+ /// x << 1 -> x * 2
855
+ /// x << 0 -> x * 1 -> x - 0 -> x + 0 -> x & 11...1 -> x | 0
856
+ /// x * 0 -> x & 0
857
+ /// x * -1 -> 0 - x
858
+ /// TODO: support more patterns
859
+ static SmallVector<InterchangeableInstruction>
860
+ getInterchangeableInstruction(Instruction *I) {
861
+ // PII = Possible Interchangeable Instruction
862
+ SmallVector<InterchangeableInstruction> PII;
863
+ unsigned Opcode = I->getOpcode();
864
+ PII.emplace_back(Opcode, I->operands());
865
+ if (!is_contained({Instruction::Shl, Instruction::Mul, Instruction::Sub,
866
+ Instruction::Add},
867
+ Opcode))
868
+ return PII;
869
+ Constant *C;
870
+ if (match(I, m_BinOp(m_Value(), m_Constant(C)))) {
871
+ ConstantInt *V = nullptr;
872
+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
873
+ V = CI;
874
+ } else if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
875
+ if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
876
+ V = CI;
877
+ }
878
+ if (!V)
879
+ return PII;
880
+ Value *Op0 = I->getOperand(0);
881
+ Type *Op1Ty = I->getOperand(1)->getType();
882
+ const APInt &Op1Int = V->getValue();
883
+ Constant *Zero =
884
+ ConstantInt::get(Op1Ty, APInt::getZero(Op1Int.getBitWidth()));
885
+ Constant *UnsignedMax =
886
+ ConstantInt::get(Op1Ty, APInt::getMaxValue(Op1Int.getBitWidth()));
887
+ switch (Opcode) {
888
+ case Instruction::Shl: {
889
+ PII.emplace_back(Instruction::Mul, Op0,
890
+ ConstantInt::get(Op1Ty, 1 << Op1Int.getZExtValue()));
891
+ if (Op1Int.isZero()) {
892
+ PII.emplace_back(Instruction::Sub, Op0, Zero);
893
+ PII.emplace_back(Instruction::Add, Op0, Zero);
894
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
895
+ PII.emplace_back(Instruction::Or, Op0, Zero);
896
+ }
897
+ break;
898
+ }
899
+ case Instruction::Mul: {
900
+ switch (Op1Int.getSExtValue()) {
901
+ case 1:
902
+ PII.emplace_back(Instruction::Sub, Op0, Zero);
903
+ PII.emplace_back(Instruction::Add, Op0, Zero);
904
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
905
+ PII.emplace_back(Instruction::Or, Op0, Zero);
906
+ break;
907
+ case 0:
908
+ PII.emplace_back(Instruction::And, Op0, Zero);
909
+ break;
910
+ case -1:
911
+ PII.emplace_back(Instruction::Sub, Zero, Op0);
912
+ break;
913
+ }
914
+ break;
915
+ }
916
+ case Instruction::Sub:
917
+ if (Op1Int.isZero()) {
918
+ PII.emplace_back(Instruction::Add, Op0, Zero);
919
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
920
+ PII.emplace_back(Instruction::Or, Op0, Zero);
921
+ }
922
+ break;
923
+ case Instruction::Add:
924
+ if (Op1Int.isZero()) {
925
+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
926
+ PII.emplace_back(Instruction::Or, Op0, Zero);
927
+ }
928
+ break;
929
+ }
930
+ }
931
+ // std::set_intersection requires a sorted range.
932
+ sort(PII);
933
+ return PII;
934
+ }
935
+
837
936
/// \returns true if \p Opcode is allowed as part of the main/alternate
838
937
/// instruction for SLP vectorization.
839
938
///
@@ -938,18 +1037,54 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
938
1037
if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
939
1038
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
940
1039
}
1040
+ // Currently, this is only used for binary ops.
1041
+ // TODO: support all instructions
1042
+ SmallVector<InterchangeableInstruction> InterchangeableOpcode =
1043
+ getInterchangeableInstruction(cast<Instruction>(VL[BaseIndex]));
1044
+ SmallVector<InterchangeableInstruction> AlternateInterchangeableOpcode;
1045
+ auto UpdateInterchangeableOpcode =
1046
+ [](SmallVector<InterchangeableInstruction> &LHS,
1047
+ ArrayRef<InterchangeableInstruction> RHS) {
1048
+ SmallVector<InterchangeableInstruction> NewInterchangeableOpcode;
1049
+ std::set_intersection(LHS.begin(), LHS.end(), RHS.begin(), RHS.end(),
1050
+ std::back_inserter(NewInterchangeableOpcode));
1051
+ if (NewInterchangeableOpcode.empty())
1052
+ return false;
1053
+ LHS = std::move(NewInterchangeableOpcode);
1054
+ return true;
1055
+ };
941
1056
for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
942
1057
auto *I = cast<Instruction>(VL[Cnt]);
943
1058
unsigned InstOpcode = I->getOpcode();
944
1059
if (IsBinOp && isa<BinaryOperator>(I)) {
945
- if (InstOpcode == Opcode || InstOpcode == AltOpcode)
1060
+ SmallVector<InterchangeableInstruction> ThisInterchangeableOpcode(
1061
+ getInterchangeableInstruction(I));
1062
+ if (UpdateInterchangeableOpcode(InterchangeableOpcode,
1063
+ ThisInterchangeableOpcode))
946
1064
continue;
947
- if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
948
- isValidForAlternation(Opcode)) {
949
- AltOpcode = InstOpcode;
950
- AltIndex = Cnt;
1065
+ if (AlternateInterchangeableOpcode.empty()) {
1066
+ InterchangeableOpcode.erase(
1067
+ std::remove_if(InterchangeableOpcode.begin(),
1068
+ InterchangeableOpcode.end(),
1069
+ [](const InterchangeableInstruction &I) {
1070
+ return !isValidForAlternation(I.Opcode);
1071
+ }),
1072
+ InterchangeableOpcode.end());
1073
+ ThisInterchangeableOpcode.erase(
1074
+ std::remove_if(ThisInterchangeableOpcode.begin(),
1075
+ ThisInterchangeableOpcode.end(),
1076
+ [](const InterchangeableInstruction &I) {
1077
+ return !isValidForAlternation(I.Opcode);
1078
+ }),
1079
+ ThisInterchangeableOpcode.end());
1080
+ if (InterchangeableOpcode.empty() || ThisInterchangeableOpcode.empty())
1081
+ return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1082
+ AlternateInterchangeableOpcode = std::move(ThisInterchangeableOpcode);
951
1083
continue;
952
1084
}
1085
+ if (UpdateInterchangeableOpcode(AlternateInterchangeableOpcode,
1086
+ ThisInterchangeableOpcode))
1087
+ continue;
953
1088
} else if (IsCastOp && isa<CastInst>(I)) {
954
1089
Value *Op0 = IBase->getOperand(0);
955
1090
Type *Ty0 = Op0->getType();
@@ -1043,6 +1178,21 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1043
1178
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1044
1179
}
1045
1180
1181
+ if (IsBinOp) {
1182
+ auto FindOp = [&](ArrayRef<InterchangeableInstruction> CandidateOp) {
1183
+ for (Value *V : VL)
1184
+ for (const InterchangeableInstruction &I : CandidateOp)
1185
+ if (cast<Instruction>(V)->getOpcode() == I.Opcode)
1186
+ return cast<Instruction>(V);
1187
+ llvm_unreachable(
1188
+ "Cannot find the candidate instruction for InstructionsState.");
1189
+ };
1190
+ Instruction *MainOp = FindOp(InterchangeableOpcode);
1191
+ Instruction *AltOp = AlternateInterchangeableOpcode.empty()
1192
+ ? MainOp
1193
+ : FindOp(AlternateInterchangeableOpcode);
1194
+ return InstructionsState(VL[BaseIndex], MainOp, AltOp);
1195
+ }
1046
1196
return InstructionsState(VL[BaseIndex], cast<Instruction>(VL[BaseIndex]),
1047
1197
cast<Instruction>(VL[AltIndex]));
1048
1198
}
@@ -2335,24 +2485,41 @@ class BoUpSLP {
2335
2485
: cast<Instruction>(VL[0])->getNumOperands();
2336
2486
OpsVec.resize(NumOperands);
2337
2487
unsigned NumLanes = VL.size();
2338
- for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) {
2488
+ InstructionsState S = getSameOpcode(VL, TLI);
2489
+ for (unsigned OpIdx : seq<unsigned>(NumOperands))
2339
2490
OpsVec[OpIdx].resize(NumLanes);
2340
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
2341
- assert(isa<Instruction>(VL[Lane]) && "Expected instruction");
2342
- // Our tree has just 3 nodes: the root and two operands.
2343
- // It is therefore trivial to get the APO. We only need to check the
2344
- // opcode of VL[Lane] and whether the operand at OpIdx is the LHS or
2345
- // RHS operand. The LHS operand of both add and sub is never attached
2346
- // to an inversese operation in the linearized form, therefore its APO
2347
- // is false. The RHS is true only if VL[Lane] is an inverse operation.
2348
-
2349
- // Since operand reordering is performed on groups of commutative
2350
- // operations or alternating sequences (e.g., +, -), we can safely
2351
- // tell the inverse operations by checking commutativity.
2352
- bool IsInverseOperation = !isCommutative(cast<Instruction>(VL[Lane]));
2491
+ for (auto [I, V] : enumerate(VL)) {
2492
+ assert(isa<Instruction>(V) && "Expected instruction");
2493
+ SmallVector<InterchangeableInstruction> IIList =
2494
+ getInterchangeableInstruction(cast<Instruction>(V));
2495
+ Value *SelectedOp;
2496
+ auto Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
2497
+ return II.Opcode == S.MainOp->getOpcode();
2498
+ });
2499
+ if (Iter == IIList.end()) {
2500
+ Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
2501
+ return II.Opcode == S.AltOp->getOpcode();
2502
+ });
2503
+ SelectedOp = S.AltOp;
2504
+ } else {
2505
+ SelectedOp = S.MainOp;
2506
+ }
2507
+ assert(Iter != IIList.end() &&
2508
+ "Cannot find an interchangeable instruction.");
2509
+ // Our tree has just 3 nodes: the root and two operands.
2510
+ // It is therefore trivial to get the APO. We only need to check the
2511
+ // opcode of V and whether the operand at OpIdx is the LHS or RHS
2512
+ // operand. The LHS operand of both add and sub is never attached to an
2513
+ // inversese operation in the linearized form, therefore its APO is
2514
+ // false. The RHS is true only if V is an inverse operation.
2515
+
2516
+ // Since operand reordering is performed on groups of commutative
2517
+ // operations or alternating sequences (e.g., +, -), we can safely
2518
+ // tell the inverse operations by checking commutativity.
2519
+ bool IsInverseOperation = !isCommutative(cast<Instruction>(SelectedOp));
2520
+ for (unsigned OpIdx : seq<unsigned>(NumOperands)) {
2353
2521
bool APO = (OpIdx == 0) ? false : IsInverseOperation;
2354
- OpsVec[OpIdx][Lane] = {cast<Instruction>(VL[Lane])->getOperand(OpIdx),
2355
- APO, false};
2522
+ OpsVec[OpIdx][I] = {Iter->Ops[OpIdx], APO, false};
2356
2523
}
2357
2524
}
2358
2525
}
@@ -3252,15 +3419,25 @@ class BoUpSLP {
3252
3419
auto *I0 = cast<Instruction>(Scalars[0]);
3253
3420
Operands.resize(I0->getNumOperands());
3254
3421
unsigned NumLanes = Scalars.size();
3255
- for ( unsigned OpIdx = 0, NumOperands = I0->getNumOperands();
3256
- OpIdx != NumOperands; ++OpIdx) {
3422
+ unsigned NumOperands = I0->getNumOperands();
3423
+ for (unsigned OpIdx : seq<unsigned>( NumOperands))
3257
3424
Operands[OpIdx].resize(NumLanes);
3258
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
3259
- auto *I = cast<Instruction>(Scalars[Lane]);
3260
- assert(I->getNumOperands() == NumOperands &&
3261
- "Expected same number of operands");
3262
- Operands[OpIdx][Lane] = I->getOperand(OpIdx);
3263
- }
3425
+ for (auto [I, V] : enumerate(Scalars)) {
3426
+ SmallVector<InterchangeableInstruction> IIList =
3427
+ getInterchangeableInstruction(cast<Instruction>(V));
3428
+ auto Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
3429
+ return II.Opcode == MainOp->getOpcode();
3430
+ });
3431
+ if (Iter == IIList.end())
3432
+ Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
3433
+ return II.Opcode == AltOp->getOpcode();
3434
+ });
3435
+ assert(Iter != IIList.end() &&
3436
+ "Cannot find an interchangeable instruction.");
3437
+ assert(Iter->Ops.size() == NumOperands &&
3438
+ "Expected same number of operands");
3439
+ for (auto [J, Op] : enumerate(Iter->Ops))
3440
+ Operands[J][I] = Op;
3264
3441
}
3265
3442
}
3266
3443
@@ -14935,7 +15112,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
14935
15112
Value *V = Builder.CreateBinOp(
14936
15113
static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
14937
15114
RHS);
14938
- propagateIRFlags(V, E->Scalars, VL0 , It == MinBWs.end());
15115
+ propagateIRFlags(V, E->Scalars, nullptr , It == MinBWs.end());
14939
15116
if (auto *I = dyn_cast<Instruction>(V)) {
14940
15117
V = propagateMetadata(I, E->Scalars);
14941
15118
// Drop nuw flags for abs(sub(commutative), true).
0 commit comments