@@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed;
9
9
use rustc_hir as hir;
10
10
use rustc_hir:: def:: DefKind ;
11
11
use rustc_hir:: def_id:: { DefId , LocalDefId } ;
12
- use rustc_hir:: { CoroutineKind , Node } ;
12
+ use rustc_hir:: Node ;
13
13
use rustc_index:: bit_set:: GrowableBitSet ;
14
14
use rustc_index:: { Idx , IndexSlice , IndexVec } ;
15
15
use rustc_infer:: infer:: { InferCtxt , TyCtxtInferExt } ;
@@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
177
177
check_overflow : bool ,
178
178
fn_span : Span ,
179
179
arg_count : usize ,
180
- coroutine_kind : Option < CoroutineKind > ,
180
+ coroutine : Option < Box < CoroutineInfo < ' tcx > > > ,
181
181
182
182
/// The current set of scopes, updated as we traverse;
183
183
/// see the `scope` module for more details.
@@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
458
458
) -> Body < ' tcx > {
459
459
let span = tcx. def_span ( fn_def) ;
460
460
let fn_id = tcx. local_def_id_to_hir_id ( fn_def) ;
461
- let coroutine_kind = tcx. coroutine_kind ( fn_def) ;
462
461
463
462
// The representation of thir for `-Zunpretty=thir-tree` relies on
464
463
// the entry expression being the last element of `thir.exprs`.
@@ -488,17 +487,17 @@ fn construct_fn<'tcx>(
488
487
489
488
let arguments = & thir. params ;
490
489
491
- let ( resume_ty , yield_ty , return_ty ) = if coroutine_kind . is_some ( ) {
492
- let coroutine_ty = arguments [ thir :: UPVAR_ENV_PARAM ] . ty ;
493
- let coroutine_sig = match coroutine_ty . kind ( ) {
494
- ty :: Coroutine ( _ , gen_args , .. ) => gen_args . as_coroutine ( ) . sig ( ) ,
495
- _ => {
496
- span_bug ! ( span , "coroutine w/o coroutine type: {:?}" , coroutine_ty )
497
- }
498
- } ;
499
- ( Some ( coroutine_sig . resume_ty ) , Some ( coroutine_sig . yield_ty ) , coroutine_sig . return_ty )
500
- } else {
501
- ( None , None , fn_sig . output ( ) )
490
+ let ( return_ty , coroutine ) = match tcx . type_of ( fn_def ) . instantiate_identity ( ) . kind ( ) {
491
+ ty :: Coroutine ( _ , args ) => (
492
+ fn_sig . output ( ) ,
493
+ Some ( Box :: new ( CoroutineInfo :: initial (
494
+ tcx . coroutine_kind ( fn_def ) . unwrap ( ) ,
495
+ args . as_coroutine ( ) . yield_ty ( ) ,
496
+ args . as_coroutine ( ) . resume_ty ( ) ,
497
+ ) ) ) ,
498
+ ) ,
499
+ ty :: Closure ( .. ) | ty :: FnDef ( .. ) => ( fn_sig . output ( ) , None ) ,
500
+ ty => span_bug ! ( span_with_body , "unexpected type of body: {ty:?}" ) ,
502
501
} ;
503
502
504
503
if let Some ( custom_mir_attr) =
@@ -529,7 +528,7 @@ fn construct_fn<'tcx>(
529
528
safety,
530
529
return_ty,
531
530
return_ty_span,
532
- coroutine_kind ,
531
+ coroutine ,
533
532
) ;
534
533
535
534
let call_site_scope =
@@ -563,11 +562,6 @@ fn construct_fn<'tcx>(
563
562
None
564
563
} ;
565
564
566
- if coroutine_kind. is_some ( ) {
567
- body. coroutine . as_mut ( ) . unwrap ( ) . yield_ty = yield_ty;
568
- body. coroutine . as_mut ( ) . unwrap ( ) . resume_ty = resume_ty;
569
- }
570
-
571
565
body
572
566
}
573
567
@@ -632,45 +626,63 @@ fn construct_const<'a, 'tcx>(
632
626
fn construct_error ( tcx : TyCtxt < ' _ > , def_id : LocalDefId , guar : ErrorGuaranteed ) -> Body < ' _ > {
633
627
let span = tcx. def_span ( def_id) ;
634
628
let hir_id = tcx. local_def_id_to_hir_id ( def_id) ;
635
- let coroutine_kind = tcx. coroutine_kind ( def_id) ;
636
629
637
- let ( inputs, output, resume_ty , yield_ty ) = match tcx. def_kind ( def_id) {
630
+ let ( inputs, output, coroutine ) = match tcx. def_kind ( def_id) {
638
631
DefKind :: Const
639
632
| DefKind :: AssocConst
640
633
| DefKind :: AnonConst
641
634
| DefKind :: InlineConst
642
- | DefKind :: Static ( _) => ( vec ! [ ] , tcx. type_of ( def_id) . instantiate_identity ( ) , None , None ) ,
635
+ | DefKind :: Static ( _) => ( vec ! [ ] , tcx. type_of ( def_id) . instantiate_identity ( ) , None ) ,
643
636
DefKind :: Ctor ( ..) | DefKind :: Fn | DefKind :: AssocFn => {
644
637
let sig = tcx. liberate_late_bound_regions (
645
638
def_id. to_def_id ( ) ,
646
639
tcx. fn_sig ( def_id) . instantiate_identity ( ) ,
647
640
) ;
648
- ( sig. inputs ( ) . to_vec ( ) , sig. output ( ) , None , None )
649
- }
650
- DefKind :: Closure if coroutine_kind. is_some ( ) => {
651
- let coroutine_ty = tcx. type_of ( def_id) . instantiate_identity ( ) ;
652
- let ty:: Coroutine ( _, args) = coroutine_ty. kind ( ) else {
653
- bug ! ( "expected type of coroutine-like closure to be a coroutine" )
654
- } ;
655
- let args = args. as_coroutine ( ) ;
656
- let resume_ty = args. resume_ty ( ) ;
657
- let yield_ty = args. yield_ty ( ) ;
658
- let return_ty = args. return_ty ( ) ;
659
- ( vec ! [ coroutine_ty, args. resume_ty( ) ] , return_ty, Some ( resume_ty) , Some ( yield_ty) )
641
+ ( sig. inputs ( ) . to_vec ( ) , sig. output ( ) , None )
660
642
}
661
643
DefKind :: Closure => {
662
644
let closure_ty = tcx. type_of ( def_id) . instantiate_identity ( ) ;
663
- let ty:: Closure ( _, args) = closure_ty. kind ( ) else {
664
- bug ! ( "expected type of closure to be a closure" )
665
- } ;
666
- let args = args. as_closure ( ) ;
667
- let sig = tcx. liberate_late_bound_regions ( def_id. to_def_id ( ) , args. sig ( ) ) ;
668
- let self_ty = match args. kind ( ) {
669
- ty:: ClosureKind :: Fn => Ty :: new_imm_ref ( tcx, tcx. lifetimes . re_erased , closure_ty) ,
670
- ty:: ClosureKind :: FnMut => Ty :: new_mut_ref ( tcx, tcx. lifetimes . re_erased , closure_ty) ,
671
- ty:: ClosureKind :: FnOnce => closure_ty,
672
- } ;
673
- ( [ self_ty] . into_iter ( ) . chain ( sig. inputs ( ) . to_vec ( ) ) . collect ( ) , sig. output ( ) , None , None )
645
+ match closure_ty. kind ( ) {
646
+ ty:: Closure ( _, args) => {
647
+ let args = args. as_closure ( ) ;
648
+ let sig = tcx. liberate_late_bound_regions ( def_id. to_def_id ( ) , args. sig ( ) ) ;
649
+ let self_ty = match args. kind ( ) {
650
+ ty:: ClosureKind :: Fn => {
651
+ Ty :: new_imm_ref ( tcx, tcx. lifetimes . re_erased , closure_ty)
652
+ }
653
+ ty:: ClosureKind :: FnMut => {
654
+ Ty :: new_mut_ref ( tcx, tcx. lifetimes . re_erased , closure_ty)
655
+ }
656
+ ty:: ClosureKind :: FnOnce => closure_ty,
657
+ } ;
658
+ (
659
+ [ self_ty] . into_iter ( ) . chain ( sig. inputs ( ) . to_vec ( ) ) . collect ( ) ,
660
+ sig. output ( ) ,
661
+ None ,
662
+ )
663
+ }
664
+ ty:: Coroutine ( _, args) => {
665
+ let args = args. as_coroutine ( ) ;
666
+ let resume_ty = args. resume_ty ( ) ;
667
+ let yield_ty = args. yield_ty ( ) ;
668
+ let return_ty = args. return_ty ( ) ;
669
+ (
670
+ vec ! [ closure_ty, args. resume_ty( ) ] ,
671
+ return_ty,
672
+ Some ( Box :: new ( CoroutineInfo :: initial (
673
+ tcx. coroutine_kind ( def_id) . unwrap ( ) ,
674
+ yield_ty,
675
+ resume_ty,
676
+ ) ) ) ,
677
+ )
678
+ }
679
+ _ => {
680
+ span_bug ! (
681
+ tcx. def_span( def_id) ,
682
+ "expected type of closure body to be a closure or coroutine"
683
+ ) ;
684
+ }
685
+ }
674
686
}
675
687
dk => bug ! ( "{:?} is not a body: {:?}" , def_id, dk) ,
676
688
} ;
@@ -696,7 +708,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
696
708
697
709
cfg. terminate ( START_BLOCK , source_info, TerminatorKind :: Unreachable ) ;
698
710
699
- let mut body = Body :: new (
711
+ Body :: new (
700
712
MirSource :: item ( def_id. to_def_id ( ) ) ,
701
713
cfg. basic_blocks ,
702
714
source_scopes,
@@ -705,16 +717,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
705
717
inputs. len ( ) ,
706
718
vec ! [ ] ,
707
719
span,
708
- coroutine_kind ,
720
+ coroutine ,
709
721
Some ( guar) ,
710
- ) ;
711
-
712
- body. coroutine . as_mut ( ) . map ( |gen| {
713
- gen. yield_ty = yield_ty;
714
- gen. resume_ty = resume_ty;
715
- } ) ;
716
-
717
- body
722
+ )
718
723
}
719
724
720
725
impl < ' a , ' tcx > Builder < ' a , ' tcx > {
@@ -728,7 +733,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
728
733
safety : Safety ,
729
734
return_ty : Ty < ' tcx > ,
730
735
return_span : Span ,
731
- coroutine_kind : Option < CoroutineKind > ,
736
+ coroutine : Option < Box < CoroutineInfo < ' tcx > > > ,
732
737
) -> Builder < ' a , ' tcx > {
733
738
let tcx = infcx. tcx ;
734
739
let attrs = tcx. hir ( ) . attrs ( hir_id) ;
@@ -759,7 +764,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
759
764
cfg : CFG { basic_blocks : IndexVec :: new ( ) } ,
760
765
fn_span : span,
761
766
arg_count,
762
- coroutine_kind ,
767
+ coroutine ,
763
768
scopes : scope:: Scopes :: new ( ) ,
764
769
block_context : BlockContext :: new ( ) ,
765
770
source_scopes : IndexVec :: new ( ) ,
@@ -803,7 +808,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
803
808
self . arg_count ,
804
809
self . var_debug_info ,
805
810
self . fn_span ,
806
- self . coroutine_kind ,
811
+ self . coroutine ,
807
812
None ,
808
813
)
809
814
}
0 commit comments