Instantiate binders when checking supertrait upcasting

This commit is contained in:
Michael Goulet 2024-09-26 23:20:59 -04:00
parent d4ee408afc
commit 4fb097a5de
4 changed files with 132 additions and 48 deletions

View File

@ -159,7 +159,24 @@ pub fn eq<T>(
ToTrace::to_trace(self.cause, true, expected, actual), ToTrace::to_trace(self.cause, true, expected, actual),
self.param_env, self.param_env,
define_opaque_types, define_opaque_types,
); ToTrace::to_trace(self.cause, expected, actual),
expected,
actual,
)
}
/// Makes `expected == actual`.
pub fn eq_trace<T>(
self,
define_opaque_types: DefineOpaqueTypes,
trace: TypeTrace<'tcx>,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: Relate<TyCtxt<'tcx>>,
{
let mut fields = CombineFields::new(self.infcx, trace, self.param_env, define_opaque_types);
fields.equate(StructurallyRelateAliases::No).relate(expected, actual)?; fields.equate(StructurallyRelateAliases::No).relate(expected, actual)?;
Ok(InferOk { Ok(InferOk {
value: (), value: (),

View File

@ -448,10 +448,10 @@ fn compute_goal(&mut self, goal: Goal<I, I::Predicate>) -> QueryResult<I> {
} }
} }
} else { } else {
self.delegate.enter_forall(kind, |kind| { self.enter_forall(kind, |ecx, kind| {
let goal = goal.with(self.cx(), ty::Binder::dummy(kind)); let goal = goal.with(ecx.cx(), ty::Binder::dummy(kind));
self.add_goal(GoalSource::InstantiateHigherRanked, goal); ecx.add_goal(GoalSource::InstantiateHigherRanked, goal);
self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
}) })
} }
} }
@ -840,12 +840,14 @@ pub(super) fn instantiate_binder_with_infer<T: TypeFoldable<I> + Copy>(
self.delegate.instantiate_binder_with_infer(value) self.delegate.instantiate_binder_with_infer(value)
} }
/// `enter_forall`, but takes `&mut self` and passes it back through the
/// callback since it can't be aliased during the call.
pub(super) fn enter_forall<T: TypeFoldable<I> + Copy, U>( pub(super) fn enter_forall<T: TypeFoldable<I> + Copy, U>(
&self, &mut self,
value: ty::Binder<I, T>, value: ty::Binder<I, T>,
f: impl FnOnce(T) -> U, f: impl FnOnce(&mut Self, T) -> U,
) -> U { ) -> U {
self.delegate.enter_forall(value, f) self.delegate.enter_forall(value, |value| f(self, value))
} }
pub(super) fn resolve_vars_if_possible<T>(&self, value: T) -> T pub(super) fn resolve_vars_if_possible<T>(&self, value: T) -> T

View File

