From 3bd47c0285433b5eb258196a81b95141d2a70505 Mon Sep 17 00:00:00 2001
From: Marcus Klaas de Vries <mail@marcusklaas.nl>
Date: Fri, 25 Jan 2019 21:16:02 +0100
Subject: [PATCH] First attempt at generic type inference for fns

---
 crates/ra_hir/src/code_model_api.rs           |  6 +-
 crates/ra_hir/src/code_model_impl/function.rs |  8 +--
 crates/ra_hir/src/generics.rs                 |  3 +-
 crates/ra_hir/src/ty.rs                       | 61 +++++++++++++++----
 crates/ra_hir/src/ty/tests.rs                 | 22 +++++++
 .../src/completion/completion_item.rs         |  2 +-
 6 files changed, 81 insertions(+), 21 deletions(-)

diff --git a/crates/ra_hir/src/code_model_api.rs b/crates/ra_hir/src/code_model_api.rs
index 191104890af..82ebb275a75 100644
--- a/crates/ra_hir/src/code_model_api.rs
+++ b/crates/ra_hir/src/code_model_api.rs
@@ -388,7 +388,7 @@ pub use crate::code_model_impl::function::ScopeEntryWithSyntax;
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct FnSignature {
     pub(crate) name: Name,
-    pub(crate) params: Vec<TypeRef>,
+    pub(crate) args: Vec<TypeRef>,
     pub(crate) ret_type: TypeRef,
     /// True if the first param is `self`. This is relevant to decide whether this
     /// can be called as a method.
@@ -400,8 +400,8 @@ impl FnSignature {
         &self.name
     }
 
-    pub fn params(&self) -> &[TypeRef] {
-        &self.params
+    pub fn args(&self) -> &[TypeRef] {
+        &self.args
     }
 
     pub fn ret_type(&self) -> &TypeRef {
diff --git a/crates/ra_hir/src/code_model_impl/function.rs b/crates/ra_hir/src/code_model_impl/function.rs
index e0dd4d6290b..b4aa18540e1 100644
--- a/crates/ra_hir/src/code_model_impl/function.rs
+++ b/crates/ra_hir/src/code_model_impl/function.rs
@@ -32,7 +32,7 @@ impl FnSignature {
             .name()
             .map(|n| n.as_name())
             .unwrap_or_else(Name::missing);
-        let mut params = Vec::new();
+        let mut args = Vec::new();
         let mut has_self_param = false;
         if let Some(param_list) = node.param_list() {
             if let Some(self_param) = param_list.self_param() {
@@ -50,12 +50,12 @@ impl FnSignature {
                         }
                     }
                 };
-                params.push(self_type);
+                args.push(self_type);
                 has_self_param = true;
             }
             for param in param_list.params() {
                 let type_ref = TypeRef::from_ast_opt(param.type_ref());
-                params.push(type_ref);
+                args.push(type_ref);
             }
         }
         let ret_type = if let Some(type_ref) = node.ret_type().and_then(|rt| rt.type_ref()) {
@@ -66,7 +66,7 @@ impl FnSignature {
 
         let sig = FnSignature {
             name,
-            params,
+            args,
             ret_type,
             has_self_param,
         };
diff --git a/crates/ra_hir/src/generics.rs b/crates/ra_hir/src/generics.rs
index 64c20a46222..a5501d54388 100644
--- a/crates/ra_hir/src/generics.rs
+++ b/crates/ra_hir/src/generics.rs
@@ -49,7 +49,8 @@ impl GenericParams {
         Arc::new(generics)
     }
 
-    fn fill(&mut self, node: &impl TypeParamsOwner) {
+    // FIXME: probably shouldnt be pub(crate)
+    pub(crate) fn fill(&mut self, node: &impl TypeParamsOwner) {
         if let Some(params) = node.type_param_list() {
             self.fill_params(params)
         }
diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs
index 31ea4570627..95de916ee18 100644
--- a/crates/ra_hir/src/ty.rs
+++ b/crates/ra_hir/src/ty.rs
@@ -209,6 +209,18 @@ pub enum Ty {
     /// `&'a mut T` or `&'a T`.
     Ref(Arc<Ty>, Mutability),
 
+    /// The anonymous type of a function declaration/definition. Each
+    /// function has a unique type, which is output (for a function
+    /// named `foo` returning an `i32`) as `fn() -> i32 {foo}`.
+    ///
+    /// For example the type of `bar` here:
+    ///
+    /// ```rust
+    /// fn foo() -> i32 { 1 }
+    /// let bar = foo; // bar: fn() -> i32 {foo}
+    /// ```
+    FnDef(Function, Substs),
+
     /// A pointer to a function.  Written as `fn() -> i32`.
     ///
     /// For example the type of `bar` here:
@@ -485,7 +497,7 @@ impl Ty {
                 }
                 sig_mut.output.walk_mut(f);
             }
-            Ty::Adt { substs, .. } => {
+            Ty::FnDef(_, substs) | Ty::Adt { substs, .. } => {
                 // Without an Arc::make_mut_slice, we can't avoid the clone here:
                 let mut v: Vec<_> = substs.0.iter().cloned().collect();
                 for t in &mut v {
@@ -524,6 +536,7 @@ impl Ty {
                 name,
                 substs,
             },
+            Ty::FnDef(func, _) => Ty::FnDef(func, substs),
             _ => self,
         }
     }
@@ -579,6 +592,7 @@ impl fmt::Display for Ty {
                         .to_fmt(f)
                 }
             }
+            Ty::FnDef(_func, _substs) => write!(f, "FNDEF-IMPLEMENT-ME"),
             Ty::FnPtr(sig) => {
                 join(sig.input.iter())
                     .surround_with("fn(", ")")
@@ -608,12 +622,18 @@ impl fmt::Display for Ty {
 /// Compute the declared type of a function. This should not need to look at the
 /// function body.
 fn type_for_fn(db: &impl HirDatabase, f: Function) -> Ty {
+    let generics = f.generic_params(db);
+    let substs = make_substs(&generics);
+    Ty::FnDef(f.into(), substs)
+}
+
+fn get_func_sig(db: &impl HirDatabase, f: Function) -> FnSig {
     let signature = f.signature(db);
     let module = f.module(db);
     let impl_block = f.impl_block(db);
     let generics = f.generic_params(db);
     let input = signature
-        .params()
+        .args()
         .iter()
         .map(|tr| Ty::from_hir(db, &module, impl_block.as_ref(), &generics, tr))
         .collect::<Vec<_>>();
@@ -624,8 +644,7 @@ fn type_for_fn(db: &impl HirDatabase, f: Function) -> Ty {
         &generics,
         signature.ret_type(),
     );
-    let sig = FnSig { input, output };
-    Ty::FnPtr(Arc::new(sig))
+    FnSig { input, output }
 }
 
 fn make_substs(generics: &GenericParams) -> Substs {
@@ -1142,7 +1161,13 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
                 let ty = self.insert_type_vars(ty.apply_substs(substs));
                 (ty, Some(var.into()))
             }
-            TypableDef::Enum(_) | TypableDef::Function(_) => (Ty::Unknown, None),
+            TypableDef::Function(func) => {
+                let ty = type_for_fn(self.db, func);
+                let ty = self.insert_type_vars(ty.apply_substs(substs));
+                // FIXME: is this right?
+                (ty, None)
+            }
+            TypableDef::Enum(_) => (Ty::Unknown, None),
         }
     }
 
@@ -1331,12 +1356,27 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
             }
             Expr::Call { callee, args } => {
                 let callee_ty = self.infer_expr(*callee, &Expectation::none());
+                // FIXME: so manu unnecessary clones
                 let (param_tys, ret_ty) = match &callee_ty {
-                    Ty::FnPtr(sig) => (&sig.input[..], sig.output.clone()),
+                    Ty::FnPtr(sig) => (sig.input.clone(), sig.output.clone()),
+                    Ty::FnDef(func, substs) => {
+                        let fn_sig = func.signature(self.db);
+                        // TODO: get input and return types from the fn_sig.
+                        // it contains typerefs which we can make into proper tys
+
+                        let sig = get_func_sig(self.db, *func);
+                        (
+                            sig.input
+                                .iter()
+                                .map(|ty| ty.clone().subst(&substs))
+                                .collect(),
+                            sig.output.clone().subst(&substs),
+                        )
+                    }
                     _ => {
                         // not callable
                         // TODO report an error?
-                        (&[][..], Ty::Unknown)
+                        (Vec::new(), Ty::Unknown)
                     }
                 };
                 for (i, arg) in args.iter().enumerate() {
@@ -1604,15 +1644,12 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
 
     fn collect_fn_signature(&mut self, signature: &FnSignature) {
         let body = Arc::clone(&self.body); // avoid borrow checker problem
-        for (type_ref, pat) in signature.params().iter().zip(body.params()) {
+        for (type_ref, pat) in signature.args().iter().zip(body.params()) {
             let ty = self.make_ty(type_ref);
 
             self.infer_pat(*pat, &ty);
         }
-        self.return_ty = {
-            let ty = self.make_ty(signature.ret_type());
-            ty
-        };
+        self.return_ty = self.make_ty(signature.ret_type());
     }
 
     fn infer_body(&mut self) {
diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs
index f74d6f5ea37..40913b164ab 100644
--- a/crates/ra_hir/src/ty/tests.rs
+++ b/crates/ra_hir/src/ty/tests.rs
@@ -594,6 +594,28 @@ fn test() {
     );
 }
 
+#[test]
+fn infer_type_param() {
+    check_inference(
+        "generic_fn",
+        r#"
+fn id<T>(x: T) -> T {
+    x
+}
+
+fn clone<T>(x: &T) -> T {
+    x
+}
+
+fn test() {
+    let y = 10u32;
+    id(y);
+    let x: bool = clone(z);
+}
+"#,
+    );
+}
+
 fn infer(content: &str) -> String {
     let (db, _, file_id) = MockDatabase::with_single_file(content);
     let source_file = db.parse(file_id);
diff --git a/crates/ra_ide_api/src/completion/completion_item.rs b/crates/ra_ide_api/src/completion/completion_item.rs
index b16ac2b289c..6e9a68e4068 100644
--- a/crates/ra_ide_api/src/completion/completion_item.rs
+++ b/crates/ra_ide_api/src/completion/completion_item.rs
@@ -240,7 +240,7 @@ impl Builder {
         if ctx.use_item_syntax.is_none() && !ctx.is_call {
             tested_by!(inserts_parens_for_function_calls);
             let sig = function.signature(ctx.db);
-            if sig.params().is_empty() || sig.has_self_param() && sig.params().len() == 1 {
+            if sig.args().is_empty() || sig.has_self_param() && sig.args().len() == 1 {
                 self.insert_text = Some(format!("{}()$0", self.label));
             } else {
                 self.insert_text = Some(format!("{}($0)", self.label));