Skip to content

Commit d1e6ba5

Browse files
committed
[NVPTX] support switch statement with brx.idx (llvm#102400)
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx] (https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)). Depending on the heuristics in DAG selection, `switch` statements may now be lowered using `brx.idx`
1 parent 1d9e1c6 commit d1e6ba5

File tree

6 files changed

+169
-8
lines changed

6 files changed

+169
-8
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
38433843
/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
38443844
virtual unsigned getJumpTableEncoding() const;
38453845

3846+
virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
3847+
return getPointerTy(DL);
3848+
}
3849+
38463850
virtual const MCExpr *
38473851
LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
38483852
const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
29772977
// Emit the code for the jump table
29782978
assert(JT.SL && "Should set SDLoc for SelectionDAG!");
29792979
assert(JT.Reg != -1U && "Should lower JT Header first!");
2980-
EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
2980+
EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
29812981
SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
29822982
SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
29832983
SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
30053005
// This value may be smaller or larger than the target's pointer type, and
30063006
// therefore require extension or truncating.
30073007
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
3008-
SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
3008+
SwitchOp =
3009+
DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
30093010

30103011
unsigned JumpTableReg =
3011-
FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
3012-
SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
3013-
JumpTableReg, SwitchOp);
3012+
FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
3013+
SDValue CopyTo =
3014+
DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
30143015
JT.Reg = JumpTableReg;
30153016

30163017
if (!JTH.FallthroughUnreachable) {

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/CodeGen/Analysis.h"
2626
#include "llvm/CodeGen/ISDOpcodes.h"
2727
#include "llvm/CodeGen/MachineFunction.h"
28+
#include "llvm/CodeGen/MachineJumpTableInfo.h"
2829
#include "llvm/CodeGen/MachineMemOperand.h"
2930
#include "llvm/CodeGen/SelectionDAG.h"
3031
#include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
582583
setOperationAction(ISD::ROTR, MVT::i8, Expand);
583584
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
584585

585-
// Indirect branch is not supported.
586-
// This also disables Jump Table creation.
587-
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
586+
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
588587
setOperationAction(ISD::BRIND, MVT::Other, Expand);
589588

590589
setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
945944
MAKE_CASE(NVPTXISD::Dummy)
946945
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
947946
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
947+
MAKE_CASE(NVPTXISD::BrxEnd)
948+
MAKE_CASE(NVPTXISD::BrxItem)
949+
MAKE_CASE(NVPTXISD::BrxStart)
948950
MAKE_CASE(NVPTXISD::Tex1DFloatS32)
949951
MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
950952
MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27852787
return LowerFP_ROUND(Op, DAG);
27862788
case ISD::FP_EXTEND:
27872789
return LowerFP_EXTEND(Op, DAG);
2790+
case ISD::BR_JT:
2791+
return LowerBR_JT(Op, DAG);
27882792
case ISD::VAARG:
27892793
return LowerVAARG(Op, DAG);
27902794
case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28102814
}
28112815
}
28122816

2817+
SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2818+
SDLoc DL(Op);
2819+
SDValue Chain = Op.getOperand(0);
2820+
const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
2821+
SDValue Index = Op.getOperand(2);
2822+
2823+
unsigned JId = JT->getIndex();
2824+
MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
2825+
ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2826+
2827+
SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
2828+
2829+
// Generate BrxStart node
2830+
SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
2831+
Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
2832+
2833+
// Generate BrxItem nodes
2834+
assert(!MBBs.empty());
2835+
for (MachineBasicBlock *MBB : MBBs.drop_back())
2836+
Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
2837+
DAG.getBasicBlock(MBB), Chain.getValue(1));
2838+
2839+
// Generate BrxEnd nodes
2840+
SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
2841+
IdV, Chain.getValue(1)};
2842+
SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
2843+
2844+
return BrxEnd;
2845+
}
2846+
2847+
// This will prevent AsmPrinter from trying to print the jump tables itself.
2848+
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2849+
return MachineJumpTableInfo::EK_Inline;
2850+
}
2851+
28132852
// This function is almost a copy of SelectionDAG::expandVAArg().
28142853
// The only diff is that this one produces loads from local address space.
28152854
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
6262
BFI,
6363
PRMT,
6464
DYNAMIC_STACKALLOC,
65+
BrxStart,
66+
BrxItem,
67+
BrxEnd,
6568
Dummy,
6669

6770
LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
580583
return true;
581584
}
582585

