Skip to content

[HLSL][RootSignature] Add parsing for RootFlags #138055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: users/inbelic/pr-138007
Choose a base branch
from

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Apr 30, 2025

  • defines the RootFlags in-memory enum

  • defines parseRootFlags to parse the various flag enums into a single uint32_t

  • adds corresponding unit tests

  • improves the diagnostic message for when we provide a non-zero integer value to the flags

Resolves #126575

inbelic added 5 commits April 30, 2025 23:14
- defines the `RootFlags` in-memory enum
- defines a template of `parseRootFlags` that will allow handling of
parsing root flags
- we shouldn't pass in a string to the diagnostic, instead we should
define the string properly and provide better context
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Apr 30, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes
  • defines the RootFlags in-memory enum

  • defines parseRootFlags to parse the various flag enums into a single uint32_t

  • adds corresponding unit tests

  • improves the diagnostic message for when we provide a non-zero integer value to the flags

Resolves #126575


Full diff: https://github.com/llvm/llvm-project/pull/138055.diff

7 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1)
  • (modified) clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def (+19)
  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+1)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+63-10)
  • (modified) clang/unittests/Lex/LexHLSLRootSignatureTest.cpp (+14-1)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+51-1)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+19-2)
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 72e765bcb800d..75ed28f95cd32 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1842,5 +1842,6 @@ def err_hlsl_unexpected_end_of_params
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
+def err_hlsl_rootsig_non_zero_flag : Error<"non-zero integer literal specified for flag value">;
 
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
index ecb8cfc7afa16..eac6ebda84965 100644
--- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
@@ -27,6 +27,9 @@
 #endif
 
 // Defines the various types of enum
+#ifndef ROOT_FLAG_ENUM
+#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
 #ifndef UNBOUNDED_ENUM
 #define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT)
 #endif
@@ -73,6 +76,7 @@ PUNCTUATOR(minus,   '-')
 
 // RootElement Keywords:
 KEYWORD(RootSignature) // used only for diagnostic messaging
+KEYWORD(RootFlags)
 KEYWORD(DescriptorTable)
 KEYWORD(RootConstants)
 
@@ -100,6 +104,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded")
 // Descriptor Range Offset Enum:
 DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")
 
+// Root Flag Enums:
+ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT")
+ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT")
+ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE")
+ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED")
+ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED")
+
 // Root Descriptor Flag Enums:
 ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE")
 ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE")
@@ -127,6 +145,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH")
 #undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
 #undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
 #undef ROOT_DESCRIPTOR_FLAG_ENUM
+#undef ROOT_FLAG_ENUM
 #undef DESCRIPTOR_RANGE_OFFSET_ENUM
 #undef UNBOUNDED_ENUM
 #undef ENUM
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 2ac2083983741..915266f8a36ae 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -71,6 +71,7 @@ class RootSignatureParser {
   // expected, or, there is a lexing error
 
   /// Root Element parse methods:
+  std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
   std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
   std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
   std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index a5006b77a6e44..4780af0f94162 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
 bool RootSignatureParser::parse() {
   // Iterate as many RootElements as possible
   do {
+    if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      auto Flags = parseRootFlags();
+      if (!Flags.has_value())
+        return true;
+      Elements.push_back(*Flags);
+    }
+
     if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
       auto Constants = parseRootConstants();
       if (!Constants.has_value())
@@ -47,6 +54,61 @@ bool RootSignatureParser::parse() {
                               /*param of=*/TokenKind::kw_RootSignature);
 }
 
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+  if (!Flags.has_value())
+    return Flag;
+
+  return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+                               llvm::to_underlying(Flag));
+}
+
+std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
+  assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
+         "Expects to only be invoked starting at given keyword");
+
+  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
+                           CurToken.TokKind))
+    return std::nullopt;
+
+  std::optional<RootFlags> Flags = RootFlags::None;
+
+  // Handle the edge-case of '0' to specify no flags set
+  if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+    if (!verifyZeroFlag()) {
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
+      return std::nullopt;
+    }
+  } else {
+    // Otherwise, parse as many flags as possible
+    TokenKind Expected[] = {
+#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+    };
+
+    do {
+      if (tryConsumeExpectedToken(Expected)) {
+        switch (CurToken.TokKind) {
+#define ROOT_FLAG_ENUM(NAME, LIT)                                              \
+  case TokenKind::en_##NAME:                                                   \
+    Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME);                    \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+        default:
+          llvm_unreachable("Switch for consumed enum token was not provided");
+        }
+      }
+    } while (tryConsumeExpectedToken(TokenKind::pu_or));
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_RootFlags))
+    return std::nullopt;
+
+  return Flags;
+}
+
 std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
          "Expects to only be invoked starting at given keyword");
