Skip to content

Commit fa2a6d6

Browse files
authored
[CodeGenPrepare][RISCV] Combine (X ^ Y) and (X == Y) where appropriate (#130922)
Fixes #130510.

In RISCV, modify the folding of (X ^ Y == 0) -> (X == Y) to account for cases where the (X ^ Y) will be re-used.

If a constant is being used for the XOR before a branch, ensure that it is small enough to fit within a 12-bit immediate field. Otherwise, the equality check is more efficient than the check against 0. Compare the following:

```
# %bb.0:
        lui a1, 5
        addiw a1, a1, 1365
        xor a0, a0, a1
        beqz a0, .LBB0_2
# %bb.1:
        ret
.LBB0_2:
```

```
# %bb.0:
        lui a1, 5
        addiw a1, a1, 1365
        beq a0, a1, .LBB0_2
# %bb.1:
        xor a0, a0, a1
        ret
.LBB0_2:
```

Similarly, if the XOR is between 1 and a one-bit integer, we should still fold away the XOR, since that comparison can be optimized into a comparison against 0:

```
# %bb.0:
        slt a0, a0, a1
        xor a0, a0, 1
        beqz a0, .LBB0_2
# %bb.1:
        ret
.LBB0_2:
```

```
# %bb.0:
        slt a0, a0, a1
        bnez a0, .LBB0_2
# %bb.1:
        xor a0, a0, 1
        ret
.LBB0_2:
```

One question about my code: I used a hard-coded value for the width of a RISCV ALU immediate. Do you know of a way to gather this from the `context`? I was unable to devise one.
1 parent 74ec038 commit fa2a6d6

File tree

3 files changed

+121
-2
lines changed

3 files changed

+121
-2
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8575,7 +8575,8 @@ static bool optimizeBranch(BranchInst *Branch, const TargetLowering &TLI,
85758575
}
85768576
if (Cmp->isEquality() &&
85778577
(match(UI, m_Add(m_Specific(X), m_SpecificInt(-CmpC))) ||
8578-
match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))))) {
8578+
match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))) ||
8579+
match(UI, m_Xor(m_Specific(X), m_SpecificInt(CmpC))))) {
85798580
IRBuilder<> Builder(Branch);
85808581
if (UI->getParent() != Branch->getParent())
85818582
UI->moveBefore(Branch->getIterator());

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17400,12 +17400,56 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
1740017400
return true;
1740117401
}
1740217402

17403+
// If XOR is reused and has an immediate that will fit in XORI,
17404+
// do not fold.
17405+
auto isXorImmediate = [](const SDValue &Op) -> bool {
17406+
if (const auto *XorCnst = dyn_cast<ConstantSDNode>(Op))
17407+
return isInt<12>(XorCnst->getSExtValue());
17408+
return false;
17409+
};
17410+
// Fold (X(i1) ^ 1) == 0 -> X != 0
17411+
auto singleBitOp = [&DAG](const SDValue &VarOp,
17412+
const SDValue &ConstOp) -> bool {
17413+
if (const auto *XorCnst = dyn_cast<ConstantSDNode>(ConstOp)) {
17414+
const APInt Mask = APInt::getBitsSetFrom(VarOp.getValueSizeInBits(), 1);
17415+
return (XorCnst->getSExtValue() == 1) &&
17416+
DAG.MaskedValueIsZero(VarOp, Mask);
17417+
}
17418+
return false;
17419+
};
17420+
auto onlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
17421+
for (const SDNode *UserNode : Op->users()) {
17422+
const unsigned Opcode = UserNode->getOpcode();
17423+
if (Opcode != RISCVISD::SELECT_CC && Opcode != RISCVISD::BR_CC)
17424+
return false;
17425+
}
17426+
return true;
17427+
};
17428+
auto isFoldableXorEq = [isXorImmediate, singleBitOp, onlyUsedBySelectOrBR](
17429+
const SDValue &LHS, const SDValue &RHS) -> bool {
17430+
return LHS.getOpcode() == ISD::XOR && isNullConstant(RHS) &&
17431+
(!isXorImmediate(LHS.getOperand(1)) ||
17432+
singleBitOp(LHS.getOperand(0), LHS.getOperand(1)) ||
17433+
onlyUsedBySelectOrBR(LHS));
17434+
};
1740317435
// Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
17404-
if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
17436+
if (isFoldableXorEq(LHS, RHS)) {
1740517437
RHS = LHS.getOperand(1);
1740617438
LHS = LHS.getOperand(0);
1740717439
return true;
1740817440
}
17441+
// Fold ((sext (xor X, C)), 0, eq/ne) -> ((sext X), (sext C), eq/ne)
17442+
if (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG) {
17443+
const SDValue LHS0 = LHS.getOperand(0);
17444+
if (isFoldableXorEq(LHS0, RHS) && isa<ConstantSDNode>(LHS0.getOperand(1))) {
17445+
// SEXT(XOR(X, Y)) -> XOR(SEXT(X), SEXT(Y))
17446+
RHS = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, LHS.getValueType(),
17447+
LHS0.getOperand(1), LHS.getOperand(1));
17448+
LHS = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, LHS.getValueType(),
17449+
LHS0.getOperand(0), LHS.getOperand(1));
17450+
return true;
17451+
}
17452+
}
1740917453

