From 6691c4cdad06db87eb9fd31183e8b46bbfd741e3 Mon Sep 17 00:00:00 2001 From: lcnr Date: Fri, 5 May 2023 13:51:01 +0200 Subject: [PATCH] forbid escaping bound vars in combine removes the `CollectAllMismatches` in favor of a slightly more manual approach. --- compiler/rustc_infer/src/infer/combine.rs | 22 ++-- .../traits/error_reporting/method_chain.rs | 102 ------------------ .../src/traits/error_reporting/mod.rs | 4 +- .../src/traits/error_reporting/suggestions.rs | 49 ++++++--- 4 files changed, 46 insertions(+), 131 deletions(-) delete mode 100644 compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs index c9e13be02ff..2a51439b0a9 100644 --- a/compiler/rustc_infer/src/infer/combine.rs +++ b/compiler/rustc_infer/src/infer/combine.rs @@ -73,6 +73,8 @@ pub fn super_combine_tys( R: ObligationEmittingRelation<'tcx>, { let a_is_expected = relation.a_is_expected(); + debug_assert!(!a.has_escaping_bound_vars()); + debug_assert!(!b.has_escaping_bound_vars()); match (a.kind(), b.kind()) { // Relate integral variables to other types @@ -163,6 +165,8 @@ pub fn super_combine_consts( R: ObligationEmittingRelation<'tcx>, { debug!("{}.consts({:?}, {:?})", relation.tag(), a, b); + debug_assert!(!a.has_escaping_bound_vars()); + debug_assert!(!b.has_escaping_bound_vars()); if a == b { return Ok(a); } @@ -238,22 +242,12 @@ pub fn super_combine_consts( (_, ty::ConstKind::Infer(InferConst::Var(vid))) => { return self.unify_const_variable(vid, a); } - (ty::ConstKind::Unevaluated(..), _) if self.tcx.lazy_normalization() => { - // FIXME(#59490): Need to remove the leak check to accommodate - // escaping bound variables here. - if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() { - relation.register_const_equate_obligation(a, b); - } + (ty::ConstKind::Unevaluated(..), _) | (_, ty::ConstKind::Unevaluated(..)) + if self.tcx.lazy_normalization() => + { + relation.register_const_equate_obligation(a, b); return Ok(b); } - (_, ty::ConstKind::Unevaluated(..)) if self.tcx.lazy_normalization() => { - // FIXME(#59490): Need to remove the leak check to accommodate - // escaping bound variables here. - if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() { - relation.register_const_equate_obligation(a, b); - } - return Ok(a); - } _ => {} } diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs deleted file mode 100644 index 7e1dba4ed26..00000000000 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs +++ /dev/null @@ -1,102 +0,0 @@ -use crate::infer::InferCtxt; - -use rustc_infer::infer::ObligationEmittingRelation; -use rustc_infer::traits::PredicateObligations; -use rustc_middle::ty::error::TypeError; -use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation}; -use rustc_middle::ty::{self, Ty, TyCtxt}; - -pub struct CollectAllMismatches<'a, 'tcx> { - pub infcx: &'a InferCtxt<'tcx>, - pub param_env: ty::ParamEnv<'tcx>, - pub errors: Vec>, -} - -impl<'a, 'tcx> TypeRelation<'tcx> for CollectAllMismatches<'a, 'tcx> { - fn tag(&self) -> &'static str { - "CollectAllMismatches" - } - - fn tcx(&self) -> TyCtxt<'tcx> { - self.infcx.tcx - } - - fn param_env(&self) -> ty::ParamEnv<'tcx> { - self.param_env - } - - fn a_is_expected(&self) -> bool { - true - } - - fn relate_with_variance>( - &mut self, - _: ty::Variance, - _: ty::VarianceDiagInfo<'tcx>, - a: T, - b: T, - ) -> RelateResult<'tcx, T> { - self.relate(a, b) - } - - fn regions( - &mut self, - a: ty::Region<'tcx>, - _b: ty::Region<'tcx>, - ) -> RelateResult<'tcx, ty::Region<'tcx>> { - Ok(a) - } - - fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { - self.infcx.probe(|_| { - if a.is_ty_var() || b.is_ty_var() { - Ok(a) - } else { - self.infcx.super_combine_tys(self, a, b).or_else(|e| { - self.errors.push(e); - Ok(a) - }) - } - }) - } - - fn consts( - &mut self, - a: ty::Const<'tcx>, - b: ty::Const<'tcx>, - ) -> RelateResult<'tcx, ty::Const<'tcx>> { - self.infcx.probe(|_| { - if a.is_ct_infer() || b.is_ct_infer() { - Ok(a) - } else { - relate::super_relate_consts(self, a, b) // could do something similar here for constants! - } - }) - } - - fn binders>( - &mut self, - a: ty::Binder<'tcx, T>, - b: ty::Binder<'tcx, T>, - ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> { - Ok(a.rebind(self.relate(a.skip_binder(), b.skip_binder())?)) - } -} - -impl<'tcx> ObligationEmittingRelation<'tcx> for CollectAllMismatches<'_, 'tcx> { - fn alias_relate_direction(&self) -> ty::AliasRelationDirection { - // FIXME(deferred_projection_equality): We really should get rid of this relation. - ty::AliasRelationDirection::Equate - } - - fn register_obligations(&mut self, _obligations: PredicateObligations<'tcx>) { - // FIXME(deferred_projection_equality) - } - - fn register_predicates( - &mut self, - _obligations: impl IntoIterator>, - ) { - // FIXME(deferred_projection_equality) - } -} diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index 8f2a5d649f0..afb64da8b61 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -1,5 +1,4 @@ mod ambiguity; -pub mod method_chain; pub mod on_unimplemented; pub mod suggestions; @@ -559,6 +558,7 @@ fn report_overflow_obligation( suggest_increasing_limit, |err| { self.note_obligation_cause_code( + obligation.cause.body_id, err, predicate, obligation.param_env, @@ -1431,6 +1431,7 @@ fn report_fulfillment_error(&self, error: &FulfillmentError<'tcx>) { | ObligationCauseCode::ExprItemObligation(..) = code { self.note_obligation_cause_code( + error.obligation.cause.body_id, &mut diag, error.obligation.predicate, error.obligation.param_env, @@ -2544,6 +2545,7 @@ fn note_obligation_cause(&self, err: &mut Diagnostic, obligation: &PredicateObli // message, and fall back to regular note otherwise. if !self.maybe_note_obligation_cause_for_async_await(err, obligation) { self.note_obligation_cause_code( + obligation.cause.body_id, err, obligation.predicate, obligation.param_env, diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index d34eb193453..164540cc16f 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -25,10 +25,9 @@ use rustc_hir::{Expr, HirId}; use rustc_infer::infer::error_reporting::TypeErrCtxt; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; -use rustc_infer::infer::{InferOk, LateBoundRegionConversionTime}; +use rustc_infer::infer::{DefineOpaqueTypes, InferOk, LateBoundRegionConversionTime}; use rustc_middle::hir::map; use rustc_middle::ty::error::TypeError::{self, Sorts}; -use rustc_middle::ty::relate::TypeRelation; use rustc_middle::ty::{ self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, GeneratorDiagnosticData, GeneratorInteriorTypeCause, Infer, InferTy, InternalSubsts, @@ -39,9 +38,9 @@ use rustc_span::symbol::{sym, Ident, Symbol}; use rustc_span::{BytePos, DesugaringKind, ExpnKind, MacroKind, Span, DUMMY_SP}; use rustc_target::spec::abi; +use std::iter; use std::ops::Deref; -use super::method_chain::CollectAllMismatches; use super::InferCtxtPrivExt; use crate::infer::InferCtxtExt as _; use crate::traits::query::evaluate_obligation::InferCtxtExt as _; @@ -319,6 +318,7 @@ fn note_obligation_cause_for_async_await( fn note_obligation_cause_code( &self, + body_id: LocalDefId, err: &mut Diagnostic, predicate: T, param_env: ty::ParamEnv<'tcx>, @@ -359,8 +359,9 @@ fn suggest_dereferencing_index( ); fn note_function_argument_obligation( &self, - arg_hir_id: HirId, + body_id: LocalDefId, err: &mut Diagnostic, + arg_hir_id: HirId, parent_code: &ObligationCauseCode<'tcx>, param_env: ty::ParamEnv<'tcx>, predicate: ty::Predicate<'tcx>, @@ -2742,6 +2743,7 @@ fn note_obligation_cause_for_async_await( // bound that introduced the obligation (e.g. `T: Send`). debug!(?next_code); self.note_obligation_cause_code( + obligation.cause.body_id, err, obligation.predicate, obligation.param_env, @@ -2753,6 +2755,7 @@ fn note_obligation_cause_for_async_await( fn note_obligation_cause_code( &self, + body_id: LocalDefId, err: &mut Diagnostic, predicate: T, param_env: ty::ParamEnv<'tcx>, @@ -3152,6 +3155,7 @@ fn note_obligation_cause_code( // #74711: avoid a stack overflow ensure_sufficient_stack(|| { self.note_obligation_cause_code( + body_id, err, parent_predicate, param_env, @@ -3163,6 +3167,7 @@ fn note_obligation_cause_code( } else { ensure_sufficient_stack(|| { self.note_obligation_cause_code( + body_id, err, parent_predicate, param_env, @@ -3292,6 +3297,7 @@ fn note_obligation_cause_code( // #74711: avoid a stack overflow ensure_sufficient_stack(|| { self.note_obligation_cause_code( + body_id, err, parent_predicate, param_env, @@ -3307,6 +3313,7 @@ fn note_obligation_cause_code( // #74711: avoid a stack overflow ensure_sufficient_stack(|| { self.note_obligation_cause_code( + body_id, err, parent_predicate, param_env, @@ -3323,8 +3330,9 @@ fn note_obligation_cause_code( .. } => { self.note_function_argument_obligation( - arg_hir_id, + body_id, err, + arg_hir_id, parent_code, param_env, predicate, @@ -3332,6 +3340,7 @@ fn note_obligation_cause_code( ); ensure_sufficient_stack(|| { self.note_obligation_cause_code( + body_id, err, predicate, param_env, @@ -3553,8 +3562,9 @@ fn suggest_dereferencing_index( } fn note_function_argument_obligation( &self, - arg_hir_id: HirId, + body_id: LocalDefId, err: &mut Diagnostic, + arg_hir_id: HirId, parent_code: &ObligationCauseCode<'tcx>, param_env: ty::ParamEnv<'tcx>, failed_pred: ty::Predicate<'tcx>, @@ -3587,7 +3597,6 @@ fn note_function_argument_obligation( // to an associated type (as seen from `trait_pred`) in the predicate. Like in // trait_pred `S: Sum<::Item>` and predicate `i32: Sum<&()>` let mut type_diffs = vec![]; - if let ObligationCauseCode::ExprBindingObligation(def_id, _, _, idx) = parent_code.deref() && let Some(node_substs) = typeck_results.node_substs_opt(call_hir_id) && let where_clauses = self.tcx.predicates_of(def_id).instantiate(self.tcx, node_substs) @@ -3596,14 +3605,26 @@ fn note_function_argument_obligation( if let Some(where_pred) = where_pred.to_opt_poly_trait_pred() && let Some(failed_pred) = failed_pred.to_opt_poly_trait_pred() { - let mut c = CollectAllMismatches { - infcx: self.infcx, - param_env, - errors: vec![], + let where_pred = self.instantiate_binder_with_placeholders(where_pred); + let failed_pred = self.instantiate_binder_with_fresh_vars( + expr.span, + LateBoundRegionConversionTime::FnCall, + failed_pred + ); + + let zipped = + iter::zip(where_pred.trait_ref.substs, failed_pred.trait_ref.substs); + for (expected, actual) in zipped { + self.probe(|_| { + match self + .at(&ObligationCause::misc(expr.span, body_id), param_env) + .eq(DefineOpaqueTypes::No, expected, actual) + { + Ok(_) => (), // We ignore nested obligations here for now. + Err(err) => type_diffs.push(err), + } + }) }; - if let Ok(_) = c.relate(where_pred, failed_pred) { - type_diffs = c.errors; - } } else if let Some(where_pred) = where_pred.to_opt_poly_projection_pred() && let Some(failed_pred) = failed_pred.to_opt_poly_projection_pred() && let Some(found) = failed_pred.skip_binder().term.ty()