Skip to content

Commit d8e44a9

Browse files
[RISCV] Add late optimization pass for riscv (#133256)
This patch is an alternative to PRs #117060, #131684, #131728. The patch adds a late optimization pass that replaces conditional branches that can be statically evaluated with an unconditional branch. Adding Michael as a co-author as most of the code that evaluates the condition comes from #131684. Co-authored-by: Michael Maitland [email protected]
1 parent 01e505b commit d8e44a9

16 files changed

+217
-103
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ add_llvm_target(RISCVCodeGen
4747
RISCVISelDAGToDAG.cpp
4848
RISCVISelLowering.cpp
4949
RISCVLandingPadSetup.cpp
50+
RISCVLateBranchOpt.cpp
5051
RISCVLoadStoreOptimizer.cpp
5152
RISCVMachineFunctionInfo.cpp
5253
RISCVMakeCompressible.cpp

llvm/lib/Target/RISCV/RISCV.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ void initializeRISCVLandingPadSetupPass(PassRegistry &);
4040
FunctionPass *createRISCVISelDag(RISCVTargetMachine &TM,
4141
CodeGenOptLevel OptLevel);
4242

43+
FunctionPass *createRISCVLateBranchOptPass();
44+
void initializeRISCVLateBranchOptPass(PassRegistry &);
45+
4346
FunctionPass *createRISCVMakeCompressibleOptPass();
4447
void initializeRISCVMakeCompressibleOptPass(PassRegistry &);
4548

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ static RISCVCC::CondCode getCondFromBranchOpc(unsigned Opc) {
993993
}
994994
}
995995

996-
static bool evaluateCondBranch(unsigned CC, int64_t C0, int64_t C1) {
996+
bool RISCVInstrInfo::evaluateCondBranch(unsigned CC, int64_t C0, int64_t C1) {
997997
switch (CC) {
998998
default:
999999
llvm_unreachable("Unexpected CC");
@@ -1297,6 +1297,31 @@ bool RISCVInstrInfo::reverseBranchCondition(
12971297
return false;
12981298
}
12991299

1300+
// Return true if the instruction is a load immediate instruction (i.e.
1301+
// ADDI x0, imm).
1302+
static bool isLoadImm(const MachineInstr *MI, int64_t &Imm) {
1303+
if (MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
1304+
MI->getOperand(1).getReg() == RISCV::X0) {
1305+
Imm = MI->getOperand(2).getImm();
1306+
return true;
1307+
}
1308+
return false;
1309+
}
1310+
1311+
bool RISCVInstrInfo::isFromLoadImm(const MachineRegisterInfo &MRI,
1312+
const MachineOperand &Op, int64_t &Imm) {
1313+
// Either a load from immediate instruction or X0.
1314+
if (!Op.isReg())
1315+
return false;
1316+
1317+
Register Reg = Op.getReg();
1318+
if (Reg == RISCV::X0) {
1319+
Imm = 0;
1320+
return true;
1321+
}
1322+
return Reg.isVirtual() && isLoadImm(MRI.getVRegDef(Reg), Imm);
1323+
}
1324+
13001325
bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
13011326
MachineBasicBlock *MBB = MI.getParent();
13021327
MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
@@ -1319,31 +1344,10 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
13191344
MI.eraseFromParent();
13201345
};
13211346

1322-
// Right now we only care about LI (i.e. ADDI x0, imm)
1323-
auto isLoadImm = [](const MachineInstr *MI, int64_t &Imm) -> bool {
1324-
if (MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
1325-
MI->getOperand(1).getReg() == RISCV::X0) {
1326-
Imm = MI->getOperand(2).getImm();
1327-
return true;
1328-
}
1329-
return false;
1330-
};
1331-
// Either a load from immediate instruction or X0.
1332-
auto isFromLoadImm = [&](const MachineOperand &Op, int64_t &Imm) -> bool {
1333-
if (!Op.isReg())
1334-
return false;
1335-
Register Reg = Op.getReg();
1336-
if (Reg == RISCV::X0) {
1337-
Imm = 0;
1338-
return true;
1339-
}
1340-
return Reg.isVirtual() && isLoadImm(MRI.getVRegDef(Reg), Imm);
1341-
};
1342-
13431347
// Canonicalize conditional branches which can be constant folded into
13441348
// beqz or bnez. We can't modify the CFG here.
13451349
int64_t C0, C1;
1346-
if (isFromLoadImm(Cond[1], C0) && isFromLoadImm(Cond[2], C1)) {
1350+
if (isFromLoadImm(MRI, Cond[1], C0) && isFromLoadImm(MRI, Cond[2], C1)) {
13471351
unsigned NewCC =
13481352
evaluateCondBranch(CC, C0, C1) ? RISCVCC::COND_EQ : RISCVCC::COND_NE;
13491353
Cond[0] = MachineOperand::CreateImm(NewCC);
@@ -1389,7 +1393,7 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
13891393
return Register();
13901394
};
13911395

1392-
if (isFromLoadImm(LHS, C0) && MRI.hasOneUse(LHS.getReg())) {
1396+
if (isFromLoadImm(MRI, LHS, C0) && MRI.hasOneUse(LHS.getReg())) {
13931397
// Might be case 1.
13941398
// Signed integer overflow is UB. (UINT64_MAX is bigger so we don't need
13951399
// to worry about unsigned overflow here)
@@ -1404,7 +1408,7 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
14041408
modifyBranch();
14051409
return true;
14061410
}
1407-
} else if (isFromLoadImm(RHS, C0) && MRI.hasOneUse(RHS.getReg())) {
1411+
} else if (isFromLoadImm(MRI, RHS, C0) && MRI.hasOneUse(RHS.getReg())) {
14081412
// Might be case 2.
14091413
// For unsigned cases, we don't want C1 to wrap back to UINT64_MAX
14101414
// when C0 is zero.

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,15 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
307307
static bool isLdStSafeToPair(const MachineInstr &LdSt,
308308
const TargetRegisterInfo *TRI);
309309

310+
/// Return the result of the evaluation of C0 CC C1, where CC is a
311+
/// RISCVCC::CondCode.
312+
static bool evaluateCondBranch(unsigned CC, int64_t C0, int64_t C1);
313+
314+
/// Return true if the operand is X0 or a virtual register defined by a load
315+
/// immediate instruction (ADDI x0, imm), and set Imm to the constant value.
316+
static bool isFromLoadImm(const MachineRegisterInfo &MRI,
317+
const MachineOperand &Op, int64_t &Imm);
318+
310319
protected:
311320
const RISCVSubtarget &STI;
312321

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//===-- RISCVLateBranchOpt.cpp - Late Stage Branch Optimization -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// This file provides RISC-V specific target optimizations. Currently it is
10+
/// limited to converting conditional branches into unconditional branches when
11+
/// the condition can be statically evaluated.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "RISCVInstrInfo.h"
16+
#include "RISCVSubtarget.h"
17+
18+
using namespace llvm;
19+
20+
#define RISCV_LATE_BRANCH_OPT_NAME "RISC-V Late Branch Optimisation Pass"
21+
22+
namespace {
23+
24+
struct RISCVLateBranchOpt : public MachineFunctionPass {
25+
static char ID;
26+
27+
RISCVLateBranchOpt() : MachineFunctionPass(ID) {}
28+
29+
StringRef getPassName() const override { return RISCV_LATE_BRANCH_OPT_NAME; }
30+
31+
void getAnalysisUsage(AnalysisUsage &AU) const override {
32+
MachineFunctionPass::getAnalysisUsage(AU);
33+
}
34+
35+
bool runOnMachineFunction(MachineFunction &Fn) override;
36+
37+
private:
38+
bool runOnBasicBlock(MachineBasicBlock &MBB) const;
39+
40+
const RISCVInstrInfo *RII = nullptr;
41+
};
42+
} // namespace
43+
44+
char RISCVLateBranchOpt::ID = 0;
45+
INITIALIZE_PASS(RISCVLateBranchOpt, "riscv-late-branch-opt",
46+
RISCV_LATE_BRANCH_OPT_NAME, false, false)
47+
48+
/// Fold the conditional branch terminating \p MBB when both compared operands
/// are known constants, replacing it with an unconditional branch or a plain
/// fallthrough and updating the successor list to match.
///
/// \returns true if \p MBB was modified.
bool RISCVLateBranchOpt::runOnBasicBlock(MachineBasicBlock &MBB) const {
  MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
  SmallVector<MachineOperand, 4> Cond;
  // A true result means the terminators could not be analyzed.
  if (RII->analyzeBranch(MBB, TBB, FBB, Cond, /*AllowModify=*/false))
    return false;

  // Only handle conditional branches described as {CC, LHS, RHS}.
  if (!TBB || Cond.size() != 3)
    return false;

  RISCVCC::CondCode CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
  assert(CC != RISCVCC::COND_INVALID);

  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();

  // Try and convert a conditional branch that can be evaluated statically
  // into an unconditional branch.
  int64_t C0, C1;
  if (!RISCVInstrInfo::isFromLoadImm(MRI, Cond[1], C0) ||
      !RISCVInstrInfo::isFromLoadImm(MRI, Cond[2], C1))
    return false;

  // The block control always reaches; null means "fall through".
  MachineBasicBlock *Folded =
      RISCVInstrInfo::evaluateCondBranch(CC, C0, C1) ? TBB : FBB;

  // Capture the branch debug location before the branches are erased;
  // afterwards there is no branch left to take it from.
  DebugLoc DL = MBB.findBranchDebugLoc();

  // At this point, it's legal to optimize.
  RII->removeBranch(MBB);

  // Only need to insert a branch if we're not falling through.
  if (Folded)
    RII->insertBranch(MBB, Folded, nullptr, {}, DL);

  // Update the successors. Remove them all and add back the correct one.
  // NOTE(review): this drops every successor edge, not only the two branch
  // targets - confirm blocks reaching this point cannot have additional
  // successors (e.g. EH pads).
  while (!MBB.succ_empty())
    MBB.removeSuccessor(MBB.succ_end() - 1);

  if (Folded) {
    MBB.addSuccessor(Folded);
  } else {
    // If it's a fallthrough, we need to figure out where MBB is going.
    MachineFunction::iterator Fallthrough = ++MBB.getIterator();
    if (Fallthrough != MBB.getParent()->end())
      MBB.addSuccessor(&*Fallthrough);
  }

  return true;
}
95+
96+
/// Pass entry point: cache the RISC-V instruction info for \p Fn and attempt
/// the branch folding on each of its basic blocks.
///
/// \returns true if any block was modified.
bool RISCVLateBranchOpt::runOnMachineFunction(MachineFunction &Fn) {
  if (skipFunction(Fn.getFunction()))
    return false;

  RII = Fn.getSubtarget<RISCVSubtarget>().getInstrInfo();

  bool Modified = false;
  for (MachineBasicBlock &Block : Fn) {
    // Visit every block, even after an earlier one has already changed.
    if (runOnBasicBlock(Block))
      Modified = true;
  }
  return Modified;
}
108+
109+
/// Create a new, heap-allocated instance of the RISC-V late branch
/// optimization pass.
FunctionPass *llvm::createRISCVLateBranchOptPass() {
  FunctionPass *Pass = new RISCVLateBranchOpt();
  return Pass;
}

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
127127
initializeRISCVPostLegalizerCombinerPass(*PR);
128128
initializeKCFIPass(*PR);
129129
initializeRISCVDeadRegisterDefinitionsPass(*PR);
130+
initializeRISCVLateBranchOptPass(*PR);
130131
initializeRISCVMakeCompressibleOptPass(*PR);
131132
initializeRISCVGatherScatterLoweringPass(*PR);
132133
initializeRISCVCodeGenPreparePass(*PR);
@@ -565,6 +566,8 @@ void RISCVPassConfig::addPreEmitPass() {
565566
if (TM->getOptLevel() >= CodeGenOptLevel::Default &&
566567
EnableRISCVCopyPropagation)
567568
addPass(createMachineCopyPropagationPass(true));
569+
if (TM->getOptLevel() >= CodeGenOptLevel::Default)
570+
addPass(createRISCVLateBranchOptPass());
568571
addPass(&BranchRelaxationPassID);
569572
addPass(createRISCVMakeCompressibleOptPass());
570573
}

llvm/test/CodeGen/RISCV/GlobalISel/rv32zbb.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ define i64 @ctpop_i64(i64 %a) nounwind {
357357
define i1 @ctpop_i64_ugt_two(i64 %a) nounwind {
358358
; RV32I-LABEL: ctpop_i64_ugt_two:
359359
; RV32I: # %bb.0:
360-
; RV32I-NEXT: beqz zero, .LBB6_2
360+
; RV32I-NEXT: j .LBB6_2
361361
; RV32I-NEXT: # %bb.1:
362362
; RV32I-NEXT: sltiu a0, zero, 0
363363
; RV32I-NEXT: ret
@@ -404,7 +404,7 @@ define i1 @ctpop_i64_ugt_two(i64 %a) nounwind {
404404
;
405405
; RV32ZBB-LABEL: ctpop_i64_ugt_two:
406406
; RV32ZBB: # %bb.0:
407-
; RV32ZBB-NEXT: beqz zero, .LBB6_2
407+
; RV32ZBB-NEXT: j .LBB6_2
408408
; RV32ZBB-NEXT: # %bb.1:
409409
; RV32ZBB-NEXT: sltiu a0, zero, 0
410410
; RV32ZBB-NEXT: ret
@@ -422,7 +422,7 @@ define i1 @ctpop_i64_ugt_two(i64 %a) nounwind {
422422
define i1 @ctpop_i64_ugt_one(i64 %a) nounwind {
423423
; RV32I-LABEL: ctpop_i64_ugt_one:
424424
; RV32I: # %bb.0:
425-
; RV32I-NEXT: beqz zero, .LBB7_2
425+
; RV32I-NEXT: j .LBB7_2
426426
; RV32I-NEXT: # %bb.1:
427427
; RV32I-NEXT: snez a0, zero
428428
; RV32I-NEXT: ret
@@ -470,7 +470,7 @@ define i1 @ctpop_i64_ugt_one(i64 %a) nounwind {
470470
;
471471
; RV32ZBB-LABEL: ctpop_i64_ugt_one:
472472
; RV32ZBB: # %bb.0:
473-
; RV32ZBB-NEXT: beqz zero, .LBB7_2
473+
; RV32ZBB-NEXT: j .LBB7_2
474474
; RV32ZBB-NEXT: # %bb.1:
475475
; RV32ZBB-NEXT: snez a0, zero
476476
; RV32ZBB-NEXT: ret

llvm/test/CodeGen/RISCV/O3-pipeline.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@
194194
; CHECK-NEXT: Insert XRay ops
195195
; CHECK-NEXT: Implement the 'patchable-function' attribute
196196
; CHECK-NEXT: Machine Copy Propagation Pass
197+
; CHECK-NEXT: RISC-V Late Branch Optimisation Pass
197198
; CHECK-NEXT: Branch relaxation pass
198199
; CHECK-NEXT: RISC-V Make Compressible
199200
; CHECK-NEXT: Contiguously Lay Out Funclets

llvm/test/CodeGen/RISCV/bfloat-br-fcmp.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ declare bfloat @dummy(bfloat)
1111
define void @br_fcmp_false(bfloat %a, bfloat %b) nounwind {
1212
; RV32IZFBFMIN-LABEL: br_fcmp_false:
1313
; RV32IZFBFMIN: # %bb.0:
14-
; RV32IZFBFMIN-NEXT: beqz zero, .LBB0_2
14+
; RV32IZFBFMIN-NEXT: j .LBB0_2
1515
; RV32IZFBFMIN-NEXT: # %bb.1: # %if.then
1616
; RV32IZFBFMIN-NEXT: ret
1717
; RV32IZFBFMIN-NEXT: .LBB0_2: # %if.else
@@ -21,7 +21,7 @@ define void @br_fcmp_false(bfloat %a, bfloat %b) nounwind {
2121
;
2222
; RV64IZFBFMIN-LABEL: br_fcmp_false:
2323
; RV64IZFBFMIN: # %bb.0:
24-
; RV64IZFBFMIN-NEXT: beqz zero, .LBB0_2
24+
; RV64IZFBFMIN-NEXT: j .LBB0_2
2525
; RV64IZFBFMIN-NEXT: # %bb.1: # %if.then
2626
; RV64IZFBFMIN-NEXT: ret
2727
; RV64IZFBFMIN-NEXT: .LBB0_2: # %if.else
@@ -581,7 +581,7 @@ if.then:
581581
define void @br_fcmp_true(bfloat %a, bfloat %b) nounwind {
582582
; RV32IZFBFMIN-LABEL: br_fcmp_true:
583583
; RV32IZFBFMIN: # %bb.0:
584-
; RV32IZFBFMIN-NEXT: beqz zero, .LBB16_2
584+
; RV32IZFBFMIN-NEXT: j .LBB16_2
585585
; RV32IZFBFMIN-NEXT: # %bb.1: # %if.else
586586
; RV32IZFBFMIN-NEXT: ret
587587
; RV32IZFBFMIN-NEXT: .LBB16_2: # %if.then
@@ -591,7 +591,7 @@ define void @br_fcmp_true(bfloat %a, bfloat %b) nounwind {
591591
;
592592
; RV64IZFBFMIN-LABEL: br_fcmp_true:
593593
; RV64IZFBFMIN: # %bb.0:
594-
; RV64IZFBFMIN-NEXT: beqz zero, .LBB16_2
594+
; RV64IZFBFMIN-NEXT: j .LBB16_2
595595
; RV64IZFBFMIN-NEXT: # %bb.1: # %if.else
596596
; RV64IZFBFMIN-NEXT: ret
597597
; RV64IZFBFMIN-NEXT: .LBB16_2: # %if.then

llvm/test/CodeGen/RISCV/branch_zero.ll

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
define void @foo(i16 %finder_idx) {
66
; CHECK-LABEL: foo:
77
; CHECK: # %bb.0: # %entry
8-
; CHECK-NEXT: .LBB0_1: # %for.body
9-
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
8+
; CHECK-NEXT: # %bb.1: # %for.body
109
; CHECK-NEXT: slli a0, a0, 48
1110
; CHECK-NEXT: bltz a0, .LBB0_4
1211
; CHECK-NEXT: # %bb.2: # %while.cond.preheader.i
13-
; CHECK-NEXT: # in Loop: Header=BB0_1 Depth=1
1412
; CHECK-NEXT: li a0, 0
15-
; CHECK-NEXT: bnez zero, .LBB0_1
1613
; CHECK-NEXT: # %bb.3: # %while.body
1714
; CHECK-NEXT: .LBB0_4: # %while.cond1.preheader.i
1815
entry:
@@ -46,14 +43,11 @@ if.then:
4643
define void @bar(i16 %finder_idx) {
4744
; CHECK-LABEL: bar:
4845
; CHECK: # %bb.0: # %entry
49-
; CHECK-NEXT: .LBB1_1: # %for.body
50-
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
46+
; CHECK-NEXT: # %bb.1: # %for.body
5147
; CHECK-NEXT: slli a0, a0, 48
5248
; CHECK-NEXT: bgez a0, .LBB1_4
5349
; CHECK-NEXT: # %bb.2: # %while.cond.preheader.i
54-
; CHECK-NEXT: # in Loop: Header=BB1_1 Depth=1
5550
; CHECK-NEXT: li a0, 0
56-
; CHECK-NEXT: bnez zero, .LBB1_1
5751
; CHECK-NEXT: # %bb.3: # %while.body
5852
; CHECK-NEXT: .LBB1_4: # %while.cond1.preheader.i
5953
entry:

0 commit comments

Comments
 (0)