From 9a757d6ee4229bd7e022a2f3f26946cf05fddbcf Mon Sep 17 00:00:00 2001 From: lcnr Date: Tue, 17 Jan 2023 12:26:28 +0100 Subject: [PATCH] add `eq` to `InferCtxtExt` --- .../src/solve/infcx_ext.rs | 40 ++++++++++++++++++- .../src/solve/project_goals.rs | 29 ++++---------- .../src/solve/trait_goals.rs | 18 ++------- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/compiler/rustc_trait_selection/src/solve/infcx_ext.rs b/compiler/rustc_trait_selection/src/solve/infcx_ext.rs index 8a8c3091d54..f92d6463134 100644 --- a/compiler/rustc_trait_selection/src/solve/infcx_ext.rs +++ b/compiler/rustc_trait_selection/src/solve/infcx_ext.rs @@ -1,11 +1,28 @@ +use rustc_infer::infer::at::ToTrace; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; -use rustc_infer::infer::InferCtxt; -use rustc_middle::ty::Ty; +use rustc_infer::infer::{InferCtxt, InferOk}; +use rustc_infer::traits::query::NoSolution; +use rustc_infer::traits::ObligationCause; +use rustc_middle::ty::{self, Ty}; use rustc_span::DUMMY_SP; +use super::Goal; + /// Methods used inside of the canonical queries of the solver. +/// +/// Most notably these do not care about diagnostics information. +/// If you find this while looking for methods to use outside of the +/// solver, you may look at the implementation of these method for +/// help. pub(super) trait InferCtxtExt<'tcx> { fn next_ty_infer(&self) -> Ty<'tcx>; + + fn eq>( + &self, + param_env: ty::ParamEnv<'tcx>, + lhs: T, + rhs: T, + ) -> Result>>, NoSolution>; } impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> { @@ -15,4 +32,23 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> { span: DUMMY_SP, }) } + + #[instrument(level = "debug", skip(self, param_env), ret)] + fn eq>( + &self, + param_env: ty::ParamEnv<'tcx>, + lhs: T, + rhs: T, + ) -> Result>>, NoSolution> { + self.at(&ObligationCause::dummy(), param_env) + .define_opaque_types(false) + .eq(lhs, rhs) + .map(|InferOk { value: (), obligations }| { + obligations.into_iter().map(|o| o.into()).collect() + }) + .map_err(|e| { + debug!(?e, "failed to equate"); + NoSolution + }) + } } diff --git a/compiler/rustc_trait_selection/src/solve/project_goals.rs b/compiler/rustc_trait_selection/src/solve/project_goals.rs index 92c5d4e53f5..cf07926f85a 100644 --- a/compiler/rustc_trait_selection/src/solve/project_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/project_goals.rs @@ -1,14 +1,15 @@ use crate::traits::{specialization_graph, translate_substs}; use super::assembly::{self, Candidate, CandidateSource}; +use super::infcx_ext::InferCtxtExt; use super::{Certainty, EvalCtxt, Goal, QueryResult}; use rustc_errors::ErrorGuaranteed; use rustc_hir::def::DefKind; use rustc_hir::def_id::DefId; -use rustc_infer::infer::{InferCtxt, InferOk}; +use rustc_infer::infer::InferCtxt; use rustc_infer::traits::query::NoSolution; use rustc_infer::traits::specialization_graph::LeafDef; -use rustc_infer::traits::{ObligationCause, Reveal}; +use rustc_infer::traits::Reveal; use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; use rustc_middle::ty::ProjectionPredicate; use rustc_middle::ty::TypeVisitable; @@ -112,14 +113,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> { let impl_substs = ecx.infcx.fresh_substs_for_item(DUMMY_SP, impl_def_id); let impl_trait_ref = impl_trait_ref.subst(tcx, impl_substs); - let Ok(InferOk { obligations, .. }) = ecx.infcx - .at(&ObligationCause::dummy(), goal.param_env) - .define_opaque_types(false) - .eq(goal_trait_ref, impl_trait_ref) - .map_err(|e| debug!("failed to equate trait refs: {e:?}")) - else { - return Err(NoSolution) - }; + let mut nested_goals = ecx.infcx.eq(goal.param_env, goal_trait_ref, impl_trait_ref)?; let where_clause_bounds = tcx .predicates_of(impl_def_id) .instantiate(tcx, impl_substs) @@ -127,8 +121,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> { .into_iter() .map(|pred| goal.with(tcx, pred)); - let nested_goals = - obligations.into_iter().map(|o| o.into()).chain(where_clause_bounds).collect(); + nested_goals.extend(where_clause_bounds); let trait_ref_certainty = ecx.evaluate_all(nested_goals)?; let Some(assoc_def) = fetch_eligible_assoc_item_def( @@ -185,16 +178,8 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> { ty.map_bound(|ty| ty.into()) }; - let Ok(InferOk { obligations, .. }) = ecx.infcx - .at(&ObligationCause::dummy(), goal.param_env) - .define_opaque_types(false) - .eq(goal.predicate.term, term.subst(tcx, substs)) - .map_err(|e| debug!("failed to equate trait refs: {e:?}")) - else { - return Err(NoSolution); - }; - - let nested_goals = obligations.into_iter().map(|o| o.into()).collect(); + let nested_goals = + ecx.infcx.eq(goal.param_env, goal.predicate.term, term.subst(tcx, substs))?; let rhs_certainty = ecx.evaluate_all(nested_goals)?; Ok(trait_ref_certainty.unify_and(rhs_certainty)) diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index 3c8314aa565..bbe175d5cc8 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -3,11 +3,10 @@ use std::iter; use super::assembly::{self, Candidate, CandidateSource}; +use super::infcx_ext::InferCtxtExt; use super::{Certainty, EvalCtxt, Goal, QueryResult}; use rustc_hir::def_id::DefId; -use rustc_infer::infer::InferOk; use rustc_infer::traits::query::NoSolution; -use rustc_infer::traits::ObligationCause; use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; use rustc_middle::ty::TraitPredicate; use rustc_middle::ty::{self, Ty, TyCtxt}; @@ -45,24 +44,15 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { let impl_substs = ecx.infcx.fresh_substs_for_item(DUMMY_SP, impl_def_id); let impl_trait_ref = impl_trait_ref.subst(tcx, impl_substs); - let Ok(InferOk { obligations, .. }) = ecx.infcx - .at(&ObligationCause::dummy(), goal.param_env) - .define_opaque_types(false) - .eq(goal.predicate.trait_ref, impl_trait_ref) - .map_err(|e| debug!("failed to equate trait refs: {e:?}")) - else { - return Err(NoSolution); - }; + let mut nested_goals = + ecx.infcx.eq(goal.param_env, goal.predicate.trait_ref, impl_trait_ref)?; let where_clause_bounds = tcx .predicates_of(impl_def_id) .instantiate(tcx, impl_substs) .predicates .into_iter() .map(|pred| goal.with(tcx, pred)); - - let nested_goals = - obligations.into_iter().map(|o| o.into()).chain(where_clause_bounds).collect(); - + nested_goals.extend(where_clause_bounds); ecx.evaluate_all(nested_goals) }) }