Properly infer types with type casts

This commit is contained in:
Ryo Yoshida 2023-07-12 23:46:23 +09:00
parent 75ac37f317
commit 074488b290
No known key found for this signature in database
GPG Key ID: E25698A930586171
5 changed files with 112 additions and 28 deletions

View File

@ -13,6 +13,15 @@
//! to certain types. To record this, we use the union-find implementation from
//! the `ena` crate, which is extracted from rustc.
mod cast;
pub(crate) mod closure;
mod coerce;
mod expr;
mod mutability;
mod pat;
mod path;
pub(crate) mod unify;
use std::{convert::identity, ops::Index};
use chalk_ir::{
@ -60,15 +69,8 @@
#[allow(unreachable_pub)]
pub use unify::could_unify;
pub(crate) use self::closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
pub(crate) mod unify;
mod path;
mod expr;
mod pat;
mod coerce;
pub(crate) mod closure;
mod mutability;
use cast::CastCheck;
pub(crate) use closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
/// The entry point of type inference.
pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
@ -508,6 +510,8 @@ pub(crate) struct InferenceContext<'a> {
diverges: Diverges,
breakables: Vec<BreakableContext>,
deferred_cast_checks: Vec<CastCheck>,
// fields related to closure capture
current_captures: Vec<CapturedItemWithoutTy>,
current_closure: Option<ClosureId>,
@ -582,7 +586,8 @@ fn new(
resolver,
diverges: Diverges::Maybe,
breakables: Vec::new(),
current_captures: vec![],
deferred_cast_checks: Vec::new(),
current_captures: Vec::new(),
current_closure: None,
deferred_closures: FxHashMap::default(),
closure_dependencies: FxHashMap::default(),
@ -594,7 +599,7 @@ fn new(
// used this function for another workaround, mention it here. If you really need this function and believe that
// there is no problem in it being `pub(crate)`, remove this comment.
pub(crate) fn resolve_all(self) -> InferenceResult {
let InferenceContext { mut table, mut result, .. } = self;
let InferenceContext { mut table, mut result, deferred_cast_checks, .. } = self;
// Destructure every single field so whenever new fields are added to `InferenceResult` we
// don't forget to handle them here.
let InferenceResult {
@ -622,6 +627,13 @@ pub(crate) fn resolve_all(self) -> InferenceResult {
table.fallback_if_possible();
// Comment from rustc:
// Even though coercion casts provide type hints, we check casts after fallback for
// backwards compatibility. This makes fallback a stronger type hint than a cast coercion.
for cast in deferred_cast_checks {
cast.check(&mut table);
}
// FIXME resolve obligations as well (use Guidance if necessary)
table.resolve_obligations_as_possible();

View File

@ -0,0 +1,46 @@
//! Type cast logic. Basically coercion + additional casts.
use crate::{infer::unify::InferenceTable, Interner, Ty, TyExt, TyKind};
#[derive(Clone, Debug)]
pub(super) struct CastCheck {
expr_ty: Ty,
cast_ty: Ty,
}
impl CastCheck {
pub(super) fn new(expr_ty: Ty, cast_ty: Ty) -> Self {
Self { expr_ty, cast_ty }
}
pub(super) fn check(self, table: &mut InferenceTable<'_>) {
// FIXME: This function currently only implements the bits that influence the type
// inference. We should return the adjustments on success and report diagnostics on error.
let expr_ty = table.resolve_ty_shallow(&self.expr_ty);
let cast_ty = table.resolve_ty_shallow(&self.cast_ty);
if expr_ty.contains_unknown() || cast_ty.contains_unknown() {
return;
}
if table.coerce(&expr_ty, &cast_ty).is_ok() {
return;
}
if check_ref_to_ptr_cast(expr_ty, cast_ty, table) {
// Note that this type of cast is actually split into a coercion to a
// pointer type and a cast:
// &[T; N] -> *[T; N] -> *T
return;
}
// FIXME: Check other kinds of non-coercion casts and report error if any?
}
}
fn check_ref_to_ptr_cast(expr_ty: Ty, cast_ty: Ty, table: &mut InferenceTable<'_>) -> bool {
let Some((expr_inner_ty, _, _)) = expr_ty.as_reference() else { return false; };
let Some((cast_inner_ty, _)) = cast_ty.as_raw_ptr() else { return false; };
let TyKind::Array(expr_elt_ty, _) = expr_inner_ty.kind(Interner) else { return false; };
table.coerce(expr_elt_ty, cast_inner_ty).is_ok()
}

View File

@ -46,8 +46,8 @@
};
use super::{
coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges, Expectation,
InferenceContext, InferenceDiagnostic, TypeMismatch,
cast::CastCheck, coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges,
Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch,
};
impl InferenceContext<'_> {
@ -574,16 +574,8 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
}
Expr::Cast { expr, type_ref } => {
let cast_ty = self.make_ty(type_ref);
// FIXME: propagate the "castable to" expectation
let inner_ty = self.infer_expr_no_expect(*expr);
match (inner_ty.kind(Interner), cast_ty.kind(Interner)) {
(TyKind::Ref(_, _, inner), TyKind::Raw(_, cast)) => {
// FIXME: record invalid cast diagnostic in case of mismatch
self.unify(inner, cast);
}
// FIXME check the other kinds of cast...
_ => (),
}
let expr_ty = self.infer_expr(*expr, &Expectation::Castable(cast_ty.clone()));
self.deferred_cast_checks.push(CastCheck::new(expr_ty, cast_ty.clone()));
cast_ty
}
Expr::Ref { expr, rawness, mutability } => {
@ -1592,7 +1584,7 @@ fn expected_inputs_for_expected_output(
output: Ty,
inputs: Vec<Ty>,
) -> Vec<Ty> {
if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
if let Some(expected_ty) = expected_output.only_has_type(&mut self.table) {
self.table.fudge_inference(|table| {
if table.try_unify(&expected_ty, &output).is_ok() {
table.resolve_with_fallback(inputs, &|var, kind, _, _| match kind {

View File

@ -1978,3 +1978,23 @@ fn f(self) {
"#,
);
}
#[test]
fn dont_unify_on_casts() {
// #15246
check_types(
r#"
fn unify(_: [bool; 1]) {}
fn casted(_: *const bool) {}
fn default<T>() -> T { loop {} }
fn test() {
let foo = default();
//^^^ [bool; 1]
casted(&foo as *const _);
unify(foo);
}
"#,
);
}

View File

@ -3513,7 +3513,6 @@ fn func() {
);
}
// FIXME
#[test]
fn castable_to() {
check_infer(
@ -3538,10 +3537,10 @@ fn func() {
120..122 '{}': ()
138..184 '{ ...0]>; }': ()
148..149 'x': Box<[i32; 0]>
152..160 'Box::new': fn new<[{unknown}; 0]>([{unknown}; 0]) -> Box<[{unknown}; 0]>
152..164 'Box::new([])': Box<[{unknown}; 0]>
152..160 'Box::new': fn new<[i32; 0]>([i32; 0]) -> Box<[i32; 0]>
152..164 'Box::new([])': Box<[i32; 0]>
152..181 'Box::n...2; 0]>': Box<[i32; 0]>
161..163 '[]': [{unknown}; 0]
161..163 '[]': [i32; 0]
"#]],
);
}
@ -3577,6 +3576,21 @@ fn f<T>(t: Ark<T>) {
);
}
#[test]
fn ref_to_array_to_ptr_cast() {
check_types(
r#"
fn default<T>() -> T { loop {} }
fn foo() {
let arr = [default()];
//^^^ [i32; 1]
let ref_to_arr = &arr;
let casted = ref_to_arr as *const i32;
}
"#,
);
}
#[test]
fn const_dependent_on_local() {
check_types(