@@ -612,6 +612,9 @@ struct RegisterBindingFlags {
612
612
613
613
bool ContainsNumeric = false ;
614
614
bool DefaultGlobals = false ;
615
+
616
+ // used only when Resource == true
617
+ std::optional<llvm::dxil::ResourceClass> ResourceClass;
615
618
};
616
619
617
620
static bool isDeclaredWithinCOrTBuffer (const Decl *TheDecl) {
@@ -677,65 +680,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) {
677
680
return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl);
678
681
}
679
682
680
- static void updateFlagsFromType (QualType TheQualTy ,
681
- RegisterBindingFlags &Flags);
682
-
683
- static void updateResourceClassFlagsFromRecordDecl (RegisterBindingFlags &Flags,
684
- const RecordDecl *RD) {
685
- if (!RD)
686
- return ;
687
-
688
- if (RD-> isCompleteDefinition ()) {
689
- for ( auto Field : RD-> fields ()) {
690
- QualType T = Field-> getType () ;
691
- updateFlagsFromType (T, Flags) ;
683
+ static void updateResourceClassFlagsFromRecordType (RegisterBindingFlags &Flags ,
684
+ const RecordType *RT) {
685
+ llvm::SmallVector< const Type *> TypesToScan;
686
+ TypesToScan. emplace_back (RT);
687
+
688
+ while (!TypesToScan. empty ()) {
689
+ const Type *T = TypesToScan. pop_back_val () ;
690
+ while (T-> isArrayType ())
691
+ T = T-> getArrayElementTypeNoTypeQual ();
692
+ if (T-> isIntegralOrEnumerationType () || T-> isFloatingType ()) {
693
+ Flags. ContainsNumeric = true ;
694
+ continue ;
692
695
}
693
- }
694
- }
695
-
696
- static void updateFlagsFromType (QualType TheQualTy,
697
- RegisterBindingFlags &Flags) {
698
- // if the member's type is a numeric type, set the ContainsNumeric flag
699
- if (TheQualTy->isIntegralOrEnumerationType () || TheQualTy->isFloatingType ()) {
700
- Flags.ContainsNumeric = true ;
701
- return ;
702
- }
703
-
704
- const clang::Type *TheBaseType = TheQualTy.getTypePtr ();
705
- while (TheBaseType->isArrayType ())
706
- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
707
- // otherwise, if the member's base type is not a record type, return
708
- const RecordType *TheRecordTy = TheBaseType->getAs <RecordType>();
709
- if (!TheRecordTy)
710
- return ;
711
-
712
- RecordDecl *SubRecordDecl = TheRecordTy->getDecl ();
713
- const HLSLResourceClassAttr *Attr =
714
- getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl);
715
- // find the attr if it's on the member, or on any of the member's fields
716
- if (Attr) {
717
- llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass ();
718
- updateResourceClassFlagsFromDeclResourceClass (Flags, DeclResourceClass);
719
- }
696
+ const RecordType *RT = T->getAs <RecordType>();
697
+ if (!RT)
698
+ continue ;
720
699
721
- // otherwise, dig deeper and recurse into the member
722
- else {
723
- updateResourceClassFlagsFromRecordDecl (Flags, SubRecordDecl);
700
+ const RecordDecl *RD = RT->getDecl ();
701
+ for (FieldDecl *FD : RD->fields ()) {
702
+ if (HLSLResourceClassAttr *RCAttr =
703
+ FD->getAttr <HLSLResourceClassAttr>()) {
704
+ updateResourceClassFlagsFromDeclResourceClass (
705
+ Flags, RCAttr->getResourceClass ());
706
+ continue ;
707
+ }
708
+ TypesToScan.emplace_back (FD->getType ().getTypePtr ());
709
+ }
724
710
}
725
711
}
726
712
727
713
static RegisterBindingFlags HLSLFillRegisterBindingFlags (Sema &S,
728
714
Decl *TheDecl) {
729
-
730
- // Cbuffers and Tbuffers are HLSLBufferDecl types
731
- HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
732
- // Samplers, UAVs, and SRVs are VarDecl types
733
- VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
734
-
735
- assert (((TheVarDecl && !CBufferOrTBuffer) ||
736
- (!TheVarDecl && CBufferOrTBuffer)) &&
737
- " either TheVarDecl or CBufferOrTBuffer should be set" );
738
-
739
715
RegisterBindingFlags Flags;
740
716
741
717
// check if the decl type is groupshared
@@ -744,58 +720,60 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
744
720
return Flags;
745
721
}
746
722
747
- if (!isDeclaredWithinCOrTBuffer (TheDecl)) {
748
- // make sure the type is a basic / numeric type
749
- if (TheVarDecl) {
750
- QualType TheQualTy = TheVarDecl->getType ();
751
- // a numeric variable or an array of numeric variables
752
- // will inevitably end up in $Globals buffer
753
- const clang::Type *TheBaseType = TheQualTy.getTypePtr ();
754
- while (TheBaseType->isArrayType ())
755
- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
756
- if (TheBaseType->isIntegralType (S.getASTContext ()) ||
757
- TheBaseType->isFloatingType ())
758
- Flags.DefaultGlobals = true ;
759
- }
760
- }
761
-
762
- if (CBufferOrTBuffer) {
723
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
724
+ if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
763
725
Flags.Resource = true ;
764
- if (CBufferOrTBuffer->isCBuffer ())
765
- Flags.CBV = true ;
766
- else
767
- Flags.SRV = true ;
768
- } else if (TheVarDecl) {
726
+ Flags.ResourceClass = CBufferOrTBuffer->isCBuffer ()
727
+ ? llvm::dxil::ResourceClass::CBuffer
728
+ : llvm::dxil::ResourceClass::SRV;
729
+ }
730
+ // Samplers, UAVs, and SRVs are VarDecl types
731
+ else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
769
732
const HLSLResourceClassAttr *resClassAttr =
770
733
getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
771
-
772
734
if (resClassAttr) {
773
- llvm::hlsl::ResourceClass DeclResourceClass =
774
- resClassAttr->getResourceClass ();
775
735
Flags.Resource = true ;
776
- updateResourceClassFlagsFromDeclResourceClass ( Flags, DeclResourceClass );
736
+ Flags. ResourceClass = resClassAttr-> getResourceClass ( );
777
737
} else {
778
738
const clang::Type *TheBaseType = TheVarDecl->getType ().getTypePtr ();
779
739
while (TheBaseType->isArrayType ())
780
740
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
781
- if (TheBaseType->isArithmeticType ())
741
+
742
+ if (TheBaseType->isArithmeticType ()) {
782
743
Flags.Basic = true ;
783
- else if (TheBaseType->isRecordType ()) {
744
+ if (!isDeclaredWithinCOrTBuffer (TheDecl) &&
745
+ (TheBaseType->isIntegralType (S.getASTContext ()) ||
746
+ TheBaseType->isFloatingType ()))
747
+ Flags.DefaultGlobals = true ;
748
+ } else if (TheBaseType->isRecordType ()) {
784
749
Flags.UDT = true ;
785
750
const RecordType *TheRecordTy = TheBaseType->getAs <RecordType>();
786
- assert (TheRecordTy && " The Qual Type should be Record Type" );
787
- const RecordDecl *TheRecordDecl = TheRecordTy->getDecl ();
788
- // recurse through members, set appropriate resource class flags.
789
- updateResourceClassFlagsFromRecordDecl (Flags, TheRecordDecl);
751
+ updateResourceClassFlagsFromRecordType (Flags, TheRecordTy);
790
752
} else
791
753
Flags.Other = true ;
792
754
}
755
+ } else {
756
+ llvm_unreachable (" expected be VarDecl or HLSLBufferDecl" );
793
757
}
794
758
return Flags;
795
759
}
796
760
797
761
enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
798
762
763
+ static RegisterType getRegisterType (llvm::dxil::ResourceClass RC) {
764
+ switch (RC) {
765
+ case llvm::dxil::ResourceClass::SRV:
766
+ return RegisterType::SRV;
767
+ case llvm::dxil::ResourceClass::UAV:
768
+ return RegisterType::UAV;
769
+ case llvm::dxil::ResourceClass::CBuffer:
770
+ return RegisterType::CBuffer;
771
+ case llvm::dxil::ResourceClass::Sampler:
772
+ return RegisterType::Sampler;
773
+ }
774
+ llvm_unreachable (" unexpected ResourceClass value" );
775
+ }
776
+
799
777
static RegisterType getRegisterType (StringRef Slot) {
800
778
switch (Slot[0 ]) {
801
779
case ' t' :
@@ -886,34 +864,8 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
886
864
// next, if resource is set, make sure the register type in the register
887
865
// annotation is compatible with the variable's resource type.
888
866
if (Flags.Resource ) {
889
- const HLSLResourceClassAttr *resClassAttr = nullptr ;
890
- if (CBufferOrTBuffer) {
891
- resClassAttr = CBufferOrTBuffer->getAttr <HLSLResourceClassAttr>();
892
- } else if (TheVarDecl) {
893
- resClassAttr =
894
- getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
895
- }
896
-
897
- assert (resClassAttr &&
898
- " any decl that set the resource flag on analysis should "
899
- " have a resource class attribute attached." );
900
- const llvm::hlsl::ResourceClass DeclResourceClass =
901
- resClassAttr->getResourceClass ();
902
-
903
- // confirm that the register type is bound to its expected resource class
904
- static RegisterType ExpectedRegisterTypesForResourceClass[] = {
905
- RegisterType::SRV,
906
- RegisterType::UAV,
907
- RegisterType::CBuffer,
908
- RegisterType::Sampler,
909
- };
910
- assert ((size_t )DeclResourceClass <
911
- std::size (ExpectedRegisterTypesForResourceClass) &&
912
- " DeclResourceClass has unexpected value" );
913
-
914
- RegisterType ExpectedRegisterType =
915
- ExpectedRegisterTypesForResourceClass[(int )DeclResourceClass];
916
- if (regType != ExpectedRegisterType) {
867
+ RegisterType expRegType = getRegisterType (Flags.ResourceClass .value ());
868
+ if (regType != expRegType) {
917
869
S.Diag (TheDecl->getLocation (), diag::err_hlsl_binding_type_mismatch)
918
870
<< regTypeNum;
919
871
}
@@ -955,7 +907,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
955
907
}
956
908
957
909
void SemaHLSL::handleResourceBindingAttr (Decl *TheDecl, const ParsedAttr &AL) {
958
- if (dyn_cast <VarDecl>(TheDecl)) {
910
+ if (isa <VarDecl>(TheDecl)) {
959
911
if (SemaRef.RequireCompleteType (TheDecl->getBeginLoc (),
960
912
cast<ValueDecl>(TheDecl)->getType (),
961
913
diag::err_incomplete_type))
0 commit comments