diff --git a/crates/assists/src/handlers/replace_impl_trait_with_generic.rs b/crates/assists/src/handlers/replace_impl_trait_with_generic.rs new file mode 100644 index 00000000000..6738bc13419 --- /dev/null +++ b/crates/assists/src/handlers/replace_impl_trait_with_generic.rs @@ -0,0 +1,168 @@ +use syntax::ast::{self, edit::AstNodeEdit, make, AstNode, GenericParamsOwner}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: replace_impl_trait_with_generic +// +// Replaces `impl Trait` function argument with the named generic. +// +// ``` +// fn foo(bar: <|>impl Bar) {} +// ``` +// -> +// ``` +// fn foo(bar: B) {} +// ``` +pub(crate) fn replace_impl_trait_with_generic( + acc: &mut Assists, + ctx: &AssistContext, +) -> Option<()> { + let type_impl_trait = ctx.find_node_at_offset::()?; + let type_param = type_impl_trait.syntax().parent().and_then(ast::Param::cast)?; + let type_fn = type_param.syntax().ancestors().find_map(ast::Fn::cast)?; + + let impl_trait_ty = type_impl_trait.type_bound_list()?; + + let target = type_fn.syntax().text_range(); + acc.add( + AssistId("replace_impl_trait_with_generic", AssistKind::RefactorRewrite), + "Replace impl trait with generic", + target, + |edit| { + let generic_letter = impl_trait_ty.to_string().chars().next().unwrap().to_string(); + + let generic_param_list = type_fn + .generic_param_list() + .unwrap_or_else(|| make::generic_param_list(None)) + .append_param(make::generic_param(generic_letter.clone(), Some(impl_trait_ty))); + + let new_type_fn = type_fn + .replace_descendant::(type_impl_trait.into(), make::ty(&generic_letter)) + .with_generic_param_list(generic_param_list); + + edit.replace_ast(type_fn.clone(), new_type_fn); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::check_assist; + + #[test] + fn replace_impl_trait_with_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo(bar: <|>impl Bar) {} + "#, + r#" + fn foo(bar: B) {} + "#, + ); + } + + #[test] + fn replace_impl_trait_without_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo(bar: <|>impl Bar) {} + "#, + r#" + fn foo(bar: B) {} + "#, + ); + } + + #[test] + fn replace_two_impl_trait_with_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo(foo: impl Foo, bar: <|>impl Bar) {} + "#, + r#" + fn foo(foo: impl Foo, bar: B) {} + "#, + ); + } + + #[test] + fn replace_impl_trait_with_empty_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo<>(bar: <|>impl Bar) {} + "#, + r#" + fn foo(bar: B) {} + "#, + ); + } + + #[test] + fn replace_impl_trait_with_empty_multiline_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo< + >(bar: <|>impl Bar) {} + "#, + r#" + fn foo(bar: B) {} + "#, + ); + } + + #[test] + #[ignore = "This case is very rare but there is no simple solutions to fix it."] + fn replace_impl_trait_with_exist_generic_letter() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo(bar: <|>impl Bar) {} + "#, + r#" + fn foo(bar: C) {} + "#, + ); + } + + #[test] + fn replace_impl_trait_with_multiline_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo< + G: Foo, + F, + H, + >(bar: <|>impl Bar) {} + "#, + r#" + fn foo< + G: Foo, + F, + H, B: Bar + >(bar: B) {} + "#, + ); + } + + #[test] + fn replace_impl_trait_multiple() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo(bar: <|>impl Foo + Bar) {} + "#, + r#" + fn foo(bar: F) {} + "#, + ); + } +} diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs index 2e0d191a609..cbac53e7111 100644 --- a/crates/assists/src/lib.rs +++ b/crates/assists/src/lib.rs @@ -155,6 +155,7 @@ mod handlers { mod remove_unused_param; mod reorder_fields; mod replace_if_let_with_match; + mod replace_impl_trait_with_generic; mod replace_let_with_if_let; mod replace_qualified_name_with_use; mod replace_unwrap_with_match; @@ -202,6 +203,7 @@ pub(crate) fn all() -> &'static [Handler] { remove_unused_param::remove_unused_param, reorder_fields::reorder_fields, replace_if_let_with_match::replace_if_let_with_match, + replace_impl_trait_with_generic::replace_impl_trait_with_generic, replace_let_with_if_let::replace_let_with_if_let, replace_qualified_name_with_use::replace_qualified_name_with_use, replace_unwrap_with_match::replace_unwrap_with_match, diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs index 04c8fd1f94e..27d15adb08a 100644 --- a/crates/assists/src/tests/generated.rs +++ b/crates/assists/src/tests/generated.rs @@ -814,6 +814,19 @@ fn handle(action: Action) { ) } +#[test] +fn doctest_replace_impl_trait_with_generic() { + check_doc_test( + "replace_impl_trait_with_generic", + r#####" +fn foo(bar: <|>impl Bar) {} +"#####, + r#####" +fn foo(bar: B) {} +"#####, + ) +} + #[test] fn doctest_replace_let_with_if_let() { check_doc_test( diff --git a/crates/syntax/src/ast/edit.rs b/crates/syntax/src/ast/edit.rs index 82347533326..8b1c65dd6f1 100644 --- a/crates/syntax/src/ast/edit.rs +++ b/crates/syntax/src/ast/edit.rs @@ -13,7 +13,7 @@ ast::{ self, make::{self, tokens}, - AstNode, TypeBoundsOwner, + AstNode, GenericParamsOwner, NameOwner, TypeBoundsOwner, }, AstToken, Direction, InsertPosition, SmolStr, SyntaxElement, SyntaxKind, SyntaxKind::{ATTR, COMMENT, WHITESPACE}, @@ -46,6 +46,19 @@ pub fn with_body(&self, body: ast::BlockExpr) -> ast::Fn { to_insert.push(body.syntax().clone().into()); self.replace_children(single_node(old_body_or_semi), to_insert) } + + #[must_use] + pub fn with_generic_param_list(&self, generic_args: ast::GenericParamList) -> ast::Fn { + if let Some(old) = self.generic_param_list() { + return self.replace_descendant(old, generic_args); + } + + let anchor = self.name().expect("The function must have a name").syntax().clone(); + + let mut to_insert: ArrayVec<[SyntaxElement; 1]> = ArrayVec::new(); + to_insert.push(generic_args.syntax().clone().into()); + self.insert_children(InsertPosition::After(anchor.into()), to_insert) + } } fn make_multiline(node: N) -> N @@ -459,6 +472,61 @@ pub fn append_arm(&self, item: ast::MatchArm) -> ast::MatchArmList { } } +impl ast::GenericParamList { + #[must_use] + pub fn append_params( + &self, + params: impl IntoIterator, + ) -> ast::GenericParamList { + let mut res = self.clone(); + params.into_iter().for_each(|it| res = res.append_param(it)); + res + } + + #[must_use] + pub fn append_param(&self, item: ast::GenericParam) -> ast::GenericParamList { + let space = tokens::single_space(); + + let mut to_insert: ArrayVec<[SyntaxElement; 4]> = ArrayVec::new(); + if self.generic_params().next().is_some() { + to_insert.push(space.into()); + } + to_insert.push(item.syntax().clone().into()); + + macro_rules! after_l_angle { + () => {{ + let anchor = match self.l_angle_token() { + Some(it) => it.into(), + None => return self.clone(), + }; + InsertPosition::After(anchor) + }}; + } + + macro_rules! after_field { + ($anchor:expr) => { + if let Some(comma) = $anchor + .syntax() + .siblings_with_tokens(Direction::Next) + .find(|it| it.kind() == T![,]) + { + InsertPosition::After(comma) + } else { + to_insert.insert(0, make::token(T![,]).into()); + InsertPosition::After($anchor.syntax().clone().into()) + } + }; + }; + + let position = match self.generic_params().last() { + Some(it) => after_field!(it), + None => after_l_angle!(), + }; + + self.insert_children(position, to_insert) + } +} + #[must_use] pub fn remove_attrs_and_docs(node: &N) -> N { N::cast(remove_attrs_and_docs_inner(node.syntax().clone())).unwrap() diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 33f1ad7b34e..25e8a359d9e 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -294,6 +294,21 @@ pub fn param_list(pats: impl IntoIterator) -> ast::ParamList ast_from_text(&format!("fn f({}) {{ }}", args)) } +pub fn generic_param(name: String, ty: Option) -> ast::GenericParam { + let bound = match ty { + Some(it) => format!(": {}", it), + None => String::new(), + }; + ast_from_text(&format!("fn f<{}{}>() {{ }}", name, bound)) +} + +pub fn generic_param_list( + pats: impl IntoIterator, +) -> ast::GenericParamList { + let args = pats.into_iter().join(", "); + ast_from_text(&format!("fn f<{}>() {{ }}", args)) +} + pub fn visibility_pub_crate() -> ast::Visibility { ast_from_text("pub(crate) struct S") }