From 075ab03851d7b4304ddf563d75c4ee3d713f2583 Mon Sep 17 00:00:00 2001
From: Dorian Scheidt <dorian.scheidt@gmail.com>
Date: Sat, 30 Apr 2022 12:29:55 -0500
Subject: [PATCH 1/2] fix: Support generics in extract_function assist

This change attempts to resolve issue #7636: Extract into Function does not
create a generic function with constraints when extracting generic code.

In `FunctionBody::analyze_container`, we now traverse the `ancestors` in search
of `AnyHasGenericParams`, and attach any `GenericParamList`s and `WhereClause`s
we find to the `ContainerInfo`.

Later, in `format_function`, we collect all the `GenericParam`s and
`WherePred`s from the container, and filter them to keep only types matching
`TypeParam`s used within the newly extracted function body or param list. We
can then include the new `GenericParamList` and `WhereClause` in the new
function definition.

This change only impacts `TypeParam`s. `LifetimeParam`s and `ConstParam`s are
out of scope for this change.
---
 crates/hir/src/lib.rs                         |   9 +
 .../src/handlers/extract_function.rs          | 434 +++++++++++++++++-
 2 files changed, 436 insertions(+), 7 deletions(-)

diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs
index 96424d087ef..86124b68b51 100644
--- a/crates/hir/src/lib.rs
+++ b/crates/hir/src/lib.rs
@@ -3307,6 +3307,15 @@ impl Type {
         let tys = hir_ty::replace_errors_with_variables(&(self.ty.clone(), to.ty.clone()));
         hir_ty::could_coerce(db, self.env.clone(), &tys)
     }
+
+    pub fn as_type_param(&self, db: &dyn HirDatabase) -> Option<TypeParam> {
+        match self.ty.kind(Interner) {
+            TyKind::Placeholder(p) => Some(TypeParam {
+                id: TypeParamId::from_unchecked(hir_ty::from_placeholder_idx(db, *p)),
+            }),
+            _ => None,
+        }
+    }
 }
 
 #[derive(Debug)]
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 9233c198df8..aa1c3a548c9 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -2,7 +2,9 @@ use std::iter;
 
 use ast::make;
 use either::Either;
-use hir::{HasSource, HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
+use hir::{
+    HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam,
+};
 use ide_db::{
     defs::{Definition, NameRefClass},
     famous_defs::FamousDefs,
@@ -18,7 +20,7 @@ use syntax::{
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
-        AstNode,
+        AstNode, HasGenericParams,
     },
     match_ast, ted, SyntaxElement,
     SyntaxKind::{self, COMMENT},
@@ -294,6 +296,8 @@ struct ContainerInfo {
     parent_loop: Option<SyntaxNode>,
     /// The function's return type, const's type etc.
     ret_type: Option<hir::Type>,
+    generic_param_lists: Vec<ast::GenericParamList>,
+    where_clauses: Vec<ast::WhereClause>,
 }
 
 /// Control flow that is exported from extracted function
@@ -517,6 +521,24 @@ impl FunctionBody {
         }
     }
 
+    fn descendants(&self) -> impl Iterator<Item = SyntaxNode> {
+        match self {
+            FunctionBody::Expr(expr) => expr.syntax().descendants(),
+            FunctionBody::Span { parent, .. } => parent.syntax().descendants(),
+        }
+    }
+
+    fn descendant_paths(&self) -> impl Iterator<Item = ast::Path> {
+        self.descendants().filter_map(|node| {
+            match_ast! {
+                match node {
+                    ast::Path(it) => Some(it),
+                    _ => None
+                }
+            }
+        })
+    }
+
     fn from_expr(expr: ast::Expr) -> Option<Self> {
         match expr {
             ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr),
@@ -731,6 +753,7 @@ impl FunctionBody {
                 parent_loop.get_or_insert(loop_.syntax().clone());
             }
         };
+
         let (is_const, expr, ty) = loop {
             let anc = ancestors.next()?;
             break match_ast! {
@@ -798,7 +821,19 @@ impl FunctionBody {
             container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
                 container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
             });
-        Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty })
+
+        let parent = self.parent()?;
+        let generic_param_lists = parent_generic_param_lists(&parent);
+        let where_clauses = parent_where_clauses(&parent);
+
+        Some(ContainerInfo {
+            is_in_tail,
+            is_const,
+            parent_loop,
+            ret_type: ty,
+            generic_param_lists,
+            where_clauses,
+        })
     }
 
     fn return_ty(&self, ctx: &AssistContext) -> Option<RetType> {
@@ -955,6 +990,26 @@ impl FunctionBody {
     }
 }
 
