From ce11ae5d0de03d23b770f4d7ee912c792678f09d Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 27 Oct 2022 15:40:30 +0000 Subject: [PATCH] Address some more nits --- .../src/check/compare_method.rs | 15 +++--------- .../src/traits/engine.rs | 23 ++++++++++++++++--- compiler/rustc_traits/src/type_op.rs | 17 +++----------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/check/compare_method.rs b/compiler/rustc_hir_analysis/src/check/compare_method.rs index e72f18012ab..616472a1a3f 100644 --- a/compiler/rustc_hir_analysis/src/check/compare_method.rs +++ b/compiler/rustc_hir_analysis/src/check/compare_method.rs @@ -290,10 +290,7 @@ fn compare_predicate_entailment<'tcx>( // type would be more appropriate. In other places we have a `Vec` // corresponding to their `Vec`, but we don't have that here. // Fixing this would improve the output of test `issue-83765.rs`. - let mut result = infcx - .at(&cause, param_env) - .sup(trait_fty, impl_fty) - .map(|infer_ok| ocx.register_infer_ok_obligations(infer_ok)); + let mut result = ocx.sup_types(&cause, param_env, trait_fty, impl_fty); // HACK(RPITIT): #101614. When we are trying to infer the hidden types for // RPITITs, we need to equate the output tys instead of just subtyping. If @@ -302,10 +299,7 @@ fn compare_predicate_entailment<'tcx>( // fixed up to `ReEmpty`, and which is certainly not what we want. if trait_fty.has_infer_types() { result = result.and_then(|()| { - infcx - .at(&cause, param_env) - .eq(trait_sig.output(), impl_sig.output()) - .map(|infer_ok| ocx.register_infer_ok_obligations(infer_ok)) + ocx.equate_types(&cause, param_env, trait_sig.output(), impl_sig.output()) }); } @@ -1389,10 +1383,7 @@ pub(crate) fn raw_compare_const_impl<'tcx>( debug!("compare_const_impl: trait_ty={:?}", trait_ty); - let err = infcx - .at(&cause, param_env) - .sup(trait_ty, impl_ty) - .map(|ok| ocx.register_infer_ok_obligations(ok)); + let err = ocx.sup_types(&cause, param_env, trait_ty, impl_ty); if let Err(terr) = err { debug!( diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs index 18ad99b4935..c760ce1fed9 100644 --- a/compiler/rustc_trait_selection/src/traits/engine.rs +++ b/compiler/rustc_trait_selection/src/traits/engine.rs @@ -6,6 +6,7 @@ use super::{ChalkFulfillmentContext, FulfillmentContext}; use crate::infer::InferCtxtExt; use rustc_data_structures::fx::FxHashSet; use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_infer::infer::at::ToTrace; use rustc_infer::infer::canonical::{ Canonical, CanonicalVarValues, CanonicalizedQueryResponse, QueryResponse, }; @@ -111,12 +112,12 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> { self.register_infer_ok_obligations(infer_ok) } - pub fn equate_types( + pub fn equate_types>( &self, cause: &ObligationCause<'tcx>, param_env: ty::ParamEnv<'tcx>, - expected: Ty<'tcx>, - actual: Ty<'tcx>, + expected: T, + actual: T, ) -> Result<(), TypeError<'tcx>> { match self.infcx.at(cause, param_env).eq(expected, actual) { Ok(InferOk { obligations, value: () }) => { @@ -127,6 +128,22 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> { } } + pub fn sup_types>( + &self, + cause: &ObligationCause<'tcx>, + param_env: ty::ParamEnv<'tcx>, + expected: T, + actual: T, + ) -> Result<(), TypeError<'tcx>> { + match self.infcx.at(cause, param_env).sup(expected, actual) { + Ok(InferOk { obligations, value: () }) => { + self.register_obligations(obligations); + Ok(()) + } + Err(e) => Err(e), + } + } + pub fn select_all_or_error(&self) -> Vec> { self.engine.borrow_mut().select_all_or_error(self.infcx) } diff --git a/compiler/rustc_traits/src/type_op.rs b/compiler/rustc_traits/src/type_op.rs index 827c747e8ed..e0465121ad9 100644 --- a/compiler/rustc_traits/src/type_op.rs +++ b/compiler/rustc_traits/src/type_op.rs @@ -87,12 +87,7 @@ impl<'me, 'tcx> AscribeUserTypeCx<'me, 'tcx> { where T: ToTrace<'tcx>, { - Ok(self.ocx.register_infer_ok_obligations( - self.ocx - .infcx - .at(&ObligationCause::dummy_with_span(self.span), self.param_env) - .eq(a, b)?, - )) + Ok(self.ocx.equate_types(&ObligationCause::dummy_with_span(self.span), self.param_env, a, b)?) } fn prove_predicate(&self, predicate: Predicate<'tcx>, cause: ObligationCause<'tcx>) { @@ -181,10 +176,7 @@ fn type_op_eq<'tcx>( ) -> Result<&'tcx Canonical<'tcx, QueryResponse<'tcx, ()>>, NoSolution> { tcx.infer_ctxt().enter_canonical_trait_query(&canonicalized, |ocx, key| { let (param_env, Eq { a, b }) = key.into_parts(); - ocx.register_infer_ok_obligations( - ocx.infcx.at(&ObligationCause::dummy(), param_env).eq(a, b)?, - ); - Ok(()) + Ok(ocx.equate_types(&ObligationCause::dummy(), param_env, a, b)?) }) } @@ -236,10 +228,7 @@ fn type_op_subtype<'tcx>( ) -> Result<&'tcx Canonical<'tcx, QueryResponse<'tcx, ()>>, NoSolution> { tcx.infer_ctxt().enter_canonical_trait_query(&canonicalized, |ocx, key| { let (param_env, Subtype { sub, sup }) = key.into_parts(); - ocx.register_infer_ok_obligations( - ocx.infcx.at(&ObligationCause::dummy(), param_env).sup(sup, sub)?, - ); - Ok(()) + Ok(ocx.sup_types(&ObligationCause::dummy(), param_env, sup, sub)?) }) }