586+
// The default is the same as pointer type, but brx.idx only accepts i32
587+
MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
588+
589+
unsigned getJumpTableEncoding() const override;
590+
583591
bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
584592

585593
// The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
637645

638646
SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
639647

648+
SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
649+
640650
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
641651
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
642652

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3880,6 +3880,44 @@ def DYNAMIC_STACKALLOC64 :
38803880
[(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
38813881
Requires<[hasPTX<73>, hasSM<52>]>;
38823882

3883+
3884+
//
3885+
// BRX
3886+
//
3887+
3888+
def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
3889+
def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
3890+
def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;
3891+
3892+
def brx_start :
3893+
SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
3894+
[SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
3895+
def brx_item :
3896+
SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
3897+
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
3898+
def brx_end :
3899+
SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
3900+
[SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
3901+
3902+
let isTerminator = 1, isBranch = 1, isIndirectBranch = 1 in {
3903+
3904+
def BRX_START :
3905+
NVPTXInst<(outs), (ins i32imm:$id),
3906+
"$$L_brx_$id: .branchtargets",
3907+
[(brx_start (i32 imm:$id))]>;
3908+
3909+
def BRX_ITEM :
3910+
NVPTXInst<(outs), (ins brtarget:$target),
3911+
"\t$target,",
3912+
[(brx_item bb:$target)]>;
3913+
3914+
def BRX_END :
3915+
NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
3916+
"\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
3917+
[(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]>;
3918+
}
3919+
3920+
38833921
include "NVPTXIntrinsics.td"
38843922

38853923
//-----------------------------------

llvm/test/CodeGen/NVPTX/jump-table.ll

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
7+
@out = addrspace(1) global i32 0, align 4
8+
9+
define void @foo(i32 %i) {
10+
; CHECK-LABEL: foo(
11+
; CHECK: {
12+
; CHECK-NEXT: .reg .pred %p<2>;
13+
; CHECK-NEXT: .reg .b32 %r<7>;
14+
; CHECK-EMPTY:
15+
; CHECK-NEXT: // %bb.0: // %entry
16+
; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0];
17+
; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3;
18+
; CHECK-NEXT: @%p1 bra $L__BB0_6;
19+
; CHECK-NEXT: // %bb.1: // %entry
20+
; CHECK-NEXT: $L_brx_0: .branchtargets
21+
; CHECK-NEXT: $L__BB0_2,
22+
; CHECK-NEXT: $L__BB0_3,
23+
; CHECK-NEXT: $L__BB0_4,
24+
; CHECK-NEXT: $L__BB0_5;
25+
; CHECK-NEXT: brx.idx %r2, $L_brx_0;
26+
; CHECK-NEXT: $L__BB0_2: // %case0
27+
; CHECK-NEXT: mov.b32 %r6, 0;
28+
; CHECK-NEXT: st.global.u32 [out], %r6;
29+
; CHECK-NEXT: bra.uni $L__BB0_6;
30+
; CHECK-NEXT: $L__BB0_4: // %case2
31+
; CHECK-NEXT: mov.b32 %r4, 2;
32+
; CHECK-NEXT: st.global.u32 [out], %r4;
33+
; CHECK-NEXT: bra.uni $L__BB0_6;
34+
; CHECK-NEXT: $L__BB0_5: // %case3
35+
; CHECK-NEXT: mov.b32 %r3, 3;
36+
; CHECK-NEXT: st.global.u32 [out], %r3;
37+
; CHECK-NEXT: bra.uni $L__BB0_6;
38+
; CHECK-NEXT: $L__BB0_3: // %case1
39+
; CHECK-NEXT: mov.b32 %r5, 1;
40+
; CHECK-NEXT: st.global.u32 [out], %r5;
41+
; CHECK-NEXT: $L__BB0_6: // %end
42+
; CHECK-NEXT: ret;
43+
entry:
44+
switch i32 %i, label %end [
45+
i32 0, label %case0
46+
i32 1, label %case1
47+
i32 2, label %case2
48+
i32 3, label %case3
49+
]
50+
51+
case0:
52+
store i32 0, ptr addrspace(1) @out, align 4
53+
br label %end
54+
55+
case1:
56+
store i32 1, ptr addrspace(1) @out, align 4
57+
br label %end
58+
59+
case2:
60+
store i32 2, ptr addrspace(1) @out, align 4
61+
br label %end
62+
63+
case3:
64+
store i32 3, ptr addrspace(1) @out, align 4
65+
br label %end
66+
67+
end:
68+
ret void
69+
}

0 commit comments

Comments
 (0)