diff --git a/crates/base_db/src/lib.rs b/crates/base_db/src/lib.rs index 54baa3a6330..d26f8f18081 100644 --- a/crates/base_db/src/lib.rs +++ b/crates/base_db/src/lib.rs @@ -42,7 +42,7 @@ pub struct FilePosition { pub offset: TextSize, } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] pub struct FileRange { pub file_id: FileId, pub range: TextRange, diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs index 2c667da2569..f023c1fb7de 100644 --- a/crates/hir_ty/src/infer.rs +++ b/crates/hir_ty/src/infer.rs @@ -761,6 +761,38 @@ fn only_has_type(&self, table: &mut unify::InferenceTable) -> Option { Expectation::RValueLikeUnsized(_) | Expectation::None => None, } } + + /// Comment copied from rustc: + /// Disregard "castable to" expectations because they + /// can lead us astray. Consider for example `if cond + /// {22} else {c} as u8` -- if we propagate the + /// "castable to u8" constraint to 22, it will pick the + /// type 22u8, which is overly constrained (c might not + /// be a u8). In effect, the problem is that the + /// "castable to" expectation is not the tightest thing + /// we can say, so we want to drop it in this case. + /// The tightest thing we can say is "must unify with + /// else branch". Note that in the case of a "has type" + /// constraint, this limitation does not hold. + /// + /// If the expected type is just a type variable, then don't use + /// an expected type. Otherwise, we might write parts of the type + /// when checking the 'then' block which are incompatible with the + /// 'else' branch. + fn adjust_for_branches(&self, table: &mut unify::InferenceTable) -> Expectation { + match self { + Expectation::HasType(ety) => { + let ety = table.resolve_ty_shallow(&ety); + if !ety.is_ty_var() { + Expectation::HasType(ety) + } else { + Expectation::None + } + } + Expectation::RValueLikeUnsized(ety) => Expectation::RValueLikeUnsized(ety.clone()), + _ => Expectation::None, + } + } } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs index f73bf43b215..e34f194fff8 100644 --- a/crates/hir_ty/src/infer/expr.rs +++ b/crates/hir_ty/src/infer/expr.rs @@ -337,10 +337,15 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { Expr::Match { expr, arms } => { let input_ty = self.infer_expr(*expr, &Expectation::none()); + let expected = expected.adjust_for_branches(&mut self.table); + let mut result_ty = if arms.is_empty() { TyKind::Never.intern(&Interner) } else { - self.table.new_type_var() + match &expected { + Expectation::HasType(ty) => ty.clone(), + _ => self.table.new_type_var(), + } }; let matchee_diverges = self.diverges; diff --git a/crates/hir_ty/src/tests.rs b/crates/hir_ty/src/tests.rs index 9d726b0248b..b873585c47d 100644 --- a/crates/hir_ty/src/tests.rs +++ b/crates/hir_ty/src/tests.rs @@ -9,7 +9,7 @@ mod display_source_code; mod incremental; -use std::{env, sync::Arc}; +use std::{collections::HashMap, env, sync::Arc}; use base_db::{fixture::WithFixture, FileRange, SourceDatabase, SourceDatabaseExt}; use expect_test::Expect; @@ -83,9 +83,105 @@ fn check_types_impl(ra_fixture: &str, display_source: bool) { checked_one = true; } } + assert!(checked_one, "no `//^` annotations found"); } +fn check_no_mismatches(ra_fixture: &str) { + check_mismatches_impl(ra_fixture, true) +} + +#[allow(unused)] +fn check_mismatches(ra_fixture: &str) { + check_mismatches_impl(ra_fixture, false) +} + +fn check_mismatches_impl(ra_fixture: &str, allow_none: bool) { + let _tracing = setup_tracing(); + let (db, file_id) = TestDB::with_single_file(ra_fixture); + let module = db.module_for_file(file_id); + let def_map = module.def_map(&db); + + let mut defs: Vec = Vec::new(); + visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it)); + defs.sort_by_key(|def| match def { + DefWithBodyId::FunctionId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.syntax().text_range().start() + } + DefWithBodyId::ConstId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.syntax().text_range().start() + } + DefWithBodyId::StaticId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.syntax().text_range().start() + } + }); + let mut mismatches = HashMap::new(); + let mut push_mismatch = |src_ptr: InFile, mismatch: TypeMismatch| { + let range = src_ptr.value.text_range(); + if src_ptr.file_id.call_node(&db).is_some() { + panic!("type mismatch in macro expansion"); + } + let file_range = FileRange { file_id: src_ptr.file_id.original_file(&db), range }; + let actual = format!( + "expected {}, got {}", + mismatch.expected.display_test(&db), + mismatch.actual.display_test(&db) + ); + mismatches.insert(file_range, actual); + }; + for def in defs { + let (_body, body_source_map) = db.body_with_source_map(def); + let inference_result = db.infer(def); + for (pat, mismatch) in inference_result.pat_type_mismatches() { + let syntax_ptr = match body_source_map.pat_syntax(pat) { + Ok(sp) => { + let root = db.parse_or_expand(sp.file_id).unwrap(); + sp.map(|ptr| { + ptr.either( + |it| it.to_node(&root).syntax().clone(), + |it| it.to_node(&root).syntax().clone(), + ) + }) + } + Err(SyntheticSyntax) => continue, + }; + push_mismatch(syntax_ptr, mismatch.clone()); + } + for (expr, mismatch) in inference_result.expr_type_mismatches() { + let node = match body_source_map.expr_syntax(expr) { + Ok(sp) => { + let root = db.parse_or_expand(sp.file_id).unwrap(); + sp.map(|ptr| ptr.to_node(&root).syntax().clone()) + } + Err(SyntheticSyntax) => continue, + }; + push_mismatch(node, mismatch.clone()); + } + } + let mut checked_one = false; + for (file_id, annotations) in db.extract_annotations() { + for (range, expected) in annotations { + let file_range = FileRange { file_id, range }; + if let Some(mismatch) = mismatches.remove(&file_range) { + assert_eq!(mismatch, expected); + } else { + assert!(false, "Expected mismatch not encountered: {}\n", expected); + } + checked_one = true; + } + } + let mut buf = String::new(); + for (range, mismatch) in mismatches { + format_to!(buf, "{:?}: {}\n", range.range, mismatch,); + } + assert!(buf.is_empty(), "Unexpected type mismatches:\n{}", buf); + + assert!(checked_one || allow_none, "no `//^` annotations found"); +} + fn type_at_range(db: &TestDB, pos: FileRange) -> Ty { let file = db.parse(pos.file_id).ok().unwrap(); let expr = algo::find_node_at_range::(file.syntax(), pos.range).unwrap(); diff --git a/crates/hir_ty/src/tests/coercion.rs b/crates/hir_ty/src/tests/coercion.rs index 6dac7e10397..71047703d18 100644 --- a/crates/hir_ty/src/tests/coercion.rs +++ b/crates/hir_ty/src/tests/coercion.rs @@ -1,6 +1,6 @@ use expect_test::expect; -use super::{check_infer, check_infer_with_mismatches, check_types}; +use super::{check_infer, check_infer_with_mismatches, check_no_mismatches, check_types}; #[test] fn infer_block_expr_type_mismatch() { @@ -963,7 +963,7 @@ fn test() -> i32 { #[test] fn panic_macro() { - check_infer_with_mismatches( + check_no_mismatches( r#" mod panic { #[macro_export] @@ -991,15 +991,34 @@ fn main() { panic!() } "#, - expect![[r#" - 174..185 '{ loop {} }': ! - 176..183 'loop {}': ! - 181..183 '{}': () - !0..24 '$crate...:panic': fn panic() -> ! - !0..26 '$crate...anic()': ! - !0..26 '$crate...anic()': ! - !0..28 '$crate...015!()': ! - 454..470 '{ ...c!() }': () - "#]], + ); +} + +#[test] +fn coerce_unsize_expected_type() { + check_no_mismatches( + r#" +#[lang = "sized"] +pub trait Sized {} +#[lang = "unsize"] +pub trait Unsize {} +#[lang = "coerce_unsized"] +pub trait CoerceUnsized {} + +impl, U> CoerceUnsized<&U> for &T {} + +fn main() { + let foo: &[u32] = &[1, 2]; + let foo: &[u32] = match true { + true => &[1, 2], + false => &[1, 2, 3], + }; + let foo: &[u32] = if true { + &[1, 2] + } else { + &[1, 2, 3] + }; +} + "#, ); } diff --git a/crates/hir_ty/src/tests/patterns.rs b/crates/hir_ty/src/tests/patterns.rs index 7d00cee9b3c..aa513c56d5c 100644 --- a/crates/hir_ty/src/tests/patterns.rs +++ b/crates/hir_ty/src/tests/patterns.rs @@ -1,6 +1,6 @@ use expect_test::expect; -use super::{check_infer, check_infer_with_mismatches, check_types}; +use super::{check_infer, check_infer_with_mismatches, check_mismatches, check_types}; #[test] fn infer_pattern() { @@ -518,47 +518,24 @@ fn test(a1: A, o: Option) { #[test] fn infer_const_pattern() { - check_infer_with_mismatches( + check_mismatches( r#" - enum Option { None } - use Option::None; - struct Foo; - const Bar: usize = 1; +enum Option { None } +use Option::None; +struct Foo; +const Bar: usize = 1; - fn test() { - let a: Option = None; - let b: Option = match a { - None => None, - }; - let _: () = match () { Foo => Foo }; // Expected mismatch - let _: () = match () { Bar => Bar }; // Expected mismatch - } +fn test() { + let a: Option = None; + let b: Option = match a { + None => None, + }; + let _: () = match () { Foo => () }; + // ^^^ expected (), got Foo + let _: () = match () { Bar => () }; + // ^^^ expected (), got usize +} "#, - expect![[r#" - 73..74 '1': usize - 87..309 '{ ...atch }': () - 97..98 'a': Option - 114..118 'None': Option - 128..129 'b': Option - 145..182 'match ... }': Option - 151..152 'a': Option - 163..167 'None': Option - 171..175 'None': Option - 192..193 '_': () - 200..223 'match ... Foo }': Foo - 206..208 '()': () - 211..214 'Foo': Foo - 218..221 'Foo': Foo - 254..255 '_': () - 262..285 'match ... Bar }': usize - 268..270 '()': () - 273..276 'Bar': usize - 280..283 'Bar': usize - 200..223: expected (), got Foo - 211..214: expected (), got Foo - 262..285: expected (), got usize - 273..276: expected (), got usize - "#]], ); }