Skip to content

Commit 24dcf6f

Browse files
committed
Allow to use super trait bounds in where clauses
1 parent 361543d commit 24dcf6f

File tree

6 files changed

+167
-23
lines changed

6 files changed

+167
-23
lines changed

compiler/rustc_middle/src/query/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ rustc_queries! {
438438

439439
/// To avoid cycles within the predicates of a single item we compute
440440
/// per-type-parameter predicates for resolving `T::AssocTy`.
441-
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
441+
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
442442
desc { |tcx| "computing the bounds for type parameter `{}`", {
443443
let id = tcx.hir().local_def_id_to_hir_id(key.1);
444444
tcx.hir().ty_param_name(id)

compiler/rustc_middle/src/ty/query/keys.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
77
use crate::ty::{self, Ty, TyCtxt};
88
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
99
use rustc_query_system::query::DefaultCacheSelector;
10-
use rustc_span::symbol::Symbol;
10+
use rustc_span::symbol::{Ident, Symbol};
1111
use rustc_span::{Span, DUMMY_SP};
1212

1313
/// The `Key` trait controls what types can legally be used as the key
@@ -149,6 +149,17 @@ impl Key for (LocalDefId, DefId) {
149149
}
150150
}
151151

152+
impl Key for (DefId, LocalDefId, Ident) {
153+
type CacheSelector = DefaultCacheSelector;
154+
155+
fn query_crate(&self) -> CrateNum {
156+
self.0.krate
157+
}
158+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
159+
self.1.default_span(tcx)
160+
}
161+
}
162+
152163
impl Key for (CrateNum, DefId) {
153164
type CacheSelector = DefaultCacheSelector;
154165

compiler/rustc_typeck/src/astconv/mod.rs

+13-6
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {
4949

5050
fn default_constness_for_trait_bounds(&self) -> Constness;
5151

52-
/// Returns predicates in scope of the form `X: Foo`, where `X` is
53-
/// a type parameter `X` with the given id `def_id`. This is a
54-
/// subset of the full set of predicates.
52+
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
53+
/// is a type parameter `X` with the given id `def_id` and T
54+
/// matches assoc_name. This is a subset of the full set of
55+
/// predicates.
5556
///
5657
/// This is used for one specific purpose: resolving "short-hand"
5758
/// associated type references like `T::Item`. In principle, we
@@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
6061
/// but this can lead to cycle errors. The problem is that we have
6162
/// to do this resolution *in order to create the predicates in
6263
/// the first place*. Hence, we have this "special pass".
63-
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
64+
fn get_type_parameter_bounds(
65+
&self,
66+
span: Span,
67+
def_id: DefId,
68+
assoc_name: Ident,
69+
) -> ty::GenericPredicates<'tcx>;
6470

6571
/// Returns the lifetime to use when a lifetime is omitted (and not elided).
6672
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
@@ -1361,8 +1367,9 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
13611367
ty_param_def_id, assoc_name, span,
13621368
);
13631369

1364-
let predicates =
1365-
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
1370+
let predicates = &self
1371+
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
1372+
.predicates;
13661373

13671374
debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);
13681375

compiler/rustc_typeck/src/check/fn_ctxt/mod.rs

+24-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
2020
use rustc_middle::ty::subst::GenericArgKind;
2121
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
2222
use rustc_session::Session;
23+
use rustc_span::symbol::Ident;
2324
use rustc_span::{self, Span};
2425
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
2526

@@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
183184
}
184185
}
185186

