From 24ba1bed040f7d8f483b250dbd4e49383823f644 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 4 Mar 2023 12:48:57 +0100 Subject: [PATCH] Set expectation for no-semi expression statements to unit --- crates/hir-def/src/resolver.rs | 4 +- crates/hir-ty/src/infer/expr.rs | 75 ++++++++++++++++---------- crates/hir-ty/src/infer/path.rs | 25 ++++----- crates/hir-ty/src/tests/diagnostics.rs | 21 ++++++++ crates/hir-ty/src/tests/regression.rs | 4 +- 5 files changed, 82 insertions(+), 47 deletions(-) diff --git a/crates/hir-def/src/resolver.rs b/crates/hir-def/src/resolver.rs index 0b9c136c7eb..664db292a7f 100644 --- a/crates/hir-def/src/resolver.rs +++ b/crates/hir-def/src/resolver.rs @@ -294,8 +294,8 @@ impl Resolver { } } - if let res @ Some(_) = self.module_scope.resolve_path_in_value_ns(db, path) { - return res; + if let Some(res) = self.module_scope.resolve_path_in_value_ns(db, path) { + return Some(res); } // If a path of the shape `u16::from_le_bytes` failed to resolve at all, then we fall back diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 02024e1ea78..81e97a9b0bf 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -130,7 +130,7 @@ impl<'a> InferenceContext<'a> { ); let ty = match label { Some(_) => { - let break_ty = self.table.new_type_var(); + let break_ty = expected.coercion_target_type(&mut self.table); let (breaks, ty) = self.with_breakable_ctx( BreakableKind::Block, Some(break_ty.clone()), @@ -403,37 +403,47 @@ impl<'a> InferenceContext<'a> { Expr::Match { expr, arms } => { let input_ty = self.infer_expr(*expr, &Expectation::none()); - let expected = expected.adjust_for_branches(&mut self.table); - - let result_ty = if arms.is_empty() { + if arms.is_empty() { + self.diverges = Diverges::Always; self.result.standard_types.never.clone() } else { - expected.coercion_target_type(&mut self.table) - }; - let mut coerce = CoerceMany::new(result_ty); - - let matchee_diverges = self.diverges; - let mut all_arms_diverge = Diverges::Always; - - for arm in arms.iter() { - self.diverges = Diverges::Maybe; - let input_ty = self.resolve_ty_shallow(&input_ty); - self.infer_top_pat(arm.pat, &input_ty); - if let Some(guard_expr) = arm.guard { - self.infer_expr( - guard_expr, - &Expectation::HasType(self.result.standard_types.bool_.clone()), - ); + let matchee_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); + let mut all_arms_diverge = Diverges::Always; + for arm in arms.iter() { + let input_ty = self.resolve_ty_shallow(&input_ty); + self.infer_top_pat(arm.pat, &input_ty); } - let arm_ty = self.infer_expr_inner(arm.expr, &expected); - all_arms_diverge &= self.diverges; - coerce.coerce(self, Some(arm.expr), &arm_ty); + let expected = expected.adjust_for_branches(&mut self.table); + let result_ty = match &expected { + // We don't coerce to `()` so that if the match expression is a + // statement it's branches can have any consistent type. + Expectation::HasType(ty) if *ty != self.result.standard_types.unit => { + ty.clone() + } + _ => self.table.new_type_var(), + }; + let mut coerce = CoerceMany::new(result_ty); + + for arm in arms.iter() { + if let Some(guard_expr) = arm.guard { + self.diverges = Diverges::Maybe; + self.infer_expr( + guard_expr, + &Expectation::HasType(self.result.standard_types.bool_.clone()), + ); + } + self.diverges = Diverges::Maybe; + + let arm_ty = self.infer_expr_inner(arm.expr, &expected); + all_arms_diverge &= self.diverges; + coerce.coerce(self, Some(arm.expr), &arm_ty); + } + + self.diverges = matchee_diverges | all_arms_diverge; + + coerce.complete(self) } - - self.diverges = matchee_diverges | all_arms_diverge; - - coerce.complete(self) } Expr::Path(p) => { // FIXME this could be more efficient... @@ -1179,8 +1189,15 @@ impl<'a> InferenceContext<'a> { self.diverges = previous_diverges; } } - Statement::Expr { expr, .. } => { - self.infer_expr(*expr, &Expectation::none()); + &Statement::Expr { expr, has_semi } => { + self.infer_expr( + expr, + &if has_semi { + Expectation::none() + } else { + Expectation::HasType(self.result.standard_types.unit.clone()) + }, + ); } } } diff --git a/crates/hir-ty/src/infer/path.rs b/crates/hir-ty/src/infer/path.rs index 0a8527afbd0..b3867623f37 100644 --- a/crates/hir-ty/src/infer/path.rs +++ b/crates/hir-ty/src/infer/path.rs @@ -40,20 +40,14 @@ impl<'a> InferenceContext<'a> { id: ExprOrPatId, ) -> Option { let (value, self_subst) = if let Some(type_ref) = path.type_anchor() { - if path.segments().is_empty() { - // This can't actually happen syntax-wise - return None; - } + let Some(last) = path.segments().last() else { return None }; let ty = self.make_ty(type_ref); let remaining_segments_for_ty = path.segments().take(path.segments().len() - 1); let ctx = crate::lower::TyLoweringContext::new(self.db, resolver); let (ty, _) = ctx.lower_ty_relative_path(ty, None, remaining_segments_for_ty); - self.resolve_ty_assoc_item( - ty, - path.segments().last().expect("path had at least one segment").name, - id, - )? + self.resolve_ty_assoc_item(ty, last.name, id)? } else { + // FIXME: report error, unresolved first path segment let value_or_partial = resolver.resolve_path_in_value_ns(self.db.upcast(), path.mod_path())?; @@ -66,10 +60,13 @@ impl<'a> InferenceContext<'a> { }; let typable: ValueTyDefId = match value { - ValueNs::LocalBinding(pat) => { - let ty = self.result.type_of_pat.get(pat)?.clone(); - return Some(ty); - } + ValueNs::LocalBinding(pat) => match self.result.type_of_pat.get(pat) { + Some(ty) => return Some(ty.clone()), + None => { + never!("uninferred pattern?"); + return None; + } + }, ValueNs::FunctionId(it) => it.into(), ValueNs::ConstId(it) => it.into(), ValueNs::StaticId(it) => it.into(), @@ -91,7 +88,7 @@ impl<'a> InferenceContext<'a> { let ty = self.db.value_ty(struct_id.into()).substitute(Interner, &substs); return Some(ty); } else { - // FIXME: diagnostic, invalid Self reference + // FIXME: report error, invalid Self reference return None; } } diff --git a/crates/hir-ty/src/tests/diagnostics.rs b/crates/hir-ty/src/tests/diagnostics.rs index f00fa972948..1876be303ad 100644 --- a/crates/hir-ty/src/tests/diagnostics.rs +++ b/crates/hir-ty/src/tests/diagnostics.rs @@ -73,3 +73,24 @@ fn test(x: bool) -> &'static str { "#, ); } + +#[test] +fn non_unit_block_expr_stmt_no_semi() { + check( + r#" +fn test(x: bool) { + if x { + "notok" + //^^^^^^^ expected (), got &str + } else { + "ok" + //^^^^ expected (), got &str + } + match x { true => true, false => 0 } + //^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected (), got bool + //^ expected bool, got i32 + () +} +"#, + ); +} diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs index de6ae7fff8f..5fc2f46d560 100644 --- a/crates/hir-ty/src/tests/regression.rs +++ b/crates/hir-ty/src/tests/regression.rs @@ -1015,9 +1015,9 @@ fn cfg_tail() { 20..31 '{ "first" }': () 22..29 '"first"': &str 72..190 '{ ...] 13 }': () - 78..88 '{ "fake" }': &str + 78..88 '{ "fake" }': () 80..86 '"fake"': &str - 93..103 '{ "fake" }': &str + 93..103 '{ "fake" }': () 95..101 '"fake"': &str 108..120 '{ "second" }': () 110..118 '"second"': &str