@@ -3,6 +3,7 @@ use core::ops::ControlFlow;
3
3
use hir:: def:: CtorKind ;
4
4
use hir:: intravisit:: { Visitor , walk_expr, walk_stmt} ;
5
5
use hir:: { LetStmt , QPath } ;
6
+ use itertools:: { EitherOrBoth , Itertools } ;
6
7
use rustc_data_structures:: fx:: FxIndexSet ;
7
8
use rustc_errors:: { Applicability , Diag } ;
8
9
use rustc_hir as hir;
@@ -20,7 +21,7 @@ use tracing::debug;
20
21
use crate :: error_reporting:: TypeErrCtxt ;
21
22
use crate :: error_reporting:: infer:: hir:: Path ;
22
23
use crate :: errors:: {
23
- ConsiderAddingAwait , FnConsiderCasting , FnItemsAreDistinct , FnUniqTypes ,
24
+ ConsiderAddingAwait , FnConsiderCasting , FnConsiderCastingBoth , FnItemsAreDistinct , FnUniqTypes ,
24
25
FunctionPointerSuggestion , SuggestAccessingField , SuggestRemoveSemiOrReturnBinding ,
25
26
SuggestTuplePatternMany , SuggestTuplePatternOne , TypeErrorAdditionalDiags ,
26
27
} ;
@@ -369,6 +370,133 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
369
370
}
370
371
}
371
372
373
+ fn find_mismatched_fn_item (
374
+ & self ,
375
+ ty1 : Ty < ' tcx > ,
376
+ ty2 : Ty < ' tcx > ,
377
+ ) -> Option < ( Ty < ' tcx > , Ty < ' tcx > ) > {
378
+ if let Some ( fns) = self . find_mismatched_fn_items ( ty1, ty2)
379
+ && fns. len ( ) == 1
380
+ {
381
+ Some ( fns[ 0 ] )
382
+ } else {
383
+ None
384
+ }
385
+ }
386
+
387
+ fn find_mismatched_fn_items (
388
+ & self ,
389
+ ty1 : Ty < ' tcx > ,
390
+ ty2 : Ty < ' tcx > ,
391
+ ) -> Option < Vec < ( Ty < ' tcx > , Ty < ' tcx > ) > > {
392
+ match ( ty1. kind ( ) , ty2. kind ( ) ) {
393
+ ( & ty:: Adt ( def1, sub1) , & ty:: Adt ( def2, sub2) ) if sub1. len ( ) == sub2. len ( ) => {
394
+ let did1 = def1. did ( ) ;
395
+ let did2 = def2. did ( ) ;
396
+
397
+ if did1 != did2 {
398
+ return None ;
399
+ }
400
+
401
+ for lifetime in sub1. regions ( ) . zip_longest ( sub2. regions ( ) ) {
402
+ match lifetime {
403
+ EitherOrBoth :: Both ( l1, l2) if l1 == l2 => continue ,
404
+ _ => return None ,
405
+ }
406
+ }
407
+
408
+ for ca in sub1. consts ( ) . zip_longest ( sub2. consts ( ) ) {
409
+ match ca {
410
+ EitherOrBoth :: Both ( c1, c2) if c1 == c2 => continue ,
411
+ _ => return None ,
412
+ }
413
+ }
414
+
415
+ let mut fns = Vec :: new ( ) ;
416
+ for ty in sub1. types ( ) . zip_longest ( sub2. types ( ) ) {
417
+ match ty {
418
+ EitherOrBoth :: Both ( t1, t2) => {
419
+ let Some ( new_fns) = self . find_mismatched_fn_items ( t1, t2) else {
420
+ return None ;
421
+ } ;
422
+
423
+ fns. extend ( new_fns) ;
424
+ }
425
+ _ => return None ,
426
+ }
427
+ }
428
+ Some ( fns)
429
+ }
430
+
431
+ ( & ty:: Tuple ( args1) , & ty:: Tuple ( args2) ) if args1. len ( ) == args2. len ( ) => {
432
+ let mut fns = Vec :: new ( ) ;
433
+ for ( left, right) in args1. iter ( ) . zip ( args2) {
434
+ let Some ( new_fns) = self . find_mismatched_fn_items ( left, right) else {
435
+ return None ;
436
+ } ;
437
+ fns. extend ( new_fns) ;
438
+ }
439
+ Some ( fns)
440
+ }
441
+
442
+ ( ty:: FnDef ( did, args) , ty:: FnPtr ( sig_tys, hdr) ) => {
443
+ let sig1 =
444
+ & ( self . normalize_fn_sig ) ( self . tcx . fn_sig ( * did) . instantiate ( self . tcx , args) ) ;
445
+ let sig2 = & ( self . normalize_fn_sig ) ( sig_tys. with ( * hdr) ) ;
446
+ self . same_type_modulo_infer ( * sig1, * sig2) . then ( || vec ! [ ( ty1, ty2) ] )
447
+ }
448
+
449
+ ( ty:: FnDef ( did1, args1) , ty:: FnDef ( did2, args2) ) => {
450
+ let sig1 =
451
+ & ( self . normalize_fn_sig ) ( self . tcx . fn_sig ( * did1) . instantiate ( self . tcx , args1) ) ;
452
+ let sig2 =
453
+ & ( self . normalize_fn_sig ) ( self . tcx . fn_sig ( * did2) . instantiate ( self . tcx , args2) ) ;
454
+ self . same_type_modulo_infer ( * sig1, * sig2) . then ( || vec ! [ ( ty1, ty2) ] )
455
+ }
456
+
457
+ ( ty:: FnPtr ( sig_tys, hdr) , ty:: FnDef ( did, args) ) => {
458
+ let sig1 = & ( self . normalize_fn_sig ) ( sig_tys. with ( * hdr) ) ;
459
+ let sig2 =
460
+ & ( self . normalize_fn_sig ) ( self . tcx . fn_sig ( * did) . instantiate ( self . tcx , args) ) ;
461
+ self . same_type_modulo_infer ( * sig1, * sig2) . then ( || vec ! [ ( ty1, ty2) ] )
462
+ }
463
+
464
+ _ => ty1. eq ( & ty2) . then ( || Vec :: new ( ) ) ,
465
+ }
466
+ }
467
+
468
+ pub fn suggest_function_pointers_simple (
469
+ & self ,
470
+ diag : & mut Diag < ' _ > ,
471
+ found : Ty < ' tcx > ,
472
+ expected : Ty < ' tcx > ,
473
+ ) {
474
+ let Some ( ( found, expected) ) = self . find_mismatched_fn_item ( found, expected) else {
475
+ return ;
476
+ } ;
477
+
478
+ match ( expected. kind ( ) , found. kind ( ) ) {
479
+ ( ty:: FnPtr ( sig_tys, hdr) , ty:: FnDef ( did, args) )
480
+ | ( ty:: FnDef ( did, args) , ty:: FnPtr ( sig_tys, hdr) ) => {
481
+ let sig = sig_tys. with ( * hdr) ;
482
+
483
+ let fn_name = self . tcx . def_path_str_with_args ( * did, args) ;
484
+ let casting = format ! ( "{fn_name} as {sig}" ) ;
485
+
486
+ diag. subdiagnostic ( FnItemsAreDistinct ) ;
487
+ diag. subdiagnostic ( FnConsiderCasting { casting } ) ;
488
+ }
489
+ ( ty:: FnDef ( did, args) , ty:: FnDef ( ..) ) => {
490
+ let sig =
491
+ ( self . normalize_fn_sig ) ( self . tcx . fn_sig ( * did) . instantiate ( self . tcx , args) ) ;
492
+
493
+ diag. subdiagnostic ( FnUniqTypes ) ;
494
+ diag. subdiagnostic ( FnConsiderCastingBoth { sig } ) ;
495
+ }
496
+ _ => ( ) ,
497
+ } ;
498
+ }
499
+
372
500
pub ( super ) fn suggest_function_pointers (
373
501
& self ,
374
502
cause : & ObligationCause < ' tcx > ,
@@ -381,6 +509,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
381
509
let expected_inner = expected. peel_refs ( ) ;
382
510
let found_inner = found. peel_refs ( ) ;
383
511
if !expected_inner. is_fn ( ) || !found_inner. is_fn ( ) {
512
+ self . suggest_function_pointers_simple ( diag, * found, * expected) ;
384
513
return ;
385
514
}
386
515
match ( expected_inner. kind ( ) , found_inner. kind ( ) ) {
0 commit comments