16
16
include "mlir/Dialect/GPU/IR/GPUBase.td"
17
17
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
18
18
19
+ //===----------------------------------------------------------------------===//
20
+ // GPU kernel metadata attribute
21
+ //===----------------------------------------------------------------------===//
22
+
23
+ def GPU_KernelMetadataAttr : GPU_Attr<"KernelMetadata", "kernel_metadata"> {
24
+ let description = [{
25
+ GPU attribute for storing metadata related to a compiled kernel. The
26
+ attribute contains the name and arguments type of the kernel.
27
+
28
+ The attribute also contains optional parameters for storing the arguments
29
+ attributes as well as a dictionary for additional metadata, like occupancy
30
+ information or other function attributes.
31
+
32
+ Note: The `arg_attrs` parameter is expected to follow all the constraints
33
+ imposed by the `mlir::FunctionOpInterface` interface.
34
+
35
+ Examples:
36
+ ```mlir
37
+ #gpu.kernel_metadata<@kernel1, (i32) -> (), arg_attrs = [...], metadata = {reg_count = 255, ...}>
38
+ #gpu.kernel_metadata<@kernel2, (i32, f64) -> ()>
39
+ ```
40
+ }];
41
+ let parameters = (ins
42
+ "StringAttr":$name,
43
+ "Type":$function_type,
44
+ OptionalParameter<"ArrayAttr", "arguments attributes">:$arg_attrs,
45
+ OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
46
+ );
47
+ let assemblyFormat = [{
48
+ `<` $name `,` $function_type (`,` struct($arg_attrs, $metadata)^)? `>`
49
+ }];
50
+ let builders = [
51
+ AttrBuilderWithInferredContext<(ins "StringAttr":$name,
52
+ "Type":$functionType,
53
+ CArg<"ArrayAttr", "nullptr">:$argAttrs,
54
+ CArg<"DictionaryAttr",
55
+ "nullptr">:$metadata), [{
56
+ assert(name && "invalid name");
57
+ return $_get(name.getContext(), name, functionType, argAttrs, metadata);
58
+ }]>,
59
+ AttrBuilderWithInferredContext<(ins "FunctionOpInterface":$kernel,
60
+ CArg<"DictionaryAttr",
61
+ "nullptr">:$metadata)>
62
+ ];
63
+ let genVerifyDecl = 1;
64
+ let extraClassDeclaration = [{
65
+ /// Compare two kernels based on the name.
66
+ bool operator<(const KernelMetadataAttr& other) const {
67
+ return getName().getValue() < other.getName().getValue();
68
+ }
69
+
70
+ /// Returns the metadata attribute corresponding to `key` or `nullptr`
71
+ /// if missing.
72
+ Attribute getAttr(StringRef key) const {
73
+ DictionaryAttr attrs = getMetadata();
74
+ return attrs ? attrs.get(key) : nullptr;
75
+ }
76
+ template <typename ConcreteAttr>
77
+ ConcreteAttr getAttr(StringRef key) const {
78
+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
79
+ }
80
+ Attribute getAttr(StringAttr key) const {
81
+ DictionaryAttr attrs = getMetadata();
82
+ return attrs ? attrs.get(key) : nullptr;
83
+ }
84
+ template <typename ConcreteAttr>
85
+ ConcreteAttr getAttr(StringAttr key) const {
86
+ return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
87
+ }
88
+
89
+ /// Returns the attribute dictionary at position `index`.
90
+ DictionaryAttr getArgAttrDict(unsigned index) {
91
+ ArrayAttr argArray = getArgAttrs();
92
+ return argArray ? llvm::cast<DictionaryAttr>(argArray[index]) : nullptr;
93
+ }
94
+
95
+ /// Return the specified attribute, if present, for the argument at 'index',
96
+ /// null otherwise.
97
+ Attribute getArgAttr(unsigned index, StringAttr name) {
98
+ DictionaryAttr argDict = getArgAttrDict(index);
99
+ return argDict ? argDict.get(name) : nullptr;
100
+ }
101
+ Attribute getArgAttr(unsigned index, StringRef name) {
102
+ DictionaryAttr argDict = getArgAttrDict(index);
103
+ return argDict ? argDict.get(name) : nullptr;
104
+ }
105
+
106
+ /// Returns a new KernelMetadataAttr that contains `attrs` in the metadata dictionary.
107
+ KernelMetadataAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
108
+ }];
109
+ }
110
+
111
+ //===----------------------------------------------------------------------===//
112
+ // GPU kernel table attribute
113
+ //===----------------------------------------------------------------------===//
114
+
115
+ def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
116
+ let description = [{
117
+ GPU attribute representing a list of `#gpu.kernel_metadata` attributes. This
118
+ attribute supports searching kernels by name. All kernels in the table must
119
+ have an unique name.
120
+
121
+ Examples:
122
+ ```mlir
123
+ // Empty table.
124
+ #gpu.kernel_table<>
125
+
126
+ // Table with a single kernel.
127
+ #gpu.kernel_table<[#gpu.kernel_metadata<kernel0, () -> () >]>
128
+
129
+ // Table with multiple kernels.
130
+ #gpu.kernel_table<[
131
+ #gpu.kernel_metadata<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
132
+ #gpu.kernel_metadata<"kernel1", (i32) -> ()>
133
+ ]>
134
+ ```
135
+ }];
136
+ let parameters = (ins
137
+ OptionalArrayRefParameter<"KernelMetadataAttr", "array of kernels">:$kernel_table
138
+ );
139
+ let assemblyFormat = [{
140
+ `<` (`[` qualified($kernel_table)^ `]`)? `>`
141
+ }];
142
+ let builders = [
143
+ AttrBuilder<(ins "ArrayRef<KernelMetadataAttr>":$kernels,
144
+ CArg<"bool", "false">:$isSorted)>
145
+ ];
146
+ let skipDefaultBuilders = 1;
147
+ let genVerifyDecl = 1;
148
+ let extraClassDeclaration = [{
149
+ llvm::ArrayRef<KernelMetadataAttr>::iterator begin() const {
150
+ return getKernelTable().begin();
151
+ }
152
+ llvm::ArrayRef<KernelMetadataAttr>::iterator end() const {
153
+ return getKernelTable().end();
154
+ }
155
+ size_t size() const {
156
+ return getKernelTable().size();
157
+ }
158
+ bool empty() const {
159
+ return getKernelTable().empty();
160
+ }
161
+
162
+ /// Returns the kernel with name `key` or `nullptr` if not present.
163
+ KernelMetadataAttr lookup(StringRef key) const;
164
+ KernelMetadataAttr lookup(StringAttr key) const;
165
+ }];
166
+ }
167
+
19
168
//===----------------------------------------------------------------------===//
20
169
// GPU object attribute.
21
170
//===----------------------------------------------------------------------===//
@@ -36,8 +185,9 @@ def GPU_CompilationTargetEnum : GPU_I32Enum<
36
185
def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
37
186
let description = [{
38
187
A GPU object attribute glues together a GPU target, the object kind, a
39
- binary string with the object, and the object properties, encapsulating how
40
- the object was generated and its properties with the object itself.
188
+ binary string with the object, the object properties, and kernel metadata,
189
+ encapsulating how the object was generated and its properties with the
190
+ object itself.
41
191
42
192
There are four object formats:
43
193
1. `Offload`: represents generic objects not described by the other three
@@ -55,6 +205,10 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
55
205
56
206
Object properties are specified through the `properties` dictionary
57
207
attribute and can be used to define additional information.
208
+
209
+ Kernel metadata is specified through the `kernels` parameter, and can be
210
+ used to specify additional information on a kernel by kernel basis.
211
+
58
212
The target attribute must implement or promise the `TargetAttrInterface`
59
213
interface.
60
214
@@ -63,16 +217,29 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
63
217
#gpu.object<#nvvm.target, properties = {O = 3 : i32}, assembly = "..."> // An assembly object with additional properties.
64
218
#gpu.object<#rocdl.target, bin = "..."> // A binary object.
65
219
#gpu.object<#nvvm.target, "..."> // A fatbin object.
220
+ #gpu.object<#nvvm.target, kernels = #gpu.kernel_table<...>, "..."> // An object with a kernel table.
66
221
```
67
222
}];
68
223
let parameters = (ins
69
224
"Attribute":$target,
70
225
DefaultValuedParameter<"CompilationTarget", "CompilationTarget::Fatbin">:$format,
71
226
"StringAttr":$object,
72
- OptionalParameter<"DictionaryAttr">:$properties
227
+ OptionalParameter<"DictionaryAttr">:$properties,
228
+ OptionalParameter<"KernelTableAttr">:$kernels
73
229
);
230
+ let builders = [
231
+ AttrBuilderWithInferredContext<(ins "Attribute":$target,
232
+ "CompilationTarget":$format,
233
+ "StringAttr":$object,
234
+ CArg<"DictionaryAttr", "nullptr">:$properties,
235
+ CArg<"KernelTableAttr", "nullptr">:$kernels), [{
236
+ assert(target && "invalid target");
237
+ return $_get(target.getContext(), target, format, object, properties, kernels);
238
+ }]>
239
+ ];
74
240
let assemblyFormat = [{ `<`
75
- $target `,` (`properties` `=` $properties ^ `,`)?
241
+ $target `,` (`properties` `=` $properties^ `,`)?
242
+ (`kernels` `=` $kernels^ `,`)?
76
243
custom<Object>($format, $object)
77
244
`>`
78
245
}];
0 commit comments