@@ -722,4 +722,164 @@ def ROCDL_TargettAttr :
722
722
}
723
723
}];
724
724
}
725
+
726
+ //===----------------------------------------------------------------------===//
727
+ // ROCDL kernel attribute
728
+ //===----------------------------------------------------------------------===//
729
+
730
+ def ROCDL_KernelAttr :
731
+ ROCDL_Attr<"ROCDLKernel", "kernel"> {
732
+ let description = [{
733
+ ROCDL attribute for storing metadata related to a compiled kernel. It
734
+ contains the attribute dictionary of the LLVM function used to generate the
735
+ kernel, as well as an optional dictionary for additional metadata, like ELF
736
+ related metadata.
737
+ For details on the ELF metadata see:
738
+ https://llvm.org/docs/AMDGPUUsage.html#code-object-v5-metadata
739
+
740
+ Examples:
741
+ ```mlir
742
+ #rocdl.kernel<{sym_name = "test_fusion__part_0", ...},
743
+ metadata = {sgpr_count = 255, ...}>
744
+ ```
745
+ }];
746
+ let parameters = (ins
747
+ "DictionaryAttr":$func_attrs,
748
+ OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
749
+ );
750
+ let assemblyFormat = [{
751
+ `<` $func_attrs (`,` `metadata` `=` $metadata^ )? `>`
752
+ }];
753
+ let builders = [
754
+ AttrBuilderWithInferredContext<(ins "DictionaryAttr":$funcAttrs,
755
+ CArg<"DictionaryAttr",
756
+ "nullptr">:$metadata), [{
757
+ assert(funcAttrs && "invalid function attributes dictionary");
758
+ return $_get(funcAttrs.getContext(), funcAttrs, metadata);
759
+ }]>
760
+ ];
761
+ let extraClassDeclaration = [{
762
+ /// Returns the function attribute corresponding to key or nullptr if missing.
763
+ Attribute getAttr(StringRef key) const {
764
+ return getFuncAttrs().get(key);
765
+ }
766
+ template <typename ConcreteAttr>
767
+ ConcreteAttr getAttr(StringRef key) const {
768
+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
769
+ }
770
+ Attribute getAttr(StringAttr key) const;
771
+ template <typename ConcreteAttr>
772
+ ConcreteAttr getAttr(StringAttr key) const {
773
+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
774
+ }
775
+
776
+ /// Returns the name of the kernel.
777
+ StringAttr getName() const {
778
+ return getAttr<StringAttr>("sym_name");
779
+ }
780
+
781
+ /// Returns the metadta attribute corresponding to key or nullptr if missing.
782
+ Attribute getMDAttr(StringRef key) const {
783
+ if (DictionaryAttr attrs = getMetadata())
784
+ return attrs.get(key);
785
+ return nullptr;
786
+ }
787
+ template <typename ConcreteAttr>
788
+ ConcreteAttr getMDAttr(StringRef key) const {
789
+ return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
790
+ }
791
+ Attribute getMDAttr(StringAttr key) const;
792
+ template <typename ConcreteAttr>
793
+ ConcreteAttr getMDAttr(StringAttr key) const {
794
+ return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
795
+ }
796
+
797
+ /// Returns the number of required scalar registers, or nullptr if the field
798
+ /// is missing.
799
+ IntegerAttr getSGPR() const {
800
+ return getMDAttr<IntegerAttr>("sgpr_count");
801
+ }
802
+
803
+ /// Returns the number of required scalar registers, or nullptr if the field
804
+ /// is missing.
805
+ IntegerAttr getVGPR() const {
806
+ return getMDAttr<IntegerAttr>("vgpr_count");
807
+ }
808
+
809
+ /// Returns the number of required scalar registers, or nullptr if the field
810
+ /// is missing.
811
+ IntegerAttr getAGPR() const {
812
+ return getMDAttr<IntegerAttr>("agpr_count");
813
+ }
814
+
815
+ /// Returns the number of spilled SGPR, or nullptr if the field is missing.
816
+ IntegerAttr getSGPRSpill() const {
817
+ return getMDAttr<IntegerAttr>("sgpr_spill_count");
818
+ }
819
+
820
+ /// Returns the number of spilled VGPR, or nullptr if the field is missing.
821
+ IntegerAttr getVGPRSpill() const {
822
+ return getMDAttr<IntegerAttr>("vgpr_spill_count");
823
+ }
824
+
825
+ /// Helper function for appending metadata to a kernel attribute.
826
+ ROCDLKernelAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
827
+ }];
828
+ }
829
+
830
+ //===----------------------------------------------------------------------===//
831
+ // ROCDL object metadata
832
+ //===----------------------------------------------------------------------===//
833
+
834
+ def ROCDL_ObjectMDAttr :
835
+ ROCDL_Attr<"ROCDLObjectMD", "object_metadata"> {
836
+ let description = [{
837
+ ROCDL attribute representing a table of kernels metadata. All the attributes
838
+ in the dictionary must be of type `#rocdl.kernel`.
839
+
840
+ Examples:
841
+ ```mlir
842
+ #rocdl.object_metadata<{kernel0 = #rocdl.kernel<...>}>
843
+ ```
844
+ }];
845
+ let parameters = (ins
846
+ "DictionaryAttr":$kernel_table
847
+ );
848
+ let assemblyFormat = [{
849
+ `<` $kernel_table `>`
850
+ }];
851
+ let builders = [
852
+ AttrBuilderWithInferredContext<(ins "DictionaryAttr":$kernel_table), [{
853
+ assert(kernel_table && "invalid kernel table");
854
+ return $_get(kernel_table.getContext(), kernel_table);
855
+ }]>
856
+ ];
857
+ let skipDefaultBuilders = 1;
858
+ let genVerifyDecl = 1;
859
+ let extraClassDeclaration = [{
860
+ /// Helper iterator class for traversing the kernel table.
861
+ struct KernelIterator
862
+ : llvm::mapped_iterator_base<KernelIterator,
863
+ llvm::ArrayRef<NamedAttribute>::iterator,
864
+ std::pair<StringAttr, ROCDLKernelAttr>> {
865
+ using llvm::mapped_iterator_base<
866
+ KernelIterator, llvm::ArrayRef<NamedAttribute>::iterator,
867
+ std::pair<StringAttr, ROCDLKernelAttr>>::mapped_iterator_base;
868
+ /// Map the iterator to the kernel name and a KernelAttribute.
869
+ std::pair<StringAttr, ROCDLKernelAttr> mapElement(NamedAttribute attr) const {
870
+ return {attr.getName(), llvm::cast<ROCDLKernelAttr>(attr.getValue())};
871
+ }
872
+ };
873
+ auto begin() const {
874
+ return KernelIterator(getKernelTable().begin());
875
+ }
876
+ auto end() const {
877
+ return KernelIterator(getKernelTable().end());
878
+ }
879
+ size_t size() const {
880
+ return getKernelTable().size();
881
+ }
882
+ }];
883
+ }
884
+
725
885
#endif // ROCDLIR_OPS
0 commit comments