Skip to content

Commit 79261d4

Browse files
authored
[NVPTX][InferAS] assume alloca instructions are in local AS (#121710)
1 parent ca0406d commit 79261d4

File tree

11 files changed

+339
-152
lines changed

11 files changed

+339
-152
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "llvm/Support/CommandLine.h"
2727
#include "llvm/Support/ErrorHandling.h"
2828
#include "llvm/Support/FormatVariadic.h"
29+
#include <optional>
2930

3031
using namespace llvm;
3132

@@ -342,30 +343,28 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
342343
return true;
343344
}
344345

345-
static unsigned int getCodeAddrSpace(MemSDNode *N) {
346-
const Value *Src = N->getMemOperand()->getValue();
347-
348-
if (!Src)
346+
static std::optional<unsigned> convertAS(unsigned AS) {
347+
switch (AS) {
348+
case llvm::ADDRESS_SPACE_LOCAL:
349+
return NVPTX::AddressSpace::Local;
350+
case llvm::ADDRESS_SPACE_GLOBAL:
351+
return NVPTX::AddressSpace::Global;
352+
case llvm::ADDRESS_SPACE_SHARED:
353+
return NVPTX::AddressSpace::Shared;
354+
case llvm::ADDRESS_SPACE_GENERIC:
349355
return NVPTX::AddressSpace::Generic;
350-
351-
if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
352-
switch (PT->getAddressSpace()) {
353-
case llvm::ADDRESS_SPACE_LOCAL:
354-
return NVPTX::AddressSpace::Local;
355-
case llvm::ADDRESS_SPACE_GLOBAL:
356-
return NVPTX::AddressSpace::Global;
357-
case llvm::ADDRESS_SPACE_SHARED:
358-
return NVPTX::AddressSpace::Shared;
359-
case llvm::ADDRESS_SPACE_GENERIC:
360-
return NVPTX::AddressSpace::Generic;
361-
case llvm::ADDRESS_SPACE_PARAM:
362-
return NVPTX::AddressSpace::Param;
363-
case llvm::ADDRESS_SPACE_CONST:
364-
return NVPTX::AddressSpace::Const;
365-
default: break;
366-
}
356+
case llvm::ADDRESS_SPACE_PARAM:
357+
return NVPTX::AddressSpace::Param;
358+
case llvm::ADDRESS_SPACE_CONST:
359+
return NVPTX::AddressSpace::Const;
360+
default:
361+
return std::nullopt;
367362
}
368-
return NVPTX::AddressSpace::Generic;
363+
}
364+
365+
static unsigned int getCodeAddrSpace(const MemSDNode *N) {
366+
return convertAS(N->getMemOperand()->getAddrSpace())
367+
.value_or(NVPTX::AddressSpace::Generic);
369368
}
370369

371370
namespace {

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,19 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
14051405
return false;
14061406
}
14071407

1408+
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1409+
const DataLayout &DL,
1410+
const TargetLowering &TL) {
1411+
if (Ptr->getOpcode() == ISD::FrameIndex) {
1412+
auto Ty = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
1413+
Ptr = DAG.getAddrSpaceCast(SDLoc(), Ty, Ptr, ADDRESS_SPACE_GENERIC,
1414+
ADDRESS_SPACE_LOCAL);
1415+
1416+
return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1417+
}
1418+
return MachinePointerInfo();
1419+
}
1420+
14081421
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14091422
SmallVectorImpl<SDValue> &InVals) const {
14101423

@@ -1564,11 +1577,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15641577
}
15651578

15661579
if (IsByVal) {
1567-
auto PtrVT = getPointerTy(DL);
1568-
SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1580+
auto MPI = refinePtrAS(StVal, DAG, DL, *this);
1581+
const EVT PtrVT = StVal.getValueType();
1582+
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
15691583
DAG.getConstant(CurOffset, dl, PtrVT));
1570-
StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1571-
PartAlign);
1584+
1585+
StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
15721586
} else if (ExtendIntegerParam) {
15731587
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
15741588
// zext/sext to i32

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/IR/Value.h"
2121
#include "llvm/Support/Casting.h"
2222
#include "llvm/Support/ErrorHandling.h"
23+
#include "llvm/Support/NVPTXAddrSpace.h"
2324
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2425
#include <optional>
2526
using namespace llvm;
@@ -564,6 +565,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
564565
return nullptr;
565566
}
566567

