-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[HLSL][RootSignature] Add parsing for RootFlags #138055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/inbelic/pr-138007
Are you sure you want to change the base?
[HLSL][RootSignature] Add parsing for RootFlags #138055
Conversation
- defines the `RootFlags` in-memory enum - defines a template of `parseRootFlags` that will allow handling of parsing root flags
- we shouldn't pass in a string to the diagnostic, instead we should define the string properly and provide better context
@llvm/pr-subscribers-clang @llvm/pr-subscribers-hlsl Author: Finn Plummer (inbelic) Changes
Resolves #126575 Full diff: https://github.com/llvm/llvm-project/pull/138055.diff 7 Files Affected:
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 72e765bcb800d..75ed28f95cd32 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1842,5 +1842,6 @@ def err_hlsl_unexpected_end_of_params
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
+def err_hlsl_rootsig_non_zero_flag : Error<"non-zero integer literal specified for flag value">;
} // end of Parser diagnostics
diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
index ecb8cfc7afa16..eac6ebda84965 100644
--- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
@@ -27,6 +27,9 @@
#endif
// Defines the various types of enum
+#ifndef ROOT_FLAG_ENUM
+#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
+#endif
#ifndef UNBOUNDED_ENUM
#define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT)
#endif
@@ -73,6 +76,7 @@ PUNCTUATOR(minus, '-')
// RootElement Keywords:
KEYWORD(RootSignature) // used only for diagnostic messaging
+KEYWORD(RootFlags)
KEYWORD(DescriptorTable)
KEYWORD(RootConstants)
@@ -100,6 +104,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded")
// Descriptor Range Offset Enum:
DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")
+// Root Flag Enums:
+ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT")
+ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS")
+ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT")
+ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE")
+ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED")
+ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED")
+
// Root Descriptor Flag Enums:
ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE")
ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE")
@@ -127,6 +145,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH")
#undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
#undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
#undef ROOT_DESCRIPTOR_FLAG_ENUM
+#undef ROOT_FLAG_ENUM
#undef DESCRIPTOR_RANGE_OFFSET_ENUM
#undef UNBOUNDED_ENUM
#undef ENUM
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 2ac2083983741..915266f8a36ae 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -71,6 +71,7 @@ class RootSignatureParser {
// expected, or, there is a lexing error
/// Root Element parse methods:
+ std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index a5006b77a6e44..4780af0f94162 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
bool RootSignatureParser::parse() {
// Iterate as many RootElements as possible
do {
+ if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+ auto Flags = parseRootFlags();
+ if (!Flags.has_value())
+ return true;
+ Elements.push_back(*Flags);
+ }
+
if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
auto Constants = parseRootConstants();
if (!Constants.has_value())
@@ -47,6 +54,61 @@ bool RootSignatureParser::parse() {
/*param of=*/TokenKind::kw_RootSignature);
}
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+ if (!Flags.has_value())
+ return Flag;
+
+ return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+ llvm::to_underlying(Flag));
+}
+
+std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
+ assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
+ "Expects to only be invoked starting at given keyword");
+
+ if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
+ CurToken.TokKind))
+ return std::nullopt;
+
+ std::optional<RootFlags> Flags = RootFlags::None;
+
+ // Handle the edge-case of '0' to specify no flags set
+ if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+ if (!verifyZeroFlag()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
+ return std::nullopt;
+ }
+ } else {
+ // Otherwise, parse as many flags as possible
+ TokenKind Expected[] = {
+#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ do {
+ if (tryConsumeExpectedToken(Expected)) {
+ switch (CurToken.TokKind) {
+#define ROOT_FLAG_ENUM(NAME, LIT) \
+ case TokenKind::en_##NAME: \
+ Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
+ break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ default:
+ llvm_unreachable("Switch for consumed enum token was not provided");
+ }
+ }
+ } while (tryConsumeExpectedToken(TokenKind::pu_or));
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_r_paren,
+ diag::err_hlsl_unexpected_end_of_params,
+ /*param of=*/TokenKind::kw_RootFlags))
+ return std::nullopt;
+
+ return Flags;
+}
+
std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
"Expects to only be invoked starting at given keyword");
@@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}
-template <typename FlagType>
-static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
- if (!Flags.has_value())
- return Flag;
-
- return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
- llvm::to_underlying(Flag));
-}
-
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
RootSignatureParser::parseDescriptorRangeFlags() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
@@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() {
// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
- getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
return std::nullopt;
}
return DescriptorRangeFlags::None;
diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
index 89e9a3183ad03..21a1f1f08ae05 100644
--- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
@@ -87,7 +87,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
RootSignature
- DescriptorTable RootConstants
+ RootFlags DescriptorTable RootConstants
num32BitConstants
@@ -98,6 +98,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
unbounded
DESCRIPTOR_RANGE_OFFSET_APPEND
+ allow_input_assembler_input_layout
+ deny_vertex_shader_root_access
+ deny_hull_shader_root_access
+ deny_domain_shader_root_access
+ deny_geometry_shader_root_access
+ deny_pixel_shader_root_access
+ deny_amplification_shader_root_access
+ deny_mesh_shader_root_access
+ allow_stream_output
+ local_root_signature
+ cbv_srv_uav_heap_directly_indexed
+ sampler_heap_directly_indexed
+
DATA_VOLATILE
DATA_STATIC_WHILE_SET_AT_EXECUTE
DATA_STATIC
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 150eb3e6e54ef..18e1e517dae8f 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(),
+ RootFlags(0),
+ RootFlags(
+ deny_domain_shader_root_access |
+ deny_pixel_shader_root_access |
+ local_root_signature |
+ cbv_srv_uav_heap_directly_indexed |
+ deny_amplification_shader_root_access |
+ deny_geometry_shader_root_access |
+ deny_hull_shader_root_access |
+ deny_mesh_shader_root_access |
+ allow_stream_output |
+ sampler_heap_directly_indexed |
+ allow_input_assembler_input_layout |
+ deny_vertex_shader_root_access
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ ASSERT_EQ(Elements.size(), 3u);
+
+ RootElement Elem = Elements[0];
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ Elem = Elements[1];
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ Elem = Elements[2];
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
@@ -496,7 +546,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
// Test correct diagnostic produced
- Consumer->setExpected(diag::err_expected);
+ Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag);
ASSERT_TRUE(Parser.parse());
ASSERT_TRUE(Consumer->isSatisfied());
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 8b8324df18bb3..2ecaf69fc2f9c 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,23 @@ namespace rootsig {
// Definition of the various enumerations and flags
+enum class RootFlags : uint32_t {
+ None = 0,
+ AllowInputAssemblerInputLayout = 0x1,
+ DenyVertexShaderRootAccess = 0x2,
+ DenyHullShaderRootAccess = 0x4,
+ DenyDomainShaderRootAccess = 0x8,
+ DenyGeometryShaderRootAccess = 0x10,
+ DenyPixelShaderRootAccess = 0x20,
+ AllowStreamOutput = 0x40,
+ LocalRootSignature = 0x80,
+ DenyAmplificationShaderRootAccess = 0x100,
+ DenyMeshShaderRootAccess = 0x200,
+ CBVSRVUAVHeapDirectlyIndexed = 0x400,
+ SamplerHeapDirectlyIndexed = 0x800,
+ ValidFlags = 0x00000fff
+};
+
enum class DescriptorRangeFlags : unsigned {
None = 0,
DescriptorsVolatile = 0x1,
@@ -97,8 +114,8 @@ struct DescriptorTableClause {
};
// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
-using RootElement =
- std::variant<RootConstants, DescriptorTable, DescriptorTableClause>;
+using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
+ DescriptorTableClause>;
} // namespace rootsig
} // namespace hlsl
|
defines the
RootFlags
in-memory enumdefines
parseRootFlags
to parse the various flag enums into a singleuint32_t
adds corresponding unit tests
improves the diagnostic message for when we provide a non-zero integer value to the flags
Resolves #126575