Skip to content

Commit 8af94ce

Browse files
Use a helper to zip together parent and child captures for coroutine-closures
1 parent 54a93ab commit 8af94ce

File tree

3 files changed

+61
-73
lines changed

3 files changed

+61
-73
lines changed

compiler/rustc_middle/src/ty/closure.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::{mir, ty};
66
use std::fmt::Write;
77

88
use crate::query::Providers;
9+
use rustc_data_structures::captures::Captures;
910
use rustc_data_structures::fx::FxIndexMap;
1011
use rustc_hir as hir;
1112
use rustc_hir::def_id::LocalDefId;
@@ -415,6 +416,52 @@ impl BorrowKind {
415416
}
416417
}
417418

419+
pub fn analyze_coroutine_closure_captures<'a, 'tcx: 'a, T>(
420+
parent_captures: impl IntoIterator<Item = &'a CapturedPlace<'tcx>>,
421+
child_captures: impl IntoIterator<Item = &'a CapturedPlace<'tcx>>,
422+
mut for_each: impl FnMut((usize, &'a CapturedPlace<'tcx>), (usize, &'a CapturedPlace<'tcx>)) -> T,
423+
) -> impl Iterator<Item = T> + Captures<'a> + Captures<'tcx> {
424+
let mut parent_captures = parent_captures.into_iter().enumerate().peekable();
425+
// Make sure we use every field at least once, b/c why are we capturing something
426+
// if it's not used in the inner coroutine.
427+
let mut field_used_at_least_once = false;
428+
child_captures.into_iter().enumerate().map(move |(child_field_idx, child_capture)| {
429+
loop {
430+
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
431+
bug!("we ran out of parent captures!")
432+
};
433+
434+
let HirPlaceBase::Upvar(parent_base) = parent_capture.place.base else {
435+
bug!("expected capture to be an upvar");
436+
};
437+
let HirPlaceBase::Upvar(child_base) = child_capture.place.base else {
438+
bug!("expected capture to be an upvar");
439+
};
440+
441+
if parent_base.var_path.hir_id != child_base.var_path.hir_id
442+
|| !std::iter::zip(
443+
&child_capture.place.projections,
444+
&parent_capture.place.projections,
445+
)
446+
.all(|(child, parent)| child.kind == parent.kind)
447+
{
448+
// Make sure the field was used at least once.
449+
assert!(
450+
field_used_at_least_once,
451+
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
452+
);
453+
field_used_at_least_once = false;
454+
// Skip this field.
455+
let _ = parent_captures.next().unwrap();
456+
continue;
457+
}
458+
459+
field_used_at_least_once = true;
460+
break for_each((parent_field_idx, parent_capture), (child_field_idx, child_capture));
461+
}
462+
})
463+
}
464+
418465
pub fn provide(providers: &mut Providers) {
419466
*providers = Providers { closure_typeinfo, ..*providers }
420467
}

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ pub use rustc_type_ir::ConstKind::{
7777
pub use rustc_type_ir::*;
7878

7979
pub use self::closure::{
80-
is_ancestor_or_same_capture, place_to_string_for_capture, BorrowKind, CaptureInfo,
81-
CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, MinCaptureList,
82-
RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath, CAPTURE_STRUCT_LOCAL,
80+
analyze_coroutine_closure_captures, is_ancestor_or_same_capture, place_to_string_for_capture,
81+
BorrowKind, CaptureInfo, CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap,
82+
MinCaptureList, RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath,
83+
CAPTURE_STRUCT_LOCAL,
8384
};
8485
pub use self::consts::{
8586
Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree,

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

Lines changed: 10 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
7272
use rustc_data_structures::unord::UnordMap;
7373
use rustc_hir as hir;
74-
use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind};
74+
use rustc_middle::hir::place::{Projection, ProjectionKind};
7575
use rustc_middle::mir::visit::MutVisitor;
7676
use rustc_middle::mir::{self, dump_mir, MirPass};
7777
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
@@ -124,62 +124,10 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
124124
.tuple_fields()
125125
.len();
126126

127-
let mut field_remapping = UnordMap::default();
128-
129-
// One parent capture may correspond to several child captures if we end up
130-
// refining the set of captures via edition-2021 precise captures. We want to
131-
// match up any number of child captures with one parent capture, so we keep
132-
// peeking off this `Peekable` until the child doesn't match anymore.
133-
let mut parent_captures =
134-
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
135-
// Make sure we use every field at least once, b/c why are we capturing something
136-
// if it's not used in the inner coroutine.
137-
let mut field_used_at_least_once = false;
138-
139-
for (child_field_idx, child_capture) in tcx
140-
.closure_captures(coroutine_def_id)
141-
.iter()
142-
.copied()
143-
// By construction we capture all the args first.
144-
.skip(num_args)
145-
.enumerate()
146-
{
147-
loop {
148-
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
149-
bug!("we ran out of parent captures!")
150-
};
151-
152-
let PlaceBase::Upvar(parent_base) = parent_capture.place.base else {
153-
bug!("expected capture to be an upvar");
154-
};
155-
let PlaceBase::Upvar(child_base) = child_capture.place.base else {
156-
bug!("expected capture to be an upvar");
157-
};
158-
159-
assert!(
160-
child_capture.place.projections.len() >= parent_capture.place.projections.len()
161-
);
162-
// A parent matches a child they share the same prefix of projections.
163-
// The child may have more, if it is capturing sub-fields out of
164-
// something that is captured by-move in the parent closure.
165-
if parent_base.var_path.hir_id != child_base.var_path.hir_id
166-
|| !std::iter::zip(
167-
&child_capture.place.projections,
168-
&parent_capture.place.projections,
169-
)
170-
.all(|(child, parent)| child.kind == parent.kind)
171-
{
172-
// Make sure the field was used at least once.
173-
assert!(
174-
field_used_at_least_once,
175-
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
176-
);
177-
field_used_at_least_once = false;
178-
// Skip this field.
179-
let _ = parent_captures.next().unwrap();
180-
continue;
181-
}
182-
127+
let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures(
128+
tcx.closure_captures(parent_def_id).iter().copied(),
129+
tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(),
130+
|(parent_field_idx, parent_capture), (child_field_idx, child_capture)| {
183131
// Store this set of additional projections (fields and derefs).
184132
// We need to re-apply them later.
185133
let child_precise_captures =
@@ -210,26 +158,18 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
210158
),
211159
};
212160

213-
field_remapping.insert(
161+
(
214162
FieldIdx::from_usize(child_field_idx + num_args),
215163
(
216164
FieldIdx::from_usize(parent_field_idx + num_args),
217165
parent_capture_ty,
218166
needs_deref,
219167
child_precise_captures,
220168
),
221-
);
222-
223-
field_used_at_least_once = true;
224-
break;
225-
}
226-
}
227-
228-
// Pop the last parent capture
229-
if field_used_at_least_once {
230-
let _ = parent_captures.next().unwrap();
231-
}
232-
assert_eq!(parent_captures.next(), None, "leftover parent captures?");
169+
)
170+
},
171+
)
172+
.collect();
233173

234174
if coroutine_kind == ty::ClosureKind::FnOnce {
235175
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());

0 commit comments

Comments
 (0)