Skip to content

Commit 687d6fb

Browse files
[NVPTX] Basic support for "grid_constant" (#96125)
- Adds a helper function for checking whether an argument is a [grid_constant](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties).
- Adds support for cvta.param using changes from #95289.
- Supports escaped grid_constant pointers conservatively, by casting all uses to the generic address space with cvta.param.
1 parent a030c8b commit 687d6fb

File tree

6 files changed

+297
-86
lines changed

6 files changed

+297
-86
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,12 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
15961596
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
15971597
"llvm.nvvm.ptr.gen.to.param">;
15981598

1599+
// sm70+, PTX7.7+
1600+
def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
1601+
[llvm_anyptr_ty],
1602+
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
1603+
"llvm.nvvm.ptr.param.to.gen">;
1604+
15991605
// Move intrinsics, used in nvvm internally
16001606

16011607
def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2475,6 +2475,7 @@ defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
24752475
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
24762476
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
24772477
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
2478+
defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
24782479

24792480
defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
24802481
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@
9595
#include "llvm/Analysis/ValueTracking.h"
9696
#include "llvm/CodeGen/TargetPassConfig.h"
9797
#include "llvm/IR/Function.h"
98+
#include "llvm/IR/IRBuilder.h"
9899
#include "llvm/IR/Instructions.h"
100+
#include "llvm/IR/IntrinsicsNVPTX.h"
99101
#include "llvm/IR/Module.h"
100102
#include "llvm/IR/Type.h"
101103
#include "llvm/InitializePasses.h"
@@ -336,8 +338,9 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
336338
while (!ValuesToCheck.empty()) {
337339
Value *V = ValuesToCheck.pop_back_val();
338340
if (!IsALoadChainInstr(V)) {
339-
LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
340-
<< "\n");
341+
LLVM_DEBUG(dbgs() << "Need a "
342+
<< (isParamGridConstant(*Arg) ? "cast " : "copy ")
343+
<< "of " << *Arg << " because of " << *V << "\n");
341344
(void)Arg;
342345
return false;
343346
}
@@ -366,27 +369,59 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
366369
return;
367370
}
368371

369-
// Otherwise we have to create a temporary copy.
370372
const DataLayout &DL = Func->getParent()->getDataLayout();
371373
unsigned AS = DL.getAllocaAddrSpace();
372-
AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
373-
// Set the alignment to alignment of the byval parameter. This is because,
374-
// later load/stores assume that alignment, and we are going to replace
375-
// the use of the byval parameter with this alloca instruction.
376-
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
377-
.value_or(DL.getPrefTypeAlign(StructType)));
378-
Arg->replaceAllUsesWith(AllocA);
379-
380-
Value *ArgInParam = new AddrSpaceCastInst(
381-
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
382-
FirstInst);
383-
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
384-
// addrspacecast preserves alignment. Since params are constant, this load is
385-
// definitely not volatile.
386-
LoadInst *LI =
387-
new LoadInst(StructType, ArgInParam, Arg->getName(),
388-
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
389-
new StoreInst(LI, AllocA, FirstInst);
374+
if (isParamGridConstant(*Arg)) {
375+
// Writes to a grid constant are undefined behaviour. We do not need a
376+
// temporary copy. When a pointer might have escaped, conservatively replace
377+
// all of its uses (which might include a device function call) with a cast
378+
// to the generic address space.
379+
// TODO: only cast byval grid constant parameters at use points that need
380+
// generic address (e.g., merging parameter pointers with other address
381+
// space, or escaping to call-sites, inline-asm, memory), and use the
382+
// parameter address space for normal loads.
383+
IRBuilder<> IRB(&Func->getEntryBlock().front());
384+
385+
// Cast argument to param address space
386+
auto *CastToParam =
387+
cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
388+
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
389+
390+
// Cast param address to generic address space. We do not use an
391+
// addrspacecast to generic here, because, LLVM considers `Arg` to be in the
392+
// generic address space, and a `generic -> param` cast followed by a `param
393+
// -> generic` cast will be folded away. The `param -> generic` intrinsic
394+
// will be correctly lowered to `cvta.param`.
395+
Value *CvtToGenCall = IRB.CreateIntrinsic(
396+
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
397+
CastToParam, nullptr, CastToParam->getName() + ".gen");
398+
399+
Arg->replaceAllUsesWith(CvtToGenCall);
400+
401+
// Do not replace Arg in the cast to param space
402+
CastToParam->setOperand(0, Arg);
403+
} else {
404+
// Otherwise we have to create a temporary copy.
405+
AllocaInst *AllocA =
406+
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
407+
// Set the alignment to alignment of the byval parameter. This is because,
408+
// later load/stores assume that alignment, and we are going to replace
409+
// the use of the byval parameter with this alloca instruction.
410+
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
411+
.value_or(DL.getPrefTypeAlign(StructType)));
412+
Arg->replaceAllUsesWith(AllocA);
413+
414+
Value *ArgInParam = new AddrSpaceCastInst(
415+
Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
416+
Arg->getName(), FirstInst);
417+
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
418+
// addrspacecast preserves alignment. Since params are constant, this load
419+
// is definitely not volatile.
420+
LoadInst *LI =
421+
new LoadInst(StructType, ArgInParam, Arg->getName(),
422+
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
423+
new StoreInst(LI, AllocA, FirstInst);
424+
}
390425
}
391426