@@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() {
   return std::nullopt;
 }
 
-template <typename FlagType>
-static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
-  if (!Flags.has_value())
-    return Flag;
-
-  return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
-                               llvm::to_underlying(Flag));
-}
-
 std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
 RootSignatureParser::parseDescriptorRangeFlags() {
   assert(CurToken.TokKind == TokenKind::pu_equal &&
@@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() {
   // Handle the edge-case of '0' to specify no flags set
   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
     if (!verifyZeroFlag()) {
-      getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
       return std::nullopt;
     }
     return DescriptorRangeFlags::None;
diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
index 89e9a3183ad03..21a1f1f08ae05 100644
--- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
@@ -87,7 +87,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
 
     RootSignature
 
-    DescriptorTable RootConstants
+    RootFlags DescriptorTable RootConstants
 
     num32BitConstants
 
@@ -98,6 +98,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
     unbounded
     DESCRIPTOR_RANGE_OFFSET_APPEND
 
+    allow_input_assembler_input_layout
+    deny_vertex_shader_root_access
+    deny_hull_shader_root_access
+    deny_domain_shader_root_access
+    deny_geometry_shader_root_access
+    deny_pixel_shader_root_access
+    deny_amplification_shader_root_access
+    deny_mesh_shader_root_access
+    allow_stream_output
+    local_root_signature
+    cbv_srv_uav_heap_directly_indexed
+    sampler_heap_directly_indexed
+
     DATA_VOLATILE
     DATA_STATIC_WHILE_SET_AT_EXECUTE
     DATA_STATIC
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 150eb3e6e54ef..18e1e517dae8f 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(),
+    RootFlags(0),
+    RootFlags(
+      deny_domain_shader_root_access |
+      deny_pixel_shader_root_access |
+      local_root_signature |
+      cbv_srv_uav_heap_directly_indexed |
+      deny_amplification_shader_root_access |
+      deny_geometry_shader_root_access |
+      deny_hull_shader_root_access |
+      deny_mesh_shader_root_access |
+      allow_stream_output |
+      sampler_heap_directly_indexed |
+      allow_input_assembler_input_layout |
+      deny_vertex_shader_root_access
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  ASSERT_EQ(Elements.size(), 3u);
+
+  RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+  Elem = Elements[2];
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
   // This test will checks we can handling trailing commas ','
   const llvm::StringLiteral Source = R"cc(
@@ -496,7 +546,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
   hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
 
   // Test correct diagnostic produced
-  Consumer->setExpected(diag::err_expected);
+  Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag);
   ASSERT_TRUE(Parser.parse());
 
   ASSERT_TRUE(Consumer->isSatisfied());
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 8b8324df18bb3..2ecaf69fc2f9c 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,23 @@ namespace rootsig {
 
 // Definition of the various enumerations and flags
 
+enum class RootFlags : uint32_t {
+  None = 0,
+  AllowInputAssemblerInputLayout = 0x1,
+  DenyVertexShaderRootAccess = 0x2,
+  DenyHullShaderRootAccess = 0x4,
+  DenyDomainShaderRootAccess = 0x8,
+  DenyGeometryShaderRootAccess = 0x10,
+  DenyPixelShaderRootAccess = 0x20,
+  AllowStreamOutput = 0x40,
+  LocalRootSignature = 0x80,
+  DenyAmplificationShaderRootAccess = 0x100,
+  DenyMeshShaderRootAccess = 0x200,
+  CBVSRVUAVHeapDirectlyIndexed = 0x400,
+  SamplerHeapDirectlyIndexed = 0x800,
+  ValidFlags = 0x00000fff
+};
+
 enum class DescriptorRangeFlags : unsigned {
   None = 0,
   DescriptorsVolatile = 0x1,
@@ -97,8 +114,8 @@ struct DescriptorTableClause {
 };
 
 // Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
-using RootElement =
-    std::variant<RootConstants, DescriptorTable, DescriptorTableClause>;
+using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
+                                 DescriptorTableClause>;
 
 } // namespace rootsig
 } // namespace hlsl

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants