diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs
index 12e0b041316..746c3e4ee88 100644
--- a/crates/hir_ty/src/infer/expr.rs
+++ b/crates/hir_ty/src/infer/expr.rs
@@ -735,10 +735,11 @@ impl<'a> InferenceContext<'a> {
                         _ => self.table.new_type_var(),
                     };
 
+                let expected = Expectation::has_type(elem_ty.clone());
                 let len = match array {
                     Array::ElementList(items) => {
                         for expr in items.iter() {
-                            let cur_elem_ty = self.infer_expr_inner(*expr, expected);
+                            let cur_elem_ty = self.infer_expr_inner(*expr, &expected);
                             elem_ty = self.coerce_merge_branch(Some(*expr), &elem_ty, &cur_elem_ty);
                         }
                         Some(items.len() as u64)
diff --git a/crates/hir_ty/src/tests/regression.rs b/crates/hir_ty/src/tests/regression.rs
index d80375f02fd..915dfbbc0d4 100644
--- a/crates/hir_ty/src/tests/regression.rs
+++ b/crates/hir_ty/src/tests/regression.rs
@@ -117,23 +117,34 @@ fn recursive_vars_2() {
         "#,
         expect![[r#"
             10..79 '{     ...x)]; }': ()
-            20..21 'x': {unknown}
-            24..31 'unknown': {unknown}
+            20..21 'x': &{unknown}
+            24..31 'unknown': &{unknown}
             41..42 'y': {unknown}
             45..52 'unknown': {unknown}
-            58..76 '[(x, y..., &x)]': [({unknown}, {unknown}); 2]
-            59..65 '(x, y)': ({unknown}, {unknown})
-            60..61 'x': {unknown}
+            58..76 '[(x, y..., &x)]': [(&{unknown}, {unknown}); 2]
+            59..65 '(x, y)': (&{unknown}, {unknown})
+            60..61 'x': &{unknown}
             63..64 'y': {unknown}
-            67..75 '(&y, &x)': (&{unknown}, &{unknown})
+            67..75 '(&y, &x)': (&{unknown}, {unknown})
             68..70 '&y': &{unknown}
             69..70 'y': {unknown}
-            72..74 '&x': &{unknown}
-            73..74 'x': {unknown}
+            72..74 '&x': &&{unknown}
+            73..74 'x': &{unknown}
         "#]],
     );
 }
 
+#[test]
+fn array_elements_expected_type() {
+    check_no_mismatches(
+        r#"
+        fn test() {
+            let x: [[u32; 2]; 2] = [[1, 2], [3, 4]];
+        }
+        "#,
+    );
+}
+
 #[test]
 fn infer_std_crash_1() {
     // caused stack overflow, taken from std