+fn parent_where_clauses(parent: &SyntaxNode) -> Vec<ast::WhereClause> {
+    let mut where_clause: Vec<ast::WhereClause> = parent
+        .ancestors()
+        .filter_map(ast::AnyHasGenericParams::cast)
+        .filter_map(|it| it.where_clause())
+        .collect();
+    where_clause.reverse();
+    where_clause
+}
+
+fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec<ast::GenericParamList> {
+    let mut generic_param_list: Vec<ast::GenericParamList> = parent
+        .ancestors()
+        .filter_map(ast::AnyHasGenericParams::cast)
+        .filter_map(|it| it.generic_param_list())
+        .collect();
+    generic_param_list.reverse();
+    generic_param_list
+}
+
 /// checks if relevant var is used with `&mut` access inside body
 fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool {
     usages
@@ -1362,37 +1417,154 @@ fn format_function(
     let const_kw = if fun.mods.is_const { "const " } else { "" };
     let async_kw = if fun.control_flow.is_async { "async " } else { "" };
     let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
+    let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
     match ctx.config.snippet_cap {
         Some(_) => format_to!(
             fn_def,
-            "\n\n{}{}{}{}fn $0{}{}",
+            "\n\n{}{}{}{}fn $0{}",
             new_indent,
             const_kw,
             async_kw,
             unsafe_kw,
             fun.name,
-            params
         ),
         None => format_to!(
             fn_def,
-            "\n\n{}{}{}{}fn {}{}",
+            "\n\n{}{}{}{}fn {}",
             new_indent,
             const_kw,
             async_kw,
             unsafe_kw,
             fun.name,
-            params
         ),
     }
+
+    if let Some(generic_params) = generic_params {
+        format_to!(fn_def, "{}", generic_params);
+    }
+
+    format_to!(fn_def, "{}", params);
+
     if let Some(ret_ty) = ret_ty {
         format_to!(fn_def, " {}", ret_ty);
     }
+
+    if let Some(where_clause) = where_clause {
+        format_to!(fn_def, " {}", where_clause);
+    }
+
     format_to!(fn_def, " {}", body);
 
     fn_def
 }
 
+fn make_generic_params_and_where_clause(
+    ctx: &AssistContext,
+    fun: &Function,
+) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) {
+    let used_type_params = fun.type_params(ctx);
+
+    let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params);
+    let where_clause = make_where_clause(ctx, fun, &used_type_params);
+
+    (generic_param_list, where_clause)
+}
+
+fn make_generic_param_list(
+    ctx: &AssistContext,
+    fun: &Function,
+    used_type_params: &[TypeParam],
+) -> Option<ast::GenericParamList> {
+    let mut generic_params = fun
+        .mods
+        .generic_param_lists
+        .iter()
+        .flat_map(|parent_params| {
+            parent_params
+                .generic_params()
+                .filter(|param| param_is_required(ctx, param, used_type_params))
+        })
+        .peekable();
+
+    if generic_params.peek().is_some() {
+        Some(make::generic_param_list(generic_params))
+    } else {
+        None
+    }
+}
+
+fn param_is_required(
+    ctx: &AssistContext,
+    param: &ast::GenericParam,
+    used_type_params: &[TypeParam],
+) -> bool {
+    match param {
+        ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false,
+        ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) {
+            Some(def) => used_type_params.contains(def),
+            _ => false,
+        },
+    }
+}
+
+fn make_where_clause(
+    ctx: &AssistContext,
+    fun: &Function,
+    used_type_params: &[TypeParam],
+) -> Option<ast::WhereClause> {
+    let mut predicates = fun
+        .mods
+        .where_clauses
+        .iter()
+        .flat_map(|parent_where_clause| {
+            parent_where_clause
+                .predicates()
+                .filter(|pred| pred_is_required(ctx, pred, used_type_params))
+        })
+        .peekable();
+
+    if predicates.peek().is_some() {
+        Some(make::where_clause(predicates))
+    } else {
+        None
+    }
+}
+
+fn pred_is_required(
+    ctx: &AssistContext,
+    pred: &ast::WherePred,
+    used_type_params: &[TypeParam],
+) -> bool {
+    match resolved_type_param(ctx, pred) {
+        Some(it) => used_type_params.contains(&it),
+        None => false,
+    }
+}
+
+fn resolved_type_param(ctx: &AssistContext, pred: &ast::WherePred) -> Option<TypeParam> {
+    let path = match pred.ty()? {
+        ast::Type::PathType(path_type) => path_type.path(),
+        _ => None,
+    }?;
+
+    match ctx.sema.resolve_path(&path)? {
+        PathResolution::TypeParam(type_param) => Some(type_param),
+        _ => None,
+    }
+}
+
 impl Function {
+    /// Collect all the `TypeParam`s used in the `body` and `params`.
+    fn type_params(&self, ctx: &AssistContext) -> Vec<TypeParam> {
+        let type_params_in_descendant_paths =
+            self.body.descendant_paths().filter_map(|it| match ctx.sema.resolve_path(&it) {
+                Some(PathResolution::TypeParam(type_param)) => Some(type_param),
+                _ => None,
+            });
+        let type_params_in_params = self.params.iter().filter_map(|p| p.ty.as_type_param(ctx.db()));
+        type_params_in_descendant_paths.chain(type_params_in_params).collect()
+    }
+
     fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList {
         let self_param = self.self_param.clone();
         let params = self.params.iter().map(|param| param.to_param(ctx, module));
@@ -4872,6 +5044,254 @@ fn parent(factor: i32) {
 fn $0fun_name(v: &[i32; 3], factor: i32) {
     v.iter().map(|it| it * factor);
 }
+"#,
+        );
+    }
+
+    #[test]
+    fn preserve_generics() {
+        check_assist(
+            extract_function,
+            r#"
+fn func<T: Debug>(i: T) {
+    $0foo(i);$0
+}
+"#,
+            r#"
+fn func<T: Debug>(i: T) {
+    fun_name(i);
+}
+
+fn $0fun_name<T: Debug>(i: T) {
+    foo(i);
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn preserve_generics_from_body() {
+        check_assist(
+            extract_function,
+            r#"
+fn func<T: Default>() -> T {
+    $0T::default()$0
+}
+"#,
+            r#"
+fn func<T: Default>() -> T {
+    fun_name()
+}
+
+fn $0fun_name<T: Default>() -> T {
+    T::default()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn filter_unused_generics() {
+        check_assist(
+            extract_function,
+            r#"
+fn func<T: Debug, U: Copy>(i: T, u: U) {
+    bar(u);
+    $0foo(i);$0
+}
+"#,
+            r#"
+fn func<T: Debug, U: Copy>(i: T, u: U) {
+    bar(u);
+    fun_name(i);
+}
+
+fn $0fun_name<T: Debug>(i: T) {
+    foo(i);
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn empty_generic_param_list() {
+        check_assist(
+            extract_function,
+            r#"
+fn func<T: Debug>(t: T, i: u32) {
+    bar(t);
+    $0foo(i);$0
+}
+"#,
+            r#"
+fn func<T: Debug>(t: T, i: u32) {
+    bar(t);
+    fun_name(i);
+}
+
+fn $0fun_name(i: u32) {
+    foo(i);
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn preserve_where_clause() {
+        check_assist(
+            extract_function,
+            r#"
+fn func<T>(i: T) where T: Debug {
+    $0foo(i);$0
+}
+"#,
+            r#"
+fn func<T>(i: T) where T: Debug {
+    fun_name(i);
+}
+
+fn $0fun_name<T>(i: T) where T: Debug {
+    foo(i);
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn filter_unused_where_clause() {
+        check_assist(
+            extract_function,
+            r#"
+fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
+    bar(u);
+    $0foo(i);$0
+}
+"#,
+            r#"
+fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
+    bar(u);
+    fun_name(i);
+}
+
+fn $0fun_name<T>(i: T) where T: Debug {
+    foo(i);
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn nested_generics() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct<T: Into<i32>>(T);
+impl <T: Into<i32> + Copy> Struct<T> {
+    fn func<V: Into<i32>>(&self, v: V) -> i32 {
+        let t = self.0;
+        $0t.into() + v.into()$0
+    }
+}
+"#,
+            r#"
+struct Struct<T: Into<i32>>(T);
+impl <T: Into<i32> + Copy> Struct<T> {
+    fn func<V: Into<i32>>(&self, v: V) -> i32 {
+        let t = self.0;
+        fun_name(t, v)
+    }
+}
+
+fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
+    t.into() + v.into()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn filters_unused_nested_generics() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct<T: Into<i32>, U: Debug>(T, U);
+impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
+    fn func<V: Into<i32>>(&self, v: V) -> i32 {
+        let t = self.0;
+        $0t.into() + v.into()$0
+    }
+}
+"#,
+            r#"
+struct Struct<T: Into<i32>, U: Debug>(T, U);
+impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
+    fn func<V: Into<i32>>(&self, v: V) -> i32 {
+        let t = self.0;
+        fun_name(t, v)
+    }
+}
+
+fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
+    t.into() + v.into()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn nested_where_clauses() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct<T>(T) where T: Into<i32>;
+impl <T> Struct<T> where T: Into<i32> + Copy {
+    fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
+        let t = self.0;
+        $0t.into() + v.into()$0
+    }
+}
+"#,
+            r#"
+struct Struct<T>(T) where T: Into<i32>;
+impl <T> Struct<T> where T: Into<i32> + Copy {
+    fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
+        let t = self.0;
+        fun_name(t, v)
+    }
+}
+
+fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
+    t.into() + v.into()
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn filters_unused_nested_where_clauses() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
+impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
+    fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
+        let t = self.0;
+        $0t.into() + v.into()$0
+    }
+}
+"#,
+            r#"
+struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
+impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
+    fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
+        let t = self.0;
+        fun_name(t, v)
+    }
+}
+
+fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
+    t.into() + v.into()
+}
 "#,
         );
     }

From 796641b5d85d06f2884e28971a358421976aefaa Mon Sep 17 00:00:00 2001
From: Dorian Scheidt <dorian.scheidt@gmail.com>
Date: Wed, 13 Jul 2022 10:20:55 -0500
Subject: [PATCH 2/2] Make search for applicable generics more precise

---
 .../src/handlers/extract_function.rs          | 67 ++++++++++++++-----
 1 file changed, 49 insertions(+), 18 deletions(-)

diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index aa1c3a548c9..94b638d4c60 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -823,8 +823,9 @@ impl FunctionBody {
             });
 
         let parent = self.parent()?;
-        let generic_param_lists = parent_generic_param_lists(&parent);
-        let where_clauses = parent_where_clauses(&parent);
+        let parents = generic_parents(&parent);
+        let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
+        let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
 
         Some(ContainerInfo {
             is_in_tail,
@@ -990,24 +991,54 @@ impl FunctionBody {
     }
 }
 
-fn parent_where_clauses(parent: &SyntaxNode) -> Vec<ast::WhereClause> {
-    let mut where_clause: Vec<ast::WhereClause> = parent
-        .ancestors()
-        .filter_map(ast::AnyHasGenericParams::cast)
-        .filter_map(|it| it.where_clause())
-        .collect();
-    where_clause.reverse();
-    where_clause
+enum GenericParent {
+    Fn(ast::Fn),
+    Impl(ast::Impl),
+    Trait(ast::Trait),
 }
 
-fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec<ast::GenericParamList> {
-    let mut generic_param_list: Vec<ast::GenericParamList> = parent
-        .ancestors()
-        .filter_map(ast::AnyHasGenericParams::cast)
-        .filter_map(|it| it.generic_param_list())
-        .collect();
-    generic_param_list.reverse();
-    generic_param_list
+impl GenericParent {
+    fn generic_param_list(&self) -> Option<ast::GenericParamList> {
+        match self {
+            GenericParent::Fn(fn_) => fn_.generic_param_list(),
+            GenericParent::Impl(impl_) => impl_.generic_param_list(),
+            GenericParent::Trait(trait_) => trait_.generic_param_list(),
+        }
+    }
+
+    fn where_clause(&self) -> Option<ast::WhereClause> {
+        match self {
+            GenericParent::Fn(fn_) => fn_.where_clause(),
+            GenericParent::Impl(impl_) => impl_.where_clause(),
+            GenericParent::Trait(trait_) => trait_.where_clause(),
+        }
+    }
+}
+
+/// Search `parent`'s ancestors for items with potentially applicable generic parameters
+fn generic_parents(parent: &SyntaxNode) -> Vec<GenericParent> {
+    let mut list = Vec::new();
+    if let Some(parent_item) = parent.ancestors().find_map(ast::Item::cast) {
+        match parent_item {
+            ast::Item::Fn(ref fn_) => {
+                if let Some(parent_parent) = parent_item
+                    .syntax()
+                    .parent()
+                    .and_then(|it| it.parent())
+                    .and_then(ast::Item::cast)
+                {
+                    match parent_parent {
+                        ast::Item::Impl(impl_) => list.push(GenericParent::Impl(impl_)),
+                        ast::Item::Trait(trait_) => list.push(GenericParent::Trait(trait_)),
+                        _ => (),
+                    }
+                }
+                list.push(GenericParent::Fn(fn_.clone()));
+            }
+            _ => (),
+        }
+    }
+    list
 }
 
 /// checks if relevant var is used with `&mut` access inside body