@ -895,10 +895,13 @@ fn consider_builtin_upcast_to_principal(
source_projection.item_def_id() == target_projection.item_def_id() source_projection.item_def_id() == target_projection.item_def_id()
&& ecx && ecx
.probe(|_| ProbeKind::UpcastProjectionCompatibility) .probe(|_| ProbeKind::UpcastProjectionCompatibility)
.enter(|ecx| -> Result<(), NoSolution> { .enter(|ecx| -> Result<_, NoSolution> {
ecx.sub(param_env, source_projection, target_projection)?; ecx.enter_forall(target_projection, |ecx, target_projection| {
let _ = ecx.try_evaluate_added_goals()?; let source_projection =
Ok(()) ecx.instantiate_binder_with_infer(source_projection);
ecx.eq(param_env, source_projection, target_projection)?;
ecx.try_evaluate_added_goals()
})
}) })
.is_ok() .is_ok()
}; };
@ -909,11 +912,14 @@ fn consider_builtin_upcast_to_principal(
// Check that a's supertrait (upcast_principal) is compatible // Check that a's supertrait (upcast_principal) is compatible
// with the target (b_ty). // with the target (b_ty).
ty::ExistentialPredicate::Trait(target_principal) => { ty::ExistentialPredicate::Trait(target_principal) => {
ecx.sub( let source_principal = upcast_principal.unwrap();
param_env, let target_principal = bound.rebind(target_principal);
upcast_principal.unwrap(), ecx.enter_forall(target_principal, |ecx, target_principal| {
bound.rebind(target_principal), let source_principal =
)?; ecx.instantiate_binder_with_infer(source_principal);
ecx.eq(param_env, source_principal, target_principal)?;
ecx.try_evaluate_added_goals()
})?;
} }
// Check that b_ty's projection is satisfied by exactly one of // Check that b_ty's projection is satisfied by exactly one of
// a_ty's projections. First, we look through the list to see if // a_ty's projections. First, we look through the list to see if
@ -934,7 +940,12 @@ fn consider_builtin_upcast_to_principal(
Certainty::AMBIGUOUS, Certainty::AMBIGUOUS,
); );
} }
ecx.sub(param_env, source_projection, target_projection)?; ecx.enter_forall(target_projection, |ecx, target_projection| {
let source_projection =
ecx.instantiate_binder_with_infer(source_projection);
ecx.eq(param_env, source_projection, target_projection)?;
ecx.try_evaluate_added_goals()
})?;
} }
// Check that b_ty's auto traits are present in a_ty's bounds. // Check that b_ty's auto traits are present in a_ty's bounds.
ty::ExistentialPredicate::AutoTrait(def_id) => { ty::ExistentialPredicate::AutoTrait(def_id) => {
@ -1187,17 +1198,15 @@ fn probe_and_evaluate_goal_for_constituent_tys(
) -> Result<Vec<ty::Binder<I, I::Ty>>, NoSolution>, ) -> Result<Vec<ty::Binder<I, I::Ty>>, NoSolution>,
) -> Result<Candidate<I>, NoSolution> { ) -> Result<Candidate<I>, NoSolution> {
self.probe_trait_candidate(source).enter(|ecx| { self.probe_trait_candidate(source).enter(|ecx| {
ecx.add_goals( let goals = constituent_tys(ecx, goal.predicate.self_ty())?
GoalSource::ImplWhereBound, .into_iter()
constituent_tys(ecx, goal.predicate.self_ty())? .map(|ty| {
.into_iter() ecx.enter_forall(ty, |ecx, ty| {
.map(|ty| { goal.with(ecx.cx(), goal.predicate.with_self_ty(ecx.cx(), ty))
ecx.enter_forall(ty, |ty| {
goal.with(ecx.cx(), goal.predicate.with_self_ty(ecx.cx(), ty))
})
}) })
.collect::<Vec<_>>(), })
); .collect::<Vec<_>>();
ecx.add_goals(GoalSource::ImplWhereBound, goals);
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
}) })
} }

View File

