Skip to content

Commit 26b832a

Browse files
authored
[RISCV] Add DAG combine to turn (sub (shl X, 8-Y), (srl X, Y)) into orc.b (#111828)
This patch generalizes the DAG combine for `(sub (shl X, 8), X) => (orc.b X)` into the more general form of `(sub (shl X, 8 - Y), (srl X, Y)) => (orc.b X)`. The generalized transform is proven with Alive2 (https://alive2.llvm.org/ce/z/dFcf_n). Related issue: #96595. Related PR: #96680.
1 parent c84f759 commit 26b832a

File tree

2 files changed

+408
-8
lines changed

2 files changed

+408
-8
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13569,8 +13569,10 @@ static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
1356913569
return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
1357013570
}
1357113571

13572-
// Looks for (sub (shl X, 8), X) where only bits 8, 16, 24, 32, etc. of X are
13573-
// non-zero. Replace with orc.b.
13572+
// Looks for (sub (shl X, 8-Y), (srl X, Y)) where the Y-th bit in each byte is
13573+
// potentially set. It is fine for Y to be 0, meaning that (sub (shl X, 8), X)
13574+
// is also valid. Replace with (orc.b X). For example, 0b0000_1000_0000_1000 is
13575+
// valid with Y=3, while 0b0000_1000_0000_0100 is not.
1357413576
static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
1357513577
const RISCVSubtarget &Subtarget) {
1357613578
if (!Subtarget.hasStdExtZbb())
@@ -13584,18 +13586,44 @@ static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
1358413586
SDValue N0 = N->getOperand(0);
1358513587
SDValue N1 = N->getOperand(1);
1358613588

13587-
if (N0.getOpcode() != ISD::SHL || N0.getOperand(0) != N1 || !N0.hasOneUse())
13589+
if (N0->getOpcode() != ISD::SHL)
1358813590
return SDValue();
1358913591

13590-
auto *ShAmtC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
13591-
if (!ShAmtC || ShAmtC->getZExtValue() != 8)
13592+
auto *ShAmtCLeft = dyn_cast<ConstantSDNode>(N0.getOperand(1));
13593+
if (!ShAmtCLeft)
1359213594
return SDValue();
13595+
unsigned ShiftedAmount = 8 - ShAmtCLeft->getZExtValue();
1359313596

13594-
APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0xfe));
13595-
if (!DAG.MaskedValueIsZero(N1, Mask))
13597+
if (ShiftedAmount >= 8)
1359613598
return SDValue();
1359713599

13598-
return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, N1);
13600+
SDValue LeftShiftOperand = N0->getOperand(0);
13601+
SDValue RightShiftOperand = N1;
13602+
13603+
if (ShiftedAmount != 0) { // Right operand must be a right shift.
13604+
if (N1->getOpcode() != ISD::SRL)
13605+
return SDValue();
13606+
auto *ShAmtCRight = dyn_cast<ConstantSDNode>(N1.getOperand(1));
13607+
if (!ShAmtCRight || ShAmtCRight->getZExtValue() != ShiftedAmount)
13608+
return SDValue();
13609+
RightShiftOperand = N1.getOperand(0);
13610+
}
13611+
13612+
// At least one shift should have a single use.
13613+
if (!N0.hasOneUse() && (ShiftedAmount == 0 || !N1.hasOneUse()))
13614+
return SDValue();
13615+
13616+
if (LeftShiftOperand != RightShiftOperand)
13617+
return SDValue();
13618+
13619+
APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0x1));
13620+
Mask <<= ShiftedAmount;
13621+
// Check that X has indeed the right shape (only the Y-th bit can be set in
13622+
// every byte).
13623+
if (!DAG.MaskedValueIsZero(LeftShiftOperand, ~Mask))
13624+
return SDValue();
13625+
13626+
return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, LeftShiftOperand);
1359913627
}
1360013628

1360113629
static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,

0 commit comments

Comments
 (0)