-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType #93847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
HLSLShaderAttr::ShaderType enum is a subset of llvm::Triple::EnvironmentType and is not needed.
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clang-codegen Author: Helena Kotas (hekota) Changes
Full diff: https://github.com/llvm/llvm-project/pull/93847.diff 5 Files Affected:
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 2665b7353ca4a..a337509d3e2b5 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4469,36 +4469,20 @@ def HLSLShader : InheritableAttr {
let Subjects = SubjectList<[HLSLEntry]>;
let LangOpts = [HLSL];
let Args = [
- EnumArgument<"Type", "ShaderType", /*is_string=*/true,
+ EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
["pixel", "vertex", "geometry", "hull", "domain", "compute",
"raygeneration", "intersection", "anyhit", "closesthit",
"miss", "callable", "mesh", "amplification"],
["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
"RayGeneration", "Intersection", "AnyHit", "ClosestHit",
- "Miss", "Callable", "Mesh", "Amplification"]>
+ "Miss", "Callable", "Mesh", "Amplification"],
+ /*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
];
let Documentation = [HLSLSV_ShaderTypeAttrDocs];
let AdditionalMembers =
[{
- static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;
-
- static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
- switch (ShaderType) {
- case HLSLShaderAttr::Pixel: return llvm::Triple::Pixel;
- case HLSLShaderAttr::Vertex: return llvm::Triple::Vertex;
- case HLSLShaderAttr::Geometry: return llvm::Triple::Geometry;
- case HLSLShaderAttr::Hull: return llvm::Triple::Hull;
- case HLSLShaderAttr::Domain: return llvm::Triple::Domain;
- case HLSLShaderAttr::Compute: return llvm::Triple::Compute;
- case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
- case HLSLShaderAttr::Intersection: return llvm::Triple::Intersection;
- case HLSLShaderAttr::AnyHit: return llvm::Triple::AnyHit;
- case HLSLShaderAttr::ClosestHit: return llvm::Triple::ClosestHit;
- case HLSLShaderAttr::Miss: return llvm::Triple::Miss;
- case HLSLShaderAttr::Callable: return llvm::Triple::Callable;
- case HLSLShaderAttr::Mesh: return llvm::Triple::Mesh;
- case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
- }
+ static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
+ return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
}
}];
}
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index eac1f7c07c85d..00df6c2bd15e4 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -38,7 +38,7 @@ class SemaHLSL : public SemaBase {
const AttributeCommonInfo &AL, int X,
int Y, int Z);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType);
+ llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);
@@ -47,8 +47,8 @@ class SemaHLSL : public SemaBase {
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
+ const Attr *A, llvm::Triple::EnvironmentType Stage,
+ std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
};
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 5e6a3dd4878f4..55ba21ae2ba69 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
const StringRef ShaderAttrKindStr = "hlsl.shader";
Fn->addFnAttr(ShaderAttrKindStr,
- ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+ llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
const StringRef NumThreadsKindStr = "hlsl.numthreads";
std::string NumThreadsStr =
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 7c1fb23b90728..49c9de73aafb5 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7341,8 +7341,8 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
return;
- HLSLShaderAttr::ShaderType ShaderType;
- if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+ llvm::Triple::EnvironmentType ShaderType;
+ if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
<< AL << Str << ArgLoc;
return;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9e614ae99f37d..da9bda3eaf3d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -145,7 +145,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType) {
+ llvm::Triple::EnvironmentType ShaderType) {
if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
if (NT->getType() != ShaderType) {
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
@@ -183,13 +183,12 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
return;
- StringRef Env = TargetInfo.getTriple().getEnvironmentName();
- HLSLShaderAttr::ShaderType ShaderType;
- if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+ llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
+ if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
// The entry point is already annotated - check that it matches the
// triple.
- if (Shader->getType() != ShaderType) {
+ if (Shader->getType() != Env) {
Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
<< Shader;
FD->setInvalidDecl();
@@ -197,11 +196,11 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
} else {
// Implicitly add the shader attribute if the entry function isn't
// explicitly annotated.
- FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+ FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
FD->getBeginLoc()));
}
} else {
- switch (TargetInfo.getTriple().getEnvironment()) {
+ switch (Env) {
case llvm::Triple::UnknownEnvironment:
case llvm::Triple::Library:
break;
@@ -214,38 +213,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+ llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
switch (ST) {
- case HLSLShaderAttr::Pixel:
- case HLSLShaderAttr::Vertex:
- case HLSLShaderAttr::Geometry:
- case HLSLShaderAttr::Hull:
- case HLSLShaderAttr::Domain:
- case HLSLShaderAttr::RayGeneration:
- case HLSLShaderAttr::Intersection:
- case HLSLShaderAttr::AnyHit:
- case HLSLShaderAttr::ClosestHit:
- case HLSLShaderAttr::Miss:
- case HLSLShaderAttr::Callable:
+ case llvm::Triple::Pixel:
+ case llvm::Triple::Vertex:
+ case llvm::Triple::Geometry:
+ case llvm::Triple::Hull:
+ case llvm::Triple::Domain:
+ case llvm::Triple::RayGeneration:
+ case llvm::Triple::Intersection:
+ case llvm::Triple::AnyHit:
+ case llvm::Triple::ClosestHit:
+ case llvm::Triple::Miss:
+ case llvm::Triple::Callable:
if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
DiagnoseAttrStageMismatch(NT, ST,
- {HLSLShaderAttr::Compute,
- HLSLShaderAttr::Amplification,
- HLSLShaderAttr::Mesh});
+ {llvm::Triple::Compute,
+ llvm::Triple::Amplification,
+ llvm::Triple::Mesh});
FD->setInvalidDecl();
}
break;
- case HLSLShaderAttr::Compute:
- case HLSLShaderAttr::Amplification:
- case HLSLShaderAttr::Mesh:
+ case llvm::Triple::Compute:
+ case llvm::Triple::Amplification:
+ case llvm::Triple::Mesh:
if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
- << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+ << llvm::Triple::getEnvironmentTypeName(ST);
FD->setInvalidDecl();
}
break;
+ default:
+ llvm_unreachable("Unhandled environment in triple");
}
for (ParmVarDecl *Param : FD->parameters()) {
@@ -267,14 +268,14 @@ void SemaHLSL::CheckSemanticAnnotation(
const HLSLAnnotationAttr *AnnotationAttr) {
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+ llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
- if (ST == HLSLShaderAttr::Compute)
+ if (ST == llvm::Triple::Compute)
return;
- DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+ DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
break;
default:
llvm_unreachable("Unknown HLSLAnnotationAttr");
@@ -282,16 +283,16 @@ void SemaHLSL::CheckSemanticAnnotation(
}
void SemaHLSL::DiagnoseAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+ const Attr *A, llvm::Triple::EnvironmentType Stage,
+ std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
SmallVector<StringRef, 8> StageStrings;
llvm::transform(AllowedStages, std::back_inserter(StageStrings),
- [](HLSLShaderAttr::ShaderType ST) {
+ [](llvm::Triple::EnvironmentType ST) {
return StringRef(
- HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+ HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
});
Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
- << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+ << A << llvm::Triple::getEnvironmentTypeName(Stage)
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
}
@@ -321,16 +322,22 @@ class DiagnoseHLSLAvailability
//
// Maps FunctionDecl to an unsigned number that represents the set of shader
// environments the function has been scanned for.
- // Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
- // defined without any assigned values, it is guaranteed to be numbered
- // sequentially from 0 up and we can use it to 'index' individual bits
- // in the set.
+ // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
+ // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
+ // (verified by static_asserts in Triple.cpp), we can use it to index
+ // individual bits in the set, as long as we shift the values to start with 0
+ // by subtracting the value of llvm::Triple::Pixel first.
+ //
// The N'th bit in the set will be set if the function has been scanned
- // in shader environment whose ShaderType integer value equals N.
+ // in shader environment whose llvm::Triple::EnvironmentType integer value
+ // equals (llvm::Triple::Pixel + N).
+ //
// For example, if a function has been scanned in compute and pixel stage
- // environment, the value will be 0x21 (100001 binary) because
- // (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
- // (int)HLSLShaderAttr::ShaderType::Compute == 5.
+ // environment, the value will be 0x21 (100001 binary) because:
+ //
+ // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
+ // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
+ //
// A FunctionDecl is mapped to 0 (or not included in the map) if it has not
// been scanned in any environment.
llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
@@ -346,12 +353,16 @@ class DiagnoseHLSLAvailability
bool ReportOnlyShaderStageIssues;
// Helper methods for dealing with current stage context / environment
- void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
+ void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
static_assert(sizeof(unsigned) >= 4);
- assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"
-
- CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
- CurrentShaderStageBit = (1 << ShaderType);
+ assert(HLSLShaderAttr::isValidShaderType(ShaderType));
+ assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
+ "ShaderType is too big for this bitmap"); // 31 is reserved for
+ // "unknown"
+
+ unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
+ CurrentShaderEnvironment = ShaderType;
+ CurrentShaderStageBit = (1 << bitmapIndex);
}
void SetUnknownShaderStageContext() {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! I like how clean this change is!
…haderType (llvm#93847) `HLSLShaderAttr::ShaderType` enum is a subset of `llvm::Triple::EnvironmentType`. We can use `llvm::Triple::EnvironmentType` directly and avoid converting one enum to another. Signed-off-by: Hafidz Muzakky <[email protected]>
HLSLShaderAttr::ShaderType
enum is a subset ofllvm::Triple::EnvironmentType
. We can usellvm::Triple::EnvironmentType
directly and avoid converting one enum to another.