Skip to content

Commit c1e0f3f

Browse files
committed
[RISCV] Add DAG combine to turn (sub (shl X, 8-Y), (shr X, Y)) into orc.b
This patch generalizes the DAG combine for (sub (shl X, 8), X) => (orc.b X) into the more general form of (sub (shl X, 8 - Y), (srl X, Y)) => (orc.b X). Alive2 generalized proof: https://alive2.llvm.org/ce/z/dFcf_n Related issue: #96595 Related PR: #96680
1 parent b773da0 commit c1e0f3f

File tree

2 files changed

+408
-8
lines changed

2 files changed

+408
-8
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13572,8 +13572,10 @@ static SDValue combineSubOfBoolean(SDNode *N, SelectionDAG &DAG) {
1357213572
return DAG.getNode(ISD::ADD, DL, VT, NewLHS, NewRHS);
1357313573
}
1357413574

13575-
// Looks for (sub (shl X, 8), X) where only bits 8, 16, 24, 32, etc. of X are
13576-
// non-zero. Replace with orc.b.
13575+
// Looks for (sub (shl X, 8-Y), (shr X, Y)) where the Y-th bit in each byte is
13576+
// potentially set. It is fine for Y to be 0, meaning that (sub (shl X, 8), X)
13577+
// is also valid. Replace with (orc.b X). For example, 0b0000_1000_0000_1000 is
13578+
// valid with Y=3, while 0b0000_1000_0000_0100 is not.
1357713579
static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
1357813580
const RISCVSubtarget &Subtarget) {
1357913581
if (!Subtarget.hasStdExtZbb())
@@ -13587,18 +13589,44 @@ static SDValue combineSubShiftToOrcB(SDNode *N, SelectionDAG &DAG,
1358713589
SDValue N0 = N->getOperand(0);
1358813590
SDValue N1 = N->getOperand(1);
1358913591

13590-
if (N0.getOpcode() != ISD::SHL || N0.getOperand(0) != N1 || !N0.hasOneUse())
13592+
if (N0->getOpcode() != ISD::SHL)
13593+
return SDValue();
13594+
13595+
auto *ShAmtCLeft = dyn_cast<ConstantSDNode>(N0.getOperand(1));
13596+
if (!ShAmtCLeft)
13597+
return SDValue();
13598+
unsigned ShiftedAmount = 8 - ShAmtCLeft->getZExtValue();
13599+
13600+
if (ShiftedAmount >= 8)
13601+
return SDValue();
13602+
13603+
SDValue LeftShiftOperand = N0->getOperand(0);
13604+
SDValue RightShiftOperand = N1;
13605+
13606+
if (ShiftedAmount != 0) { // Right operand must be a right shift.
13607+
if (N1->getOpcode() != ISD::SRL)
13608+
return SDValue();
13609+
auto *ShAmtCRight = dyn_cast<ConstantSDNode>(N1.getOperand(1));
13610+
if (!ShAmtCRight || ShAmtCRight->getZExtValue() != ShiftedAmount)
13611+
return SDValue();
13612+
RightShiftOperand = N1.getOperand(0);
13613+
}
13614+
13615+
// At least one shift should have a single use.
13616+
if (!N0.hasOneUse() && (ShiftedAmount == 0 || !N1.hasOneUse()))
1359113617
return SDValue();
1359213618

13593-
auto *ShAmtC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
13594-
if (!ShAmtC || ShAmtC->getZExtValue() != 8)
13619+
if (LeftShiftOperand != RightShiftOperand)
1359513620
return SDValue();
1359613621

13597-
APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0xfe));
13598-
if (!DAG.MaskedValueIsZero(N1, Mask))
13622+
APInt Mask = APInt::getSplat(VT.getSizeInBits(), APInt(8, 0x1));
13623+
Mask <<= ShiftedAmount;
13624+
// Check that X has indeed the right shape (only the Y-th bit can be set in
13625+
// every byte).
13626+
if (!DAG.MaskedValueIsZero(LeftShiftOperand, ~Mask))
1359913627
return SDValue();
1360013628

13601-
return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, N1);
13629+
return DAG.getNode(RISCVISD::ORC_B, SDLoc(N), VT, LeftShiftOperand);
1360213630
}
1360313631

1360413632
static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,

0 commit comments

Comments
 (0)