Skip to content

Commit 334d123

Browse files
authored
[HLSL] Adjust resource binding diagnostic flags code (#106657)
Adjust register binding diagnostic flags code in a couple of ways: - Store the resource class in the Flags struct to avoid duplicated scanning for HLSLResourceClassAttribute - Avoid unnecessary indirection when converting resource class to register type - Remove recursion and reduce duplicated code Also fixes a case where struct with an array was incorrectly diagnosed unfit for `c` register binding. This will also simplify work that is needed to be done in this area for #104861.
1 parent 0ef7b1d commit 334d123

File tree

2 files changed

+71
-112
lines changed

2 files changed

+71
-112
lines changed

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 64 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,9 @@ struct RegisterBindingFlags {
612612

613613
bool ContainsNumeric = false;
614614
bool DefaultGlobals = false;
615+
616+
// used only when Resource == true
617+
std::optional<llvm::dxil::ResourceClass> ResourceClass;
615618
};
616619

617620
static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
@@ -677,65 +680,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) {
677680
return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl);
678681
}
679682

680-
static void updateFlagsFromType(QualType TheQualTy,
681-
RegisterBindingFlags &Flags);
682-
683-
static void updateResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
684-
const RecordDecl *RD) {
685-
if (!RD)
686-
return;
687-
688-
if (RD->isCompleteDefinition()) {
689-
for (auto Field : RD->fields()) {
690-
QualType T = Field->getType();
691-
updateFlagsFromType(T, Flags);
683+
static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
684+
const RecordType *RT) {
685+
llvm::SmallVector<const Type *> TypesToScan;
686+
TypesToScan.emplace_back(RT);
687+
688+
while (!TypesToScan.empty()) {
689+
const Type *T = TypesToScan.pop_back_val();
690+
while (T->isArrayType())
691+
T = T->getArrayElementTypeNoTypeQual();
692+
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
693+
Flags.ContainsNumeric = true;
694+
continue;
692695
}
693-
}
694-
}
695-
696-
static void updateFlagsFromType(QualType TheQualTy,
697-
RegisterBindingFlags &Flags) {
698-
// if the member's type is a numeric type, set the ContainsNumeric flag
699-
if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) {
700-
Flags.ContainsNumeric = true;
701-
return;
702-
}
703-
704-
const clang::Type *TheBaseType = TheQualTy.getTypePtr();
705-
while (TheBaseType->isArrayType())
706-
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
707-
// otherwise, if the member's base type is not a record type, return
708-
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
709-
if (!TheRecordTy)
710-
return;
711-
712-
RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
713-
const HLSLResourceClassAttr *Attr =
714-
getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl);
715-
// find the attr if it's on the member, or on any of the member's fields
716-
if (Attr) {
717-
llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
718-
updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
719-
}
696+
const RecordType *RT = T->getAs<RecordType>();
697+
if (!RT)
698+
continue;
720699

721-
// otherwise, dig deeper and recurse into the member
722-
else {
723-
updateResourceClassFlagsFromRecordDecl(Flags, SubRecordDecl);
700+
const RecordDecl *RD = RT->getDecl();
701+
for (FieldDecl *FD : RD->fields()) {
702+
if (HLSLResourceClassAttr *RCAttr =
703+
FD->getAttr<HLSLResourceClassAttr>()) {
704+
updateResourceClassFlagsFromDeclResourceClass(
705+
Flags, RCAttr->getResourceClass());
706+
continue;
707+
}
708+
TypesToScan.emplace_back(FD->getType().getTypePtr());
709+
}
724710
}
725711
}
726712

727713
static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
728714
Decl *TheDecl) {
729-
730-
// Cbuffers and Tbuffers are HLSLBufferDecl types
731-
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
732-
// Samplers, UAVs, and SRVs are VarDecl types
733-
VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
734-
735-
assert(((TheVarDecl && !CBufferOrTBuffer) ||
736-
(!TheVarDecl && CBufferOrTBuffer)) &&
737-
"either TheVarDecl or CBufferOrTBuffer should be set");
738-
739715
RegisterBindingFlags Flags;
740716

741717
// check if the decl type is groupshared
@@ -744,58 +720,60 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
744720
return Flags;
745721
}
746722