568+
unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
569+
if (isa<AllocaInst>(V))
570+
return ADDRESS_SPACE_LOCAL;
571+
572+
return -1;
573+
}
574+
567575
void NVPTXTTIImpl::collectKernelLaunchBounds(
568576
const Function &F,
569577
SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const {

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
129129

130130
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
131131
Value *NewV) const;
132+
unsigned getAssumedAddrSpace(const Value *V) const;
132133

133134
void collectKernelLaunchBounds(
134135
const Function &F,

llvm/test/CodeGen/NVPTX/indirect_byval.ll

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@ define internal i32 @foo() {
1717
; CHECK-NEXT: .reg .b64 %SPL;
1818
; CHECK-NEXT: .reg .b16 %rs<2>;
1919
; CHECK-NEXT: .reg .b32 %r<3>;
20-
; CHECK-NEXT: .reg .b64 %rd<3>;
20+
; CHECK-NEXT: .reg .b64 %rd<5>;
2121
; CHECK-EMPTY:
2222
; CHECK-NEXT: // %bb.0: // %entry
2323
; CHECK-NEXT: mov.u64 %SPL, __local_depot0;
2424
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
2525
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
26-
; CHECK-NEXT: ld.u8 %rs1, [%SP+1];
27-
; CHECK-NEXT: add.u64 %rd2, %SP, 0;
26+
; CHECK-NEXT: add.u64 %rd3, %SPL, 1;
27+
; CHECK-NEXT: ld.local.u8 %rs1, [%rd3];
28+
; CHECK-NEXT: add.u64 %rd4, %SP, 0;
2829
; CHECK-NEXT: { // callseq 0, 0
2930
; CHECK-NEXT: .param .align 1 .b8 param0[1];
3031
; CHECK-NEXT: st.param.b8 [param0], %rs1;
3132
; CHECK-NEXT: .param .b64 param1;
32-
; CHECK-NEXT: st.param.b64 [param1], %rd2;
33+
; CHECK-NEXT: st.param.b64 [param1], %rd4;
3334
; CHECK-NEXT: .param .b32 retval0;
3435
; CHECK-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
3536
; CHECK-NEXT: call (retval0),
@@ -59,19 +60,20 @@ define internal i32 @bar() {
5960
; CHECK-NEXT: .reg .b64 %SP;
6061
; CHECK-NEXT: .reg .b64 %SPL;
6162
; CHECK-NEXT: .reg .b32 %r<3>;
62-
; CHECK-NEXT: .reg .b64 %rd<4>;
63+
; CHECK-NEXT: .reg .b64 %rd<6>;
6364
; CHECK-EMPTY:
6465
; CHECK-NEXT: // %bb.0: // %entry
6566
; CHECK-NEXT: mov.u64 %SPL, __local_depot1;
6667
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
6768
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
68-
; CHECK-NEXT: ld.u64 %rd2, [%SP+8];
69-
; CHECK-NEXT: add.u64 %rd3, %SP, 0;
69+
; CHECK-NEXT: add.u64 %rd3, %SPL, 8;
70+
; CHECK-NEXT: ld.local.u64 %rd4, [%rd3];
71+
; CHECK-NEXT: add.u64 %rd5, %SP, 0;
7072
; CHECK-NEXT: { // callseq 1, 0
7173
; CHECK-NEXT: .param .align 8 .b8 param0[8];
72-
; CHECK-NEXT: st.param.b64 [param0], %rd2;
74+
; CHECK-NEXT: st.param.b64 [param0], %rd4;
7375
; CHECK-NEXT: .param .b64 param1;
74-
; CHECK-NEXT: st.param.b64 [param1], %rd3;
76+
; CHECK-NEXT: st.param.b64 [param1], %rd5;
7577
; CHECK-NEXT: .param .b32 retval0;
7678
; CHECK-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
7779
; CHECK-NEXT: call (retval0),

0 commit comments

Comments
 (0)