Track place validity

This commit is contained in:
Nadrieril 2023-11-18 03:52:54 +01:00
parent 1978168c13
commit f5dbb54648
2 changed files with 216 additions and 30 deletions

View File

@ -42,7 +42,7 @@ pub(crate) fn check_match(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Result<(), Err
for param in thir.params.iter() {
if let Some(box ref pattern) = param.pat {
visitor.check_binding_is_irrefutable(pattern, "function argument", None);
visitor.check_binding_is_irrefutable(pattern, "function argument", None, None);
}
}
visitor.error
@ -254,10 +254,11 @@ fn visit_land_rhs(
self.with_lint_level(lint_level, |this| this.visit_land_rhs(&this.thir[value]))
}
ExprKind::Let { box ref pat, expr } => {
let expr = &self.thir()[expr];
self.with_let_source(LetSource::None, |this| {
this.visit_expr(&this.thir()[expr]);
this.visit_expr(expr);
});
Ok(Some((ex.span, self.is_let_irrefutable(pat)?)))
Ok(Some((ex.span, self.is_let_irrefutable(pat, Some(expr))?)))
}
_ => {
self.with_let_source(LetSource::None, |this| {
@ -287,35 +288,114 @@ fn lower_pattern(
}
}
/// Inspects the match scrutinee expression to determine whether the place it evaluates to may
/// hold invalid data.
fn is_known_valid_scrutinee(&self, scrutinee: &Expr<'tcx>) -> bool {
use ExprKind::*;
match &scrutinee.kind {
// Both pointers and references can validly point to a place with invalid data.
Deref { .. } => false,
// Inherit validity of the parent place, unless the parent is an union.
Field { lhs, .. } => {
let lhs = &self.thir()[*lhs];
match lhs.ty.kind() {
ty::Adt(def, _) if def.is_union() => false,
_ => self.is_known_valid_scrutinee(lhs),
}
}
// Essentially a field access.
Index { lhs, .. } => {
let lhs = &self.thir()[*lhs];
self.is_known_valid_scrutinee(lhs)
}
// No-op.
Scope { value, .. } => self.is_known_valid_scrutinee(&self.thir()[*value]),
// Casts don't cause a load.
NeverToAny { source }
| Cast { source }
| Use { source }
| PointerCoercion { source, .. }
| PlaceTypeAscription { source, .. }
| ValueTypeAscription { source, .. } => {
self.is_known_valid_scrutinee(&self.thir()[*source])
}
// These diverge.
Become { .. } | Break { .. } | Continue { .. } | Return { .. } => true,
// These are statements that evaluate to `()`.
Assign { .. } | AssignOp { .. } | InlineAsm { .. } | Let { .. } => true,
// These evaluate to a value.
AddressOf { .. }
| Adt { .. }
| Array { .. }
| Binary { .. }
| Block { .. }
| Borrow { .. }
| Box { .. }
| Call { .. }
| Closure { .. }
| ConstBlock { .. }
| ConstParam { .. }
| If { .. }
| Literal { .. }
| LogicalOp { .. }
| Loop { .. }
| Match { .. }
| NamedConst { .. }
| NonHirLiteral { .. }
| OffsetOf { .. }
| Repeat { .. }
| StaticRef { .. }
| ThreadLocalRef { .. }
| Tuple { .. }
| Unary { .. }
| UpvarRef { .. }
| VarRef { .. }
| ZstLiteral { .. }
| Yield { .. } => true,
}
}
fn new_cx(
&self,
refutability: RefutableFlag,
match_span: Option<Span>,
whole_match_span: Option<Span>,
scrutinee: Option<&Expr<'tcx>>,
scrut_span: Span,
) -> MatchCheckCtxt<'p, 'tcx> {
let refutable = match refutability {
Irrefutable => false,
Refutable => true,
};
// If we don't have a scrutinee we're either a function parameter or a `let x;`. Both cases
// require validity.
let known_valid_scrutinee =
scrutinee.map(|scrut| self.is_known_valid_scrutinee(scrut)).unwrap_or(true);
MatchCheckCtxt {
tcx: self.tcx,
param_env: self.param_env,
module: self.tcx.parent_module(self.lint_level).to_def_id(),
pattern_arena: self.pattern_arena,
match_lint_level: self.lint_level,
match_span,
whole_match_span,
scrut_span,
refutable,
known_valid_scrutinee,
}
}
#[instrument(level = "trace", skip(self))]
fn check_let(&mut self, pat: &Pat<'tcx>, scrutinee: Option<ExprId>, span: Span) {
assert!(self.let_source != LetSource::None);
let scrut = scrutinee.map(|id| &self.thir[id]);
if let LetSource::PlainLet = self.let_source {
self.check_binding_is_irrefutable(pat, "local binding", Some(span))
self.check_binding_is_irrefutable(pat, "local binding", scrut, Some(span))
} else {
let Ok(refutability) = self.is_let_irrefutable(pat) else { return };
let Ok(refutability) = self.is_let_irrefutable(pat, scrut) else { return };
if matches!(refutability, Irrefutable) {
report_irrefutable_let_patterns(
self.tcx,
@ -336,7 +416,7 @@ fn check_match(
expr_span: Span,
) {
let scrut = &self.thir[scrut];
let cx = self.new_cx(Refutable, Some(expr_span), scrut.span);
let cx = self.new_cx(Refutable, Some(expr_span), Some(scrut), scrut.span);
let mut tarms = Vec::with_capacity(arms.len());
for &arm in arms {
@ -377,7 +457,12 @@ fn check_match(
debug_assert_eq!(pat.span.desugaring_kind(), Some(DesugaringKind::ForLoop));
let PatKind::Variant { ref subpatterns, .. } = pat.kind else { bug!() };
let [pat_field] = &subpatterns[..] else { bug!() };
self.check_binding_is_irrefutable(&pat_field.pattern, "`for` loop binding", None);
self.check_binding_is_irrefutable(
&pat_field.pattern,
"`for` loop binding",
None,
None,
);
} else {
self.error = Err(report_non_exhaustive_match(
&cx, self.thir, scrut_ty, scrut.span, witnesses, arms, expr_span,
@ -457,16 +542,21 @@ fn analyze_binding(
&mut self,
pat: &Pat<'tcx>,
refutability: RefutableFlag,
scrut: Option<&Expr<'tcx>>,
) -> Result<(MatchCheckCtxt<'p, 'tcx>, UsefulnessReport<'p, 'tcx>), ErrorGuaranteed> {
let cx = self.new_cx(refutability, None, pat.span);
let cx = self.new_cx(refutability, None, scrut, pat.span);
let pat = self.lower_pattern(&cx, pat)?;
let arms = [MatchArm { pat, hir_id: self.lint_level, has_guard: false }];
let report = compute_match_usefulness(&cx, &arms, pat.ty());
Ok((cx, report))
}
fn is_let_irrefutable(&mut self, pat: &Pat<'tcx>) -> Result<RefutableFlag, ErrorGuaranteed> {
let (cx, report) = self.analyze_binding(pat, Refutable)?;
fn is_let_irrefutable(
&mut self,
pat: &Pat<'tcx>,
scrut: Option<&Expr<'tcx>>,
) -> Result<RefutableFlag, ErrorGuaranteed> {
let (cx, report) = self.analyze_binding(pat, Refutable, scrut)?;
// Report if the pattern is unreachable, which can only occur when the type is uninhabited.
// This also reports unreachable sub-patterns.
report_arm_reachability(&cx, &report);
@ -476,10 +566,16 @@ fn is_let_irrefutable(&mut self, pat: &Pat<'tcx>) -> Result<RefutableFlag, Error
}
#[instrument(level = "trace", skip(self))]
fn check_binding_is_irrefutable(&mut self, pat: &Pat<'tcx>, origin: &str, sp: Option<Span>) {
fn check_binding_is_irrefutable(
&mut self,
pat: &Pat<'tcx>,
origin: &str,
scrut: Option<&Expr<'tcx>>,
sp: Option<Span>,
) {
let pattern_ty = pat.ty;
let Ok((cx, report)) = self.analyze_binding(pat, Irrefutable) else { return };
let Ok((cx, report)) = self.analyze_binding(pat, Irrefutable, scrut) else { return };
let witnesses = report.non_exhaustiveness_witnesses;
if witnesses.is_empty() {
// The pattern is irrefutable.

View File

@ -551,6 +551,7 @@
//! I (Nadrieril) prefer to put new tests in `ui/pattern/usefulness` unless there's a specific
//! reason not to, for example if they crucially depend on a particular feature like `or_patterns`.
use self::ValidityConstraint::*;
use super::deconstruct_pat::{
Constructor, ConstructorSet, DeconstructedPat, IntRange, MaybeInfiniteInt, SplitConstructorSet,
WitnessPat,
@ -587,11 +588,14 @@ pub(crate) struct MatchCheckCtxt<'p, 'tcx> {
/// Lint level at the match.
pub(crate) match_lint_level: HirId,
/// The span of the whole match, if applicable.
pub(crate) match_span: Option<Span>,
pub(crate) whole_match_span: Option<Span>,
/// Span of the scrutinee.
pub(crate) scrut_span: Span,
/// Only produce `NON_EXHAUSTIVE_OMITTED_PATTERNS` lint on refutable patterns.
pub(crate) refutable: bool,
/// Whether the data at the scrutinee is known to be valid. This is false if the scrutinee comes
/// from a union field, a pointer deref, or a reference deref (pending opsem decisions).
pub(crate) known_valid_scrutinee: bool,
}
impl<'a, 'tcx> MatchCheckCtxt<'a, 'tcx> {
@ -620,12 +624,63 @@ pub(super) struct PatCtxt<'a, 'p, 'tcx> {
pub(super) is_top_level: bool,
}
impl<'a, 'p, 'tcx> PatCtxt<'a, 'p, 'tcx> {
/// A `PatCtxt` when code other than `is_useful` needs one.
fn new_dummy(cx: &'a MatchCheckCtxt<'p, 'tcx>, ty: Ty<'tcx>) -> Self {
PatCtxt { cx, ty, is_top_level: false }
}
}
impl<'a, 'p, 'tcx> fmt::Debug for PatCtxt<'a, 'p, 'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PatCtxt").field("ty", &self.ty).finish()
}
}
/// In the matrix, tracks whether a given place (aka column) is known to contain a valid value or
/// not.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) enum ValidityConstraint {
ValidOnly,
MaybeInvalid,
}
impl ValidityConstraint {
pub(super) fn from_bool(is_valid_only: bool) -> Self {
if is_valid_only { ValidOnly } else { MaybeInvalid }
}
/// If the place has validity given by `self` and we read that the value at the place has
/// constructor `ctor`, this computes what we can assume about the validity of the constructor
/// fields.
///
/// Pending further opsem decisions, the current behavior is: validity is preserved, except
/// under `&` where validity is reset to `MaybeInvalid`.
pub(super) fn specialize<'tcx>(
self,
pcx: &PatCtxt<'_, '_, 'tcx>,
ctor: &Constructor<'tcx>,
) -> Self {
// We preserve validity except when we go under a reference.
if matches!(ctor, Constructor::Single) && matches!(pcx.ty.kind(), ty::Ref(..)) {
// Validity of `x: &T` does not imply validity of `*x: T`.
MaybeInvalid
} else {
self
}
}
}
impl fmt::Display for ValidityConstraint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
ValidOnly => "",
MaybeInvalid => "?",
};
write!(f, "{s}")
}
}
/// Represents a pattern-tuple under investigation.
#[derive(Clone)]
struct PatStack<'p, 'tcx> {
@ -770,10 +825,15 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
/// the matrix will correspond to `scrutinee.0.Some.0` and the second column to `scrutinee.1`.
#[derive(Clone)]
struct Matrix<'p, 'tcx> {
/// Vector of rows. The rows must form a rectangular 2D array. Moreover, all the patterns of
/// each column must have the same type. Each column corresponds to a place within the
/// scrutinee.
rows: Vec<MatrixRow<'p, 'tcx>>,
/// Stores an extra fictitious row full of wildcards. Mostly used to keep track of the type of
/// each column. This must obey the same invariants as the real rows.
wildcard_row: PatStack<'p, 'tcx>,
/// Track for each column/place whether it contains a known valid value.
place_validity: SmallVec<[ValidityConstraint; 2]>,
}
impl<'p, 'tcx> Matrix<'p, 'tcx> {
@ -791,10 +851,22 @@ fn expand_and_push(&mut self, row: MatrixRow<'p, 'tcx>) {
}
/// Build a new matrix from an iterator of `MatchArm`s.
fn new(cx: &MatchCheckCtxt<'p, 'tcx>, arms: &[MatchArm<'p, 'tcx>], scrut_ty: Ty<'tcx>) -> Self {
fn new<'a>(
cx: &MatchCheckCtxt<'p, 'tcx>,
arms: &[MatchArm<'p, 'tcx>],
scrut_ty: Ty<'tcx>,
scrut_validity: ValidityConstraint,
) -> Self
where
'p: 'a,
{
let wild_pattern = cx.pattern_arena.alloc(DeconstructedPat::wildcard(scrut_ty, DUMMY_SP));
let wildcard_row = PatStack::from_pattern(wild_pattern);
let mut matrix = Matrix { rows: Vec::with_capacity(arms.len()), wildcard_row };
let mut matrix = Matrix {
rows: Vec::with_capacity(arms.len()),
wildcard_row,
place_validity: smallvec![scrut_validity],
};
for (row_id, arm) in arms.iter().enumerate() {
let v = MatrixRow {
pats: PatStack::from_pattern(arm.pat),
@ -858,7 +930,13 @@ fn specialize_constructor(
ctor: &Constructor<'tcx>,
) -> Matrix<'p, 'tcx> {
let wildcard_row = self.wildcard_row.pop_head_constructor(pcx, ctor);
let mut matrix = Matrix { rows: Vec::new(), wildcard_row };
let new_validity = self.place_validity[0].specialize(pcx, ctor);
let new_place_validity = std::iter::repeat(new_validity)
.take(ctor.arity(pcx))
.chain(self.place_validity[1..].iter().copied())
.collect();
let mut matrix =
Matrix { rows: Vec::new(), wildcard_row, place_validity: new_place_validity };
for (i, row) in self.rows().enumerate() {
if ctor.is_covered_by(pcx, row.head().ctor()) {
let new_row = row.pop_head_constructor(pcx, ctor, i);
@ -877,27 +955,38 @@ fn specialize_constructor(
/// + true + [Second(true)] +
/// + false + [_] +
/// + _ + [_, _, tail @ ..] +
/// | ✓ | ? | // column validity
/// ```
impl<'p, 'tcx> fmt::Debug for Matrix<'p, 'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "\n")?;
let Matrix { rows, .. } = self;
let pretty_printed_matrix: Vec<Vec<String>> =
rows.iter().map(|row| row.iter().map(|pat| format!("{pat:?}")).collect()).collect();
let mut pretty_printed_matrix: Vec<Vec<String>> = self
.rows
.iter()
.map(|row| row.iter().map(|pat| format!("{pat:?}")).collect())
.collect();
pretty_printed_matrix
.push(self.place_validity.iter().map(|validity| format!("{validity}")).collect());
let column_count = rows.iter().map(|row| row.len()).next().unwrap_or(0);
assert!(rows.iter().all(|row| row.len() == column_count));
let column_count = self.column_count();
assert!(self.rows.iter().all(|row| row.len() == column_count));
assert!(self.place_validity.len() == column_count);
let column_widths: Vec<usize> = (0..column_count)
.map(|col| pretty_printed_matrix.iter().map(|row| row[col].len()).max().unwrap_or(0))
.collect();
for row in pretty_printed_matrix {
write!(f, "+")?;
for (row_i, row) in pretty_printed_matrix.into_iter().enumerate() {
let is_validity_row = row_i == self.rows.len();
let sep = if is_validity_row { "|" } else { "+" };
write!(f, "{sep}")?;
for (column, pat_str) in row.into_iter().enumerate() {
write!(f, " ")?;
write!(f, "{:1$}", pat_str, column_widths[column])?;
write!(f, " +")?;
write!(f, " {sep}")?;
}
if is_validity_row {
write!(f, " // column validity")?;
}
write!(f, "\n")?;
}
@ -1287,7 +1376,7 @@ fn collect_nonexhaustive_missing_variants<'p, 'tcx>(
let Some(ty) = column.head_ty() else {
return Vec::new();
};
let pcx = &PatCtxt { cx, ty, is_top_level: false };
let pcx = &PatCtxt::new_dummy(cx, ty);
let set = column.analyze_ctors(pcx);
if set.present.is_empty() {
@ -1336,7 +1425,7 @@ fn lint_overlapping_range_endpoints<'p, 'tcx>(
let Some(ty) = column.head_ty() else {
return;
};
let pcx = &PatCtxt { cx, ty, is_top_level: false };
let pcx = &PatCtxt::new_dummy(cx, ty);
let set = column.analyze_ctors(pcx);
@ -1439,7 +1528,8 @@ pub(crate) fn compute_match_usefulness<'p, 'tcx>(
arms: &[MatchArm<'p, 'tcx>],
scrut_ty: Ty<'tcx>,
) -> UsefulnessReport<'p, 'tcx> {
let mut matrix = Matrix::new(cx, arms, scrut_ty);
let scrut_validity = ValidityConstraint::from_bool(cx.known_valid_scrutinee);
let mut matrix = Matrix::new(cx, arms, scrut_ty, scrut_validity);
let non_exhaustiveness_witnesses = compute_exhaustiveness_and_usefulness(cx, &mut matrix, true);
let non_exhaustiveness_witnesses: Vec<_> = non_exhaustiveness_witnesses.single_column();
@ -1496,7 +1586,7 @@ pub(crate) fn compute_match_usefulness<'p, 'tcx>(
if !matches!(lint_level, rustc_session::lint::Level::Allow) {
let decorator = NonExhaustiveOmittedPatternLintOnArm {
lint_span: lint_level_source.span(),
suggest_lint_on_match: cx.match_span.map(|span| span.shrink_to_lo()),
suggest_lint_on_match: cx.whole_match_span.map(|span| span.shrink_to_lo()),
lint_level: lint_level.as_str(),
lint_name: "non_exhaustive_omitted_patterns",
};