Verify name references more rigidly

Previously we didn't verify that record expressions/patterns that were
found did actually point to the struct we're operating on. Moreover,
when that record expressions/patterns had missing child nodes, we would
continue traversing their ancestor nodes.
This commit is contained in:
Ryo Yoshida 2023-05-29 19:44:31 +09:00
parent ab9347542c
commit 033e6ac57a
No known key found for this signature in database
GPG Key ID: E25698A930586171

View File

@ -1,9 +1,9 @@
use either::Either;
use ide_db::defs::Definition;
use ide_db::{defs::Definition, search::FileReference};
use itertools::Itertools;
use syntax::{
ast::{self, AstNode, HasGenericParams, HasVisibility},
match_ast, SyntaxKind, SyntaxNode,
match_ast, SyntaxKind,
};
use crate::{assist_context::SourceChangeBuilder, AssistContext, AssistId, AssistKind, Assists};
@ -141,59 +141,72 @@ fn edit_struct_references(
};
let usages = strukt_def.usages(&ctx.sema).include_self_refs().all();
let edit_node = |edit: &mut SourceChangeBuilder, node: SyntaxNode| -> Option<()> {
match_ast! {
match node {
ast::RecordPat(record_struct_pat) => {
let Some(fr) = ctx.sema.original_range_opt(record_struct_pat.syntax()) else {
// We've found the node to replace, so we should return `Some` even if the
// replacement failed to stop the ancestor node traversal.
return Some(());
};
edit.replace(
fr.range,
ast::make::tuple_struct_pat(
record_struct_pat.path()?,
record_struct_pat
.record_pat_field_list()?
.fields()
.filter_map(|pat| pat.pat())
)
.to_string()
);
},
ast::RecordExpr(record_expr) => {
let Some(fr) = ctx.sema.original_range_opt(record_expr.syntax()) else {
// See the comment above.
return Some(());
};
let path = record_expr.path()?;
let args = record_expr
.record_expr_field_list()?
.fields()
.filter_map(|f| f.expr())
.join(", ");
edit.replace(fr.range, format!("{path}({args})"));
},
_ => return None,
}
}
Some(())
};
for (file_id, refs) in usages {
edit.edit_file(file_id);
for r in refs {
for node in r.name.syntax().ancestors() {
if edit_node(edit, node).is_some() {
break;
}
}
process_struct_name_reference(ctx, r, edit);
}
}
}
fn process_struct_name_reference(
ctx: &AssistContext<'_>,
r: FileReference,
edit: &mut SourceChangeBuilder,
) -> Option<()> {
// First check if it's the last semgnet of a path that directly belongs to a record
// expression/pattern.
let name_ref = r.name.as_name_ref()?;
let path_segment = name_ref.syntax().parent().and_then(ast::PathSegment::cast)?;
// A `PathSegment` always belongs to a `Path`, so there's at least one `Path` at this point.
let full_path =
path_segment.syntax().parent()?.ancestors().map_while(ast::Path::cast).last().unwrap();
if full_path.segment().unwrap().name_ref()? != *name_ref {
// `name_ref` isn't the last segment of the path, so `full_path` doesn't point to the
// struct we want to edit.
return None;
}
let parent = full_path.syntax().parent()?;
match_ast! {
match parent {
ast::RecordPat(record_struct_pat) => {
// When we failed to get the original range for the whole struct expression node,
// we can't provide any reasonable edit. Leave it untouched.
let file_range = ctx.sema.original_range_opt(record_struct_pat.syntax())?;
edit.replace(
file_range.range,
ast::make::tuple_struct_pat(
record_struct_pat.path()?,
record_struct_pat
.record_pat_field_list()?
.fields()
.filter_map(|pat| pat.pat())
)
.to_string()
);
},
ast::RecordExpr(record_expr) => {
// When we failed to get the original range for the whole struct pattern node,
// we can't provide any reasonable edit. Leave it untouched.
let file_range = ctx.sema.original_range_opt(record_expr.syntax())?;
let path = record_expr.path()?;
let args = record_expr
.record_expr_field_list()?
.fields()
.filter_map(|f| f.expr())
.join(", ");
edit.replace(file_range.range, format!("{path}({args})"));
},
_ => {}
}
}
Some(())
}
fn edit_field_references(
ctx: &AssistContext<'_>,
edit: &mut SourceChangeBuilder,
@ -898,6 +911,69 @@ fn test() {
let Struct(inner) = s;
}
}
"#,
);
}
#[test]
fn struct_name_ref_may_not_be_part_of_struct_expr_or_struct_pat() {
check_assist(
convert_named_struct_to_tuple_struct,
r#"
struct $0Struct {
inner: i32,
}
struct Outer<T> {
value: T,
}
fn foo<T>() -> T { loop {} }
fn test() {
Outer {
value: foo::<Struct>();
}
}
trait HasAssoc {
type Assoc;
fn test();
}
impl HasAssoc for Struct {
type Assoc = Outer<i32>;
fn test() {
let a = Self::Assoc {
value: 42,
};
let Self::Assoc { value } = a;
}
}
"#,
r#"
struct Struct(i32);
struct Outer<T> {
value: T,
}
fn foo<T>() -> T { loop {} }
fn test() {
Outer {
value: foo::<Struct>();
}
}
trait HasAssoc {
type Assoc;
fn test();
}
impl HasAssoc for Struct {
type Assoc = Outer<i32>;
fn test() {
let a = Self::Assoc {
value: 42,
};
let Self::Assoc { value } = a;
}
}
"#,
);
}