diff --git a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs index 745a870ab6b..7f2c01772ba 100644 --- a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs +++ b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs @@ -1,6 +1,6 @@ use ide_db::defs::{Definition, NameRefClass}; use syntax::{ - ast::{self, HasName}, + ast::{self, HasName, Name}, ted, AstNode, SyntaxNode, }; @@ -48,7 +48,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<' other => format!("{{ {other} }}"), }; let extracting_arm_pat = extracting_arm.pat()?; - let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?; + let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?; acc.add( AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite), @@ -56,7 +56,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<' let_stmt.syntax().text_range(), |builder| { let extracting_arm_pat = - rename_variable(&extracting_arm_pat, extracted_variable, binding); + rename_variable(&extracting_arm_pat, &extracted_variable_positions, binding); builder.replace( let_stmt.syntax().text_range(), format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"), @@ -95,14 +95,15 @@ fn find_arms( } // Given an extracting arm, find the extracted variable. -fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option { +fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option> { match arm.expr()? { ast::Expr::PathExpr(path) => { let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?; match NameRefClass::classify(&ctx.sema, &name_ref)? { NameRefClass::Definition(Definition::Local(local)) => { - let source = local.primary_source(ctx.db()).into_ident_pat()?; - Some(source.name()?) + let source = + local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name()); + source.collect() } _ => None, } @@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti } // Rename `extracted` with `binding` in `pat`. -fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::Pat) -> SyntaxNode { +fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode { let syntax = pat.syntax().clone_for_update(); - let extracted_syntax = syntax.covering_element(extracted.syntax().text_range()); + let extracted = extracted + .iter() + .map(|e| syntax.covering_element(e.syntax().text_range())) + .collect::>(); + for extracted_syntax in extracted { + // If `extracted` variable is a record field, we should rename it to `binding`, + // otherwise we just need to replace `extracted` with `binding`. - // If `extracted` variable is a record field, we should rename it to `binding`, - // otherwise we just need to replace `extracted` with `binding`. - - if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast) - { - if let Some(name_ref) = record_pat_field.field_name() { - ted::replace( - record_pat_field.syntax(), - ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding) + if let Some(record_pat_field) = + extracted_syntax.ancestors().find_map(ast::RecordPatField::cast) + { + if let Some(name_ref) = record_pat_field.field_name() { + ted::replace( + record_pat_field.syntax(), + ast::make::record_pat_field( + ast::make::name_ref(&name_ref.text()), + binding.clone(), + ) .syntax() .clone_for_update(), - ); + ); + } + } else { + ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update()); } - } else { - ted::replace(extracted_syntax, binding.syntax().clone_for_update()); } - syntax } @@ -162,6 +170,39 @@ fn foo(opt: Option<()>) { ); } + #[test] + fn or_pattern_multiple_binding() { + check_assist( + convert_match_to_let_else, + r#" +//- minicore: option +enum Foo { + A(u32), + B(u32), + C(String), +} + +fn foo(opt: Option) -> Result { + let va$0lue = match opt { + Some(Foo::A(it) | Foo::B(it)) => it, + _ => return Err(()), + }; +} + "#, + r#" +enum Foo { + A(u32), + B(u32), + C(String), +} + +fn foo(opt: Option) -> Result { + let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) }; +} + "#, + ); + } + #[test] fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() { cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);