Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 0ee4e6a

Browse files
committed
Auto merge of rust-lang#12086 - iDawer:infer.rpit, r=flodiebold
infer from RPIT bounds of _this_ function Collect obligations from RPITs (Return Position `impl Trait`) of a function which is being inferred. This allows inferring {unknown}s from RPIT bounds. Closes rust-lang#8403
2 parents eeb4532 + 970276b commit 0ee4e6a

File tree

2 files changed

+75
-13
lines changed

2 files changed

+75
-13
lines changed

crates/hir-ty/src/infer.rs

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use std::sync::Arc;
1919
use chalk_ir::{cast::Cast, ConstValue, DebruijnIndex, Mutability, Safety, Scalar, TypeFlags};
2020
use hir_def::{
2121
body::Body,
22-
data::{ConstData, FunctionData, StaticData},
22+
data::{ConstData, StaticData},
2323
expr::{BindingAnnotation, ExprId, PatId},
2424
lang_item::LangItemTarget,
2525
path::{path, Path},
@@ -32,12 +32,13 @@ use hir_expand::name::{name, Name};
3232
use itertools::Either;
3333
use la_arena::ArenaMap;
3434
use rustc_hash::FxHashMap;
35-
use stdx::impl_from;
35+
use stdx::{always, impl_from};
3636

3737
use crate::{
38-
db::HirDatabase, fold_tys_and_consts, infer::coerce::CoerceMany, lower::ImplTraitLoweringMode,
39-
to_assoc_type_id, AliasEq, AliasTy, Const, DomainGoal, GenericArg, Goal, InEnvironment,
40-
Interner, ProjectionTy, Substitution, TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
38+
db::HirDatabase, fold_tys, fold_tys_and_consts, infer::coerce::CoerceMany,
39+
lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Const, DomainGoal,
40+
GenericArg, Goal, ImplTraitId, InEnvironment, Interner, ProjectionTy, Substitution,
41+
TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
4142
};
4243

4344
// This lint has a false positive here. See the link below for details.
@@ -64,7 +65,7 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<Infer
6465

6566
match def {
6667
DefWithBodyId::ConstId(c) => ctx.collect_const(&db.const_data(c)),
67-
DefWithBodyId::FunctionId(f) => ctx.collect_fn(&db.function_data(f)),
68+
DefWithBodyId::FunctionId(f) => ctx.collect_fn(f),
6869
DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_data(s)),
6970
}
7071

@@ -457,7 +458,8 @@ impl<'a> InferenceContext<'a> {
457458
self.return_ty = self.make_ty(&data.type_ref);
458459
}
459460

460-
fn collect_fn(&mut self, data: &FunctionData) {
461+
fn collect_fn(&mut self, func: FunctionId) {
462+
let data = self.db.function_data(func);
461463
let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver)
462464
.with_impl_trait_mode(ImplTraitLoweringMode::Param);
463465
let param_tys =
@@ -474,8 +476,42 @@ impl<'a> InferenceContext<'a> {
474476
} else {
475477
&*data.ret_type
476478
};
477-
let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Disallowed); // FIXME implement RPIT
479+
let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Opaque);
478480
self.return_ty = return_ty;
481+
482+
if let Some(rpits) = self.db.return_type_impl_traits(func) {
483+
// RPIT opaque types use substitution of their parent function.
484+
let fn_placeholders = TyBuilder::placeholder_subst(self.db, func);
485+
self.return_ty = fold_tys(
486+
self.return_ty.clone(),
487+
|ty, _| {
488+
let opaque_ty_id = match ty.kind(Interner) {
489+
TyKind::OpaqueType(opaque_ty_id, _) => *opaque_ty_id,
490+
_ => return ty,
491+
};
492+
let idx = match self.db.lookup_intern_impl_trait_id(opaque_ty_id.into()) {
493+
ImplTraitId::ReturnTypeImplTrait(_, idx) => idx,
494+
_ => unreachable!(),
495+
};
496+
let bounds = (*rpits).map_ref(|rpits| {
497+
rpits.impl_traits[idx as usize].bounds.map_ref(|it| it.into_iter())
498+
});
499+
let var = self.table.new_type_var();
500+
let var_subst = Substitution::from1(Interner, var.clone());
501+
for bound in bounds {
502+
let predicate =
503+
bound.map(|it| it.cloned()).substitute(Interner, &fn_placeholders);
504+
let (var_predicate, binders) = predicate
505+
.substitute(Interner, &var_subst)
506+
.into_value_and_skipped_binders();
507+
always!(binders.len(Interner) == 0); // quantified where clauses not yet handled
508+
self.push_obligation(var_predicate.cast(Interner));
509+
}
510+
var
511+
},
512+
DebruijnIndex::INNERMOST,
513+
);
514+
}
479515
}
480516

481517
fn infer_body(&mut self) {

crates/hir-ty/src/tests/traits.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,32 @@ fn test() {
12551255
);
12561256
}
12571257

1258+
#[test]
1259+
fn infer_from_return_pos_impl_trait() {
1260+
check_infer_with_mismatches(
1261+
r#"
1262+
//- minicore: fn, sized
1263+
trait Trait<T> {}
1264+
struct Bar<T>(T);
1265+
impl<T> Trait<T> for Bar<T> {}
1266+
fn foo<const C: u8, T>() -> (impl FnOnce(&str, T), impl Trait<u8>) {
1267+
(|input, t| {}, Bar(C))
1268+
}
1269+
"#,
1270+
expect![[r#"
1271+
134..165 '{ ...(C)) }': (|&str, T| -> (), Bar<u8>)
1272+
140..163 '(|inpu...ar(C))': (|&str, T| -> (), Bar<u8>)
1273+
141..154 '|input, t| {}': |&str, T| -> ()
1274+
142..147 'input': &str
1275+
149..150 't': T
1276+
152..154 '{}': ()
1277+
156..159 'Bar': Bar<u8>(u8) -> Bar<u8>
1278+
156..162 'Bar(C)': Bar<u8>
1279+
160..161 'C': u8
1280+
"#]],
1281+
);
1282+
}
1283+
12581284
#[test]
12591285
fn dyn_trait() {
12601286
check_infer(
@@ -2392,7 +2418,7 @@ fn test() -> impl Trait<i32> {
23922418
171..182 '{ loop {} }': T
23932419
173..180 'loop {}': !
23942420
178..180 '{}': ()
2395-
213..309 '{ ...t()) }': S<{unknown}>
2421+
213..309 '{ ...t()) }': S<i32>
23962422
223..225 's1': S<u32>
23972423
228..229 'S': S<u32>(u32) -> S<u32>
23982424
228..240 'S(default())': S<u32>
@@ -2408,10 +2434,10 @@ fn test() -> impl Trait<i32> {
24082434
276..288 'S(default())': S<i32>
24092435
278..285 'default': fn default<i32>() -> i32
24102436
278..287 'default()': i32
2411-
295..296 'S': S<{unknown}>({unknown}) -> S<{unknown}>
2412-
295..307 'S(default())': S<{unknown}>
2413-
297..304 'default': fn default<{unknown}>() -> {unknown}
2414-
297..306 'default()': {unknown}
2437+
295..296 'S': S<i32>(i32) -> S<i32>
2438+
295..307 'S(default())': S<i32>
2439+
297..304 'default': fn default<i32>() -> i32
2440+
297..306 'default()': i32
24152441
"#]],
24162442
);
24172443
}

0 commit comments

Comments
 (0)