diff --git a/compiler/rustc_borrowck/src/nll.rs b/compiler/rustc_borrowck/src/nll.rs index 2440ae9780d..3a919e954a4 100644 --- a/compiler/rustc_borrowck/src/nll.rs +++ b/compiler/rustc_borrowck/src/nll.rs @@ -299,7 +299,7 @@ pub(crate) fn compute_regions<'cx, 'tcx>( // Solve the region constraints. let (closure_region_requirements, nll_errors) = - regioncx.solve(infcx, &body, polonius_output.clone()); + regioncx.solve(infcx, param_env, &body, polonius_output.clone()); if !nll_errors.is_empty() { // Suppress unhelpful extra errors in `infer_opaque_types`. diff --git a/compiler/rustc_borrowck/src/region_infer/mod.rs b/compiler/rustc_borrowck/src/region_infer/mod.rs index 228e88b33cf..d553a60faef 100644 --- a/compiler/rustc_borrowck/src/region_infer/mod.rs +++ b/compiler/rustc_borrowck/src/region_infer/mod.rs @@ -10,7 +10,8 @@ use rustc_hir::def_id::{DefId, CRATE_DEF_ID}; use rustc_hir::CRATE_HIR_ID; use rustc_index::vec::IndexVec; use rustc_infer::infer::canonical::QueryOutlivesConstraint; -use rustc_infer::infer::region_constraints::{GenericKind, VarInfos, VerifyBound}; +use rustc_infer::infer::outlives::test_type_match; +use rustc_infer::infer::region_constraints::{GenericKind, VarInfos, VerifyBound, VerifyIfEq}; use rustc_infer::infer::{InferCtxt, NllRegionVariableOrigin, RegionVariableOrigin}; use rustc_middle::mir::{ Body, ClosureOutlivesRequirement, ClosureOutlivesSubject, ClosureRegionRequirements, @@ -18,6 +19,7 @@ use rustc_middle::mir::{ }; use rustc_middle::traits::ObligationCause; use rustc_middle::traits::ObligationCauseCode; +use rustc_middle::ty::Region; use rustc_middle::ty::{self, subst::SubstsRef, RegionVid, Ty, TyCtxt, TypeFoldable}; use rustc_span::Span; @@ -46,6 +48,7 @@ pub mod values; pub struct RegionInferenceContext<'tcx> { pub var_infos: VarInfos, + /// Contains the definition for every region variable. Region /// variables are identified by their index (`RegionVid`). The /// definition contains information about where the region came @@ -559,6 +562,7 @@ impl<'tcx> RegionInferenceContext<'tcx> { pub(super) fn solve( &mut self, infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, body: &Body<'tcx>, polonius_output: Option>, ) -> (Option>, RegionErrors<'tcx>) { @@ -574,7 +578,13 @@ impl<'tcx> RegionInferenceContext<'tcx> { // eagerly. let mut outlives_requirements = infcx.tcx.is_typeck_child(mir_def_id).then(Vec::new); - self.check_type_tests(infcx, body, outlives_requirements.as_mut(), &mut errors_buffer); + self.check_type_tests( + infcx, + param_env, + body, + outlives_requirements.as_mut(), + &mut errors_buffer, + ); // In Polonius mode, the errors about missing universal region relations are in the output // and need to be emitted or propagated. Otherwise, we need to check whether the @@ -823,6 +833,7 @@ impl<'tcx> RegionInferenceContext<'tcx> { fn check_type_tests( &self, infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, body: &Body<'tcx>, mut propagated_outlives_requirements: Option<&mut Vec>>, errors_buffer: &mut RegionErrors<'tcx>, @@ -839,7 +850,8 @@ impl<'tcx> RegionInferenceContext<'tcx> { let generic_ty = type_test.generic_kind.to_ty(tcx); if self.eval_verify_bound( - tcx, + infcx, + param_env, body, generic_ty, type_test.lower_bound, @@ -851,6 +863,7 @@ impl<'tcx> RegionInferenceContext<'tcx> { if let Some(propagated_outlives_requirements) = &mut propagated_outlives_requirements { if self.try_promote_type_test( infcx, + param_env, body, type_test, propagated_outlives_requirements, @@ -907,6 +920,7 @@ impl<'tcx> RegionInferenceContext<'tcx> { fn try_promote_type_test( &self, infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, body: &Body<'tcx>, type_test: &TypeTest<'tcx>, propagated_outlives_requirements: &mut Vec>, @@ -938,7 +952,14 @@ impl<'tcx> RegionInferenceContext<'tcx> { // where `ur` is a local bound -- we are sometimes in a // position to prove things that our caller cannot. See // #53570 for an example. - if self.eval_verify_bound(tcx, body, generic_ty, ur, &type_test.verify_bound) { + if self.eval_verify_bound( + infcx, + param_env, + body, + generic_ty, + ur, + &type_test.verify_bound, + ) { continue; } @@ -1161,7 +1182,8 @@ impl<'tcx> RegionInferenceContext<'tcx> { /// `point`. fn eval_verify_bound( &self, - tcx: TyCtxt<'tcx>, + infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, body: &Body<'tcx>, generic_ty: Ty<'tcx>, lower_bound: RegionVid, @@ -1170,14 +1192,13 @@ impl<'tcx> RegionInferenceContext<'tcx> { debug!("eval_verify_bound(lower_bound={:?}, verify_bound={:?})", lower_bound, verify_bound); match verify_bound { - VerifyBound::IfEq(test_ty, verify_bound1) => self.eval_if_eq( - tcx, - body, - generic_ty, - lower_bound, - *test_ty, - &VerifyBound::OutlivedBy(*verify_bound1), - ), + VerifyBound::IfEq(test_ty, verify_bound1) => { + self.eval_if_eq(infcx, generic_ty, lower_bound, *test_ty, *verify_bound1) + } + + VerifyBound::IfEqBound(verify_if_eq_b) => { + self.eval_if_eq_bound(infcx, param_env, generic_ty, lower_bound, *verify_if_eq_b) + } VerifyBound::IsEmpty => { let lower_bound_scc = self.constraint_sccs.scc(lower_bound); @@ -1190,33 +1211,71 @@ impl<'tcx> RegionInferenceContext<'tcx> { } VerifyBound::AnyBound(verify_bounds) => verify_bounds.iter().any(|verify_bound| { - self.eval_verify_bound(tcx, body, generic_ty, lower_bound, verify_bound) + self.eval_verify_bound( + infcx, + param_env, + body, + generic_ty, + lower_bound, + verify_bound, + ) }), VerifyBound::AllBounds(verify_bounds) => verify_bounds.iter().all(|verify_bound| { - self.eval_verify_bound(tcx, body, generic_ty, lower_bound, verify_bound) + self.eval_verify_bound( + infcx, + param_env, + body, + generic_ty, + lower_bound, + verify_bound, + ) }), } } fn eval_if_eq( &self, - tcx: TyCtxt<'tcx>, - body: &Body<'tcx>, + infcx: &InferCtxt<'_, 'tcx>, generic_ty: Ty<'tcx>, lower_bound: RegionVid, test_ty: Ty<'tcx>, - verify_bound: &VerifyBound<'tcx>, + verify_bound: Region<'tcx>, ) -> bool { - let generic_ty_normalized = self.normalize_to_scc_representatives(tcx, generic_ty); - let test_ty_normalized = self.normalize_to_scc_representatives(tcx, test_ty); + let generic_ty_normalized = self.normalize_to_scc_representatives(infcx.tcx, generic_ty); + let test_ty_normalized = self.normalize_to_scc_representatives(infcx.tcx, test_ty); if generic_ty_normalized == test_ty_normalized { - self.eval_verify_bound(tcx, body, generic_ty, lower_bound, verify_bound) + let verify_bound_vid = self.to_region_vid(verify_bound); + self.eval_outlives(verify_bound_vid, lower_bound) } else { false } } + fn eval_if_eq_bound( + &self, + infcx: &InferCtxt<'_, 'tcx>, + param_env: ty::ParamEnv<'tcx>, + generic_ty: Ty<'tcx>, + lower_bound: RegionVid, + verify_if_eq_b: ty::Binder<'tcx, VerifyIfEq<'tcx>>, + ) -> bool { + let generic_ty = self.normalize_to_scc_representatives(infcx.tcx, generic_ty); + let verify_if_eq_b = self.normalize_to_scc_representatives(infcx.tcx, verify_if_eq_b); + match test_type_match::extract_verify_if_eq_bound( + infcx.tcx, + param_env, + &verify_if_eq_b, + generic_ty, + ) { + Some(r) => { + let r_vid = self.to_region_vid(r); + self.eval_outlives(r_vid, lower_bound) + } + None => false, + } + } + /// This is a conservative normalization procedure. It takes every /// free region in `value` and replaces it with the /// "representative" of its SCC (see `scc_representatives` field). diff --git a/compiler/rustc_infer/src/infer/lexical_region_resolve/mod.rs b/compiler/rustc_infer/src/infer/lexical_region_resolve/mod.rs index c5afd376217..1cc5f3d53c9 100644 --- a/compiler/rustc_infer/src/infer/lexical_region_resolve/mod.rs +++ b/compiler/rustc_infer/src/infer/lexical_region_resolve/mod.rs @@ -22,6 +22,8 @@ use rustc_middle::ty::{Region, RegionVid}; use rustc_span::Span; use std::fmt; +use super::outlives::test_type_match; + /// This function performs lexical region resolution given a complete /// set of constraints and variable origins. It performs a fixed-point /// iteration to find region values which satisfy all constraints, @@ -29,12 +31,13 @@ use std::fmt; /// all the variables as well as a set of errors that must be reported. #[instrument(level = "debug", skip(region_rels, var_infos, data))] pub(crate) fn resolve<'tcx>( + param_env: ty::ParamEnv<'tcx>, region_rels: &RegionRelations<'_, 'tcx>, var_infos: VarInfos, data: RegionConstraintData<'tcx>, ) -> (LexicalRegionResolutions<'tcx>, Vec>) { let mut errors = vec![]; - let mut resolver = LexicalResolver { region_rels, var_infos, data }; + let mut resolver = LexicalResolver { param_env, region_rels, var_infos, data }; let values = resolver.infer_variable_values(&mut errors); (values, errors) } @@ -100,6 +103,7 @@ struct RegionAndOrigin<'tcx> { type RegionGraph<'tcx> = Graph<(), Constraint<'tcx>>; struct LexicalResolver<'cx, 'tcx> { + param_env: ty::ParamEnv<'tcx>, region_rels: &'cx RegionRelations<'cx, 'tcx>, var_infos: VarInfos, data: RegionConstraintData<'tcx>, @@ -823,6 +827,21 @@ impl<'cx, 'tcx> LexicalResolver<'cx, 'tcx> { && self.bound_is_met(&VerifyBound::OutlivedBy(*r), var_values, generic_ty, min) } + VerifyBound::IfEqBound(verify_if_eq_b) => { + match test_type_match::extract_verify_if_eq_bound( + self.tcx(), + self.param_env, + verify_if_eq_b, + generic_ty, + ) { + Some(r) => { + self.bound_is_met(&VerifyBound::OutlivedBy(r), var_values, generic_ty, min) + } + + None => false, + } + } + VerifyBound::OutlivedBy(r) => { self.sub_concrete_regions(min, var_values.normalize(self.tcx(), *r)) } diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index 93a067cb516..6f88b83a473 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -1290,7 +1290,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { &RegionRelations::new(self.tcx, region_context, outlives_env.free_region_map()); let (lexical_region_resolutions, errors) = - lexical_region_resolve::resolve(region_rels, var_infos, data); + lexical_region_resolve::resolve(outlives_env.param_env, region_rels, var_infos, data); let old_value = self.lexical_region_resolutions.replace(Some(lexical_region_resolutions)); assert!(old_value.is_none()); diff --git a/compiler/rustc_infer/src/infer/outlives/mod.rs b/compiler/rustc_infer/src/infer/outlives/mod.rs index b9652e83e65..2a085288fb7 100644 --- a/compiler/rustc_infer/src/infer/outlives/mod.rs +++ b/compiler/rustc_infer/src/infer/outlives/mod.rs @@ -3,6 +3,7 @@ pub mod components; pub mod env; pub mod obligations; +pub mod test_type_match; pub mod verify; use rustc_middle::traits::query::OutlivesBound; diff --git a/compiler/rustc_infer/src/infer/outlives/test_type_match.rs b/compiler/rustc_infer/src/infer/outlives/test_type_match.rs new file mode 100644 index 00000000000..99d6aabf0d6 --- /dev/null +++ b/compiler/rustc_infer/src/infer/outlives/test_type_match.rs @@ -0,0 +1,179 @@ +use std::collections::hash_map::Entry; + +use rustc_data_structures::fx::FxHashMap; +use rustc_middle::ty::TypeFoldable; +use rustc_middle::ty::{ + self, + error::TypeError, + relate::{self, Relate, RelateResult, TypeRelation}, + Ty, TyCtxt, +}; + +use crate::infer::region_constraints::VerifyIfEq; + +/// Given a "verify-if-eq" type test like: +/// +/// exists<'a...> { +/// verify_if_eq(some_type, bound_region) +/// } +/// +/// and the type `test_ty` that the type test is being tested against, +/// returns: +/// +/// * `None` if `some_type` cannot be made equal to `test_ty`, +/// no matter the values of the variables in `exists`. +/// * `Some(r)` with a suitable bound (typically the value of `bound_region`, modulo +/// any bound existential variables, which will be substituted) for the +/// type under test. +/// +/// NB: This function uses a simplistic, syntactic version of type equality. +/// In other words, it may spuriously return `None` even if the type-under-test +/// is in fact equal to `some_type`. In practice, though, this is used on types +/// that are either projections like `T::Item` or `T` and it works fine, but it +/// could have trouble when complex types with higher-ranked binders and the +/// like are used. This is a particular challenge since this function is invoked +/// very late in inference and hence cannot make use of the normal inference +/// machinery. +pub fn extract_verify_if_eq_bound<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + verify_if_eq_b: &ty::Binder<'tcx, VerifyIfEq<'tcx>>, + test_ty: Ty<'tcx>, +) -> Option> { + assert!(!verify_if_eq_b.has_escaping_bound_vars()); + let mut m = Match::new(tcx, param_env); + let verify_if_eq = verify_if_eq_b.skip_binder(); + m.relate(verify_if_eq.ty, test_ty).ok()?; + + if let ty::RegionKind::ReLateBound(depth, br) = verify_if_eq.bound.kind() { + assert!(depth == ty::INNERMOST); + match m.map.get(&br) { + Some(&r) => Some(r), + None => { + // If there is no mapping, then this region is unconstrained. + // In that case, we escalate to `'static`. + Some(tcx.lifetimes.re_static) + } + } + } else { + // The region does not contain any inference variables. + Some(verify_if_eq.bound) + } +} + +struct Match<'tcx> { + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + pattern_depth: ty::DebruijnIndex, + map: FxHashMap>, +} + +impl<'tcx> Match<'tcx> { + fn new(tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>) -> Match<'tcx> { + Match { tcx, param_env, pattern_depth: ty::INNERMOST, map: FxHashMap::default() } + } +} + +impl<'tcx> Match<'tcx> { + /// Creates the "Error" variant that signals "no match". + fn no_match(&self) -> RelateResult<'tcx, T> { + Err(TypeError::Mismatch) + } + + /// Binds the pattern variable `br` to `value`; returns an `Err` if the pattern + /// is already bound to a different value. + fn bind( + &mut self, + br: ty::BoundRegion, + value: ty::Region<'tcx>, + ) -> RelateResult<'tcx, ty::Region<'tcx>> { + match self.map.entry(br) { + Entry::Occupied(entry) => { + if *entry.get() == value { + Ok(value) + } else { + self.no_match() + } + } + Entry::Vacant(entry) => { + entry.insert(value); + Ok(value) + } + } + } +} + +impl<'tcx> TypeRelation<'tcx> for Match<'tcx> { + fn tag(&self) -> &'static str { + "Match" + } + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.param_env + } + fn a_is_expected(&self) -> bool { + true + } // irrelevant + + fn relate_with_variance>( + &mut self, + _: ty::Variance, + _: ty::VarianceDiagInfo<'tcx>, + a: T, + b: T, + ) -> RelateResult<'tcx, T> { + self.relate(a, b) + } + + #[instrument(skip(self), level = "debug")] + fn regions( + &mut self, + pattern: ty::Region<'tcx>, + value: ty::Region<'tcx>, + ) -> RelateResult<'tcx, ty::Region<'tcx>> { + if let ty::RegionKind::ReLateBound(depth, br) = pattern.kind() && depth == self.pattern_depth { + self.bind(br, pattern) + } else if pattern == value { + Ok(pattern) + } else { + self.no_match() + } + } + + fn tys(&mut self, pattern: Ty<'tcx>, value: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { + if pattern == value { + return Ok(pattern); + } else { + relate::super_relate_tys(self, pattern, value) + } + } + + fn consts( + &mut self, + pattern: ty::Const<'tcx>, + value: ty::Const<'tcx>, + ) -> RelateResult<'tcx, ty::Const<'tcx>> { + debug!("{}.consts({:?}, {:?})", self.tag(), pattern, value); + if pattern == value { + return Ok(pattern); + } else { + relate::super_relate_consts(self, pattern, value) + } + } + + fn binders( + &mut self, + pattern: ty::Binder<'tcx, T>, + value: ty::Binder<'tcx, T>, + ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> + where + T: Relate<'tcx>, + { + self.pattern_depth.shift_in(1); + let result = Ok(pattern.rebind(self.relate(pattern.skip_binder(), value.skip_binder())?)); + self.pattern_depth.shift_out(1); + result + } +} diff --git a/compiler/rustc_infer/src/infer/region_constraints/mod.rs b/compiler/rustc_infer/src/infer/region_constraints/mod.rs index d7b4f450e0f..e0ccbb2c0f9 100644 --- a/compiler/rustc_infer/src/infer/region_constraints/mod.rs +++ b/compiler/rustc_infer/src/infer/region_constraints/mod.rs @@ -226,6 +226,8 @@ pub enum VerifyBound<'tcx> { /// (after inference), and `'a: min`, then `G: min`. IfEq(Ty<'tcx>, Region<'tcx>), + IfEqBound(ty::Binder<'tcx, VerifyIfEq<'tcx>>), + /// Given a region `R`, expands to the function: /// /// ```ignore (pseudo-rust) @@ -267,6 +269,49 @@ pub enum VerifyBound<'tcx> { AllBounds(Vec>), } +/// Given a kind K and a bound B, expands to a function like the +/// following, where `G` is the generic for which this verify +/// bound was created: +/// +/// ```ignore (pseudo-rust) +/// fn(min) -> bool { +/// if G == K { +/// B(min) +/// } else { +/// false +/// } +/// } +/// ``` +/// +/// In other words, if the generic `G` that we are checking is +/// equal to `K`, then check the associated verify bound +/// (otherwise, false). +/// +/// This is used when we have something in the environment that +/// may or may not be relevant, depending on the region inference +/// results. For example, we may have `where >::Item: 'b` in our where-clauses. If we are +/// generating the verify-bound for `>::Item`, then +/// this where-clause is only relevant if `'0` winds up inferred +/// to `'a`. +/// +/// So we would compile to a verify-bound like +/// +/// ```ignore (illustrative) +/// IfEq(>::Item, AnyRegion('a)) +/// ``` +/// +/// meaning, if the subject G is equal to `>::Item` +/// (after inference), and `'a: min`, then `G: min`. +#[derive(Debug, Copy, Clone, TypeFoldable)] +pub struct VerifyIfEq<'tcx> { + /// Type which must match the generic `G` + pub ty: Ty<'tcx>, + + /// Bound that applies if `ty` is equal. + pub bound: Region<'tcx>, +} + #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub(crate) struct TwoRegions<'tcx> { a: Region<'tcx>, @@ -761,6 +806,7 @@ impl<'tcx> VerifyBound<'tcx> { pub fn must_hold(&self) -> bool { match self { VerifyBound::IfEq(..) => false, + VerifyBound::IfEqBound(..) => false, VerifyBound::OutlivedBy(re) => re.is_static(), VerifyBound::IsEmpty => false, VerifyBound::AnyBound(bs) => bs.iter().any(|b| b.must_hold()), @@ -771,6 +817,7 @@ impl<'tcx> VerifyBound<'tcx> { pub fn cannot_hold(&self) -> bool { match self { VerifyBound::IfEq(_, _) => false, + VerifyBound::IfEqBound(..) => false, VerifyBound::IsEmpty => false, VerifyBound::OutlivedBy(_) => false, VerifyBound::AnyBound(bs) => bs.iter().all(|b| b.cannot_hold()),