@@ -305,6 +305,7 @@ mod llvm_enzyme {
305
305
let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
306
306
let d_body = gen_enzyme_body (
307
307
ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
308
+ & generics,
308
309
) ;
309
310
310
311
// The first element of it is the name of the function to be generated
@@ -477,6 +478,7 @@ mod llvm_enzyme {
477
478
new_decl_span : Span ,
478
479
idents : & [ Ident ] ,
479
480
errored : bool ,
481
+ generics : & Generics ,
480
482
) -> ( P < ast:: Block > , P < ast:: Expr > , P < ast:: Expr > , P < ast:: Expr > ) {
481
483
let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
482
484
let noop = ast:: InlineAsm {
@@ -499,7 +501,7 @@ mod llvm_enzyme {
499
501
} ;
500
502
let unsf_expr = ecx. expr_block ( P ( unsf_block) ) ;
501
503
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
502
- let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
504
+ let primal_call = gen_primal_call ( ecx, span, primal, idents, generics ) ;
503
505
let black_box_primal_call = ecx. expr_call (
504
506
new_decl_span,
505
507
blackbox_call_expr. clone ( ) ,
@@ -548,6 +550,7 @@ mod llvm_enzyme {
548
550
sig_span : Span ,
549
551
idents : Vec < Ident > ,
550
552
errored : bool ,
553
+ generics : & Generics ,
551
554
) -> P < ast:: Block > {
552
555
let new_decl_span = d_sig. span ;
553
556
@@ -568,6 +571,7 @@ mod llvm_enzyme {
568
571
new_decl_span,
569
572
& idents,
570
573
errored,
574
+ generics,
571
575
) ;
572
576
573
577
if !has_ret ( & d_sig. decl . output ) {
@@ -610,7 +614,6 @@ mod llvm_enzyme {
610
614
panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
611
615
}
612
616
} ;
613
-
614
617
if x. mode . is_fwd ( ) {
615
618
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
616
619
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -670,8 +673,10 @@ mod llvm_enzyme {
670
673
span : Span ,
671
674
primal : Ident ,
672
675
idents : & [ Ident ] ,
676
+ generics : & Generics ,
673
677
) -> P < ast:: Expr > {
674
678
let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
679
+
675
680
if has_self {
676
681
let args: ThinVec < _ > =
677
682
idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
@@ -680,7 +685,51 @@ mod llvm_enzyme {
680
685
} else {
681
686
let args: ThinVec < _ > =
682
687
idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
683
- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
688
+ let mut primal_path = ecx. path_ident ( span, primal) ;
689
+
690
+ let is_generic = !generics. params . is_empty ( ) ;
691
+
692
+ match ( is_generic, primal_path. segments . last_mut ( ) ) {
693
+ ( true , Some ( function_path) ) => {
694
+ let primal_generic_types = generics
695
+ . params
696
+ . iter ( )
697
+ . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
698
+
699
+ let generated_generic_types = primal_generic_types
700
+ . map ( |type_param| {
701
+ let generic_param = TyKind :: Path (
702
+ None ,
703
+ ast:: Path {
704
+ span,
705
+ segments : thin_vec ! [ ast:: PathSegment {
706
+ ident: type_param. ident,
707
+ args: None ,
708
+ id: ast:: DUMMY_NODE_ID ,
709
+ } ] ,
710
+ tokens : None ,
711
+ } ,
712
+ ) ;
713
+
714
+ ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( P ( ast:: Ty {
715
+ id : type_param. id ,
716
+ span,
717
+ kind : generic_param,
718
+ tokens : None ,
719
+ } ) ) )
720
+ } )
721
+ . collect ( ) ;
722
+
723
+ function_path. args =
724
+ Some ( P ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
725
+ span,
726
+ args : generated_generic_types,
727
+ } ) ) ) ;
728
+ }
729
+ _ => { }
730
+ }
731
+
732
+ let primal_call_expr = ecx. expr_path ( primal_path) ;
684
733
ecx. expr_call ( span, primal_call_expr, args)
685
734
}
686
735
}
0 commit comments