747-
if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
748-
// make sure the type is a basic / numeric type
749-
if (TheVarDecl) {
750-
QualType TheQualTy = TheVarDecl->getType();
751-
// a numeric variable or an array of numeric variables
752-
// will inevitably end up in $Globals buffer
753-
const clang::Type *TheBaseType = TheQualTy.getTypePtr();
754-
while (TheBaseType->isArrayType())
755-
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
756-
if (TheBaseType->isIntegralType(S.getASTContext()) ||
757-
TheBaseType->isFloatingType())
758-
Flags.DefaultGlobals = true;
759-
}
760-
}
761-
762-
if (CBufferOrTBuffer) {
723+
// Cbuffers and Tbuffers are HLSLBufferDecl types
724+
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
763725
Flags.Resource = true;
764-
if (CBufferOrTBuffer->isCBuffer())
765-
Flags.CBV = true;
766-
else
767-
Flags.SRV = true;
768-
} else if (TheVarDecl) {
726+
Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
727+
? llvm::dxil::ResourceClass::CBuffer
728+
: llvm::dxil::ResourceClass::SRV;
729+
}
730+
// Samplers, UAVs, and SRVs are VarDecl types
731+
else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
769732
const HLSLResourceClassAttr *resClassAttr =
770733
getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
771-
772734
if (resClassAttr) {
773-
llvm::hlsl::ResourceClass DeclResourceClass =
774-
resClassAttr->getResourceClass();
775735
Flags.Resource = true;
776-
updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
736+
Flags.ResourceClass = resClassAttr->getResourceClass();
777737
} else {
778738
const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
779739
while (TheBaseType->isArrayType())
780740
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
781-
if (TheBaseType->isArithmeticType())
741+
742+
if (TheBaseType->isArithmeticType()) {
782743
Flags.Basic = true;
783-
else if (TheBaseType->isRecordType()) {
744+
if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
745+
(TheBaseType->isIntegralType(S.getASTContext()) ||
746+
TheBaseType->isFloatingType()))
747+
Flags.DefaultGlobals = true;
748+
} else if (TheBaseType->isRecordType()) {
784749
Flags.UDT = true;
785750
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
786-
assert(TheRecordTy && "The Qual Type should be Record Type");
787-
const RecordDecl *TheRecordDecl = TheRecordTy->getDecl();
788-
// recurse through members, set appropriate resource class flags.
789-
updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl);
751+
updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
790752
} else
791753
Flags.Other = true;
792754
}
755+
} else {
756+
llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
793757
}
794758
return Flags;
795759
}
796760

797761
enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
798762

763+
static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
764+
switch (RC) {
765+
case llvm::dxil::ResourceClass::SRV:
766+
return RegisterType::SRV;
767+
case llvm::dxil::ResourceClass::UAV:
768+
return RegisterType::UAV;
769+
case llvm::dxil::ResourceClass::CBuffer:
770+
return RegisterType::CBuffer;
771+
case llvm::dxil::ResourceClass::Sampler:
772+
return RegisterType::Sampler;
773+
}
774+
llvm_unreachable("unexpected ResourceClass value");
775+
}
776+
799777
static RegisterType getRegisterType(StringRef Slot) {
800778
switch (Slot[0]) {
801779
case 't':
@@ -886,34 +864,8 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
886864
// next, if resource is set, make sure the register type in the register
887865
// annotation is compatible with the variable's resource type.
888866
if (Flags.Resource) {
889-
const HLSLResourceClassAttr *resClassAttr = nullptr;
890-
if (CBufferOrTBuffer) {
891-
resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>();
892-
} else if (TheVarDecl) {
893-
resClassAttr =
894-
getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
895-
}
896-
897-
assert(resClassAttr &&
898-
"any decl that set the resource flag on analysis should "
899-
"have a resource class attribute attached.");
900-
const llvm::hlsl::ResourceClass DeclResourceClass =
901-
resClassAttr->getResourceClass();
902-
903-
// confirm that the register type is bound to its expected resource class
904-
static RegisterType ExpectedRegisterTypesForResourceClass[] = {
905-
RegisterType::SRV,
906-
RegisterType::UAV,
907-
RegisterType::CBuffer,
908-
RegisterType::Sampler,
909-
};
910-
assert((size_t)DeclResourceClass <
911-
std::size(ExpectedRegisterTypesForResourceClass) &&
912-
"DeclResourceClass has unexpected value");
913-
914-
RegisterType ExpectedRegisterType =
915-
ExpectedRegisterTypesForResourceClass[(int)DeclResourceClass];
916-
if (regType != ExpectedRegisterType) {
867+
RegisterType expRegType = getRegisterType(Flags.ResourceClass.value());
868+
if (regType != expRegType) {
917869
S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
918870
<< regTypeNum;
919871
}
@@ -955,7 +907,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
955907
}
956908

957909
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
958-
if (dyn_cast<VarDecl>(TheDecl)) {
910+
if (isa<VarDecl>(TheDecl)) {
959911
if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
960912
cast<ValueDecl>(TheDecl)->getType(),
961913
diag::err_incomplete_type))

clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,10 @@ struct Eg14{
126126
};
127127
// expected-warning@+1{{binding type 't' only applies to types containing SRV resources}}
128128
Eg14 e14 : register(t9);
129+
130+
struct Eg15 {
131+
float f[4];
132+
};
133+
// expected no error
134+
Eg15 e15 : register(c0);
135+

0 commit comments

Comments
 (0)