Skip to content

Commit 55afaa6

Browse files
committed
fix usage of autodiff macro with inner functions
- fix errors caused by the move of `ast::Item::ident` (see #138740) - move the logic of getting `sig`, `vis`, and `ident` from two seperate `match` statements into one (less repetition especially with the nested `match`)
1 parent ed20157 commit 55afaa6

File tree

1 file changed

+76
-32
lines changed

1 file changed

+76
-32
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

+76-32
Original file line numberDiff line numberDiff line change
@@ -145,27 +145,46 @@ mod llvm_enzyme {
145145
return vec![item];
146146
}
147147
let dcx = ecx.sess.dcx();
148-
// first get the annotable item:
149-
let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
148+
149+
// first get information about the annotable item:
150+
let (sig, vis, primal) = match &item {
150151
Annotatable::Item(iitem) => {
151-
let (ident, sig) = match &iitem.kind {
152-
ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
152+
let (sig, ident) = match &iitem.kind {
153+
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
153154
_ => {
154155
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
155156
return vec![item];
156157
}
157158
};
158-
(*ident, sig.clone(), false)
159+
(sig.clone(), iitem.vis.clone(), ident.clone())
159160
}
160161
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
161-
let (ident, sig) = match &assoc_item.kind {
162-
ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
162+
let (sig, ident) = match &assoc_item.kind {
163+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
164+
_ => {
165+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
166+
return vec![item];
167+
}
168+
};
169+
(sig.clone(), assoc_item.vis.clone(), ident.clone())
170+
}
171+
Annotatable::Stmt(stmt) => {
172+
let (sig, vis, ident) = match &stmt.kind {
173+
ast::StmtKind::Item(iitem) => match &iitem.kind {
174+
ast::ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
175+
(sig.clone(), iitem.vis.clone(), ident.clone())
176+
}
177+
_ => {
178+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
179+
return vec![item];
180+
}
181+
},
163182
_ => {
164183
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
165184
return vec![item];
166185
}
167186
};
168-
(*ident, sig.clone(), true)
187+
(sig, vis, ident)
169188
}
170189
_ => {
171190
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -184,15 +203,6 @@ mod llvm_enzyme {
184203
let has_ret = has_ret(&sig.decl.output);
185204
let sig_span = ecx.with_call_site_ctxt(sig.span);
186205

187-
let vis = match &item {
188-
Annotatable::Item(iitem) => iitem.vis.clone(),
189-
Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
190-
_ => {
191-
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
192-
return vec![item];
193-
}
194-
};
195-
196206
// create TokenStream from vec elemtents:
197207
// meta_item doesn't have a .tokens field
198208
let comma: Token = Token::new(TokenKind::Comma, Span::default());
@@ -303,6 +313,22 @@ mod llvm_enzyme {
303313
}
304314
Annotatable::AssocItem(assoc_item.clone(), i)
305315
}
316+
Annotatable::Stmt(ref mut stmt) => {
317+
match stmt.kind {
318+
ast::StmtKind::Item(ref mut iitem) => {
319+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
320+
iitem.attrs.push(attr);
321+
}
322+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
323+
{
324+
iitem.attrs.push(inline_never.clone());
325+
}
326+
}
327+
_ => unreachable!("stmt kind checked previously"),
328+
};
329+
330+
Annotatable::Stmt(stmt.clone())
331+
}
306332
_ => {
307333
unreachable!("annotatable kind checked previously")
308334
}
@@ -313,22 +339,40 @@ mod llvm_enzyme {
313339
delim: rustc_ast::token::Delimiter::Parenthesis,
314340
tokens: ts,
315341
});
342+
316343
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
317-
let d_annotatable = if is_impl {
318-
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
319-
let d_fn = P(ast::AssocItem {
320-
attrs: thin_vec![d_attr, inline_never],
321-
id: ast::DUMMY_NODE_ID,
322-
span,
323-
vis,
324-
kind: assoc_item,
325-
tokens: None,
326-
});
327-
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
328-
} else {
329-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
330-
d_fn.vis = vis;
331-
Annotatable::Item(d_fn)
344+
let d_annotatable = match &item {
345+
Annotatable::AssocItem(_, _) => {
346+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
347+
let d_fn = P(ast::AssocItem {
348+
attrs: thin_vec![d_attr, inline_never],
349+
id: ast::DUMMY_NODE_ID,
350+
span,
351+
vis,
352+
kind: assoc_item,
353+
tokens: None,
354+
});
355+
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
356+
}
357+
Annotatable::Item(_) => {
358+
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
359+
d_fn.vis = vis;
360+
361+
Annotatable::Item(d_fn)
362+
}
363+
Annotatable::Stmt(_) => {
364+
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
365+
d_fn.vis = vis;
366+
367+
Annotatable::Stmt(P(ast::Stmt {
368+
id: ast::DUMMY_NODE_ID,
369+
kind: ast::StmtKind::Item(d_fn),
370+
span,
371+
}))
372+
}
373+
_ => {
374+
unreachable!("item kind checked previously")
375+
}
332376
};
333377

334378
return vec![orig_annotatable, d_annotatable];

0 commit comments

Comments
 (0)