392427
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 78 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -52,29 +52,46 @@ void clearAnnotationCache(const Module *Mod) {
5252
AC.Cache.erase(Mod);
5353
}
5454

55-
static void cacheAnnotationFromMD(const MDNode *md, key_val_pair_t &retval) {
55+
static void readIntVecFromMDNode(const MDNode *MetadataNode,
56+
std::vector<unsigned> &Vec) {
57+
for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
58+
ConstantInt *Val =
59+
mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
60+
Vec.push_back(Val->getZExtValue());
61+
}
62+
}
63+
64+
static void cacheAnnotationFromMD(const MDNode *MetadataNode,
65+
key_val_pair_t &retval) {
5666
auto &AC = getAnnotationCache();
5767
std::lock_guard<sys::Mutex> Guard(AC.Lock);
58-
assert(md && "Invalid mdnode for annotation");
59-
assert((md->getNumOperands() % 2) == 1 && "Invalid number of operands");
68+
assert(MetadataNode && "Invalid mdnode for annotation");
69+
assert((MetadataNode->getNumOperands() % 2) == 1 &&
70+
"Invalid number of operands");
6071
// start index = 1, to skip the global variable key
6172
// increment = 2, to skip the value for each property-value pairs
62-
for (unsigned i = 1, e = md->getNumOperands(); i != e; i += 2) {
73+
for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
6374
// property
64-
const MDString *prop = dyn_cast<MDString>(md->getOperand(i));
75+
const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
6576
assert(prop && "Annotation property not a string");
77+
std::string Key = prop->getString().str();
6678

6779
// value
68-
ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(md->getOperand(i + 1));
69-
assert(Val && "Value operand not a constant int");
70-
71-
std::string keyname = prop->getString().str();
72-
if (retval.find(keyname) != retval.end())
73-
retval[keyname].push_back(Val->getZExtValue());
74-
else {
75-
std::vector<unsigned> tmp;
76-
tmp.push_back(Val->getZExtValue());
77-
retval[keyname] = tmp;
80+
if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
81+
MetadataNode->getOperand(i + 1))) {
82+
retval[Key].push_back(Val->getZExtValue());
83+
} else if (MDNode *VecMd =
84+
dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
85+
// note: only "grid_constant" annotations support vector MDNodes.
86+
// assert: there can only exist one unique key value pair of
87+
// the form (string key, MDNode node). Operands of such a node
88+
// shall always be unsigned ints.
89+
if (retval.find(Key) == retval.end()) {
90+
readIntVecFromMDNode(VecMd, retval[Key]);
91+
continue;
92+
}
93+
} else {
94+
llvm_unreachable("Value operand not a constant int or an mdnode");
7895
}
7996
}
8097
}
@@ -153,9 +170,9 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
153170

154171
bool isTexture(const Value &val) {
155172
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
156-
unsigned annot;
157-
if (findOneNVVMAnnotation(gv, "texture", annot)) {
158-
assert((annot == 1) && "Unexpected annotation on a texture symbol");
173+
unsigned Annot;
174+
if (findOneNVVMAnnotation(gv, "texture", Annot)) {
175+
assert((Annot == 1) && "Unexpected annotation on a texture symbol");
159176
return true;
160177
}
161178
}
@@ -164,70 +181,67 @@ bool isTexture(const Value &val) {
164181

165182
bool isSurface(const Value &val) {
166183
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
167-
unsigned annot;
168-
if (findOneNVVMAnnotation(gv, "surface", annot)) {
169-
assert((annot == 1) && "Unexpected annotation on a surface symbol");
184+
unsigned Annot;
185+
if (findOneNVVMAnnotation(gv, "surface", Annot)) {
186+
assert((Annot == 1) && "Unexpected annotation on a surface symbol");
170187
return true;
171188
}
172189
}
173190
return false;
174191
}
175192

176-
bool isSampler(const Value &val) {
177-
const char *AnnotationName = "sampler";
178-
179-
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
180-
unsigned annot;
181-
if (findOneNVVMAnnotation(gv, AnnotationName, annot)) {
182-
assert((annot == 1) && "Unexpected annotation on a sampler symbol");
183-
return true;
184-
}
185-
}
186-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
187-
const Function *func = arg->getParent();
188-
std::vector<unsigned> annot;
189-
if (findAllNVVMAnnotation(func, AnnotationName, annot)) {
190-
if (is_contained(annot, arg->getArgNo()))
193+
static bool argHasNVVMAnnotation(const Value &Val,
194+
const std::string &Annotation,
195+
const bool StartArgIndexAtOne = false) {
196+
if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
197+
const Function *Func = Arg->getParent();
198+
std::vector<unsigned> Annot;
199+
if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
200+
const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
201+
if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
191202
return true;
203+
}
192204
}
193205
}
194206
return false;
195207
}
196208

