Skip to content

Commit c7f3aac

Browse files
committed
feat: add generated parameters to generated function
- update pretty printing tests - only add generic parameters when function is actually generic (no empty turbofish)
1 parent a8c9242 commit c7f3aac

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

+52-3
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ mod llvm_enzyme {
305305
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
306306
let d_body = gen_enzyme_body(
307307
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
308+
&generics,
308309
);
309310

310311
// The first element of it is the name of the function to be generated
@@ -477,6 +478,7 @@ mod llvm_enzyme {
477478
new_decl_span: Span,
478479
idents: &[Ident],
479480
errored: bool,
481+
generics: &Generics,
480482
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
481483
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
482484
let noop = ast::InlineAsm {
@@ -499,7 +501,7 @@ mod llvm_enzyme {
499501
};
500502
let unsf_expr = ecx.expr_block(P(unsf_block));
501503
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
502-
let primal_call = gen_primal_call(ecx, span, primal, idents);
504+
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
503505
let black_box_primal_call = ecx.expr_call(
504506
new_decl_span,
505507
blackbox_call_expr.clone(),
@@ -548,6 +550,7 @@ mod llvm_enzyme {
548550
sig_span: Span,
549551
idents: Vec<Ident>,
550552
errored: bool,
553+
generics: &Generics,
551554
) -> P<ast::Block> {
552555
let new_decl_span = d_sig.span;
553556

@@ -568,6 +571,7 @@ mod llvm_enzyme {
568571
new_decl_span,
569572
&idents,
570573
errored,
574+
generics,
571575
);
572576

573577
if !has_ret(&d_sig.decl.output) {
@@ -610,7 +614,6 @@ mod llvm_enzyme {
610614
panic!("Did not expect Default ret ty: {:?}", span);
611615
}
612616
};
613-
614617
if x.mode.is_fwd() {
615618
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
616619
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -670,8 +673,10 @@ mod llvm_enzyme {
670673
span: Span,
671674
primal: Ident,
672675
idents: &[Ident],
676+
generics: &Generics,
673677
) -> P<ast::Expr> {
674678
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
679+
675680
if has_self {
676681
let args: ThinVec<_> =
677682
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@@ -680,7 +685,51 @@ mod llvm_enzyme {
680685
} else {
681686
let args: ThinVec<_> =
682687
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
683-
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
688+
let mut primal_path = ecx.path_ident(span, primal);
689+
690+
let is_generic = !generics.params.is_empty();
691+
692+
match (is_generic, primal_path.segments.last_mut()) {
693+
(true, Some(function_path)) => {
694+
let primal_generic_types = generics
695+
.params
696+
.iter()
697+
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
698+
699+
let generated_generic_types = primal_generic_types
700+
.map(|type_param| {
701+
let generic_param = TyKind::Path(
702+
None,
703+
ast::Path {
704+
span,
705+
segments: thin_vec![ast::PathSegment {
706+
ident: type_param.ident,
707+
args: None,
708+
id: ast::DUMMY_NODE_ID,
709+
}],
710+
tokens: None,
711+
},
712+
);
713+
714+
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
715+
id: type_param.id,
716+
span,
717+
kind: generic_param,
718+
tokens: None,
719+
})))
720+
})
721+
.collect();
722+
723+
function_path.args =
724+
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
725+
span,
726+
args: generated_generic_types,
727+
})));
728+
}
729+
_ => {}
730+
}
731+
732+
let primal_call_expr = ecx.expr_path(primal_path);
684733
ecx.expr_call(span, primal_call_expr, args)
685734
}
686735
}

tests/pretty/autodiff/autodiff_forward.pp

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@
191191
pub fn d_square<T: std::ops::Mul<Output = T> +
192192
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
193193
unsafe { asm!("NOP", options(pure, nomem)); };
194-
::core::hint::black_box(f10(x));
194+
::core::hint::black_box(f10::<T>(x));
195195
::core::hint::black_box((dx_0, dret));
196-
::core::hint::black_box(f10(x))
196+
::core::hint::black_box(f10::<T>(x))
197197
}
198198
fn main() {}

0 commit comments

Comments
 (0)