@@ -145,27 +145,46 @@ mod llvm_enzyme {
145
145
return vec ! [ item] ;
146
146
}
147
147
let dcx = ecx. sess . dcx ( ) ;
148
- // first get the annotable item:
149
- let ( primal, sig, is_impl) : ( Ident , FnSig , bool ) = match & item {
148
+
149
+ // first get information about the annotable item:
150
+ let ( sig, vis, primal) = match & item {
150
151
Annotatable :: Item ( iitem) => {
151
- let ( ident , sig ) = match & iitem. kind {
152
- ItemKind :: Fn ( box ast:: Fn { ident , sig , .. } ) => ( ident , sig ) ,
152
+ let ( sig , ident ) = match & iitem. kind {
153
+ ItemKind :: Fn ( box ast:: Fn { sig , ident , .. } ) => ( sig , ident ) ,
153
154
_ => {
154
155
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
155
156
return vec ! [ item] ;
156
157
}
157
158
} ;
158
- ( * ident , sig . clone ( ) , false )
159
+ ( sig . clone ( ) , iitem . vis . clone ( ) , ident . clone ( ) )
159
160
}
160
161
Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
161
- let ( ident, sig) = match & assoc_item. kind {
162
- ast:: AssocItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
162
+ let ( sig, ident) = match & assoc_item. kind {
163
+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => ( sig, ident) ,
164
+ _ => {
165
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
166
+ return vec ! [ item] ;
167
+ }
168
+ } ;
169
+ ( sig. clone ( ) , assoc_item. vis . clone ( ) , ident. clone ( ) )
170
+ }
171
+ Annotatable :: Stmt ( stmt) => {
172
+ let ( sig, vis, ident) = match & stmt. kind {
173
+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
174
+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
175
+ ( sig. clone ( ) , iitem. vis . clone ( ) , ident. clone ( ) )
176
+ }
177
+ _ => {
178
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
179
+ return vec ! [ item] ;
180
+ }
181
+ } ,
163
182
_ => {
164
183
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
165
184
return vec ! [ item] ;
166
185
}
167
186
} ;
168
- ( * ident , sig . clone ( ) , true )
187
+ ( sig , vis , ident )
169
188
}
170
189
_ => {
171
190
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -184,15 +203,6 @@ mod llvm_enzyme {
184
203
let has_ret = has_ret ( & sig. decl . output ) ;
185
204
let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
186
205
187
- let vis = match & item {
188
- Annotatable :: Item ( iitem) => iitem. vis . clone ( ) ,
189
- Annotatable :: AssocItem ( assoc_item, _) => assoc_item. vis . clone ( ) ,
190
- _ => {
191
- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
192
- return vec ! [ item] ;
193
- }
194
- } ;
195
-
196
206
// create TokenStream from vec elemtents:
197
207
// meta_item doesn't have a .tokens field
198
208
let comma: Token = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
@@ -303,6 +313,22 @@ mod llvm_enzyme {
303
313
}
304
314
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
305
315
}
316
+ Annotatable :: Stmt ( ref mut stmt) => {
317
+ match stmt. kind {
318
+ ast:: StmtKind :: Item ( ref mut iitem) => {
319
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
320
+ iitem. attrs . push ( attr) ;
321
+ }
322
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
323
+ {
324
+ iitem. attrs . push ( inline_never. clone ( ) ) ;
325
+ }
326
+ }
327
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
328
+ } ;
329
+
330
+ Annotatable :: Stmt ( stmt. clone ( ) )
331
+ }
306
332
_ => {
307
333
unreachable ! ( "annotatable kind checked previously" )
308
334
}
@@ -313,22 +339,40 @@ mod llvm_enzyme {
313
339
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
314
340
tokens : ts,
315
341
} ) ;
342
+
316
343
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
317
- let d_annotatable = if is_impl {
318
- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
319
- let d_fn = P ( ast:: AssocItem {
320
- attrs : thin_vec ! [ d_attr, inline_never] ,
321
- id : ast:: DUMMY_NODE_ID ,
322
- span,
323
- vis,
324
- kind : assoc_item,
325
- tokens : None ,
326
- } ) ;
327
- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
328
- } else {
329
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
330
- d_fn. vis = vis;
331
- Annotatable :: Item ( d_fn)
344
+ let d_annotatable = match & item {
345
+ Annotatable :: AssocItem ( _, _) => {
346
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
347
+ let d_fn = P ( ast:: AssocItem {
348
+ attrs : thin_vec ! [ d_attr, inline_never] ,
349
+ id : ast:: DUMMY_NODE_ID ,
350
+ span,
351
+ vis,
352
+ kind : assoc_item,
353
+ tokens : None ,
354
+ } ) ;
355
+ Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
356
+ }
357
+ Annotatable :: Item ( _) => {
358
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
359
+ d_fn. vis = vis;
360
+
361
+ Annotatable :: Item ( d_fn)
362
+ }
363
+ Annotatable :: Stmt ( _) => {
364
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
365
+ d_fn. vis = vis;
366
+
367
+ Annotatable :: Stmt ( P ( ast:: Stmt {
368
+ id : ast:: DUMMY_NODE_ID ,
369
+ kind : ast:: StmtKind :: Item ( d_fn) ,
370
+ span,
371
+ } ) )
372
+ }
373
+ _ => {
374
+ unreachable ! ( "item kind checked previously" )
375
+ }
332
376
} ;
333
377
334
378
return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments