Skip to content

Commit 3d9dd68

Browse files
Resolve vars in same_type_modulo_infer
1 parent 039a6ad commit 3d9dd68

File tree

2 files changed

+79
-70
lines changed
  • compiler
    • rustc_infer/src/infer/error_reporting
    • rustc_trait_selection/src/traits/error_reporting

2 files changed

+79
-70
lines changed

compiler/rustc_infer/src/infer/error_reporting/mod.rs

+78-68
Original file line numberDiff line numberDiff line change
@@ -316,37 +316,6 @@ pub fn unexpected_hidden_region_diagnostic<'tcx>(
316316
err
317317
}
318318

319-
/// Structurally compares two types, modulo any inference variables.
320-
///
321-
/// Returns `true` if two types are equal, or if one type is an inference variable compatible
322-
/// with the other type. A TyVar inference type is compatible with any type, and an IntVar or
323-
/// FloatVar inference type are compatible with themselves or their concrete types (Int and
324-
/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
325-
pub fn same_type_modulo_infer<'tcx>(a: Ty<'tcx>, b: Ty<'tcx>) -> bool {
326-
match (&a.kind(), &b.kind()) {
327-
(&ty::Adt(did_a, substs_a), &ty::Adt(did_b, substs_b)) => {
328-
if did_a != did_b {
329-
return false;
330-
}
331-
332-
substs_a.types().zip(substs_b.types()).all(|(a, b)| same_type_modulo_infer(a, b))
333-
}
334-
(&ty::Int(_), &ty::Infer(ty::InferTy::IntVar(_)))
335-
| (&ty::Infer(ty::InferTy::IntVar(_)), &ty::Int(_) | &ty::Infer(ty::InferTy::IntVar(_)))
336-
| (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_)))
337-
| (
338-
&ty::Infer(ty::InferTy::FloatVar(_)),
339-
&ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)),
340-
)
341-
| (&ty::Infer(ty::InferTy::TyVar(_)), _)
342-
| (_, &ty::Infer(ty::InferTy::TyVar(_))) => true,
343-
(&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => {
344-
mut_a == mut_b && same_type_modulo_infer(*ty_a, *ty_b)
345-
}
346-
_ => a == b,
347-
}
348-
}
349-
350319
impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
351320
pub fn report_region_errors(
352321
&self,
@@ -1723,15 +1692,14 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
17231692
};
17241693
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code());
17251694
if let Some(exp_found) = exp_found {
1726-
let should_suggest_fixes = if let ObligationCauseCode::Pattern { root_ty, .. } =
1727-
cause.code()
1728-
{
1729-
// Skip if the root_ty of the pattern is not the same as the expected_ty.
1730-
// If these types aren't equal then we've probably peeled off a layer of arrays.
1731-
same_type_modulo_infer(self.resolve_vars_if_possible(*root_ty), exp_found.expected)
1732-
} else {
1733-
true
1734-
};
1695+
let should_suggest_fixes =
1696+
if let ObligationCauseCode::Pattern { root_ty, .. } = cause.code() {
1697+
// Skip if the root_ty of the pattern is not the same as the expected_ty.
1698+
// If these types aren't equal then we've probably peeled off a layer of arrays.
1699+
self.same_type_modulo_infer(*root_ty, exp_found.expected)
1700+
} else {
1701+
true
1702+
};
17351703

17361704
if should_suggest_fixes {
17371705
self.suggest_tuple_pattern(cause, &exp_found, diag);
@@ -1786,7 +1754,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
17861754
.filter_map(|variant| {
17871755
let sole_field = &variant.fields[0];
17881756
let sole_field_ty = sole_field.ty(self.tcx, substs);
1789-
if same_type_modulo_infer(sole_field_ty, exp_found.found) {
1757+
if self.same_type_modulo_infer(sole_field_ty, exp_found.found) {
17901758
let variant_path =
17911759
with_no_trimmed_paths!(self.tcx.def_path_str(variant.def_id));
17921760
// FIXME #56861: DRYer prelude filtering
@@ -1902,47 +1870,50 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
19021870
self.get_impl_future_output_ty(exp_found.expected).map(Binder::skip_binder),
19031871
self.get_impl_future_output_ty(exp_found.found).map(Binder::skip_binder),
19041872
) {
1905-
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match cause.code() {
1906-
ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => {
1907-
diag.multipart_suggestion(
1908-
"consider `await`ing on both `Future`s",
1909-
vec![
1910-
(then.shrink_to_hi(), ".await".to_string()),
1911-
(exp_span.shrink_to_hi(), ".await".to_string()),
1912-
],
1913-
Applicability::MaybeIncorrect,
1914-
);
1915-
}
1916-
ObligationCauseCode::MatchExpressionArm(box MatchExpressionArmCause {
1917-
prior_arms,
1918-
..
1919-
}) => {
1920-
if let [.., arm_span] = &prior_arms[..] {
1873+
(Some(exp), Some(found)) if self.same_type_modulo_infer(exp, found) => {
1874+
match cause.code() {
1875+
ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => {
19211876
diag.multipart_suggestion(
19221877
"consider `await`ing on both `Future`s",
19231878
vec![
1924-
(arm_span.shrink_to_hi(), ".await".to_string()),
1879+
(then.shrink_to_hi(), ".await".to_string()),
19251880
(exp_span.shrink_to_hi(), ".await".to_string()),
19261881
],
19271882
Applicability::MaybeIncorrect,
19281883
);
1929-
} else {
1884+
}
1885+
ObligationCauseCode::MatchExpressionArm(box MatchExpressionArmCause {
1886+
prior_arms,
1887+
..
1888+
}) => {
1889+
if let [.., arm_span] = &prior_arms[..] {
1890+
diag.multipart_suggestion(
1891+
"consider `await`ing on both `Future`s",
1892+
vec![
1893+
(arm_span.shrink_to_hi(), ".await".to_string()),
1894+
(exp_span.shrink_to_hi(), ".await".to_string()),
1895+
],
1896+
Applicability::MaybeIncorrect,
1897+
);
1898+
} else {
1899+
diag.help("consider `await`ing on both `Future`s");
1900+
}
1901+
}
1902+
_ => {
19301903
diag.help("consider `await`ing on both `Future`s");
19311904
}
19321905
}
1933-
_ => {
1934-
diag.help("consider `await`ing on both `Future`s");
1935-
}
1936-
},
1937-
(_, Some(ty)) if same_type_modulo_infer(exp_found.expected, ty) => {
1906+
}
1907+
(_, Some(ty)) if self.same_type_modulo_infer(exp_found.expected, ty) => {
19381908
diag.span_suggestion_verbose(
19391909
exp_span.shrink_to_hi(),
19401910
"consider `await`ing on the `Future`",
19411911
".await",
19421912
Applicability::MaybeIncorrect,
19431913
);
19441914
}
1945-
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code() {
1915+
(Some(ty), _) if self.same_type_modulo_infer(ty, exp_found.found) => match cause.code()
1916+
{
19461917
ObligationCauseCode::Pattern { span: Some(span), .. }
19471918
| ObligationCauseCode::IfExpression(box IfExpressionCause { then: span, .. }) => {
19481919
diag.span_suggestion_verbose(
@@ -1992,7 +1963,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
19921963
.iter()
19931964
.filter(|field| field.vis.is_accessible_from(field.did, self.tcx))
19941965
.map(|field| (field.name, field.ty(self.tcx, expected_substs)))
1995-
.find(|(_, ty)| same_type_modulo_infer(*ty, exp_found.found))
1966+
.find(|(_, ty)| self.same_type_modulo_infer(*ty, exp_found.found))
19961967
{
19971968
if let ObligationCauseCode::Pattern { span: Some(span), .. } = *cause.code() {
19981969
if let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span) {
@@ -2057,7 +2028,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
20572028
| (_, ty::Infer(_))
20582029
| (ty::Param(_), _)
20592030
| (ty::Infer(_), _) => {}
2060-
_ if same_type_modulo_infer(exp_ty, found_ty) => {}
2031+
_ if self.same_type_modulo_infer(exp_ty, found_ty) => {}
20612032
_ => show_suggestion = false,
20622033
};
20632034
}
@@ -2179,7 +2150,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
21792150
) {
21802151
let [expected_tup_elem] = expected_fields[..] else { return };
21812152

2182-
if !same_type_modulo_infer(expected_tup_elem, found) {
2153+
if !self.same_type_modulo_infer(expected_tup_elem, found) {
21832154
return;
21842155
}
21852156

@@ -2647,6 +2618,45 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
26472618
span.is_desugaring(DesugaringKind::QuestionMark)
26482619
&& self.tcx.is_diagnostic_item(sym::From, trait_def_id)
26492620
}
2621+
2622+
/// Structurally compares two types, modulo any inference variables.
2623+
///
2624+
/// Returns `true` if two types are equal, or if one type is an inference variable compatible
2625+
/// with the other type. A TyVar inference type is compatible with any type, and an IntVar or
2626+
/// FloatVar inference type are compatible with themselves or their concrete types (Int and
2627+
/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
2628+
pub fn same_type_modulo_infer(&self, a: Ty<'tcx>, b: Ty<'tcx>) -> bool {
2629+
let (a, b) = self.resolve_vars_if_possible((a, b));
2630+
match (&a.kind(), &b.kind()) {
2631+
(&ty::Adt(did_a, substs_a), &ty::Adt(did_b, substs_b)) => {
2632+
if did_a != did_b {
2633+
return false;
2634+
}
2635+
2636+
substs_a
2637+
.types()
2638+
.zip(substs_b.types())
2639+
.all(|(a, b)| self.same_type_modulo_infer(a, b))
2640+
}
2641+
(&ty::Int(_) | &ty::Uint(_), &ty::Infer(ty::InferTy::IntVar(_)))
2642+
| (
2643+
&ty::Infer(ty::InferTy::IntVar(_)),
2644+
&ty::Int(_) | &ty::Uint(_) | &ty::Infer(ty::InferTy::IntVar(_)),
2645+
)
2646+
| (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_)))
2647+
| (
2648+
&ty::Infer(ty::InferTy::FloatVar(_)),
2649+
&ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)),
2650+
)
2651+
| (&ty::Infer(ty::InferTy::TyVar(_)), _)
2652+
| (_, &ty::Infer(ty::InferTy::TyVar(_))) => true,
2653+
(&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => {
2654+
mut_a == mut_b && self.same_type_modulo_infer(*ty_a, *ty_b)
2655+
}
2656+
// FIXME(compiler-errors): This needs to be generalized more
2657+
_ => a == b,
2658+
}
2659+
}
26502660
}
26512661

26522662
impl<'a, 'tcx> InferCtxt<'a, 'tcx> {

compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use rustc_hir::intravisit::Visitor;
2222
use rustc_hir::GenericParam;
2323
use rustc_hir::Item;
2424
use rustc_hir::Node;
25-
use rustc_infer::infer::error_reporting::same_type_modulo_infer;
2625
use rustc_infer::traits::TraitEngine;
2726
use rustc_middle::traits::select::OverflowError;
2827
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
@@ -640,7 +639,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
640639
if expected.len() == 1 { "" } else { "s" },
641640
)
642641
);
643-
} else if !same_type_modulo_infer(given_ty, expected_ty) {
642+
} else if !self.same_type_modulo_infer(given_ty, expected_ty) {
644643
// Print type mismatch
645644
let (expected_args, given_args) =
646645
self.cmp(given_ty, expected_ty);

0 commit comments

Comments
 (0)