Skip to content

[HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType #93847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 8, 2024

Conversation

hekota
Copy link
Member

@hekota hekota commented May 30, 2024

HLSLShaderAttr::ShaderType enum is a subset of llvm::Triple::EnvironmentType. We can use llvm::Triple::EnvironmentType directly and avoid converting one enum to another.

hekota added 3 commits May 30, 2024 09:10
HLSLShaderAttr::ShaderType enum is a subset of llvm::Triple::EnvironmentType and is not needed.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels May 30, 2024
@llvmbot
Copy link
Member

llvmbot commented May 30, 2024

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clang-codegen

Author: Helena Kotas (hekota)

Changes

HLSLShaderAttr::ShaderType enum is a subset of llvm::Triple::EnvironmentType. We can use llvm::Triple::EnvironmentType directly and avoid converting one enum to another.


Full diff: https://github.com/llvm/llvm-project/pull/93847.diff

5 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+5-21)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+3-3)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+1-1)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+2-2)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+58-47)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 2665b7353ca4a..a337509d3e2b5 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4469,36 +4469,20 @@ def HLSLShader : InheritableAttr {
   let Subjects = SubjectList<[HLSLEntry]>;
   let LangOpts = [HLSL];
   let Args = [
-    EnumArgument<"Type", "ShaderType", /*is_string=*/true,
+    EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
                  ["pixel", "vertex", "geometry", "hull", "domain", "compute",
                   "raygeneration", "intersection", "anyhit", "closesthit",
                   "miss", "callable", "mesh", "amplification"],
                  ["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
                   "RayGeneration", "Intersection", "AnyHit", "ClosestHit",
-                  "Miss", "Callable", "Mesh", "Amplification"]>
+                  "Miss", "Callable", "Mesh", "Amplification"],
+                  /*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
   ];
   let Documentation = [HLSLSV_ShaderTypeAttrDocs];
   let AdditionalMembers =
 [{
-  static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;
-
-  static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
-    switch (ShaderType) {
-      case HLSLShaderAttr::Pixel:         return llvm::Triple::Pixel;
-      case HLSLShaderAttr::Vertex:        return llvm::Triple::Vertex;
-      case HLSLShaderAttr::Geometry:      return llvm::Triple::Geometry;
-      case HLSLShaderAttr::Hull:          return llvm::Triple::Hull;
-      case HLSLShaderAttr::Domain:        return llvm::Triple::Domain;
-      case HLSLShaderAttr::Compute:       return llvm::Triple::Compute;
-      case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
-      case HLSLShaderAttr::Intersection:  return llvm::Triple::Intersection;
-      case HLSLShaderAttr::AnyHit:        return llvm::Triple::AnyHit;
-      case HLSLShaderAttr::ClosestHit:    return llvm::Triple::ClosestHit;
-      case HLSLShaderAttr::Miss:          return llvm::Triple::Miss;
-      case HLSLShaderAttr::Callable:      return llvm::Triple::Callable;
-      case HLSLShaderAttr::Mesh:          return llvm::Triple::Mesh;
-      case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
-    }
+  static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
+    return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
   }
 }];
 }
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index eac1f7c07c85d..00df6c2bd15e4 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -38,7 +38,7 @@ class SemaHLSL : public SemaBase {
                                           const AttributeCommonInfo &AL, int X,
                                           int Y, int Z);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                                  HLSLShaderAttr::ShaderType ShaderType);
+                                  llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
   mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                          HLSLParamModifierAttr::Spelling Spelling);
@@ -47,8 +47,8 @@ class SemaHLSL : public SemaBase {
   void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
                                const HLSLAnnotationAttr *AnnotationAttr);
   void DiagnoseAttrStageMismatch(
-      const Attr *A, HLSLShaderAttr::ShaderType Stage,
-      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
+      const Attr *A, llvm::Triple::EnvironmentType Stage,
+      std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
   void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
 };
 
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 5e6a3dd4878f4..55ba21ae2ba69 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
   assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
   const StringRef ShaderAttrKindStr = "hlsl.shader";
   Fn->addFnAttr(ShaderAttrKindStr,
-                ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+                llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
   if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
     const StringRef NumThreadsKindStr = "hlsl.numthreads";
     std::string NumThreadsStr =
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 7c1fb23b90728..49c9de73aafb5 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7341,8 +7341,8 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
     return;
 
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+  llvm::Triple::EnvironmentType ShaderType;
+  if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
     S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
         << AL << Str << ArgLoc;
     return;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9e614ae99f37d..da9bda3eaf3d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -145,7 +145,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
 
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                          HLSLShaderAttr::ShaderType ShaderType) {
+                          llvm::Triple::EnvironmentType ShaderType) {
   if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
     if (NT->getType() != ShaderType) {
       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
@@ -183,13 +183,12 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
   if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
     return;
 
-  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
+  if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
     if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
       // The entry point is already annotated - check that it matches the
       // triple.
-      if (Shader->getType() != ShaderType) {
+      if (Shader->getType() != Env) {
         Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
             << Shader;
         FD->setInvalidDecl();
@@ -197,11 +196,11 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
     } else {
       // Implicitly add the shader attribute if the entry function isn't
       // explicitly annotated.
-      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
                                                  FD->getBeginLoc()));
     }
   } else {
-    switch (TargetInfo.getTriple().getEnvironment()) {
+    switch (Env) {
     case llvm::Triple::UnknownEnvironment:
     case llvm::Triple::Library:
       break;
@@ -214,38 +213,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
   assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
 
   switch (ST) {
-  case HLSLShaderAttr::Pixel:
-  case HLSLShaderAttr::Vertex:
-  case HLSLShaderAttr::Geometry:
-  case HLSLShaderAttr::Hull:
-  case HLSLShaderAttr::Domain:
-  case HLSLShaderAttr::RayGeneration:
-  case HLSLShaderAttr::Intersection:
-  case HLSLShaderAttr::AnyHit:
-  case HLSLShaderAttr::ClosestHit:
-  case HLSLShaderAttr::Miss:
-  case HLSLShaderAttr::Callable:
+  case llvm::Triple::Pixel:
+  case llvm::Triple::Vertex:
+  case llvm::Triple::Geometry:
+  case llvm::Triple::Hull:
+  case llvm::Triple::Domain:
+  case llvm::Triple::RayGeneration:
+  case llvm::Triple::Intersection:
+  case llvm::Triple::AnyHit:
+  case llvm::Triple::ClosestHit:
+  case llvm::Triple::Miss:
+  case llvm::Triple::Callable:
     if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
       DiagnoseAttrStageMismatch(NT, ST,
-                                {HLSLShaderAttr::Compute,
-                                 HLSLShaderAttr::Amplification,
-                                 HLSLShaderAttr::Mesh});
+                                {llvm::Triple::Compute,
+                                 llvm::Triple::Amplification,
+                                 llvm::Triple::Mesh});
       FD->setInvalidDecl();
     }
     break;
 
-  case HLSLShaderAttr::Compute:
-  case HLSLShaderAttr::Amplification:
-  case HLSLShaderAttr::Mesh:
+  case llvm::Triple::Compute:
+  case llvm::Triple::Amplification:
+  case llvm::Triple::Mesh:
     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+          << llvm::Triple::getEnvironmentTypeName(ST);
       FD->setInvalidDecl();
     }
     break;
+  default:
+    llvm_unreachable("Unhandled environment in triple");
   }
 
   for (ParmVarDecl *Param : FD->parameters()) {
@@ -267,14 +268,14 @@ void SemaHLSL::CheckSemanticAnnotation(
     const HLSLAnnotationAttr *AnnotationAttr) {
   auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
   assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
 
   switch (AnnotationAttr->getKind()) {
   case attr::HLSLSV_DispatchThreadID:
   case attr::HLSLSV_GroupIndex:
-    if (ST == HLSLShaderAttr::Compute)
+    if (ST == llvm::Triple::Compute)
       return;
-    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
     break;
   default:
     llvm_unreachable("Unknown HLSLAnnotationAttr");
@@ -282,16 +283,16 @@ void SemaHLSL::CheckSemanticAnnotation(
 }
 
 void SemaHLSL::DiagnoseAttrStageMismatch(
-    const Attr *A, HLSLShaderAttr::ShaderType Stage,
-    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+    const Attr *A, llvm::Triple::EnvironmentType Stage,
+    std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
   SmallVector<StringRef, 8> StageStrings;
   llvm::transform(AllowedStages, std::back_inserter(StageStrings),
-                  [](HLSLShaderAttr::ShaderType ST) {
+                  [](llvm::Triple::EnvironmentType ST) {
                     return StringRef(
-                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+                        HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
                   });
   Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+      << A << llvm::Triple::getEnvironmentTypeName(Stage)
       << (AllowedStages.size() != 1) << join(StageStrings, ", ");
 }
 
@@ -321,16 +322,22 @@ class DiagnoseHLSLAvailability
   //
   // Maps FunctionDecl to an unsigned number that represents the set of shader
   // environments the function has been scanned for.
-  // Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
-  // defined without any assigned values, it is guaranteed to be numbered
-  // sequentially from 0 up and we can use it to 'index' individual bits
-  // in the set.
+  // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
+  // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
+  // (verified by static_asserts in Triple.cpp), we can use it to index
+  // individual bits in the set, as long as we shift the values to start with 0
+  // by subtracting the value of llvm::Triple::Pixel first.
+  //
   // The N'th bit in the set will be set if the function has been scanned
-  // in shader environment whose ShaderType integer value equals N.
+  // in shader environment whose llvm::Triple::EnvironmentType integer value
+  // equals (llvm::Triple::Pixel + N).
+  //
   // For example, if a function has been scanned in compute and pixel stage
-  // environment, the value will be 0x21 (100001 binary) because
-  // (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
-  // (int)HLSLShaderAttr::ShaderType::Compute == 5.
+  // environment, the value will be 0x21 (100001 binary) because:
+  //
+  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
+  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
+  //
   // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
   // been scanned in any environment.
   llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
@@ -346,12 +353,16 @@ class DiagnoseHLSLAvailability
   bool ReportOnlyShaderStageIssues;
 
   // Helper methods for dealing with current stage context / environment
-  void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
+  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
     static_assert(sizeof(unsigned) >= 4);
-    assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"
-
-    CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
-    CurrentShaderStageBit = (1 << ShaderType);
+    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
+    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
+           "ShaderType is too big for this bitmap"); // 31 is reserved for
+                                                     // "unknown"
+
+    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
+    CurrentShaderEnvironment = ShaderType;
+    CurrentShaderStageBit = (1 << bitmapIndex);
   }
 
   void SetUnknownShaderStageContext() {

Copy link
Contributor

@coopp coopp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! I like how clean this change is!

@hekota hekota merged commit 5d87ba1 into llvm:main Jun 8, 2024
7 checks passed
nekoshirro pushed a commit to nekoshirro/Alchemist-LLVM that referenced this pull request Jun 9, 2024
…haderType (llvm#93847)

`HLSLShaderAttr::ShaderType` enum is a subset of
`llvm::Triple::EnvironmentType`. We can use
`llvm::Triple::EnvironmentType` directly and avoid converting one enum
to another.

Signed-off-by: Hafidz Muzakky <[email protected]>
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

6 participants