diff --git a/crates/ra_hir/src/expr/validation.rs b/crates/ra_hir/src/expr/validation.rs index 5d9d59ff89a..c8ae198696e 100644 --- a/crates/ra_hir/src/expr/validation.rs +++ b/crates/ra_hir/src/expr/validation.rs @@ -41,7 +41,7 @@ impl<'a, 'b> ExprValidator<'a, 'b> { let body_expr = &body[body.body_expr()]; if let Expr::Block { statements: _, tail: Some(t) } = body_expr { - self.validate_results_in_tail_expr(*t, db); + self.validate_results_in_tail_expr(body.body_expr(), *t, db); } } @@ -97,8 +97,14 @@ impl<'a, 'b> ExprValidator<'a, 'b> { } } - fn validate_results_in_tail_expr(&mut self, id: ExprId, db: &impl HirDatabase) { - let mismatch = match self.infer.type_mismatch_for_expr(id) { + fn validate_results_in_tail_expr( + &mut self, + body_id: ExprId, + id: ExprId, + db: &impl HirDatabase, + ) { + // the mismatch will be on the whole block currently + let mismatch = match self.infer.type_mismatch_for_expr(body_id) { Some(m) => m, None => return, }; diff --git a/crates/ra_hir/src/ty/infer.rs b/crates/ra_hir/src/ty/infer.rs index d94e8154b0d..8129904261a 100644 --- a/crates/ra_hir/src/ty/infer.rs +++ b/crates/ra_hir/src/ty/infer.rs @@ -280,8 +280,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { let ty1 = self.resolve_ty_shallow(ty1); let ty2 = self.resolve_ty_shallow(ty2); match (&*ty1, &*ty2) { - (Ty::Unknown, ..) => true, - (.., Ty::Unknown) => true, + (Ty::Unknown, _) | (_, Ty::Unknown) => true, (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => { self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1) } @@ -976,24 +975,48 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { ret_ty } + /// This is similar to unify, but it makes the first type coerce to the + /// second one. + fn coerce(&mut self, from_ty: &Ty, to_ty: &Ty) -> bool { + if is_never(from_ty) { + // ! coerces to any type + true + } else { + self.unify(from_ty, to_ty) + } + } + fn infer_expr(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { + let ty = self.infer_expr_inner(tgt_expr, expected); + let could_unify = self.unify(&ty, &expected.ty); + if !could_unify { + self.result.type_mismatches.insert( + tgt_expr, + TypeMismatch { expected: expected.ty.clone(), actual: ty.clone() }, + ); + } + let ty = self.resolve_ty_as_possible(&mut vec![], ty); + ty + } + + fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { let body = Arc::clone(&self.body); // avoid borrow checker problem let ty = match &body[tgt_expr] { Expr::Missing => Ty::Unknown, Expr::If { condition, then_branch, else_branch } => { // if let is desugared to match, so this is always simple if self.infer_expr(*condition, &Expectation::has_type(Ty::simple(TypeCtor::Bool))); - let then_ty = self.infer_expr(*then_branch, expected); - match else_branch { - Some(else_branch) => { - self.infer_expr(*else_branch, expected); - } - None => { - // no else branch -> unit - self.unify(&then_ty, &Ty::unit()); // actually coerce - } + + let then_ty = self.infer_expr_inner(*then_branch, &expected); + self.coerce(&then_ty, &expected.ty); + + let else_ty = match else_branch { + Some(else_branch) => self.infer_expr_inner(*else_branch, &expected), + None => Ty::unit(), }; - then_ty + self.coerce(&else_ty, &expected.ty); + + expected.ty.clone() } Expr::Block { statements, tail } => self.infer_block(statements, *tail, expected), Expr::TryBlock { body } => { @@ -1073,12 +1096,14 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { Expr::MethodCall { receiver, args, method_name, generic_args } => self .infer_method_call(tgt_expr, *receiver, &args, &method_name, generic_args.as_ref()), Expr::Match { expr, arms } => { + let input_ty = self.infer_expr(*expr, &Expectation::none()); let expected = if expected.ty == Ty::Unknown { Expectation::has_type(self.new_type_var()) } else { expected.clone() }; - let input_ty = self.infer_expr(*expr, &Expectation::none()); + + let mut arm_tys = Vec::with_capacity(arms.len()); for arm in arms { for &pat in &arm.pats { @@ -1090,10 +1115,16 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { &Expectation::has_type(Ty::simple(TypeCtor::Bool)), ); } - self.infer_expr(arm.expr, &expected); + arm_tys.push(self.infer_expr_inner(arm.expr, &expected)); } - expected.ty + let lub_ty = calculate_least_upper_bound(expected.ty.clone(), &arm_tys); + + for arm_ty in &arm_tys { + self.coerce(arm_ty, &lub_ty); + } + + lub_ty } Expr::Path(p) => { // FIXME this could be more efficient... @@ -1356,15 +1387,8 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { }; // use a new type variable if we got Ty::Unknown here let ty = self.insert_type_vars_shallow(ty); - let could_unify = self.unify(&ty, &expected.ty); let ty = self.resolve_ty_as_possible(&mut vec![], ty); self.write_expr_ty(tgt_expr, ty.clone()); - if !could_unify { - self.result.type_mismatches.insert( - tgt_expr, - TypeMismatch { expected: expected.ty.clone(), actual: ty.clone() }, - ); - } ty } @@ -1394,7 +1418,8 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } } } - let ty = if let Some(expr) = tail { self.infer_expr(expr, expected) } else { Ty::unit() }; + let ty = + if let Some(expr) = tail { self.infer_expr_inner(expr, expected) } else { Ty::unit() }; ty } @@ -1616,3 +1641,37 @@ mod diagnostics { } } } + +fn is_never(ty: &Ty) -> bool { + if let Ty::Apply(ApplicationTy { ctor: TypeCtor::Never, .. }) = ty { + true + } else { + false + } +} + +fn calculate_least_upper_bound(expected_ty: Ty, actual_tys: &[Ty]) -> Ty { + let mut all_never = true; + let mut last_never_ty = None; + let mut least_upper_bound = expected_ty; + + for actual_ty in actual_tys { + if is_never(actual_ty) { + last_never_ty = Some(actual_ty.clone()); + } else { + all_never = false; + least_upper_bound = match (actual_ty, &least_upper_bound) { + (_, Ty::Unknown) + | (Ty::Infer(_), Ty::Infer(InferTy::TypeVar(_))) + | (Ty::Apply(_), _) => actual_ty.clone(), + _ => least_upper_bound, + } + } + } + + if all_never && last_never_ty.is_some() { + last_never_ty.unwrap() + } else { + least_upper_bound + } +} diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs index c5818b738f0..e3eb0c3fafc 100644 --- a/crates/ra_hir/src/ty/tests.rs +++ b/crates/ra_hir/src/ty/tests.rs @@ -718,16 +718,18 @@ fn main(foo: Foo) { } "#), @r###" -[35; 38) 'foo': Foo -[45; 109) '{ ... } }': () -[51; 107) 'if tru... }': () -[54; 58) 'true': bool -[59; 67) '{ }': () -[73; 107) 'if fal... }': i32 -[76; 81) 'false': bool -[82; 107) '{ ... }': i32 -[92; 95) 'foo': Foo -[92; 101) 'foo.field': i32"### + ⋮ + ⋮[35; 38) 'foo': Foo + ⋮[45; 109) '{ ... } }': () + ⋮[51; 107) 'if tru... }': () + ⋮[54; 58) 'true': bool + ⋮[59; 67) '{ }': () + ⋮[73; 107) 'if fal... }': () + ⋮[76; 81) 'false': bool + ⋮[82; 107) '{ ... }': i32 + ⋮[92; 95) 'foo': Foo + ⋮[92; 101) 'foo.field': i32 + "### ) } @@ -3594,3 +3596,121 @@ fn no_such_field_diagnostics() { "### ); } + +mod branching_with_never_tests { + use super::type_at; + + #[test] + fn if_never() { + let t = type_at( + r#" +//- /main.rs +fn test() { + let i = if true { + loop {} + } else { + 3.0 + }; + i<|> + () +} +"#, + ); + assert_eq!(t, "f64"); + } + + #[test] + fn if_else_never() { + let t = type_at( + r#" +//- /main.rs +fn test(input: bool) { + let i = if input { + 2.0 + } else { + return + }; + i<|> + () +} +"#, + ); + assert_eq!(t, "f64"); + } + + #[test] + fn match_first_arm_never() { + let t = type_at( + r#" +//- /main.rs +fn test(a: i32) { + let i = match a { + 1 => return, + 2 => 2.0, + 3 => loop {}, + _ => 3.0, + }; + i<|> + () +} +"#, + ); + assert_eq!(t, "f64"); + } + + #[test] + fn match_second_arm_never() { + let t = type_at( + r#" +//- /main.rs +fn test(a: i32) { + let i = match a { + 1 => 3.0, + 2 => loop {}, + 3 => 3.0, + _ => return, + }; + i<|> + () +} +"#, + ); + assert_eq!(t, "f64"); + } + + #[test] + fn match_all_arms_never() { + let t = type_at( + r#" +//- /main.rs +fn test(a: i32) { + let i = match a { + 2 => return, + _ => loop {}, + }; + i<|> + () +} +"#, + ); + assert_eq!(t, "!"); + } + + #[test] + fn match_no_never_arms() { + let t = type_at( + r#" +//- /main.rs +fn test(a: i32) { + let i = match a { + 2 => 2.0, + _ => 3.0, + }; + i<|> + () +} +"#, + ); + assert_eq!(t, "f64"); + } +}