@ -16,6 +16,7 @@
use rustc_hir::def_id::DefId; use rustc_hir::def_id::DefId;
use rustc_infer::infer::BoundRegionConversionTime::{self, HigherRankedType}; use rustc_infer::infer::BoundRegionConversionTime::{self, HigherRankedType};
use rustc_infer::infer::DefineOpaqueTypes; use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::infer::at::ToTrace;
use rustc_infer::infer::relate::TypeRelation; use rustc_infer::infer::relate::TypeRelation;
use rustc_infer::traits::TraitObligation; use rustc_infer::traits::TraitObligation;
use rustc_middle::bug; use rustc_middle::bug;
@ -44,7 +45,7 @@
TraitQueryMode, const_evaluatable, project, util, wf, TraitQueryMode, const_evaluatable, project, util, wf,
}; };
use crate::error_reporting::InferCtxtErrorExt; use crate::error_reporting::InferCtxtErrorExt;
use crate::infer::{InferCtxt, InferCtxtExt, InferOk, TypeFreshener}; use crate::infer::{InferCtxt, InferOk, TypeFreshener};
use crate::solve::InferCtxtSelectExt as _; use crate::solve::InferCtxtSelectExt as _;
use crate::traits::normalize::{normalize_with_depth, normalize_with_depth_to}; use crate::traits::normalize::{normalize_with_depth, normalize_with_depth_to};
use crate::traits::project::{ProjectAndUnifyResult, ProjectionCacheKeyExt}; use crate::traits::project::{ProjectAndUnifyResult, ProjectionCacheKeyExt};
@ -2579,16 +2580,32 @@ fn match_upcast_principal(
// Check that a_ty's supertrait (upcast_principal) is compatible // Check that a_ty's supertrait (upcast_principal) is compatible
// with the target (b_ty). // with the target (b_ty).
ty::ExistentialPredicate::Trait(target_principal) => { ty::ExistentialPredicate::Trait(target_principal) => {
let hr_source_principal = upcast_principal.map_bound(|trait_ref| {
ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
});
let hr_target_principal = bound.rebind(target_principal);
nested.extend( nested.extend(
self.infcx self.infcx
.at(&obligation.cause, obligation.param_env) .enter_forall(hr_target_principal, |target_principal| {
.sup( let source_principal =
DefineOpaqueTypes::Yes, self.infcx.instantiate_binder_with_fresh_vars(
bound.rebind(target_principal), obligation.cause.span,
upcast_principal.map_bound(|trait_ref| { HigherRankedType,
ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref) hr_source_principal,
}), );
) self.infcx.at(&obligation.cause, obligation.param_env).eq_trace(
DefineOpaqueTypes::Yes,
ToTrace::to_trace(
&obligation.cause,
true,
hr_target_principal,
hr_source_principal,
),
target_principal,
source_principal,
)
})
.map_err(|_| SelectionError::Unimplemented)? .map_err(|_| SelectionError::Unimplemented)?
.into_obligations(), .into_obligations(),
); );
@ -2599,19 +2616,41 @@ fn match_upcast_principal(
// return ambiguity. Otherwise, if exactly one matches, equate // return ambiguity. Otherwise, if exactly one matches, equate
// it with b_ty's projection. // it with b_ty's projection.
ty::ExistentialPredicate::Projection(target_projection) => { ty::ExistentialPredicate::Projection(target_projection) => {
let target_projection = bound.rebind(target_projection); let hr_target_projection = bound.rebind(target_projection);
let mut matching_projections = let mut matching_projections =
a_data.projection_bounds().filter(|source_projection| { a_data.projection_bounds().filter(|&hr_source_projection| {
// Eager normalization means that we can just use can_eq // Eager normalization means that we can just use can_eq
// here instead of equating and processing obligations. // here instead of equating and processing obligations.
source_projection.item_def_id() == target_projection.item_def_id() hr_source_projection.item_def_id() == hr_target_projection.item_def_id()
&& self.infcx.can_eq( && self.infcx.probe(|_| {
obligation.param_env, self.infcx
*source_projection, .enter_forall(hr_target_projection, |target_projection| {
target_projection, let source_projection =
) self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
hr_source_projection,
);
self.infcx
.at(&obligation.cause, obligation.param_env)
.eq_trace(
DefineOpaqueTypes::Yes,
ToTrace::to_trace(
&obligation.cause,
true,
hr_target_projection,
hr_source_projection,
),
target_projection,
source_projection,
)
})
.is_ok()
})
}); });
let Some(source_projection) = matching_projections.next() else {
let Some(hr_source_projection) = matching_projections.next() else {
return Err(SelectionError::Unimplemented); return Err(SelectionError::Unimplemented);
}; };
if matching_projections.next().is_some() { if matching_projections.next().is_some() {
@ -2619,8 +2658,25 @@ fn match_upcast_principal(
} }
nested.extend( nested.extend(
self.infcx self.infcx
.at(&obligation.cause, obligation.param_env) .enter_forall(hr_target_projection, |target_projection| {
.sup(DefineOpaqueTypes::Yes, target_projection, source_projection) let source_projection =
self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
hr_source_projection,
);
self.infcx.at(&obligation.cause, obligation.param_env).eq_trace(
DefineOpaqueTypes::Yes,
ToTrace::to_trace(
&obligation.cause,
true,
hr_target_projection,
hr_source_projection,
),
target_projection,
source_projection,
)
})
.map_err(|_| SelectionError::Unimplemented)? .map_err(|_| SelectionError::Unimplemented)?
.into_obligations(), .into_obligations(),
); );