diff --git a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs index 007aba23d21..d3ff7b65cd0 100644 --- a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs @@ -11,14 +11,19 @@ search::FileReference, RootDatabase, }; +use itertools::Itertools; use rustc_hash::FxHashSet; use syntax::{ - algo::find_node_at_offset, - ast::{self, make, AstNode, NameOwner, VisibilityOwner}, - ted, SyntaxNode, T, + ast::{ + self, make, AstNode, AttrsOwner, GenericParamsOwner, NameOwner, TypeBoundsOwner, + VisibilityOwner, + }, + match_ast, + ted::{self, Position}, + SyntaxNode, T, }; -use crate::{AssistContext, AssistId, AssistKind, Assists}; +use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists}; // Assist: extract_struct_from_enum_variant // @@ -70,11 +75,10 @@ pub(crate) fn extract_struct_from_enum_variant( continue; } builder.edit_file(file_id); - let source_file = builder.make_mut(ctx.sema.parse(file_id)); let processed = process_references( ctx, + builder, &mut visited_modules_set, - source_file.syntax(), &enum_module_def, &variant_hir_name, references, @@ -84,13 +88,12 @@ pub(crate) fn extract_struct_from_enum_variant( }); } builder.edit_file(ctx.frange.file_id); - let source_file = builder.make_mut(ctx.sema.parse(ctx.frange.file_id)); let variant = builder.make_mut(variant.clone()); if let Some(references) = def_file_references { let processed = process_references( ctx, + builder, &mut visited_modules_set, - source_file.syntax(), &enum_module_def, &variant_hir_name, references, @@ -100,12 +103,12 @@ pub(crate) fn extract_struct_from_enum_variant( }); } - let def = create_struct_def(variant_name.clone(), &field_list, enum_ast.visibility()); + let def = create_struct_def(variant_name.clone(), &field_list, &enum_ast); let start_offset = &variant.parent_enum().syntax().clone(); ted::insert_raw(ted::Position::before(start_offset), def.syntax()); ted::insert_raw(ted::Position::before(start_offset), &make::tokens::blank_line()); - update_variant(&variant); + update_variant(&variant, enum_ast.generic_param_list()); }, ) } @@ -149,7 +152,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va fn create_struct_def( variant_name: ast::Name, field_list: &Either, - visibility: Option, + enum_: &ast::Enum, ) -> ast::Struct { let pub_vis = make::visibility_pub(); @@ -184,12 +187,38 @@ fn create_struct_def( } }; - make::struct_(visibility, variant_name, None, field_list).clone_for_update() + // FIXME: This uses all the generic params of the enum, but the variant might not use all of them. + let strukt = + make::struct_(enum_.visibility(), variant_name, enum_.generic_param_list(), field_list) + .clone_for_update(); + + // copy attributes + ted::insert_all( + Position::first_child_of(strukt.syntax()), + enum_.attrs().map(|it| it.syntax().clone_for_update().into()).collect(), + ); + strukt } -fn update_variant(variant: &ast::Variant) -> Option<()> { +fn update_variant(variant: &ast::Variant, generic: Option) -> Option<()> { let name = variant.name()?; - let tuple_field = make::tuple_field(None, make::ty(&name.text())); + let ty = match generic { + // FIXME: This uses all the generic params of the enum, but the variant might not use all of them. + Some(gpl) => { + let gpl = gpl.clone_for_update(); + gpl.generic_params().for_each(|gp| { + match gp { + ast::GenericParam::LifetimeParam(it) => it.type_bound_list(), + ast::GenericParam::TypeParam(it) => it.type_bound_list(), + ast::GenericParam::ConstParam(_) => return, + } + .map(|it| it.remove()); + }); + make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", "))) + } + None => make::ty(&name.text()), + }; + let tuple_field = make::tuple_field(None, ty); let replacement = make::variant( name, Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))), @@ -208,18 +237,17 @@ fn apply_references( if let Some((scope, path)) = import { insert_use(&scope, mod_path_to_ast(&path), insert_use_cfg); } - ted::insert_raw( - ted::Position::before(segment.syntax()), - make::path_from_text(&format!("{}", segment)).clone_for_update().syntax(), - ); + // deep clone to prevent cycle + let path = make::path_from_segments(iter::once(segment.clone_subtree()), false); + ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax()); ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['('])); ted::insert_raw(ted::Position::after(&node), make::token(T![')'])); } fn process_references( ctx: &AssistContext, + builder: &mut AssistBuilder, visited_modules: &mut FxHashSet, - source_file: &SyntaxNode, enum_module_def: &ModuleDef, variant_hir_name: &Name, refs: Vec, @@ -228,8 +256,9 @@ fn process_references( // and corresponding nodes up front refs.into_iter() .flat_map(|reference| { - let (segment, scope_node, module) = - reference_to_node(&ctx.sema, source_file, reference)?; + let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?; + let segment = builder.make_mut(segment); + let scope_node = builder.make_syntax_mut(scope_node); if !visited_modules.contains(&module) { let mod_path = module.find_use_path_prefixed( ctx.sema.db, @@ -251,23 +280,22 @@ fn process_references( fn reference_to_node( sema: &hir::Semantics, - source_file: &SyntaxNode, reference: FileReference, ) -> Option<(ast::PathSegment, SyntaxNode, hir::Module)> { - let offset = reference.range.start(); - if let Some(path_expr) = find_node_at_offset::(source_file, offset) { - // tuple variant - Some((path_expr.path()?.segment()?, path_expr.syntax().parent()?)) - } else if let Some(record_expr) = find_node_at_offset::(source_file, offset) { - // record variant - Some((record_expr.path()?.segment()?, record_expr.syntax().clone())) - } else { - None - } - .and_then(|(segment, expr)| { - let module = sema.scope(&expr).module()?; - Some((segment, expr, module)) - }) + let segment = + reference.name.as_name_ref()?.syntax().parent().and_then(ast::PathSegment::cast)?; + let parent = segment.parent_path().syntax().parent()?; + let expr_or_pat = match_ast! { + match parent { + ast::PathExpr(_it) => parent.parent()?, + ast::RecordExpr(_it) => parent, + ast::TupleStructPat(_it) => parent, + ast::RecordPat(_it) => parent, + _ => return None, + } + }; + let module = sema.scope(&expr_or_pat).module()?; + Some((segment, expr_or_pat, module)) } #[cfg(test)] @@ -278,6 +306,12 @@ mod tests { use super::*; + fn check_not_applicable(ra_fixture: &str) { + let fixture = + format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE); + check_assist_not_applicable(extract_struct_from_enum_variant, &fixture) + } + #[test] fn test_extract_struct_several_fields_tuple() { check_assist( @@ -311,6 +345,32 @@ enum A { One(One) }"#, ); } + #[test] + fn test_extract_struct_carries_over_generics() { + check_assist( + extract_struct_from_enum_variant, + r"enum En { Var { a: T$0 } }", + r#"struct Var{ pub a: T } + +enum En { Var(Var) }"#, + ); + } + + #[test] + fn test_extract_struct_carries_over_attributes() { + check_assist( + extract_struct_from_enum_variant, + r#"#[derive(Debug)] +#[derive(Clone)] +enum Enum { Variant{ field: u32$0 } }"#, + r#"#[derive(Debug)]#[derive(Clone)] struct Variant{ pub field: u32 } + +#[derive(Debug)] +#[derive(Clone)] +enum Enum { Variant(Variant) }"#, + ); + } + #[test] fn test_extract_struct_keep_comments_and_attrs_one_field_named() { check_assist( @@ -496,7 +556,7 @@ enum E { } fn f() { - let e = E::V { i: 9, j: 2 }; + let E::V { i, j } = E::V { i: 9, j: 2 }; } "#, r#" @@ -507,7 +567,34 @@ enum E { } fn f() { - let e = E::V(V { i: 9, j: 2 }); + let E::V(V { i, j }) = E::V(V { i: 9, j: 2 }); +} +"#, + ) + } + + #[test] + fn extract_record_fix_references2() { + check_assist( + extract_struct_from_enum_variant, + r#" +enum E { + $0V(i32, i32) +} + +fn f() { + let E::V(i, j) = E::V(9, 2); +} +"#, + r#" +struct V(pub i32, pub i32); + +enum E { + V(V) +} + +fn f() { + let E::V(V(i, j)) = E::V(V(9, 2)); } "#, ) @@ -610,12 +697,6 @@ fn foo() { ); } - fn check_not_applicable(ra_fixture: &str) { - let fixture = - format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE); - check_assist_not_applicable(extract_struct_from_enum_variant, &fixture) - } - #[test] fn test_extract_enum_not_applicable_for_element_with_no_fields() { check_not_applicable("enum A { $0One }"); diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 0cf17062610..4c3c9661d44 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -580,12 +580,11 @@ pub fn fn_( pub fn struct_( visibility: Option, strukt_name: ast::Name, - type_params: Option, + generic_param_list: Option, field_list: ast::FieldList, ) -> ast::Struct { let semicolon = if matches!(field_list, ast::FieldList::TupleFieldList(_)) { ";" } else { "" }; - let type_params = - if let Some(type_params) = type_params { format!("<{}>", type_params) } else { "".into() }; + let type_params = generic_param_list.map_or_else(String::new, |it| it.to_string()); let visibility = match visibility { None => String::new(), Some(it) => format!("{} ", it),