1741017454
// Fold ((srl (and X, 1<<C), C), 0, eq/ne) -> ((shl X, XLen-1-C), 0, ge/lt)
1741117455
if (isNullConstant(RHS) && LHS.getOpcode() == ISD::SRL && LHS.hasOneUse() &&

llvm/test/CodeGen/RISCV/select-constant-xor.ll

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,77 @@ define i32 @oneusecmp(i32 %a, i32 %b, i32 %d) {
239239
%x = add i32 %s, %s2
240240
ret i32 %x
241241
}
242+
243+
define i32 @xor_branch_imm_ret(i32 %x) nounwind {
244+
; RV32-LABEL: xor_branch_imm_ret:
245+
; RV32: # %bb.0: # %entry
246+
; RV32-NEXT: xori a0, a0, -1365
247+
; RV32-NEXT: beqz a0, .LBB11_2
248+
; RV32-NEXT: # %bb.1: # %if.then
249+
; RV32-NEXT: ret
250+
; RV32-NEXT: .LBB11_2: # %if.end
251+
; RV32-NEXT: addi sp, sp, -16
252+
; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
253+
; RV32-NEXT: call abort
254+
;
255+
; RV64-LABEL: xor_branch_imm_ret:
256+
; RV64: # %bb.0: # %entry
257+
; RV64-NEXT: xori a0, a0, -1365
258+
; RV64-NEXT: sext.w a1, a0
259+
; RV64-NEXT: beqz a1, .LBB11_2
260+
; RV64-NEXT: # %bb.1: # %if.then
261+
; RV64-NEXT: ret
262+
; RV64-NEXT: .LBB11_2: # %if.end
263+
; RV64-NEXT: addi sp, sp, -16
264+
; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
265+
; RV64-NEXT: call abort
266+
entry:
267+
%cmp.not = icmp eq i32 %x, -1365
268+
br i1 %cmp.not, label %if.end, label %if.then
269+
if.then:
270+
%xor = xor i32 %x, -1365
271+
ret i32 %xor
272+
if.end:
273+
tail call void @abort() #2
274+
unreachable
275+
}
276+
277+
define i32 @xor_branch_ret(i32 %x) nounwind {
278+
; RV32-LABEL: xor_branch_ret:
279+
; RV32: # %bb.0: # %entry
280+
; RV32-NEXT: li a1, 1
281+
; RV32-NEXT: slli a1, a1, 11
282+
; RV32-NEXT: beq a0, a1, .LBB12_2
283+
; RV32-NEXT: # %bb.1: # %if.then
284+
; RV32-NEXT: xor a0, a0, a1
285+
; RV32-NEXT: ret
286+
; RV32-NEXT: .LBB12_2: # %if.end
287+
; RV32-NEXT: addi sp, sp, -16
288+
; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
289+
; RV32-NEXT: call abort
290+
;
291+
; RV64-LABEL: xor_branch_ret:
292+
; RV64: # %bb.0: # %entry
293+
; RV64-NEXT: li a1, 1
294+
; RV64-NEXT: slli a1, a1, 11
295+
; RV64-NEXT: sext.w a2, a0
296+
; RV64-NEXT: beq a2, a1, .LBB12_2
297+
; RV64-NEXT: # %bb.1: # %if.then
298+
; RV64-NEXT: xor a0, a0, a1
299+
; RV64-NEXT: ret
300+
; RV64-NEXT: .LBB12_2: # %if.end
301+
; RV64-NEXT: addi sp, sp, -16
302+
; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
303+
; RV64-NEXT: call abort
304+
entry:
305+
%cmp.not = icmp eq i32 %x, 2048
306+
br i1 %cmp.not, label %if.end, label %if.then
307+
if.then:
308+
%xor = xor i32 %x, 2048
309+
ret i32 %xor
310+
if.end:
311+
tail call void @abort() #2
312+
unreachable
313+
}
314+
315+
declare void @abort()

0 commit comments

Comments
 (0)