186-
fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
187+
fn get_type_parameter_bounds(
188+
&self,
189+
_: Span,
190+
def_id: DefId,
191+
assoc_name: Ident,
192+
) -> ty::GenericPredicates<'tcx> {
187193
let tcx = self.tcx;
188194
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
189195
let item_id = tcx.hir().ty_param_owner(hir_id);
@@ -196,9 +202,23 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
196202
self.param_env.caller_bounds().iter().filter_map(|predicate| {
197203
match predicate.skip_binders() {
198204
ty::PredicateAtom::Trait(data, _) if data.self_ty().is_param(index) => {
199-
// HACK(eddyb) should get the original `Span`.
200-
let span = tcx.def_span(def_id);
201-
Some((predicate, span))
205+
let trait_did = data.def_id();
206+
if tcx
207+
.associated_items(trait_did)
208+
.find_by_name_and_kind(
209+
tcx,
210+
assoc_name,
211+
ty::AssocKind::Type,
212+
trait_did,
213+
)
214+
.is_some()
215+
{
216+
// HACK(eddyb) should get the original `Span`.
217+
let span = tcx.def_span(def_id);
218+
Some((predicate, span))
219+
} else {
220+
None
221+
}
202222
}
203223
_ => None,
204224
}

compiler/rustc_typeck/src/collect.rs

+102-11
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,17 @@ impl AstConv<'tcx> for ItemCtxt<'tcx> {
310310
}
311311
}
312312

313-
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
314-
self.tcx.at(span).type_param_predicates((self.item_def_id, def_id.expect_local()))
313+
fn get_type_parameter_bounds(
314+
&self,
315+
span: Span,
316+
def_id: DefId,
317+
assoc_name: Ident,
318+
) -> ty::GenericPredicates<'tcx> {
319+
self.tcx.at(span).type_param_predicates((
320+
self.item_def_id,
321+
def_id.expect_local(),
322+
assoc_name,
323+
))
315324
}
316325

317326
fn re_infer(&self, _: Option<&ty::GenericParamDef>, _: Span) -> Option<ty::Region<'tcx>> {
@@ -492,7 +501,7 @@ fn get_new_lifetime_name<'tcx>(
492501
/// `X: Foo` where `X` is the type parameter `def_id`.
493502
fn type_param_predicates(
494503
tcx: TyCtxt<'_>,
495-
(item_def_id, def_id): (DefId, LocalDefId),
504+
(item_def_id, def_id, assoc_name): (DefId, LocalDefId, Ident),
496505
) -> ty::GenericPredicates<'_> {
497506
use rustc_hir::*;
498507

@@ -517,7 +526,7 @@ fn type_param_predicates(
517526
let mut result = parent
518527
.map(|parent| {
519528
let icx = ItemCtxt::new(tcx, parent);
520-
icx.get_type_parameter_bounds(DUMMY_SP, def_id.to_def_id())
529+
icx.get_type_parameter_bounds(DUMMY_SP, def_id.to_def_id(), assoc_name)
521530
})
522531
.unwrap_or_default();
523532
let mut extend = None;
@@ -560,12 +569,18 @@ fn type_param_predicates(
560569

561570
let icx = ItemCtxt::new(tcx, item_def_id);
562571
let extra_predicates = extend.into_iter().chain(
563-
icx.type_parameter_bounds_in_generics(ast_generics, param_id, ty, OnlySelfBounds(true))
564-
.into_iter()
565-
.filter(|(predicate, _)| match predicate.skip_binders() {
566-
ty::PredicateAtom::Trait(data, _) => data.self_ty().is_param(index),
567-
_ => false,
568-
}),
572+
icx.type_parameter_bounds_in_generics(
573+
ast_generics,
574+
param_id,
575+
ty,
576+
OnlySelfBounds(true),
577+
Some(assoc_name),
578+
)
579+
.into_iter()
580+
.filter(|(predicate, _)| match predicate.skip_binders() {
581+
ty::PredicateAtom::Trait(data, _) => data.self_ty().is_param(index),
582+
_ => false,
583+
}),
569584
);
570585
result.predicates =
571586
tcx.arena.alloc_from_iter(result.predicates.iter().copied().chain(extra_predicates));
@@ -583,6 +598,7 @@ impl ItemCtxt<'tcx> {
583598
param_id: hir::HirId,
584599
ty: Ty<'tcx>,
585600
only_self_bounds: OnlySelfBounds,
601+
assoc_name: Option<Ident>,
586602
) -> Vec<(ty::Predicate<'tcx>, Span)> {
587603
let constness = self.default_constness_for_trait_bounds();
588604
let from_ty_params = ast_generics
@@ -593,6 +609,10 @@ impl ItemCtxt<'tcx> {
593609
_ => None,
594610
})
595611
.flat_map(|bounds| bounds.iter())
612+
.filter(|b| match assoc_name {
613+
Some(assoc_name) => self.bound_defines_assoc_item(b, assoc_name),
614+
None => true,
615+
})
596616
.flat_map(|b| predicates_from_bound(self, ty, b, constness));
597617

