@@ -828,20 +828,6 @@ class InterchangeableInstruction {
828
828
protected:
829
829
Instruction *const MainOp;
830
830
831
- /// Return non nullptr if the right operand of I is ConstantInt.
832
- static ConstantInt *isBinOpWithConstantInt(Instruction *I) {
833
- Constant *C;
834
- if (!match(I, m_BinOp(m_Value(), m_Constant(C))))
835
- return nullptr;
836
- if (auto *CI = dyn_cast<ConstantInt>(C))
837
- return CI;
838
- if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
839
- if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
840
- return CI;
841
- }
842
- return nullptr;
843
- }
844
-
845
831
public:
846
832
InterchangeableInstruction(Instruction *MainOp) : MainOp(MainOp) {}
847
833
virtual bool isSame(Instruction *I) {
@@ -867,6 +853,29 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
867
853
MaskType Mask = 0b11111111;
868
854
MaskType SeenBefore = 0;
869
855
856
+ /// Return a non-nullptr if either operand of I is a ConstantInt.
857
+ static std::pair<ConstantInt *, unsigned>
858
+ isBinOpWithConstantInt(Instruction *I) {
859
+ unsigned Opcode = I->getOpcode();
860
+ unsigned Pos = 1;
861
+ Constant *C;
862
+ if (!match(I, m_BinOp(m_Value(), m_Constant(C)))) {
863
+ if (Opcode == Instruction::Sub || Opcode == Instruction::Shl ||
864
+ Opcode == Instruction::AShr)
865
+ return std::make_pair(nullptr, Pos);
866
+ if (!match(I, m_BinOp(m_Constant(C), m_Value())))
867
+ return std::make_pair(nullptr, Pos);
868
+ Pos = 0;
869
+ }
870
+ if (auto *CI = dyn_cast<ConstantInt>(C))
871
+ return std::make_pair(CI, Pos);
872
+ if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
873
+ if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
874
+ return std::make_pair(CI, Pos);
875
+ }
876
+ return std::make_pair(nullptr, Pos);
877
+ }
878
+
870
879
static MaskType opcodeToMask(unsigned Opcode) {
871
880
switch (Opcode) {
872
881
case Instruction::Shl:
@@ -904,26 +913,26 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
904
913
if (!binary_search(SupportedOp, Opcode))
905
914
return false;
906
915
SeenBefore |= opcodeToMask(Opcode);
907
- ConstantInt *CI = isBinOpWithConstantInt(I);
916
+ ConstantInt *CI = isBinOpWithConstantInt(I).first ;
908
917
if (CI) {
909
- const APInt &Op1Int = CI->getValue();
918
+ const APInt &CIValue = CI->getValue();
910
919
switch (Opcode) {
911
920
case Instruction::Shl:
912
- if (Op1Int .isZero())
921
+ if (CIValue .isZero())
913
922
return true;
914
923
return tryAnd(0b101);
915
924
case Instruction::Mul:
916
- if (Op1Int .isOne())
925
+ if (CIValue .isOne())
917
926
return true;
918
- if (Op1Int .isPowerOf2())
927
+ if (CIValue .isPowerOf2())
919
928
return tryAnd(0b101);
920
929
break;
921
930
case Instruction::And:
922
- if (Op1Int .isAllOnes())
931
+ if (CIValue .isAllOnes())
923
932
return true;
924
933
break;
925
934
default:
926
- if (Op1Int .isZero())
935
+ if (CIValue .isZero())
927
936
return true;
928
937
break;
929
938
}
@@ -957,41 +966,48 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
957
966
unsigned FromOpcode = MainOp->getOpcode();
958
967
if (FromOpcode == ToOpcode)
959
968
return SmallVector<Value *>(MainOp->operands());
960
- const APInt &Op1Int = isBinOpWithConstantInt(MainOp)->getValue();
961
- unsigned Op1IntBitWidth = Op1Int.getBitWidth();
962
- APInt RHSV;
969
+ auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
970
+ const APInt &FromCIValue = CI->getValue();
971
+ unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
972
+ APInt ToCIValue;
963
973
switch (FromOpcode) {
964
974
case Instruction::Shl:
965
975
if (ToOpcode == Instruction::Mul) {
966
- RHSV = APInt::getOneBitSet(Op1IntBitWidth, Op1Int.getZExtValue());
976
+ ToCIValue = APInt::getOneBitSet(FromCIValueBitWidth,
977
+ FromCIValue.getZExtValue());
967
978
} else {
968
- assert(Op1Int.isZero() && "Cannot convert the instruction.");
969
- RHSV = ToOpcode == Instruction::And ? APInt::getAllOnes(Op1IntBitWidth)
970
- : APInt::getZero(Op1IntBitWidth);
979
+ assert(FromCIValue.isZero() && "Cannot convert the instruction.");
980
+ ToCIValue = ToOpcode == Instruction::And
981
+ ? APInt::getAllOnes(FromCIValueBitWidth)
982
+ : APInt::getZero(FromCIValueBitWidth);
971
983
}
972
984
break;
973
985
case Instruction::Mul:
974
- assert(Op1Int .isPowerOf2() && "Cannot convert the instruction.");
986
+ assert(FromCIValue .isPowerOf2() && "Cannot convert the instruction.");
975
987
if (ToOpcode == Instruction::Shl) {
976
- RHSV = APInt(Op1IntBitWidth, Op1Int .logBase2());
988
+ ToCIValue = APInt(FromCIValueBitWidth, FromCIValue .logBase2());
977
989
} else {
978
- assert(Op1Int.isOne() && "Cannot convert the instruction.");
979
- RHSV = ToOpcode == Instruction::And ? APInt::getAllOnes(Op1IntBitWidth)
980
- : APInt::getZero(Op1IntBitWidth);
990
+ assert(FromCIValue.isOne() && "Cannot convert the instruction.");
991
+ ToCIValue = ToOpcode == Instruction::And
992
+ ? APInt::getAllOnes(FromCIValueBitWidth)
993
+ : APInt::getZero(FromCIValueBitWidth);
981
994
}
982
995
break;
983
996
case Instruction::And:
984
- assert(Op1Int .isAllOnes() && "Cannot convert the instruction.");
985
- RHSV = ToOpcode == Instruction::Mul
986
- ? APInt::getOneBitSet(Op1IntBitWidth , 0)
987
- : APInt::getZero(Op1IntBitWidth );
997
+ assert(FromCIValue .isAllOnes() && "Cannot convert the instruction.");
998
+ ToCIValue = ToOpcode == Instruction::Mul
999
+ ? APInt::getOneBitSet(FromCIValueBitWidth , 0)
1000
+ : APInt::getZero(FromCIValueBitWidth );
988
1001
break;
989
1002
default:
990
- RHSV = APInt::getZero(Op1IntBitWidth );
1003
+ ToCIValue = APInt::getZero(FromCIValueBitWidth );
991
1004
break;
992
1005
}
993
- return {MainOp->getOperand(0),
994
- ConstantInt::get(MainOp->getOperand(1)->getType(), RHSV)};
1006
+ auto LHS = MainOp->getOperand(1 - Pos);
1007
+ auto RHS = ConstantInt::get(MainOp->getOperand(Pos)->getType(), ToCIValue);
1008
+ if (Pos == 1)
1009
+ return SmallVector<Value *>({LHS, RHS});
1010
+ return SmallVector<Value *>({RHS, LHS});
995
1011
}
996
1012
};
997
1013
0 commit comments