diff --git a/compiler/rustc_hir_typeck/src/lib.rs b/compiler/rustc_hir_typeck/src/lib.rs index 67c35d717a1..80467ca9381 100644 --- a/compiler/rustc_hir_typeck/src/lib.rs +++ b/compiler/rustc_hir_typeck/src/lib.rs @@ -60,6 +60,7 @@ use rustc_hir::{HirIdMap, Node}; use rustc_hir_analysis::astconv::AstConv; use rustc_hir_analysis::check::check_abi; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; +use rustc_infer::traits::ObligationInspector; use rustc_middle::query::Providers; use rustc_middle::traits; use rustc_middle::ty::{self, Ty, TyCtxt}; @@ -139,7 +140,7 @@ fn used_trait_imports(tcx: TyCtxt<'_>, def_id: LocalDefId) -> &UnordSet(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::TypeckResults<'tcx> { let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity(); - typeck_with_fallback(tcx, def_id, fallback) + typeck_with_fallback(tcx, def_id, fallback, None) } /// Used only to get `TypeckResults` for type inference during error recovery. @@ -149,14 +150,28 @@ fn diagnostic_only_typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::T let span = tcx.hir().span(tcx.local_def_id_to_hir_id(def_id)); Ty::new_error_with_message(tcx, span, "diagnostic only typeck table used") }; - typeck_with_fallback(tcx, def_id, fallback) + typeck_with_fallback(tcx, def_id, fallback, None) } -#[instrument(level = "debug", skip(tcx, fallback), ret)] +/// Same as `typeck` but `inspect` is invoked on evaluation of each root obligation. +/// Inspecting obligations only works with the new trait solver. +/// This function is *only to be used* by external tools, it should not be +/// called from within rustc. Note, this is not a query, and thus is not cached. +pub fn inspect_typeck<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: LocalDefId, + inspect: ObligationInspector<'tcx>, +) -> &'tcx ty::TypeckResults<'tcx> { + let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity(); + typeck_with_fallback(tcx, def_id, fallback, Some(inspect)) +} + +#[instrument(level = "debug", skip(tcx, fallback, inspector), ret)] fn typeck_with_fallback<'tcx>( tcx: TyCtxt<'tcx>, def_id: LocalDefId, fallback: impl Fn() -> Ty<'tcx> + 'tcx, + inspector: Option>, ) -> &'tcx ty::TypeckResults<'tcx> { // Closures' typeck results come from their outermost function, // as they are part of the same "inference environment". @@ -178,6 +193,9 @@ fn typeck_with_fallback<'tcx>( let param_env = tcx.param_env(def_id); let inh = Inherited::new(tcx, def_id); + if let Some(inspector) = inspector { + inh.infcx.attach_obligation_inspector(inspector); + } let mut fcx = FnCtxt::new(&inh, param_env, def_id); if let Some(hir::FnSig { header, decl, .. }) = fn_sig { diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index e60e3ffeaa7..0f1af81d9f0 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -90,6 +90,7 @@ impl<'tcx> InferCtxt<'tcx> { universe: self.universe.clone(), intercrate, next_trait_solver: self.next_trait_solver, + obligation_inspector: self.obligation_inspector.clone(), } } } diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index e164041c599..39c41a93c3b 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -13,7 +13,9 @@ use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey}; use self::opaque_types::OpaqueTypeStorage; pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog}; -use crate::traits::{self, ObligationCause, PredicateObligations, TraitEngine, TraitEngineExt}; +use crate::traits::{ + self, ObligationCause, ObligationInspector, PredicateObligations, TraitEngine, TraitEngineExt, +}; use rustc_data_structures::fx::FxIndexMap; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; @@ -334,6 +336,8 @@ pub struct InferCtxt<'tcx> { pub intercrate: bool, next_trait_solver: bool, + + pub obligation_inspector: Cell>>, } impl<'tcx> ty::InferCtxtLike for InferCtxt<'tcx> { @@ -708,6 +712,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> { universe: Cell::new(ty::UniverseIndex::ROOT), intercrate, next_trait_solver, + obligation_inspector: Cell::new(None), } } } @@ -1724,6 +1729,15 @@ impl<'tcx> InferCtxt<'tcx> { } } } + + /// Attach a callback to be invoked on each root obligation evaluated in the new trait solver. + pub fn attach_obligation_inspector(&self, inspector: ObligationInspector<'tcx>) { + debug_assert!( + self.obligation_inspector.get().is_none(), + "shouldn't override a set obligation inspector" + ); + self.obligation_inspector.set(Some(inspector)); + } } impl<'tcx> TypeErrCtxt<'_, 'tcx> { diff --git a/compiler/rustc_infer/src/traits/mod.rs b/compiler/rustc_infer/src/traits/mod.rs index fdae093aac8..72ec07375ac 100644 --- a/compiler/rustc_infer/src/traits/mod.rs +++ b/compiler/rustc_infer/src/traits/mod.rs @@ -13,12 +13,15 @@ use std::hash::{Hash, Hasher}; use hir::def_id::LocalDefId; use rustc_hir as hir; +use rustc_middle::traits::query::NoSolution; +use rustc_middle::traits::solve::Certainty; use rustc_middle::ty::error::{ExpectedFound, TypeError}; use rustc_middle::ty::{self, Const, ToPredicate, Ty, TyCtxt}; use rustc_span::Span; pub use self::ImplSource::*; pub use self::SelectionError::*; +use crate::infer::InferCtxt; pub use self::engine::{TraitEngine, TraitEngineExt}; pub use self::project::MismatchedProjectionTypes; @@ -116,6 +119,11 @@ pub type PredicateObligations<'tcx> = Vec>; pub type Selection<'tcx> = ImplSource<'tcx, PredicateObligation<'tcx>>; +/// A callback that can be provided to `inspect_typeck`. Invoked on evaluation +/// of root obligations. +pub type ObligationInspector<'tcx> = + fn(&InferCtxt<'tcx>, &PredicateObligation<'tcx>, Result); + pub struct FulfillmentError<'tcx> { pub obligation: PredicateObligation<'tcx>, pub code: FulfillmentErrorCode<'tcx>, diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs index c847425ebf4..f08622816ec 100644 --- a/compiler/rustc_trait_selection/src/solve/fulfill.rs +++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs @@ -11,7 +11,7 @@ use rustc_middle::ty; use rustc_middle::ty::error::{ExpectedFound, TypeError}; use super::eval_ctxt::GenerateProofTree; -use super::{Certainty, InferCtxtEvalExt}; +use super::{Certainty, Goal, InferCtxtEvalExt}; /// A trait engine using the new trait solver. /// @@ -43,6 +43,21 @@ impl<'tcx> FulfillmentCtxt<'tcx> { ); FulfillmentCtxt { obligations: Vec::new(), usable_in_snapshot: infcx.num_open_snapshots() } } + + fn inspect_evaluated_obligation( + &self, + infcx: &InferCtxt<'tcx>, + obligation: &PredicateObligation<'tcx>, + result: &Result<(bool, Certainty, Vec>>), NoSolution>, + ) { + if let Some(inspector) = infcx.obligation_inspector.get() { + let result = match result { + Ok((_, c, _)) => Ok(*c), + Err(NoSolution) => Err(NoSolution), + }; + (inspector)(infcx, &obligation, result); + } + } } impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> { @@ -100,65 +115,66 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> { let mut has_changed = false; for obligation in mem::take(&mut self.obligations) { let goal = obligation.clone().into(); - let (changed, certainty, nested_goals) = - match infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0 { - Ok(result) => result, - Err(NoSolution) => { - errors.push(FulfillmentError { - obligation: obligation.clone(), - code: match goal.predicate.kind().skip_binder() { - ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => { - FulfillmentErrorCode::ProjectionError( - // FIXME: This could be a `Sorts` if the term is a type - MismatchedProjectionTypes { err: TypeError::Mismatch }, - ) - } - ty::PredicateKind::NormalizesTo(..) => { - FulfillmentErrorCode::ProjectionError( - MismatchedProjectionTypes { err: TypeError::Mismatch }, - ) - } - ty::PredicateKind::AliasRelate(_, _, _) => { - FulfillmentErrorCode::ProjectionError( - MismatchedProjectionTypes { err: TypeError::Mismatch }, - ) - } - ty::PredicateKind::Subtype(pred) => { - let (a, b) = infcx.instantiate_binder_with_placeholders( - goal.predicate.kind().rebind((pred.a, pred.b)), - ); - let expected_found = ExpectedFound::new(true, a, b); - FulfillmentErrorCode::SubtypeError( - expected_found, - TypeError::Sorts(expected_found), - ) - } - ty::PredicateKind::Coerce(pred) => { - let (a, b) = infcx.instantiate_binder_with_placeholders( - goal.predicate.kind().rebind((pred.a, pred.b)), - ); - let expected_found = ExpectedFound::new(false, a, b); - FulfillmentErrorCode::SubtypeError( - expected_found, - TypeError::Sorts(expected_found), - ) - } - ty::PredicateKind::Clause(_) - | ty::PredicateKind::ObjectSafe(_) - | ty::PredicateKind::Ambiguous => { - FulfillmentErrorCode::SelectionError( - SelectionError::Unimplemented, - ) - } - ty::PredicateKind::ConstEquate(..) => { - bug!("unexpected goal: {goal:?}") - } - }, - root_obligation: obligation, - }); - continue; - } - }; + let result = infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0; + self.inspect_evaluated_obligation(infcx, &obligation, &result); + let (changed, certainty, nested_goals) = match result { + Ok(result) => result, + Err(NoSolution) => { + errors.push(FulfillmentError { + obligation: obligation.clone(), + code: match goal.predicate.kind().skip_binder() { + ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => { + FulfillmentErrorCode::ProjectionError( + // FIXME: This could be a `Sorts` if the term is a type + MismatchedProjectionTypes { err: TypeError::Mismatch }, + ) + } + ty::PredicateKind::NormalizesTo(..) => { + FulfillmentErrorCode::ProjectionError( + MismatchedProjectionTypes { err: TypeError::Mismatch }, + ) + } + ty::PredicateKind::AliasRelate(_, _, _) => { + FulfillmentErrorCode::ProjectionError( + MismatchedProjectionTypes { err: TypeError::Mismatch }, + ) + } + ty::PredicateKind::Subtype(pred) => { + let (a, b) = infcx.instantiate_binder_with_placeholders( + goal.predicate.kind().rebind((pred.a, pred.b)), + ); + let expected_found = ExpectedFound::new(true, a, b); + FulfillmentErrorCode::SubtypeError( + expected_found, + TypeError::Sorts(expected_found), + ) + } + ty::PredicateKind::Coerce(pred) => { + let (a, b) = infcx.instantiate_binder_with_placeholders( + goal.predicate.kind().rebind((pred.a, pred.b)), + ); + let expected_found = ExpectedFound::new(false, a, b); + FulfillmentErrorCode::SubtypeError( + expected_found, + TypeError::Sorts(expected_found), + ) + } + ty::PredicateKind::Clause(_) + | ty::PredicateKind::ObjectSafe(_) + | ty::PredicateKind::Ambiguous => { + FulfillmentErrorCode::SelectionError( + SelectionError::Unimplemented, + ) + } + ty::PredicateKind::ConstEquate(..) => { + bug!("unexpected goal: {goal:?}") + } + }, + root_obligation: obligation, + }); + continue; + } + }; // Push any nested goals that we get from unifying our canonical response // with our obligation onto the fulfillment context. self.obligations.extend(nested_goals.into_iter().map(|goal| {