implement (unused) matching solver

This commit is contained in:
Niko Matsakis 2022-06-15 05:55:05 -04:00
parent d203c13db2
commit c5ed318b22
7 changed files with 329 additions and 24 deletions

View File

@ -299,7 +299,7 @@ pub(crate) fn compute_regions<'cx, 'tcx>(
// Solve the region constraints.
let (closure_region_requirements, nll_errors) =
regioncx.solve(infcx, &body, polonius_output.clone());
regioncx.solve(infcx, param_env, &body, polonius_output.clone());
if !nll_errors.is_empty() {
// Suppress unhelpful extra errors in `infer_opaque_types`.

View File

@ -10,7 +10,8 @@ use rustc_hir::def_id::{DefId, CRATE_DEF_ID};
use rustc_hir::CRATE_HIR_ID;
use rustc_index::vec::IndexVec;
use rustc_infer::infer::canonical::QueryOutlivesConstraint;
use rustc_infer::infer::region_constraints::{GenericKind, VarInfos, VerifyBound};
use rustc_infer::infer::outlives::test_type_match;
use rustc_infer::infer::region_constraints::{GenericKind, VarInfos, VerifyBound, VerifyIfEq};
use rustc_infer::infer::{InferCtxt, NllRegionVariableOrigin, RegionVariableOrigin};
use rustc_middle::mir::{
Body, ClosureOutlivesRequirement, ClosureOutlivesSubject, ClosureRegionRequirements,
@ -18,6 +19,7 @@ use rustc_middle::mir::{
};
use rustc_middle::traits::ObligationCause;
use rustc_middle::traits::ObligationCauseCode;
use rustc_middle::ty::Region;
use rustc_middle::ty::{self, subst::SubstsRef, RegionVid, Ty, TyCtxt, TypeFoldable};
use rustc_span::Span;
@ -46,6 +48,7 @@ pub mod values;
pub struct RegionInferenceContext<'tcx> {
pub var_infos: VarInfos,
/// Contains the definition for every region variable. Region
/// variables are identified by their index (`RegionVid`). The
/// definition contains information about where the region came
@ -559,6 +562,7 @@ impl<'tcx> RegionInferenceContext<'tcx> {
pub(super) fn solve(
&mut self,
infcx: &InferCtxt<'_, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
body: &Body<'tcx>,
polonius_output: Option<Rc<PoloniusOutput>>,
) -> (Option<ClosureRegionRequirements<'tcx>>, RegionErrors<'tcx>) {
@ -574,7 +578,13 @@ impl<'tcx> RegionInferenceContext<'tcx> {
// eagerly.
let mut outlives_requirements = infcx.tcx.is_typeck_child(mir_def_id).then(Vec::new);
self.check_type_tests(infcx, body, outlives_requirements.as_mut(), &mut errors_buffer);
self.check_type_tests(
infcx,
param_env,
body,
outlives_requirements.as_mut(),
&mut errors_buffer,
);
// In Polonius mode, the errors about missing universal region relations are in the output
// and need to be emitted or propagated. Otherwise, we need to check whether the
@ -823,6 +833,7 @@ impl<'tcx> RegionInferenceContext<'tcx> {
fn check_type_tests(
&self,
infcx: &InferCtxt<'_, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
body: &Body<'tcx>,
mut propagated_outlives_requirements: Option<&mut Vec<ClosureOutlivesRequirement<'tcx>>>,
errors_buffer: &mut RegionErrors<'tcx>,
@ -839,7 +850,8 @@ impl<'tcx> RegionInferenceContext<'tcx> {
let generic_ty = type_test.generic_kind.to_ty(tcx);
if self.eval_verify_bound(
tcx,
infcx,
param_env,
body,
generic_ty,
type_test.lower_bound,
@ -851,6 +863,7 @@ impl<'tcx> RegionInferenceContext<'tcx> {
if let Some(propagated_outlives_requirements) = &mut propagated_outlives_requirements {
if self.try_promote_type_test(
infcx,
param_env,
body,
type_test,
propagated_outlives_requirements,
@ -907,6 +920,7 @@ impl<'tcx> RegionInferenceContext<'tcx> {
fn try_promote_type_test(
&self,
infcx: &InferCtxt<'_, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
body: &Body<'tcx>,
type_test: &TypeTest<'tcx>,
propagated_outlives_requirements: &mut Vec<ClosureOutlivesRequirement<'tcx>>,
@ -938,7 +952,14 @@ impl<'tcx> RegionInferenceContext<'tcx> {
// where `ur` is a local bound -- we are sometimes in a
// position to prove things that our caller cannot. See
// #53570 for an example.
if self.eval_verify_bound(tcx, body, generic_ty, ur, &type_test.verify_bound) {
if self.eval_verify_bound(
infcx,
param_env,
body,
generic_ty,
ur,
&type_test.verify_bound,
) {
continue;
}
@ -1161,7 +1182,8 @@ impl<'tcx> RegionInferenceContext<'tcx> {
/// `point`.
fn eval_verify_bound(
&self,
tcx: TyCtxt<'tcx>,
infcx: &InferCtxt<'_, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
body: &Body<'tcx>,
generic_ty: Ty<'tcx>,
lower_bound: RegionVid,
@ -1170,14 +1192,13 @@ impl<'tcx> RegionInferenceContext<'tcx> {
debug!("eval_verify_bound(lower_bound={:?}, verify_bound={:?})", lower_bound, verify_bound);
match verify_bound {
VerifyBound::IfEq(test_ty, verify_bound1) => self.eval_if_eq(
tcx,
body,
generic_ty,
lower_bound,
*test_ty,
&VerifyBound::OutlivedBy(*verify_bound1),
),
VerifyBound::IfEq(test_ty, verify_bound1) => {
self.eval_if_eq(infcx, generic_ty, lower_bound, *test_ty, *verify_bound1)
}
VerifyBound::IfEqBound(verify_if_eq_b) => {
self.eval_if_eq_bound(infcx, param_env, generic_ty, lower_bound, *verify_if_eq_b)
}
VerifyBound::IsEmpty => {
let lower_bound_scc = self.constraint_sccs.scc(lower_bound);
@ -1190,33 +1211,71 @@ impl<'tcx> RegionInferenceContext<'tcx> {
}
VerifyBound::AnyBound(verify_bounds) => verify_bounds.iter().any(|verify_bound| {
self.eval_verify_bound(tcx, body, generic_ty, lower_bound, verify_bound)
self.eval_verify_bound(
infcx,
param_env,
body,
generic_ty,
lower_bound,
verify_bound,
)
}),
VerifyBound::AllBounds(verify_bounds) => verify_bounds.iter().all(|verify_bound| {
self.eval_verify_bound(tcx, body, generic_ty, lower_bound, verify_bound)
self.eval_verify_bound(
infcx,
param_env,
body,
generic_ty,
lower_bound,
verify_bound,
)
}),
}
}
fn eval_if_eq(
&self,
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
infcx: &InferCtxt<'_, 'tcx>,
generic_ty: Ty<'tcx>,
lower_bound: RegionVid,
test_ty: Ty<'tcx>,
verify_bound: &VerifyBound<'tcx>,
verify_bound: Region<'tcx>,
) -> bool {
let generic_ty_normalized = self.normalize_to_scc_representatives(tcx, generic_ty);
let test_ty_normalized = self.normalize_to_scc_representatives(tcx, test_ty);
let generic_ty_normalized = self.normalize_to_scc_representatives(infcx.tcx, generic_ty);
let test_ty_normalized = self.normalize_to_scc_representatives(infcx.tcx, test_ty);
if generic_ty_normalized == test_ty_normalized {
self.eval_verify_bound(tcx, body, generic_ty, lower_bound, verify_bound)
let verify_bound_vid = self.to_region_vid(verify_bound);
self.eval_outlives(verify_bound_vid, lower_bound)
} else {
false
}
}
fn eval_if_eq_bound(
&self,
infcx: &InferCtxt<'_, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
generic_ty: Ty<'tcx>,
lower_bound: RegionVid,
verify_if_eq_b: ty::Binder<'tcx, VerifyIfEq<'tcx>>,
) -> bool {
let generic_ty = self.normalize_to_scc_representatives(infcx.tcx, generic_ty);
let verify_if_eq_b = self.normalize_to_scc_representatives(infcx.tcx, verify_if_eq_b);
match test_type_match::extract_verify_if_eq_bound(
infcx.tcx,
param_env,
&verify_if_eq_b,
generic_ty,
) {
Some(r) => {
let r_vid = self.to_region_vid(r);
self.eval_outlives(r_vid, lower_bound)
}
None => false,
}
}
/// This is a conservative normalization procedure. It takes every
/// free region in `value` and replaces it with the
/// "representative" of its SCC (see `scc_representatives` field).

View File

@ -22,6 +22,8 @@ use rustc_middle::ty::{Region, RegionVid};
use rustc_span::Span;
use std::fmt;
use super::outlives::test_type_match;
/// This function performs lexical region resolution given a complete
/// set of constraints and variable origins. It performs a fixed-point
/// iteration to find region values which satisfy all constraints,
@ -29,12 +31,13 @@ use std::fmt;
/// all the variables as well as a set of errors that must be reported.
#[instrument(level = "debug", skip(region_rels, var_infos, data))]
pub(crate) fn resolve<'tcx>(
param_env: ty::ParamEnv<'tcx>,
region_rels: &RegionRelations<'_, 'tcx>,
var_infos: VarInfos,
data: RegionConstraintData<'tcx>,
) -> (LexicalRegionResolutions<'tcx>, Vec<RegionResolutionError<'tcx>>) {
let mut errors = vec![];
let mut resolver = LexicalResolver { region_rels, var_infos, data };
let mut resolver = LexicalResolver { param_env, region_rels, var_infos, data };
let values = resolver.infer_variable_values(&mut errors);
(values, errors)
}
@ -100,6 +103,7 @@ struct RegionAndOrigin<'tcx> {
type RegionGraph<'tcx> = Graph<(), Constraint<'tcx>>;
struct LexicalResolver<'cx, 'tcx> {
param_env: ty::ParamEnv<'tcx>,
region_rels: &'cx RegionRelations<'cx, 'tcx>,
var_infos: VarInfos,
data: RegionConstraintData<'tcx>,
@ -823,6 +827,21 @@ impl<'cx, 'tcx> LexicalResolver<'cx, 'tcx> {
&& self.bound_is_met(&VerifyBound::OutlivedBy(*r), var_values, generic_ty, min)
}
VerifyBound::IfEqBound(verify_if_eq_b) => {
match test_type_match::extract_verify_if_eq_bound(
self.tcx(),
self.param_env,
verify_if_eq_b,
generic_ty,
) {
Some(r) => {
self.bound_is_met(&VerifyBound::OutlivedBy(r), var_values, generic_ty, min)
}
None => false,
}
}
VerifyBound::OutlivedBy(r) => {
self.sub_concrete_regions(min, var_values.normalize(self.tcx(), *r))
}

View File

@ -1290,7 +1290,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
&RegionRelations::new(self.tcx, region_context, outlives_env.free_region_map());
let (lexical_region_resolutions, errors) =
lexical_region_resolve::resolve(region_rels, var_infos, data);
lexical_region_resolve::resolve(outlives_env.param_env, region_rels, var_infos, data);
let old_value = self.lexical_region_resolutions.replace(Some(lexical_region_resolutions));
assert!(old_value.is_none());

View File

@ -3,6 +3,7 @@
pub mod components;
pub mod env;
pub mod obligations;
pub mod test_type_match;
pub mod verify;
use rustc_middle::traits::query::OutlivesBound;

View File

@ -0,0 +1,179 @@
use std::collections::hash_map::Entry;
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::ty::TypeFoldable;
use rustc_middle::ty::{
self,
error::TypeError,
relate::{self, Relate, RelateResult, TypeRelation},
Ty, TyCtxt,
};
use crate::infer::region_constraints::VerifyIfEq;
/// Given a "verify-if-eq" type test like:
///
/// exists<'a...> {
/// verify_if_eq(some_type, bound_region)
/// }
///
/// and the type `test_ty` that the type test is being tested against,
/// returns:
///
/// * `None` if `some_type` cannot be made equal to `test_ty`,
/// no matter the values of the variables in `exists`.
/// * `Some(r)` with a suitable bound (typically the value of `bound_region`, modulo
/// any bound existential variables, which will be substituted) for the
/// type under test.
///
/// NB: This function uses a simplistic, syntactic version of type equality.
/// In other words, it may spuriously return `None` even if the type-under-test
/// is in fact equal to `some_type`. In practice, though, this is used on types
/// that are either projections like `T::Item` or `T` and it works fine, but it
/// could have trouble when complex types with higher-ranked binders and the
/// like are used. This is a particular challenge since this function is invoked
/// very late in inference and hence cannot make use of the normal inference
/// machinery.
pub fn extract_verify_if_eq_bound<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
verify_if_eq_b: &ty::Binder<'tcx, VerifyIfEq<'tcx>>,
test_ty: Ty<'tcx>,
) -> Option<ty::Region<'tcx>> {
assert!(!verify_if_eq_b.has_escaping_bound_vars());
let mut m = Match::new(tcx, param_env);
let verify_if_eq = verify_if_eq_b.skip_binder();
m.relate(verify_if_eq.ty, test_ty).ok()?;
if let ty::RegionKind::ReLateBound(depth, br) = verify_if_eq.bound.kind() {
assert!(depth == ty::INNERMOST);
match m.map.get(&br) {
Some(&r) => Some(r),
None => {
// If there is no mapping, then this region is unconstrained.
// In that case, we escalate to `'static`.
Some(tcx.lifetimes.re_static)
}
}
} else {
// The region does not contain any inference variables.
Some(verify_if_eq.bound)
}
}
struct Match<'tcx> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
pattern_depth: ty::DebruijnIndex,
map: FxHashMap<ty::BoundRegion, ty::Region<'tcx>>,
}
impl<'tcx> Match<'tcx> {
fn new(tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>) -> Match<'tcx> {
Match { tcx, param_env, pattern_depth: ty::INNERMOST, map: FxHashMap::default() }
}
}
impl<'tcx> Match<'tcx> {
/// Creates the "Error" variant that signals "no match".
fn no_match<T>(&self) -> RelateResult<'tcx, T> {
Err(TypeError::Mismatch)
}
/// Binds the pattern variable `br` to `value`; returns an `Err` if the pattern
/// is already bound to a different value.
fn bind(
&mut self,
br: ty::BoundRegion,
value: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
match self.map.entry(br) {
Entry::Occupied(entry) => {
if *entry.get() == value {
Ok(value)
} else {
self.no_match()
}
}
Entry::Vacant(entry) => {
entry.insert(value);
Ok(value)
}
}
}
}
impl<'tcx> TypeRelation<'tcx> for Match<'tcx> {
fn tag(&self) -> &'static str {
"Match"
}
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}
fn param_env(&self) -> ty::ParamEnv<'tcx> {
self.param_env
}
fn a_is_expected(&self) -> bool {
true
} // irrelevant
fn relate_with_variance<T: Relate<'tcx>>(
&mut self,
_: ty::Variance,
_: ty::VarianceDiagInfo<'tcx>,
a: T,
b: T,
) -> RelateResult<'tcx, T> {
self.relate(a, b)
}
#[instrument(skip(self), level = "debug")]
fn regions(
&mut self,
pattern: ty::Region<'tcx>,
value: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
if let ty::RegionKind::ReLateBound(depth, br) = pattern.kind() && depth == self.pattern_depth {
self.bind(br, pattern)
} else if pattern == value {
Ok(pattern)
} else {
self.no_match()
}
}
fn tys(&mut self, pattern: Ty<'tcx>, value: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
if pattern == value {
return Ok(pattern);
} else {
relate::super_relate_tys(self, pattern, value)
}
}
fn consts(
&mut self,
pattern: ty::Const<'tcx>,
value: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>> {
debug!("{}.consts({:?}, {:?})", self.tag(), pattern, value);
if pattern == value {
return Ok(pattern);
} else {
relate::super_relate_consts(self, pattern, value)
}
}
fn binders<T>(
&mut self,
pattern: ty::Binder<'tcx, T>,
value: ty::Binder<'tcx, T>,
) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
where
T: Relate<'tcx>,
{
self.pattern_depth.shift_in(1);
let result = Ok(pattern.rebind(self.relate(pattern.skip_binder(), value.skip_binder())?));
self.pattern_depth.shift_out(1);
result
}
}

View File

@ -226,6 +226,8 @@ pub enum VerifyBound<'tcx> {
/// (after inference), and `'a: min`, then `G: min`.
IfEq(Ty<'tcx>, Region<'tcx>),
IfEqBound(ty::Binder<'tcx, VerifyIfEq<'tcx>>),
/// Given a region `R`, expands to the function:
///
/// ```ignore (pseudo-rust)
@ -267,6 +269,49 @@ pub enum VerifyBound<'tcx> {
AllBounds(Vec<VerifyBound<'tcx>>),
}
/// Given a kind K and a bound B, expands to a function like the
/// following, where `G` is the generic for which this verify
/// bound was created:
///
/// ```ignore (pseudo-rust)
/// fn(min) -> bool {
/// if G == K {
/// B(min)
/// } else {
/// false
/// }
/// }
/// ```
///
/// In other words, if the generic `G` that we are checking is
/// equal to `K`, then check the associated verify bound
/// (otherwise, false).
///
/// This is used when we have something in the environment that
/// may or may not be relevant, depending on the region inference
/// results. For example, we may have `where <T as
/// Trait<'a>>::Item: 'b` in our where-clauses. If we are
/// generating the verify-bound for `<T as Trait<'0>>::Item`, then
/// this where-clause is only relevant if `'0` winds up inferred
/// to `'a`.
///
/// So we would compile to a verify-bound like
///
/// ```ignore (illustrative)
/// IfEq(<T as Trait<'a>>::Item, AnyRegion('a))
/// ```
///
/// meaning, if the subject G is equal to `<T as Trait<'a>>::Item`
/// (after inference), and `'a: min`, then `G: min`.
#[derive(Debug, Copy, Clone, TypeFoldable)]
pub struct VerifyIfEq<'tcx> {
/// Type which must match the generic `G`
pub ty: Ty<'tcx>,
/// Bound that applies if `ty` is equal.
pub bound: Region<'tcx>,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) struct TwoRegions<'tcx> {
a: Region<'tcx>,
@ -761,6 +806,7 @@ impl<'tcx> VerifyBound<'tcx> {
pub fn must_hold(&self) -> bool {
match self {
VerifyBound::IfEq(..) => false,
VerifyBound::IfEqBound(..) => false,
VerifyBound::OutlivedBy(re) => re.is_static(),
VerifyBound::IsEmpty => false,
VerifyBound::AnyBound(bs) => bs.iter().any(|b| b.must_hold()),
@ -771,6 +817,7 @@ impl<'tcx> VerifyBound<'tcx> {
pub fn cannot_hold(&self) -> bool {
match self {
VerifyBound::IfEq(_, _) => false,
VerifyBound::IfEqBound(..) => false,
VerifyBound::IsEmpty => false,
VerifyBound::OutlivedBy(_) => false,
VerifyBound::AnyBound(bs) => bs.iter().all(|b| b.cannot_hold()),