Skip to content

Commit 489acb2

Browse files
authored
[NVPTX][NFC] Refactor utilities to use std::optional (#109883)
1 parent 7a086e1 commit 489acb2

File tree

3 files changed

+70
-109
lines changed

3 files changed

+70
-109
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -563,21 +563,19 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
563563
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
564564
<< ", " << Maxntidz.value_or(1) << "\n";
565565

566-
unsigned Mincta = 0;
567-
if (getMinCTASm(F, Mincta))
568-
O << ".minnctapersm " << Mincta << "\n";
566+
if (const auto Mincta = getMinCTASm(F))
567+
O << ".minnctapersm " << *Mincta << "\n";
569568

570-
unsigned Maxnreg = 0;
571-
if (getMaxNReg(F, Maxnreg))
572-
O << ".maxnreg " << Maxnreg << "\n";
569+
if (const auto Maxnreg = getMaxNReg(F))
570+
O << ".maxnreg " << *Maxnreg << "\n";
573571

574572
// .maxclusterrank directive requires SM_90 or higher, make sure that we
575573
// filter it out for lower SM versions, as it causes a hard ptxas crash.
576574
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
577575
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
578-
unsigned Maxclusterrank = 0;
579-
if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
580-
O << ".maxclusterrank " << Maxclusterrank << "\n";
576+
if (STI->getSmVersion() >= 90)
577+
if (const auto Maxclusterrank = getMaxClusterRank(F))
578+
O << ".maxclusterrank " << *Maxclusterrank << "\n";
581579
}
582580

583581
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 54 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "NVPTXUtilities.h"
1414
#include "NVPTX.h"
1515
#include "NVPTXTargetMachine.h"
16+
#include "llvm/ADT/StringRef.h"
1617
#include "llvm/IR/Constants.h"
1718
#include "llvm/IR/Function.h"
1819
#include "llvm/IR/GlobalVariable.h"
@@ -130,8 +131,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
130131
}
131132
}
132133

133-
bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
134-
unsigned &retval) {
134+
static std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
135+
const std::string &prop) {
135136
auto &AC = getAnnotationCache();
136137
std::lock_guard<sys::Mutex> Guard(AC.Lock);
137138
const Module *m = gv->getParent();
@@ -140,21 +141,13 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
140141
else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
141142
cacheAnnotationFromMD(m, gv);
142143
if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
143-
return false;
144-
retval = AC.Cache[m][gv][prop][0];
145-
return true;
146-
}
147-
148-
static std::optional<unsigned>
149-
findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
150-
unsigned RetVal;
151-
if (findOneNVVMAnnotation(&GV, PropName, RetVal))
152-
return RetVal;
153-
return std::nullopt;
144+
return std::nullopt;
145+
return AC.Cache[m][gv][prop][0];
154146
}
155147

156-
bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
157-
std::vector<unsigned> &retval) {
148+
static bool findAllNVVMAnnotation(const GlobalValue *gv,
149+
const std::string &prop,
150+
std::vector<unsigned> &retval) {
158151
auto &AC = getAnnotationCache();
159152
std::lock_guard<sys::Mutex> Guard(AC.Lock);
160153
const Module *m = gv->getParent();
@@ -168,25 +161,13 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
168161
return true;
169162
}
170163

171-
bool isTexture(const Value &val) {
172-
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
173-
unsigned Annot;
174-
if (findOneNVVMAnnotation(gv, "texture", Annot)) {
175-
assert((Annot == 1) && "Unexpected annotation on a texture symbol");
164+
static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
165+
if (const auto *GV = dyn_cast<GlobalValue>(&V))
166+
if (const auto Annot = findOneNVVMAnnotation(GV, Prop)) {
167+
assert((*Annot == 1) && "Unexpected annotation on a symbol");
176168
return true;
177169
}
178-
}
179-
return false;
180-
}
181170

182-
bool isSurface(const Value &val) {
183-
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
184-
unsigned Annot;
185-
if (findOneNVVMAnnotation(gv, "surface", Annot)) {
186-
assert((Annot == 1) && "Unexpected annotation on a surface symbol");
187-
return true;
188-
}
189-
}
190171
return false;
191172
}
192173

@@ -220,71 +201,60 @@ bool isParamGridConstant(const Value &V) {
220201
return false;
221202
}
222203