197-
bool isImageReadOnly(const Value &val) {
198-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
199-
const Function *func = arg->getParent();
200-
std::vector<unsigned> annot;
201-
if (findAllNVVMAnnotation(func, "rdoimage", annot)) {
202-
if (is_contained(annot, arg->getArgNo()))
203-
return true;
209+
bool isParamGridConstant(const Value &V) {
210+
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
211+
// "grid_constant" counts argument indices starting from 1
212+
if (Arg->hasByValAttr() &&
213+
argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne*/true)) {
214+
assert(isKernelFunction(*Arg->getParent()) &&
215+
"only kernel arguments can be grid_constant");
216+
return true;
204217
}
205218
}
206219
return false;
207220
}
208221

209-
bool isImageWriteOnly(const Value &val) {
210-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
211-
const Function *func = arg->getParent();
212-
std::vector<unsigned> annot;
213-
if (findAllNVVMAnnotation(func, "wroimage", annot)) {
214-
if (is_contained(annot, arg->getArgNo()))
215-
return true;
222+
bool isSampler(const Value &val) {
223+
const char *AnnotationName = "sampler";
224+
225+
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
226+
unsigned Annot;
227+
if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
228+
assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
229+
return true;
216230
}
217231
}
218-
return false;
232+
return argHasNVVMAnnotation(val, AnnotationName);
233+
}
234+
235+
bool isImageReadOnly(const Value &val) {
236+
return argHasNVVMAnnotation(val, "rdoimage");
237+
}
238+
239+
bool isImageWriteOnly(const Value &val) {
240+
return argHasNVVMAnnotation(val, "wroimage");
219241
}
220242

221243
bool isImageReadWrite(const Value &val) {
222-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
223-
const Function *func = arg->getParent();
224-
std::vector<unsigned> annot;
225-
if (findAllNVVMAnnotation(func, "rdwrimage", annot)) {
226-
if (is_contained(annot, arg->getArgNo()))
227-
return true;
228-
}
229-
}
230-
return false;
244+
return argHasNVVMAnnotation(val, "rdwrimage");
231245
}
232246

233247
bool isImage(const Value &val) {
@@ -236,9 +250,9 @@ bool isImage(const Value &val) {
236250

237251
bool isManaged(const Value &val) {
238252
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
239-
unsigned annot;
240-
if (findOneNVVMAnnotation(gv, "managed", annot)) {
241-
assert((annot == 1) && "Unexpected annotation on a managed symbol");
253+
unsigned Annot;
254+
if (findOneNVVMAnnotation(gv, "managed", Annot)) {
255+
assert((Annot == 1) && "Unexpected annotation on a managed symbol");
242256
return true;
243257
}
244258
}
@@ -323,8 +337,7 @@ bool getMaxNReg(const Function &F, unsigned &x) {
323337

324338
bool isKernelFunction(const Function &F) {
325339
unsigned x = 0;
326-
bool retval = findOneNVVMAnnotation(&F, "kernel", x);
327-
if (!retval) {
340+
if (!findOneNVVMAnnotation(&F, "kernel", x)) {
328341
// There is no NVVM metadata, check the calling convention
329342
return F.getCallingConv() == CallingConv::PTX_Kernel;
330343
}

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ bool getMaxClusterRank(const Function &, unsigned &);
6262
bool getMinCTASm(const Function &, unsigned &);
6363
bool getMaxNReg(const Function &, unsigned &);
6464
bool isKernelFunction(const Function &);
65+
bool isParamGridConstant(const Value &);
6566

6667
MaybeAlign getAlign(const Function &, unsigned);
6768
MaybeAlign getAlign(const CallInst &, unsigned);

0 commit comments

Comments
 (0)