Skip to content

[RFC][BPF] Support Jump Table #133856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/AsmPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/DebugInfo/CodeView/CodeView.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

namespace llvm {
extern cl::opt<bool> EmitJumpTableSizesSection;

class AddrLabelMap;
class AsmPrinterHandler;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ static cl::opt<bool> BBAddrMapSkipEmitBBEntries(
"unnecessary for some PGOAnalysisMap features."),
cl::Hidden, cl::init(false));

static cl::opt<bool> EmitJumpTableSizesSection(
cl::opt<bool> llvm::EmitJumpTableSizesSection(
"emit-jump-table-sizes-section",
cl::desc("Emit a section containing jump table addresses and sizes"),
cl::Hidden, cl::init(false));
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ struct BPFOperand : public MCParsedAsmOperand {
.Case("callx", true)
.Case("goto", true)
.Case("gotol", true)
.Case("gotox", true)
.Case("may_goto", true)
.Case("*", true)
.Case("exit", true)
Expand Down Expand Up @@ -261,7 +262,6 @@ struct BPFOperand : public MCParsedAsmOperand {
.Case("bswap32", true)
.Case("bswap64", true)
.Case("goto", true)
.Case("gotol", true)
.Case("ll", true)
.Case("skb", true)
.Case("s", true)
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/BPF/BPFAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class BPFAsmPrinter : public AsmPrinter {
} // namespace

bool BPFAsmPrinter::doInitialization(Module &M) {
EmitJumpTableSizesSection = true;
AsmPrinter::doInitialization(M);

// Only emit BTF when debuginfo available.
Expand Down
36 changes: 34 additions & 2 deletions llvm/lib/Target/BPF/BPFISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
cl::Hidden, cl::init(false),
cl::desc("Expand memcpy into load/store pairs in order"));

static cl::opt<unsigned> BPFMinimumJumpTableEntries(
"bpf-min-jump-table-entries", cl::init(4), cl::Hidden,
cl::desc("Set minimum number of entries to use a jump table on BPF"));

static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
SDValue Val = {}) {
std::string Str;
Expand Down Expand Up @@ -65,10 +69,11 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,

setOperationAction(ISD::BR_CC, MVT::i64, Custom);
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
setOperationAction(ISD::BRIND, MVT::Other, Expand);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this does remove restriction to not produce indirect jumps?

Is there a way to control if we want to generate indirect jumps "in general" vs., say, "only for large switches"? (Or even only for a particular switch?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this does remove restriction to not produce indirect jumps?
Yes, we do not want to expand 'brind', rather we will do pattern matching with 'brind'.

Is there a way to control if we want to generate indirect jumps "in general" vs., say, "only for large switches"? (Or even only for a particular switch?)

Good point. Let me do some experiments with a flag for this. I am not sure whether I could do 'only for a particular switch', but I will do some investigation. Hopefully can find a s solution for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an option to control how many cases in a switch statement to use jump table. The default is 4 cases. But you can change it with additional clang option, e.g., the minimum number of cases must be 6, then

clang ... -mllvm -bpf-min-jump-table-entries=6

I checked other targets, there are no control for a specific switch. So I think we do not need them for now.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks!

setOperationAction(ISD::BRCOND, MVT::Other, Expand);

setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom);
setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable,
ISD::BlockAddress},
MVT::i64, Custom);

setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
Expand Down Expand Up @@ -155,6 +160,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,

setBooleanContents(ZeroOrOneBooleanContent);
setMaxAtomicSizeInBitsSupported(64);
setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);

// Function alignments
setMinFunctionAlignment(Align(8));
Expand Down Expand Up @@ -312,10 +318,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode()));
case ISD::BR_CC:
return LowerBR_CC(Op, DAG);
case ISD::JumpTable:
return LowerJumpTable(Op, DAG);
case ISD::GlobalAddress:
return LowerGlobalAddress(Op, DAG);
case ISD::ConstantPool:
return LowerConstantPool(Op, DAG);
case ISD::BlockAddress:
return LowerBlockAddress(Op, DAG);
case ISD::SELECT_CC:
return LowerSELECT_CC(Op, DAG);
case ISD::SDIV:
Expand Down Expand Up @@ -726,6 +736,11 @@ SDValue BPFTargetLowering::LowerATOMIC_LOAD_STORE(SDValue Op,
return Op;
}

SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
return getAddr(N, DAG);
}

const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
switch ((BPFISD::NodeType)Opcode) {
case BPFISD::FIRST_NUMBER:
Expand Down Expand Up @@ -757,6 +772,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty,
N->getOffset(), Flags);
}

