diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs index b13c9627bf7..4ed4164bd27 100644 --- a/compiler/rustc_infer/src/infer/combine.rs +++ b/compiler/rustc_infer/src/infer/combine.rs @@ -26,21 +26,17 @@ use super::glb::Glb; use super::lub::Lub; use super::sub::Sub; -use super::type_variable::TypeVariableValue; -use super::{DefineOpaqueTypes, InferCtxt, MiscVariable, TypeTrace}; -use crate::infer::generalize::{generalize, CombineDelegate, Generalization}; +use super::{DefineOpaqueTypes, InferCtxt, TypeTrace}; +use crate::infer::generalize::{self, CombineDelegate, Generalization}; use crate::traits::{Obligation, PredicateObligations}; use rustc_middle::infer::canonical::OriginalQueryValues; use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue}; use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind}; use rustc_middle::ty::error::{ExpectedFound, TypeError}; use rustc_middle::ty::relate::{RelateResult, TypeRelation}; -use rustc_middle::ty::{ - self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable, - TypeSuperFoldable, TypeVisitableExt, -}; +use rustc_middle::ty::{self, AliasKind, InferConst, ToPredicate, Ty, TyCtxt, TypeVisitableExt}; use rustc_middle::ty::{IntType, UintType}; -use rustc_span::{Span, DUMMY_SP}; +use rustc_span::DUMMY_SP; #[derive(Clone)] pub struct CombineFields<'infcx, 'tcx> { @@ -208,11 +204,11 @@ pub fn super_combine_consts( // matching in the solver. let a_error = self.tcx.const_error(a.ty(), guar); if let ty::ConstKind::Infer(InferConst::Var(vid)) = a.kind() { - return self.unify_const_variable(vid, a_error); + return self.unify_const_variable(vid, a_error, relation.param_env()); } let b_error = self.tcx.const_error(b.ty(), guar); if let ty::ConstKind::Infer(InferConst::Var(vid)) = b.kind() { - return self.unify_const_variable(vid, b_error); + return self.unify_const_variable(vid, b_error, relation.param_env()); } return Ok(if relation.a_is_expected() { a_error } else { b_error }); @@ -234,11 +230,11 @@ pub fn super_combine_consts( } (ty::ConstKind::Infer(InferConst::Var(vid)), _) => { - return self.unify_const_variable(vid, b); + return self.unify_const_variable(vid, b, relation.param_env()); } (_, ty::ConstKind::Infer(InferConst::Var(vid))) => { - return self.unify_const_variable(vid, a); + return self.unify_const_variable(vid, a, relation.param_env()); } (ty::ConstKind::Unevaluated(..), _) | (_, ty::ConstKind::Unevaluated(..)) if self.tcx.lazy_normalization() => @@ -291,24 +287,17 @@ fn unify_const_variable( &self, target_vid: ty::ConstVid<'tcx>, ct: ty::Const<'tcx>, + param_env: ty::ParamEnv<'tcx>, ) -> RelateResult<'tcx, ty::Const<'tcx>> { - let (for_universe, span) = { - let mut inner = self.inner.borrow_mut(); - let variable_table = &mut inner.const_unification_table(); - let var_value = variable_table.probe_value(target_vid); - match var_value.val { - ConstVariableValue::Known { value } => { - bug!("instantiating {:?} which has a known value {:?}", target_vid, value) - } - ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span), - } - }; - let value = ct.try_fold_with(&mut ConstInferUnifier { - infcx: self, - span, - for_universe, + let span = + self.inner.borrow_mut().const_unification_table().probe_value(target_vid).origin.span; + let Generalization { value, needs_wf: _ } = generalize::generalize( + self, + &mut CombineDelegate { infcx: self, span, param_env }, + ct, target_vid, - })?; + ty::Variance::Invariant, + )?; self.inner.borrow_mut().const_unification_table().union_value( target_vid, @@ -547,135 +536,3 @@ fn float_unification_error<'tcx>( let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v; TypeError::FloatMismatch(ExpectedFound::new(a_is_expected, a, b)) } - -struct ConstInferUnifier<'cx, 'tcx> { - infcx: &'cx InferCtxt<'tcx>, - - span: Span, - - for_universe: ty::UniverseIndex, - - /// The vid of the const variable that is in the process of being - /// instantiated; if we find this within the const we are folding, - /// that means we would have created a cyclic const. - target_vid: ty::ConstVid<'tcx>, -} - -impl<'tcx> FallibleTypeFolder> for ConstInferUnifier<'_, 'tcx> { - type Error = TypeError<'tcx>; - - fn interner(&self) -> TyCtxt<'tcx> { - self.infcx.tcx - } - - #[instrument(level = "debug", skip(self), ret)] - fn try_fold_ty(&mut self, t: Ty<'tcx>) -> Result, TypeError<'tcx>> { - match t.kind() { - &ty::Infer(ty::TyVar(vid)) => { - let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid); - let probe = self.infcx.inner.borrow_mut().type_variables().probe(vid); - match probe { - TypeVariableValue::Known { value: u } => { - debug!("ConstOccursChecker: known value {:?}", u); - u.try_fold_with(self) - } - TypeVariableValue::Unknown { universe } => { - if self.for_universe.can_name(universe) { - return Ok(t); - } - - let origin = - *self.infcx.inner.borrow_mut().type_variables().var_origin(vid); - let new_var_id = self - .infcx - .inner - .borrow_mut() - .type_variables() - .new_var(self.for_universe, origin); - Ok(self.interner().mk_ty_var(new_var_id)) - } - } - } - ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => Ok(t), - _ => t.try_super_fold_with(self), - } - } - - #[instrument(level = "debug", skip(self), ret)] - fn try_fold_region( - &mut self, - r: ty::Region<'tcx>, - ) -> Result, TypeError<'tcx>> { - debug!("ConstInferUnifier: r={:?}", r); - - match *r { - // Never make variables for regions bound within the type itself, - // nor for erased regions. - ty::ReLateBound(..) | ty::ReErased | ty::ReError(_) => { - return Ok(r); - } - - ty::RePlaceholder(..) - | ty::ReVar(..) - | ty::ReStatic - | ty::ReEarlyBound(..) - | ty::ReFree(..) => { - // see common code below - } - } - - let r_universe = self.infcx.universe_of_region(r); - if self.for_universe.can_name(r_universe) { - return Ok(r); - } else { - // FIXME: This is non-ideal because we don't give a - // very descriptive origin for this region variable. - Ok(self.infcx.next_region_var_in_universe(MiscVariable(self.span), self.for_universe)) - } - } - - #[instrument(level = "debug", skip(self), ret)] - fn try_fold_const(&mut self, c: ty::Const<'tcx>) -> Result, TypeError<'tcx>> { - match c.kind() { - ty::ConstKind::Infer(InferConst::Var(vid)) => { - // Check if the current unification would end up - // unifying `target_vid` with a const which contains - // an inference variable which is unioned with `target_vid`. - // - // Not doing so can easily result in stack overflows. - if self - .infcx - .inner - .borrow_mut() - .const_unification_table() - .unioned(self.target_vid, vid) - { - return Err(TypeError::CyclicConst(c)); - } - - let var_value = - self.infcx.inner.borrow_mut().const_unification_table().probe_value(vid); - match var_value.val { - ConstVariableValue::Known { value: u } => u.try_fold_with(self), - ConstVariableValue::Unknown { universe } => { - if self.for_universe.can_name(universe) { - Ok(c) - } else { - let new_var_id = - self.infcx.inner.borrow_mut().const_unification_table().new_key( - ConstVarValue { - origin: var_value.origin, - val: ConstVariableValue::Unknown { - universe: self.for_universe, - }, - }, - ); - Ok(self.interner().mk_const(new_var_id, c.ty())) - } - } - } - } - _ => c.try_super_fold_with(self), - } - } -} diff --git a/compiler/rustc_infer/src/infer/generalize.rs b/compiler/rustc_infer/src/infer/generalize.rs index c8562c84a17..8acfe638aa3 100644 --- a/compiler/rustc_infer/src/infer/generalize.rs +++ b/compiler/rustc_infer/src/infer/generalize.rs @@ -3,36 +3,44 @@ use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue}; use rustc_middle::ty::error::TypeError; use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation}; -use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitableExt}; +use rustc_middle::ty::{self, InferConst, Term, Ty, TyCtxt, TypeVisitableExt}; use rustc_span::Span; use crate::infer::nll_relate::TypeRelatingDelegate; use crate::infer::type_variable::TypeVariableValue; use crate::infer::{InferCtxt, RegionVariableOrigin}; -pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>>( +pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>, T: Into> + Relate<'tcx>>( infcx: &InferCtxt<'tcx>, delegate: &mut D, - ty: Ty<'tcx>, - for_vid: ty::TyVid, + term: T, + for_vid: impl Into>, ambient_variance: ty::Variance, -) -> RelateResult<'tcx, Generalization>> { - let for_universe = infcx.probe_ty_var(for_vid).unwrap_err(); - let for_vid_sub_root = infcx.inner.borrow_mut().type_variables().sub_root_var(for_vid); +) -> RelateResult<'tcx, Generalization> { + let (for_universe, root_vid) = match for_vid.into() { + ty::TermVid::Ty(ty_vid) => ( + infcx.probe_ty_var(ty_vid).unwrap_err(), + ty::TermVid::Ty(infcx.inner.borrow_mut().type_variables().sub_root_var(ty_vid)), + ), + ty::TermVid::Const(ct_vid) => ( + infcx.probe_const_var(ct_vid).unwrap_err(), + ty::TermVid::Const(infcx.inner.borrow_mut().const_unification_table().find(ct_vid)), + ), + }; let mut generalizer = Generalizer { infcx, delegate, ambient_variance, - for_vid_sub_root, + root_vid, for_universe, - root_ty: ty, + root_term: term.into(), needs_wf: false, cache: Default::default(), }; - assert!(!ty.has_escaping_bound_vars()); - let value = generalizer.relate(ty, ty)?; + assert!(!term.has_escaping_bound_vars()); + let value = generalizer.relate(term, term)?; let needs_wf = generalizer.needs_wf; Ok(Generalization { value, needs_wf }) } @@ -99,11 +107,8 @@ fn generalize_existential(&mut self, universe: ty::UniverseIndex) -> ty::Region< /// establishes `'0: 'x` as a constraint. /// /// [blog post]: https://is.gd/0hKvIr -struct Generalizer<'me, 'tcx, D> -where - D: GeneralizerDelegate<'tcx>, -{ - pub infcx: &'me InferCtxt<'tcx>, +struct Generalizer<'me, 'tcx, D> { + infcx: &'me InferCtxt<'tcx>, // An delegate used to abstract the behaviors of the three previous // generalizer-like implementations. @@ -116,14 +121,15 @@ struct Generalizer<'me, 'tcx, D> /// The vid of the type variable that is in the process of being /// instantiated. If we find this within the value we are folding, /// that means we would have created a cyclic value. - pub for_vid_sub_root: ty::TyVid, + root_vid: ty::TermVid<'tcx>, /// The universe of the type variable that is in the process of being /// instantiated. If we find anything that this universe cannot name, /// we reject the relation. for_universe: ty::UniverseIndex, - pub root_ty: Ty<'tcx>, + /// The root term (const or type) we're generalizing. Used for cycle errors. + root_term: Term<'tcx>, cache: SsoHashMap, Ty<'tcx>>, @@ -131,6 +137,15 @@ struct Generalizer<'me, 'tcx, D> needs_wf: bool, } +impl<'tcx, D> Generalizer<'_, 'tcx, D> { + fn cyclic_term_error(&self) -> TypeError<'tcx> { + match self.root_term.unpack() { + ty::TermKind::Ty(ty) => TypeError::CyclicTy(ty), + ty::TermKind::Const(ct) => TypeError::CyclicConst(ct), + } + } +} + impl<'tcx, D> TypeRelation<'tcx> for Generalizer<'_, 'tcx, D> where D: GeneralizerDelegate<'tcx>, @@ -226,10 +241,10 @@ fn tys(&mut self, t: Ty<'tcx>, t2: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { let mut inner = self.infcx.inner.borrow_mut(); let vid = inner.type_variables().root_var(vid); let sub_vid = inner.type_variables().sub_root_var(vid); - if sub_vid == self.for_vid_sub_root { + if TermVid::Ty(sub_vid) == self.root_vid { // If sub-roots are equal, then `for_vid` and // `vid` are related via subtyping. - Err(TypeError::CyclicTy(self.root_ty)) + Err(self.cyclic_term_error()) } else { let probe = inner.type_variables().probe(vid); match probe { @@ -363,6 +378,17 @@ fn consts( bug!("unexpected inference variable encountered in NLL generalization: {:?}", c); } ty::ConstKind::Infer(InferConst::Var(vid)) => { + // Check if the current unification would end up + // unifying `target_vid` with a const which contains + // an inference variable which is unioned with `target_vid`. + // + // Not doing so can easily result in stack overflows. + if TermVid::Const(self.infcx.inner.borrow_mut().const_unification_table().find(vid)) + == self.root_vid + { + return Err(self.cyclic_term_error()); + } + let mut inner = self.infcx.inner.borrow_mut(); let variable_table = &mut inner.const_unification_table(); let var_value = variable_table.probe_value(vid); diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index b414e1200cd..2fe0b2938ef 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -1070,6 +1070,24 @@ pub fn index(self) -> usize { } } +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub enum TermVid<'tcx> { + Ty(ty::TyVid), + Const(ty::ConstVid<'tcx>), +} + +impl From for TermVid<'_> { + fn from(value: ty::TyVid) -> Self { + TermVid::Ty(value) + } +} + +impl<'tcx> From> for TermVid<'tcx> { + fn from(value: ty::ConstVid<'tcx>) -> Self { + TermVid::Const(value) + } +} + /// This kind of predicate has no *direct* correspondent in the /// syntax, but it roughly corresponds to the syntactic forms: ///