598618
let from_where_clauses = ast_generics
@@ -611,12 +631,43 @@ impl ItemCtxt<'tcx> {
611631
} else {
612632
None
613633
};
614-
bp.bounds.iter().filter_map(move |b| bt.map(|bt| (bt, b)))
634+
bp.bounds
635+
.iter()
636+
.filter(|b| match assoc_name {
637+
Some(assoc_name) => self.bound_defines_assoc_item(b, assoc_name),
638+
None => true,
639+
})
640+
.filter_map(move |b| bt.map(|bt| (bt, b)))
615641
})
616642
.flat_map(|(bt, b)| predicates_from_bound(self, bt, b, constness));
617643

618644
from_ty_params.chain(from_where_clauses).collect()
619645
}
646+
647+
fn bound_defines_assoc_item(&self, b: &hir::GenericBound<'_>, assoc_name: Ident) -> bool {
648+
debug!("bound_defines_assoc_item(b={:?}, assoc_name={:?})", b, assoc_name);
649+
650+
match b {
651+
hir::GenericBound::Trait(poly_trait_ref, _) => {
652+
let trait_ref = &poly_trait_ref.trait_ref;
653+
let trait_did = trait_ref.trait_def_id().unwrap();
654+
let traits_did = super_traits_of(self.tcx, trait_did);
655+
656+
traits_did.iter().any(|trait_did| {
657+
self.tcx
658+
.associated_items(*trait_did)
659+
.find_by_name_and_kind(
660+
self.tcx,
661+
assoc_name,
662+
ty::AssocKind::Type,
663+
*trait_did,
664+
)
665+
.is_some()
666+
})
667+
}
668+
_ => false,
669+
}
670+
}
620671
}
621672

622673
/// Tests whether this is the AST for a reference to the type
@@ -1017,6 +1068,7 @@ fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredi
10171068
item.hir_id,
10181069
self_param_ty,
10191070
OnlySelfBounds(!is_trait_alias),
1071+
None,
10201072
);
10211073

10221074
// Combine the two lists to form the complete set of superbounds:
@@ -1034,6 +1086,45 @@ fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredi
10341086
ty::GenericPredicates { parent: None, predicates: superbounds }
10351087
}
10361088

1089+
pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<Item = DefId> {
1090+
let mut set = FxHashSet::default();
1091+
let mut stack = vec![trait_def_id];
1092+
while let Some(trait_did) = stack.pop() {
1093+
if !set.insert(trait_did) {
1094+
continue;
1095+
}
1096+
1097+
if trait_did.is_local() {
1098+
let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_did.expect_local());
1099+
1100+
let item = match tcx.hir().get(trait_hir_id) {
1101+
Node::Item(item) => item,
1102+
_ => bug!("super_trait_of {} is not an item", trait_hir_id),
1103+
};
1104+
1105+
let supertraits = match item.kind {
1106+
hir::ItemKind::Trait(.., ref supertraits, _) => supertraits,
1107+
hir::ItemKind::TraitAlias(_, ref supertraits) => supertraits,
1108+
_ => span_bug!(item.span, "super_trait_of invoked on non-trait"),
1109+
};
1110+
1111+
for supertrait in supertraits.iter() {
1112+
let trait_ref = supertrait.trait_ref();
1113+
if let Some(trait_did) = trait_ref.and_then(|trait_ref| trait_ref.trait_def_id()) {
1114+
stack.push(trait_did);
1115+
}
1116+
}
1117+
} else {
1118+
let generic_predicates = tcx.super_predicates_of(trait_did);
1119+
for (predicate, _) in generic_predicates.predicates {
1120+
if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
1121+
stack.push(data.def_id());
1122+
}
1123+
}
1124+
}
1125+
}
1126+
}
1127+
10371128
fn trait_def(tcx: TyCtxt<'_>, def_id: DefId) -> ty::TraitDef {
10381129
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
10391130
let item = tcx.hir().expect_item(hir_id);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// check-pass
2+
trait Foo {
3+
type Item;
4+
}
5+
6+
trait Bar<T> {}
7+
8+
fn baz<T>()
9+
where
10+
T: Foo,
11+
T: Bar<T::Item>,
12+
{
13+
}
14+
15+
fn main() {}

0 commit comments

Comments
 (0)