static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty,
SelectionDAG &DAG, unsigned Flags) {
return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
Flags);
}

static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,
SelectionDAG &DAG, unsigned Flags) {
return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
}

template <class NodeTy>
SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
unsigned Flags) const {
Expand All @@ -783,6 +809,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op,
return getAddr(N, DAG);
}

SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op,
SelectionDAG &DAG) const {
BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
return getAddr(N, DAG);
}

unsigned
BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB,
unsigned Reg, bool isSigned) const {
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/BPF/BPFISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class BPFTargetLowering : public TargetLowering {
SDValue LowerATOMIC_LOAD_STORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;

template <class NodeTy>
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;
Expand Down
41 changes: 41 additions & 0 deletions llvm/lib/Target/BPF/BPFInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
if (!isUnpredicatedTerminator(*I))
break;

// If a JX insn, we're done.
if (I->getOpcode() == BPF::JX)
break;

// A terminator that isn't a branch can't easily be handled
// by this analysis.
if (!I->isBranch())
Expand Down Expand Up @@ -259,3 +263,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,

return Count;
}

int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
// The pattern looks like:
// %0 = LD_imm64 %jump-table.0 ; load jump-table address
// %1 = ADD_rr %0, $another_reg ; address + offset
// %2 = LDD %1, 0 ; load the actual label
// JX %2
const MachineFunction &MF = *MI.getParent()->getParent();
const MachineRegisterInfo &MRI = MF.getRegInfo();

Register Reg = MI.getOperand(0).getReg();
if (!Reg.isVirtual())
return -1;
MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg);
if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD)
return -1;

Reg = Ldd->getOperand(1).getReg();
if (!Reg.isVirtual())
return -1;
MachineInstr *Add = MRI.getUniqueVRegDef(Reg);
if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr)
return -1;

Reg = Add->getOperand(1).getReg();
if (!Reg.isVirtual())
return -1;
MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg);
if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64)
return -1;

const MachineOperand &MO = LDimm64->getOperand(1);
if (!MO.isJTI())
return -1;

return MO.getIndex();
}
3 changes: 3 additions & 0 deletions llvm/lib/Target/BPF/BPFInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
const DebugLoc &DL,
int *BytesAdded = nullptr) const override;

int getJumpTableIndex(const MachineInstr &MI) const override;

private:
void expandMEMCPY(MachineBasicBlock::iterator) const;

Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/BPF/BPFInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ class TYPE_LD_ST<bits<3> mode, bits<2> size,
let Inst{60-59} = size;
}

// For indirect jump
class TYPE_IND_JMP<bits<4> op, bits<1> srctype,
dag outs, dag ins, string asmstr, list<dag> pattern>
: InstBPF<outs, ins, asmstr, pattern> {

let Inst{63-60} = op;
let Inst{59} = srctype;
}

// jump instructions
class JMP_RR<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
Expand Down Expand Up @@ -216,6 +225,18 @@ class JMP_RI<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
let BPFClass = BPF_JMP;
}

class JMP_IND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
: TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
(outs),
(ins GPR:$dst),
!strconcat(OpcodeStr, " $dst"),
Pattern> {
bits<4> dst;

let Inst{51-48} = dst;
let BPFClass = BPF_JMP;
}

class JMP_JCOND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
: TYPE_ALU_JMP<Opc.Value, BPF_K.Value,
(outs),
Expand Down Expand Up @@ -281,6 +302,10 @@ defm JSLT : J<BPF_JSLT, "s<", BPF_CC_LT, BPF_CC_LT_32>;
defm JSLE : J<BPF_JSLE, "s<=", BPF_CC_LE, BPF_CC_LE_32>;
defm JSET : J<BPF_JSET, "&", NoCond, NoCond>;
def JCOND : JMP_JCOND<BPF_JCOND, "may_goto", []>;

let isIndirectBranch = 1 in {
def JX : JMP_IND<BPF_JA, "gotox", [(brind i64:$dst)]>;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice to see how it should be done, I just had hardcoded it in my test branch: aspsk@98773c6

}

// ALU instructions
Expand Down Expand Up @@ -851,6 +876,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in {
// load 64-bit global addr into register
def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>;
def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>;
def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>;
def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>;

// 0xffffFFFF doesn't fit into simm32, optimize common case
def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)),
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/BPF/BPFMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
case MachineOperand::MO_ConstantPoolIndex:
MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex()));
break;
case MachineOperand::MO_JumpTableIndex:
MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex()));
break;
}

OutMI.addOperand(MCOp);
Expand Down