Skip to content

[HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType #93847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 5 additions & 22 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4470,37 +4470,20 @@ def HLSLShader : InheritableAttr {
let Subjects = SubjectList<[HLSLEntry]>;
let LangOpts = [HLSL];
let Args = [
EnumArgument<"Type", "ShaderType", /*is_string=*/true,
EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
["pixel", "vertex", "geometry", "hull", "domain", "compute",
"raygeneration", "intersection", "anyhit", "closesthit",
"miss", "callable", "mesh", "amplification"],
["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
"RayGeneration", "Intersection", "AnyHit", "ClosestHit",
"Miss", "Callable", "Mesh", "Amplification"]>
"Miss", "Callable", "Mesh", "Amplification"],
/*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
];
let Documentation = [HLSLSV_ShaderTypeAttrDocs];
let AdditionalMembers =
[{
static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;

static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
switch (ShaderType) {
case HLSLShaderAttr::Pixel: return llvm::Triple::Pixel;
case HLSLShaderAttr::Vertex: return llvm::Triple::Vertex;
case HLSLShaderAttr::Geometry: return llvm::Triple::Geometry;
case HLSLShaderAttr::Hull: return llvm::Triple::Hull;
case HLSLShaderAttr::Domain: return llvm::Triple::Domain;
case HLSLShaderAttr::Compute: return llvm::Triple::Compute;
case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
case HLSLShaderAttr::Intersection: return llvm::Triple::Intersection;
case HLSLShaderAttr::AnyHit: return llvm::Triple::AnyHit;
case HLSLShaderAttr::ClosestHit: return llvm::Triple::ClosestHit;
case HLSLShaderAttr::Miss: return llvm::Triple::Miss;
case HLSLShaderAttr::Callable: return llvm::Triple::Callable;
case HLSLShaderAttr::Mesh: return llvm::Triple::Mesh;
case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
}
llvm_unreachable("unknown enumeration value");
static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
}
}];
}
Expand Down
6 changes: 3 additions & 3 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class SemaHLSL : public SemaBase {
const AttributeCommonInfo &AL, int X,
int Y, int Z);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLShaderAttr::ShaderType ShaderType);
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);
Expand All @@ -48,8 +48,8 @@ class SemaHLSL : public SemaBase {
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
const Attr *A, HLSLShaderAttr::ShaderType Stage,
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);

void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
const StringRef ShaderAttrKindStr = "hlsl.shader";
Fn->addFnAttr(ShaderAttrKindStr,
ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
const StringRef NumThreadsKindStr = "hlsl.numthreads";
std::string NumThreadsStr =
Expand Down
109 changes: 60 additions & 49 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,

HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLShaderAttr::ShaderType ShaderType) {
llvm::Triple::EnvironmentType ShaderType) {
if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
if (NT->getType() != ShaderType) {
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
Expand Down Expand Up @@ -184,25 +184,24 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
return;

StringRef Env = TargetInfo.getTriple().getEnvironmentName();
HLSLShaderAttr::ShaderType ShaderType;
if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
// The entry point is already annotated - check that it matches the
// triple.
if (Shader->getType() != ShaderType) {
if (Shader->getType() != Env) {
Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
<< Shader;
FD->setInvalidDecl();
}
} else {
// Implicitly add the shader attribute if the entry function isn't
// explicitly annotated.
FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
FD->getBeginLoc()));
}
} else {
switch (TargetInfo.getTriple().getEnvironment()) {
switch (Env) {
case llvm::Triple::UnknownEnvironment:
case llvm::Triple::Library:
break;
Expand All @@ -215,38 +214,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

switch (ST) {
case HLSLShaderAttr::Pixel:
case HLSLShaderAttr::Vertex:
case HLSLShaderAttr::Geometry:
case HLSLShaderAttr::Hull:
case HLSLShaderAttr::Domain:
case HLSLShaderAttr::RayGeneration:
case HLSLShaderAttr::Intersection:
case HLSLShaderAttr::AnyHit:
case HLSLShaderAttr::ClosestHit:
case HLSLShaderAttr::Miss:
case HLSLShaderAttr::Callable:
case llvm::Triple::Pixel:
case llvm::Triple::Vertex:
case llvm::Triple::Geometry:
case llvm::Triple::Hull:
case llvm::Triple::Domain:
case llvm::Triple::RayGeneration:
case llvm::Triple::Intersection:
case llvm::Triple::AnyHit:
case llvm::Triple::ClosestHit:
case llvm::Triple::Miss:
case llvm::Triple::Callable:
if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
DiagnoseAttrStageMismatch(NT, ST,
{HLSLShaderAttr::Compute,
HLSLShaderAttr::Amplification,
HLSLShaderAttr::Mesh});
{llvm::Triple::Compute,
llvm::Triple::Amplification,
llvm::Triple::Mesh});
FD->setInvalidDecl();
}
break;

case HLSLShaderAttr::Compute:
case HLSLShaderAttr::Amplification:
case HLSLShaderAttr::Mesh:
case llvm::Triple::Compute:
case llvm::Triple::Amplification:
case llvm::Triple::Mesh:
if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
<< HLSLShaderAttr::ConvertShaderTypeToStr(ST);
<< llvm::Triple::getEnvironmentTypeName(ST);
FD->setInvalidDecl();
}
break;
default:
llvm_unreachable("Unhandled environment in triple");
}

for (ParmVarDecl *Param : FD->parameters()) {
Expand All @@ -268,31 +269,31 @@ void SemaHLSL::CheckSemanticAnnotation(
const HLSLAnnotationAttr *AnnotationAttr) {
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
if (ST == HLSLShaderAttr::Compute)
if (ST == llvm::Triple::Compute)
return;
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
break;
default:
llvm_unreachable("Unknown HLSLAnnotationAttr");
}
}

void SemaHLSL::DiagnoseAttrStageMismatch(
const Attr *A, HLSLShaderAttr::ShaderType Stage,
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
SmallVector<StringRef, 8> StageStrings;
llvm::transform(AllowedStages, std::back_inserter(StageStrings),
[](HLSLShaderAttr::ShaderType ST) {
[](llvm::Triple::EnvironmentType ST) {
return StringRef(
HLSLShaderAttr::ConvertShaderTypeToStr(ST));
HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
});
Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
<< A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
<< A << llvm::Triple::getEnvironmentTypeName(Stage)
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
}

Expand Down Expand Up @@ -430,8 +431,8 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
return;

HLSLShaderAttr::ShaderType ShaderType;
if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
llvm::Triple::EnvironmentType ShaderType;
if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
<< AL << Str << ArgLoc;
return;
Expand Down Expand Up @@ -549,16 +550,22 @@ class DiagnoseHLSLAvailability
//
// Maps FunctionDecl to an unsigned number that represents the set of shader
// environments the function has been scanned for.
// Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
// defined without any assigned values, it is guaranteed to be numbered
// sequentially from 0 up and we can use it to 'index' individual bits
// in the set.
// The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
// to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
// (verified by static_asserts in Triple.cpp), we can use it to index
// individual bits in the set, as long as we shift the values to start with 0
// by subtracting the value of llvm::Triple::Pixel first.
//
// The N'th bit in the set will be set if the function has been scanned
// in shader environment whose ShaderType integer value equals N.
// in shader environment whose llvm::Triple::EnvironmentType integer value
// equals (llvm::Triple::Pixel + N).
//
// For example, if a function has been scanned in compute and pixel stage
// environment, the value will be 0x21 (100001 binary) because
// (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
// (int)HLSLShaderAttr::ShaderType::Compute == 5.
// environment, the value will be 0x21 (100001 binary) because:
//
// (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
// (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
//
// A FunctionDecl is mapped to 0 (or not included in the map) if it has not
// been scanned in any environment.
llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
Expand All @@ -574,12 +581,16 @@ class DiagnoseHLSLAvailability
bool ReportOnlyShaderStageIssues;

// Helper methods for dealing with current stage context / environment
void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
static_assert(sizeof(unsigned) >= 4);
assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"

CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
CurrentShaderStageBit = (1 << ShaderType);
assert(HLSLShaderAttr::isValidShaderType(ShaderType));
assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
"ShaderType is too big for this bitmap"); // 31 is reserved for
// "unknown"

unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
CurrentShaderEnvironment = ShaderType;
CurrentShaderStageBit = (1 << bitmapIndex);
}

void SetUnknownShaderStageContext() {
Expand Down
Loading