223-
bool isSampler(const Value &val) {
204+
bool isTexture(const Value &V) { return globalHasNVVMAnnotation(V, "texture"); }
205+
206+
bool isSurface(const Value &V) { return globalHasNVVMAnnotation(V, "surface"); }
207+
208+
bool isSampler(const Value &V) {
224209
const char *AnnotationName = "sampler";
225210

226-
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
227-
unsigned Annot;
228-
if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
229-
assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
230-
return true;
231-
}
232-
}
233-
return argHasNVVMAnnotation(val, AnnotationName);
211+
return globalHasNVVMAnnotation(V, AnnotationName) ||
212+
argHasNVVMAnnotation(V, AnnotationName);
234213
}
235214

236-
bool isImageReadOnly(const Value &val) {
237-
return argHasNVVMAnnotation(val, "rdoimage");
215+
bool isImageReadOnly(const Value &V) {
216+
return argHasNVVMAnnotation(V, "rdoimage");
238217
}
239218

240-
bool isImageWriteOnly(const Value &val) {
241-
return argHasNVVMAnnotation(val, "wroimage");
219+
bool isImageWriteOnly(const Value &V) {
220+
return argHasNVVMAnnotation(V, "wroimage");
242221
}
243222

244-
bool isImageReadWrite(const Value &val) {
245-
return argHasNVVMAnnotation(val, "rdwrimage");
223+
bool isImageReadWrite(const Value &V) {
224+
return argHasNVVMAnnotation(V, "rdwrimage");
246225
}
247226

248-
bool isImage(const Value &val) {
249-
return isImageReadOnly(val) || isImageWriteOnly(val) || isImageReadWrite(val);
227+
bool isImage(const Value &V) {
228+
return isImageReadOnly(V) || isImageWriteOnly(V) || isImageReadWrite(V);
250229
}
251230

252-
bool isManaged(const Value &val) {
253-
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
254-
unsigned Annot;
255-
if (findOneNVVMAnnotation(gv, "managed", Annot)) {
256-
assert((Annot == 1) && "Unexpected annotation on a managed symbol");
257-
return true;
258-
}
259-
}
260-
return false;
261-
}
231+
bool isManaged(const Value &V) { return globalHasNVVMAnnotation(V, "managed"); }
262232

263-
std::string getTextureName(const Value &val) {
264-
assert(val.hasName() && "Found texture variable with no name");
265-
return std::string(val.getName());
233+
StringRef getTextureName(const Value &V) {
234+
assert(V.hasName() && "Found texture variable with no name");
235+
return V.getName();
266236
}
267237

268-
std::string getSurfaceName(const Value &val) {
269-
assert(val.hasName() && "Found surface variable with no name");
270-
return std::string(val.getName());
238+
StringRef getSurfaceName(const Value &V) {
239+
assert(V.hasName() && "Found surface variable with no name");
240+
return V.getName();
271241
}
272242

273-
std::string getSamplerName(const Value &val) {
274-
assert(val.hasName() && "Found sampler variable with no name");
275-
return std::string(val.getName());
243+
StringRef getSamplerName(const Value &V) {
244+
assert(V.hasName() && "Found sampler variable with no name");
245+
return V.getName();
276246
}
277247

278248
std::optional<unsigned> getMaxNTIDx(const Function &F) {
279-
return findOneNVVMAnnotation(F, "maxntidx");
249+
return findOneNVVMAnnotation(&F, "maxntidx");
280250
}
281251

282252
std::optional<unsigned> getMaxNTIDy(const Function &F) {
283-
return findOneNVVMAnnotation(F, "maxntidy");
253+
return findOneNVVMAnnotation(&F, "maxntidy");
284254
}
285255

286256
std::optional<unsigned> getMaxNTIDz(const Function &F) {
287-
return findOneNVVMAnnotation(F, "maxntidz");
257+
return findOneNVVMAnnotation(&F, "maxntidz");
288258
}
289259

290260
std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -302,20 +272,20 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
302272
return std::nullopt;
303273
}
304274

305-
bool getMaxClusterRank(const Function &F, unsigned &x) {
306-
return findOneNVVMAnnotation(&F, "maxclusterrank", x);
275+
std::optional<unsigned> getMaxClusterRank(const Function &F) {
276+
return findOneNVVMAnnotation(&F, "maxclusterrank");
307277
}
308278

309279
std::optional<unsigned> getReqNTIDx(const Function &F) {
310-
return findOneNVVMAnnotation(F, "reqntidx");
280+
return findOneNVVMAnnotation(&F, "reqntidx");
311281
}
312282

313283
std::optional<unsigned> getReqNTIDy(const Function &F) {
314-
return findOneNVVMAnnotation(F, "reqntidy");
284+
return findOneNVVMAnnotation(&F, "reqntidy");
315285
}
316286

317287
std::optional<unsigned> getReqNTIDz(const Function &F) {
318-
return findOneNVVMAnnotation(F, "reqntidz");
288+
return findOneNVVMAnnotation(&F, "reqntidz");
319289
}
320290

321291
std::optional<unsigned> getReqNTID(const Function &F) {
@@ -328,21 +298,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
328298
return std::nullopt;
329299
}
330300

331-
bool getMinCTASm(const Function &F, unsigned &x) {
332-
return findOneNVVMAnnotation(&F, "minctasm", x);
301+
std::optional<unsigned> getMinCTASm(const Function &F) {
302+
return findOneNVVMAnnotation(&F, "minctasm");
333303
}
334304

335-
bool getMaxNReg(const Function &F, unsigned &x) {
336-
return findOneNVVMAnnotation(&F, "maxnreg", x);
305+
std::optional<unsigned> getMaxNReg(const Function &F) {
306+
return findOneNVVMAnnotation(&F, "maxnreg");
337307
}
338308

339309
bool isKernelFunction(const Function &F) {
340-
unsigned x = 0;
341-
if (!findOneNVVMAnnotation(&F, "kernel", x)) {
342-
// There is no NVVM metadata, check the calling convention
343-
return F.getCallingConv() == CallingConv::PTX_Kernel;
344-
}
345-
return (x == 1);
310+
if (const auto X = findOneNVVMAnnotation(&F, "kernel"))
311+
return (*X == 1);
312+
313+
// There is no NVVM metadata, check the calling convention
314+
return F.getCallingConv() == CallingConv::PTX_Kernel;
346315
}
347316

348317
MaybeAlign getAlign(const Function &F, unsigned Index) {

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ class TargetMachine;
3232

3333
void clearAnnotationCache(const Module *);
3434

35-
bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
36-
unsigned &);
37-
bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
38-
std::vector<unsigned> &);
39-
4035
bool isTexture(const Value &);
4136
bool isSurface(const Value &);
4237
bool isSampler(const Value &);
@@ -46,23 +41,23 @@ bool isImageWriteOnly(const Value &);
4641
bool isImageReadWrite(const Value &);
4742
bool isManaged(const Value &);
4843

49-
std::string getTextureName(const Value &);
50-
std::string getSurfaceName(const Value &);
51-
std::string getSamplerName(const Value &);
44+
StringRef getTextureName(const Value &);
45+
StringRef getSurfaceName(const Value &);
46+
StringRef getSamplerName(const Value &);
5247

5348
std::optional<unsigned> getMaxNTIDx(const Function &);
5449
std::optional<unsigned> getMaxNTIDy(const Function &);
5550
std::optional<unsigned> getMaxNTIDz(const Function &);
56-
std::optional<unsigned> getMaxNTID(const Function &F);
51+
std::optional<unsigned> getMaxNTID(const Function &);
5752

5853
std::optional<unsigned> getReqNTIDx(const Function &);
5954
std::optional<unsigned> getReqNTIDy(const Function &);
6055
std::optional<unsigned> getReqNTIDz(const Function &);
6156
std::optional<unsigned> getReqNTID(const Function &);
6257

63-
bool getMaxClusterRank(const Function &, unsigned &);
64-
bool getMinCTASm(const Function &, unsigned &);
65-
bool getMaxNReg(const Function &, unsigned &);
58+
std::optional<unsigned> getMaxClusterRank(const Function &);
59+
std::optional<unsigned> getMinCTASm(const Function &);
60+
std::optional<unsigned> getMaxNReg(const Function &);
6661
bool isKernelFunction(const Function &);
6762
bool isParamGridConstant(const Value &);
6863

@@ -75,10 +70,9 @@ Function *getMaybeBitcastedCallee(const CallBase *CB);
7570
inline unsigned promoteScalarArgumentSize(unsigned size) {
7671
if (size <= 32)
7772
return 32;
78-
else if (size <= 64)
73+
if (size <= 64)
7974
return 64;
80-
else
81-
return size;
75+
return size;
8276
}
8377

8478
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);

0 commit comments

Comments
 (0)