Skip to content

Commit de8a305

Browse files
committed
address code review
splitting off futures/generators in trait selection / checking
1 parent dfabfd8 commit de8a305

File tree

10 files changed

+264
-159
lines changed

10 files changed

+264
-159
lines changed

compiler/rustc_middle/src/traits/mod.rs

+20-2
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,9 @@ pub enum ImplSource<'tcx, N> {
660660
/// ImplSource automatically generated for a generator.
661661
Generator(ImplSourceGeneratorData<'tcx, N>),
662662

663+
/// ImplSource automatically generated for a generator backing an async future.
664+
Future(ImplSourceFutureData<'tcx, N>),
665+
663666
/// ImplSource for a trait alias.
664667
TraitAlias(ImplSourceTraitAliasData<'tcx, N>),
665668

@@ -676,6 +679,7 @@ impl<'tcx, N> ImplSource<'tcx, N> {
676679
ImplSource::AutoImpl(d) => d.nested,
677680
ImplSource::Closure(c) => c.nested,
678681
ImplSource::Generator(c) => c.nested,
682+
ImplSource::Future(c) => c.nested,
679683
ImplSource::Object(d) => d.nested,
680684
ImplSource::FnPointer(d) => d.nested,
681685
ImplSource::DiscriminantKind(ImplSourceDiscriminantKindData)
@@ -694,6 +698,7 @@ impl<'tcx, N> ImplSource<'tcx, N> {
694698
ImplSource::AutoImpl(d) => &d.nested,
695699
ImplSource::Closure(c) => &c.nested,
696700
ImplSource::Generator(c) => &c.nested,
701+
ImplSource::Future(c) => &c.nested,
697702
ImplSource::Object(d) => &d.nested,
698703
ImplSource::FnPointer(d) => &d.nested,
699704
ImplSource::DiscriminantKind(ImplSourceDiscriminantKindData)
@@ -736,7 +741,11 @@ impl<'tcx, N> ImplSource<'tcx, N> {
736741
generator_def_id: c.generator_def_id,
737742
substs: c.substs,
738743
nested: c.nested.into_iter().map(f).collect(),
739-
is_async: c.is_async,
744+
}),
745+
ImplSource::Future(c) => ImplSource::Future(ImplSourceFutureData {
746+
generator_def_id: c.generator_def_id,
747+
substs: c.substs,
748+
nested: c.nested.into_iter().map(f).collect(),
740749
}),
741750
ImplSource::FnPointer(p) => ImplSource::FnPointer(ImplSourceFnPointerData {
742751
fn_ty: p.fn_ty,
@@ -795,7 +804,16 @@ pub struct ImplSourceGeneratorData<'tcx, N> {
795804
/// Nested obligations. This can be non-empty if the generator
796805
/// signature contains associated types.
797806
pub nested: Vec<N>,
798-
pub is_async: bool,
807+
}
808+
809+
#[derive(Clone, PartialEq, Eq, TyEncodable, TyDecodable, HashStable, Lift)]
810+
#[derive(TypeFoldable, TypeVisitable)]
811+
pub struct ImplSourceFutureData<'tcx, N> {
812+
pub generator_def_id: DefId,
813+
pub substs: SubstsRef<'tcx>,
814+
/// Nested obligations. This can be non-empty if the generator
815+
/// signature contains associated types.
816+
pub nested: Vec<N>,
799817
}
800818

801819
#[derive(Clone, PartialEq, Eq, TyEncodable, TyDecodable, HashStable, Lift)]

compiler/rustc_middle/src/traits/select.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,13 @@ pub enum SelectionCandidate<'tcx> {
127127
/// generated for an `||` expression.
128128
ClosureCandidate,
129129

130-
/// Implementation of a `Generator` / `Future` trait by one of the anonymous types
130+
/// Implementation of a `Generator` trait by one of the anonymous types
131131
/// generated for a generator.
132-
GeneratorCandidate {
133-
is_async: bool,
134-
},
132+
GeneratorCandidate,
133+
134+
/// Implementation of a `Future` trait by one of the generator types
135+
/// generated for an async construct.
136+
FutureCandidate,
135137

136138
/// Implementation of a `Fn`-family trait by one of the anonymous
137139
/// types generated for a fn pointer type (e.g., `fn(int) -> int`)

compiler/rustc_middle/src/traits/structural_impls.rs

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSource<'tcx, N> {
1515

1616
super::ImplSource::Generator(ref d) => write!(f, "{:?}", d),
1717

18+
super::ImplSource::Future(ref d) => write!(f, "{:?}", d),
19+
1820
super::ImplSource::FnPointer(ref d) => write!(f, "({:?})", d),
1921

2022
super::ImplSource::DiscriminantKind(ref d) => write!(f, "{:?}", d),
@@ -58,6 +60,16 @@ impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSourceGeneratorData<'tcx, N
5860
}
5961
}
6062

63+
impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSourceFutureData<'tcx, N> {
64+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65+
write!(
66+
f,
67+
"ImplSourceFutureData(generator_def_id={:?}, substs={:?}, nested={:?})",
68+
self.generator_def_id, self.substs, self.nested
69+
)
70+
}
71+
}
72+
6173
impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSourceClosureData<'tcx, N> {
6274
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
6375
write!(

compiler/rustc_mir_transform/src/generator.rs

+39-36
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
//! generator in the MIR, since it is used to create the drop glue for the generator. We'd get
1212
//! infinite recursion otherwise.
1313
//!
14-
//! This pass creates the implementation for the Generator::resume function and the drop shim
15-
//! for the generator based on the MIR input. It converts the generator argument from Self to
16-
//! &mut Self adding derefs in the MIR as needed. It computes the final layout of the generator
17-
//! struct which looks like this:
14+
//! This pass creates the implementation for either the `Generator::resume` or `Future::poll`
15+
//! function and the drop shim for the generator based on the MIR input.
16+
//! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed.
17+
//! It computes the final layout of the generator struct which looks like this:
1818
//! First upvars are stored
1919
//! It is followed by the generator state field.
2020
//! Then finally the MIR locals which are live across a suspension point are stored.
@@ -32,14 +32,15 @@
3232
//! 2 - Generator has been poisoned
3333
//!
3434
//! It also rewrites `return x` and `yield y` as setting a new generator state and returning
35-
//! GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively.
35+
//! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`,
36+
//! or `Poll::Ready(x)` and `Poll::Pending` respectively.
3637
//! MIR locals which are live across a suspension point are moved to the generator struct
3738
//! with references to them being updated with references to the generator struct.
3839
//!
3940
//! The pass creates two functions which have a switch on the generator state giving
4041
//! the action to take.
4142
//!
42-
//! One of them is the implementation of Generator::resume.
43+
//! One of them is the implementation of `Generator::resume` / `Future::poll`.
4344
//! For generators with state 0 (unresumed) it starts the execution of the generator.
4445
//! For generators with state 1 (returned) and state 2 (poisoned) it panics.
4546
//! Otherwise it continues the execution from the last suspension point.
@@ -53,10 +54,10 @@ use crate::deref_separator::deref_finder;
5354
use crate::simplify;
5455
use crate::util::expand_aggregate;
5556
use crate::MirPass;
56-
use hir::GeneratorKind;
5757
use rustc_data_structures::fx::FxHashMap;
5858
use rustc_hir as hir;
5959
use rustc_hir::lang_items::LangItem;
60+
use rustc_hir::GeneratorKind;
6061
use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
6162
use rustc_index::vec::{Idx, IndexVec};
6263
use rustc_middle::mir::dump_mir;
@@ -246,14 +247,35 @@ impl<'tcx> TransformVisitor<'tcx> {
246247
// `core::ops::GeneratorState` only has single element tuple variants,
247248
// so we can just write to the downcasted first field and then set the
248249
// discriminant to the appropriate variant.
249-
fn make_state(&self, idx: VariantIdx) -> (AggregateKind<'tcx>, Option<Ty<'tcx>>) {
250+
fn make_state(
251+
&self,
252+
val: Operand<'tcx>,
253+
source_info: SourceInfo,
254+
is_return: bool,
255+
statements: &mut Vec<Statement<'tcx>>,
256+
) {
257+
let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
258+
(true, false) => 1, // GeneratorState::Complete
259+
(false, false) => 0, // GeneratorState::Yielded
260+
(true, true) => 0, // Poll::Ready
261+
(false, true) => 1, // Poll::Pending
262+
});
263+
250264
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None);
251265

252266
// `Poll::Pending`
253267
if self.is_async_kind && idx == VariantIdx::new(1) {
254268
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
255269

256-
return (kind, None);
270+
// FIXME(swatinem): assert that `val` is indeed unit?
271+
statements.extend(expand_aggregate(
272+
Place::return_place(),
273+
std::iter::empty(),
274+
kind,
275+
source_info,
276+
self.tcx,
277+
));
278+
return;
257279
}
258280

259281
// else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
@@ -264,7 +286,13 @@ impl<'tcx> TransformVisitor<'tcx> {
264286
.bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
265287
.subst(self.tcx, self.state_substs);
266288

267-
(kind, Some(ty))
289+
statements.extend(expand_aggregate(
290+
Place::return_place(),
291+
std::iter::once((val, ty)),
292+
kind,
293+
source_info,
294+
self.tcx,
295+
));
268296
}
269297

270298
// Create a Place referencing a generator struct field
@@ -347,32 +375,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
347375
if let Some((is_return, resume, v, drop)) = ret_val {
348376
let source_info = data.terminator().source_info;
349377
// We must assign the value first in case it gets declared dead below
350-
let state_idx = VariantIdx::new(match (is_return, self.is_async_kind) {
351-
(true, false) => 1, // GeneratorState::Complete
352-
(false, false) => 0, // GeneratorState::Yielded
353-
(true, true) => 0, // Poll::Ready
354-
(false, true) => 1, // Poll::Pending
355-
});
356-
let (kind, ty) = self.make_state(state_idx);
357-
if let Some(ty) = ty {
358-
data.statements.extend(expand_aggregate(
359-
Place::return_place(),
360-
std::iter::once((v, ty)),
361-
kind,
362-
source_info,
363-
self.tcx,
364-
));
365-
} else {
366-
// FIXME(swatinem): assert that `val` is indeed unit?
367-
data.statements.extend(expand_aggregate(
368-
Place::return_place(),
369-
std::iter::empty(),
370-
kind,
371-
source_info,
372-
self.tcx,
373-
));
374-
}
375-
378+
self.make_state(v, source_info, is_return, &mut data.statements);
376379
let state = if let Some((resume, mut resume_arg)) = resume {
377380
// Yield
378381
let state = RESERVED_VARIANTS + self.suspension_points.len();

compiler/rustc_trait_selection/src/traits/project.rs

+70-47
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use super::SelectionContext;
1212
use super::SelectionError;
1313
use super::{
1414
ImplSourceClosureData, ImplSourceDiscriminantKindData, ImplSourceFnPointerData,
15-
ImplSourceGeneratorData, ImplSourcePointeeData, ImplSourceUserDefinedData,
15+
ImplSourceFutureData, ImplSourceGeneratorData, ImplSourcePointeeData,
16+
ImplSourceUserDefinedData,
1617
};
1718
use super::{Normalized, NormalizedTy, ProjectionCacheEntry, ProjectionCacheKey};
1819

@@ -1556,6 +1557,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
15561557
let eligible = match &impl_source {
15571558
super::ImplSource::Closure(_)
15581559
| super::ImplSource::Generator(_)
1560+
| super::ImplSource::Future(_)
15591561
| super::ImplSource::FnPointer(_)
15601562
| super::ImplSource::TraitAlias(_) => true,
15611563
super::ImplSource::UserDefined(impl_data) => {
@@ -1844,6 +1846,7 @@ fn confirm_select_candidate<'cx, 'tcx>(
18441846
match impl_source {
18451847
super::ImplSource::UserDefined(data) => confirm_impl_candidate(selcx, obligation, data),
18461848
super::ImplSource::Generator(data) => confirm_generator_candidate(selcx, obligation, data),
1849+
super::ImplSource::Future(data) => confirm_future_candidate(selcx, obligation, data),
18471850
super::ImplSource::Closure(data) => confirm_closure_candidate(selcx, obligation, data),
18481851
super::ImplSource::FnPointer(data) => confirm_fn_pointer_candidate(selcx, obligation, data),
18491852
super::ImplSource::DiscriminantKind(data) => {
@@ -1884,55 +1887,75 @@ fn confirm_generator_candidate<'cx, 'tcx>(
18841887
debug!(?obligation, ?gen_sig, ?obligations, "confirm_generator_candidate");
18851888

18861889
let tcx = selcx.tcx();
1890+
let gen_def_id = tcx.require_lang_item(LangItem::Generator, None);
18871891

1888-
let predicate = if impl_source.is_async {
1889-
let fut_def_id = tcx.require_lang_item(LangItem::Future, None);
1890-
1891-
// FIXME(swatinem): this is just copy-pasted from `generator_trait_ref_and_outputs` for now
1892-
let self_ty = obligation.predicate.self_ty();
1893-
debug_assert!(!self_ty.has_escaping_bound_vars());
1894-
let trait_ref =
1895-
ty::TraitRef { def_id: fut_def_id, substs: tcx.mk_substs_trait(self_ty, &[]) };
1896-
gen_sig.map_bound(|sig| (trait_ref, sig.return_ty)).map_bound(|(trait_ref, return_ty)| {
1897-
let name = tcx.associated_item(obligation.predicate.item_def_id).name;
1898-
let ty = if name == sym::Output { return_ty } else { bug!() };
1899-
1900-
ty::ProjectionPredicate {
1901-
projection_ty: ty::ProjectionTy {
1902-
substs: trait_ref.substs,
1903-
item_def_id: obligation.predicate.item_def_id,
1904-
},
1905-
term: ty.into(),
1906-
}
1907-
})
1908-
} else {
1909-
let gen_def_id = tcx.require_lang_item(LangItem::Generator, None);
1892+
let predicate = super::util::generator_trait_ref_and_outputs(
1893+
tcx,
1894+
gen_def_id,
1895+
obligation.predicate.self_ty(),
1896+
gen_sig,
1897+
)
1898+
.map_bound(|(trait_ref, yield_ty, return_ty)| {
1899+
let name = tcx.associated_item(obligation.predicate.item_def_id).name;
1900+
let ty = if name == sym::Return {
1901+
return_ty
1902+
} else if name == sym::Yield {
1903+
yield_ty
1904+
} else {
1905+
bug!()
1906+
};
19101907

1911-
super::util::generator_trait_ref_and_outputs(
1912-
tcx,
1913-
gen_def_id,
1914-
obligation.predicate.self_ty(),
1915-
gen_sig,
1916-
)
1917-
.map_bound(|(trait_ref, yield_ty, return_ty)| {
1918-
let name = tcx.associated_item(obligation.predicate.item_def_id).name;
1919-
let ty = if name == sym::Return {
1920-
return_ty
1921-
} else if name == sym::Yield {
1922-
yield_ty
1923-
} else {
1924-
bug!()
1925-
};
1908+
ty::ProjectionPredicate {
1909+
projection_ty: ty::ProjectionTy {
1910+
substs: trait_ref.substs,
1911+
item_def_id: obligation.predicate.item_def_id,
1912+
},
1913+
term: ty.into(),
1914+
}
1915+
});
19261916

1927-
ty::ProjectionPredicate {
1928-
projection_ty: ty::ProjectionTy {
1929-
substs: trait_ref.substs,
1930-
item_def_id: obligation.predicate.item_def_id,
1931-
},
1932-
term: ty.into(),
1933-
}
1934-
})
1935-
};
1917+
confirm_param_env_candidate(selcx, obligation, predicate, false)
1918+
.with_addl_obligations(impl_source.nested)
1919+
.with_addl_obligations(obligations)
1920+
}
1921+
1922+
fn confirm_future_candidate<'cx, 'tcx>(
1923+
selcx: &mut SelectionContext<'cx, 'tcx>,
1924+
obligation: &ProjectionTyObligation<'tcx>,
1925+
impl_source: ImplSourceFutureData<'tcx, PredicateObligation<'tcx>>,
1926+
) -> Progress<'tcx> {
1927+
let gen_sig = impl_source.substs.as_generator().poly_sig();
1928+
let Normalized { value: gen_sig, obligations } = normalize_with_depth(
1929+
selcx,
1930+
obligation.param_env,
1931+
obligation.cause.clone(),
1932+
obligation.recursion_depth + 1,
1933+
gen_sig,
1934+
);
1935+
1936+
debug!(?obligation, ?gen_sig, ?obligations, "confirm_future_candidate");
1937+
1938+
let tcx = selcx.tcx();
1939+
let fut_def_id = tcx.require_lang_item(LangItem::Future, None);
1940+
1941+
let predicate = super::util::future_trait_ref_and_outputs(
1942+
tcx,
1943+
fut_def_id,
1944+
obligation.predicate.self_ty(),
1945+
gen_sig,
1946+
)
1947+
.map_bound(|(trait_ref, return_ty)| {
1948+
let name = tcx.associated_item(obligation.predicate.item_def_id).name;
1949+
let ty = if name == sym::Output { return_ty } else { bug!() };
1950+
1951+
ty::ProjectionPredicate {
1952+
projection_ty: ty::ProjectionTy {
1953+
substs: trait_ref.substs,
1954+
item_def_id: obligation.predicate.item_def_id,
1955+
},
1956+
term: ty.into(),
1957+
}
1958+
});
19361959

19371960
confirm_param_env_candidate(selcx, obligation, predicate, false)
19381961
.with_addl_obligations(impl_source.nested)

0 commit comments

Comments
 (0)