diagnostic to remove trailing return

This commit is contained in:
davidsemakula 2024-01-30 19:57:36 +03:00
parent e07183461f
commit 2987fac76f
5 changed files with 318 additions and 2 deletions

View File

@ -44,6 +44,9 @@ pub enum BodyValidationDiagnostic {
match_expr: ExprId,
uncovered_patterns: String,
},
RemoveTrailingReturn {
return_expr: ExprId,
},
RemoveUnnecessaryElse {
if_expr: ExprId,
},
@ -75,6 +78,10 @@ fn validate_body(&mut self, db: &dyn HirDatabase) {
let body = db.body(self.owner);
let mut filter_map_next_checker = None;
if matches!(self.owner, DefWithBodyId::FunctionId(_)) {
self.check_for_trailing_return(body.body_expr, &body);
}
for (id, expr) in body.exprs.iter() {
if let Some((variant, missed_fields, true)) =
record_literal_missing_fields(db, &self.infer, id, expr)
@ -93,12 +100,16 @@ fn validate_body(&mut self, db: &dyn HirDatabase) {
Expr::Call { .. } | Expr::MethodCall { .. } => {
self.validate_call(db, id, expr, &mut filter_map_next_checker);
}
Expr::Closure { body: body_expr, .. } => {
self.check_for_trailing_return(*body_expr, &body);
}
Expr::If { .. } => {
self.check_for_unnecessary_else(id, expr, &body);
}
_ => {}
}
}
for (id, pat) in body.pats.iter() {
if let Some((variant, missed_fields, true)) =
record_pattern_missing_fields(db, &self.infer, id, pat)
@ -244,6 +255,26 @@ fn lower_pattern<'p>(
pattern
}
fn check_for_trailing_return(&mut self, body_expr: ExprId, body: &Body) {
match &body.exprs[body_expr] {
Expr::Block { statements, tail, .. } => {
let last_stmt = tail.or_else(|| match statements.last()? {
Statement::Expr { expr, .. } => Some(*expr),
_ => None,
});
if let Some(last_stmt) = last_stmt {
self.check_for_trailing_return(last_stmt, body);
}
}
Expr::Return { .. } => {
self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn {
return_expr: body_expr,
});
}
_ => (),
}
}
fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) {
if let Expr::If { condition: _, then_branch, else_branch } = expr {
if else_branch.is_none() {

View File

@ -68,6 +68,7 @@ fn from(d: $diag) -> AnyDiagnostic {
PrivateAssocItem,
PrivateField,
ReplaceFilterMapNextWithFindMap,
RemoveTrailingReturn,
RemoveUnnecessaryElse,
TraitImplIncorrectSafety,
TraitImplMissingAssocItems,
@ -343,6 +344,12 @@ pub struct TraitImplRedundantAssocItems {
pub assoc_item: (Name, AssocItem),
}
#[derive(Debug)]
pub struct RemoveTrailingReturn {
pub file_id: HirFileId,
pub return_expr: AstPtr<ast::Expr>,
}
#[derive(Debug)]
pub struct RemoveUnnecessaryElse {
pub if_expr: InFile<AstPtr<ast::IfExpr>>,
@ -450,6 +457,17 @@ pub(crate) fn body_validation_diagnostic(
Err(SyntheticSyntax) => (),
}
}
BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => {
if let Ok(source_ptr) = source_map.expr_syntax(return_expr) {
return Some(
RemoveTrailingReturn {
file_id: source_ptr.file_id,
return_expr: source_ptr.value,
}
.into(),
);
}
}
BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr } => {
if let Ok(source_ptr) = source_map.expr_syntax(if_expr) {
if let Some(ptr) = source_ptr.value.cast::<ast::IfExpr>() {

View File

@ -0,0 +1,262 @@
use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn, HirFileIdExt, InFile};
use ide_db::{assists::Assist, source_change::SourceChange};
use syntax::{ast, AstNode, SyntaxNodePtr};
use text_edit::TextEdit;
use crate::{fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
// Diagnostic: remove-trailing-return
//
// This diagnostic is triggered when there is a redundant `return` at the end of a function
// or closure.
pub(crate) fn remove_trailing_return(
ctx: &DiagnosticsContext<'_>,
d: &RemoveTrailingReturn,
) -> Diagnostic {
let display_range = ctx.sema.diagnostics_display_range(InFile {
file_id: d.file_id,
value: expr_stmt(ctx, d)
.as_ref()
.map(|stmt| SyntaxNodePtr::new(stmt.syntax()))
.unwrap_or_else(|| d.return_expr.into()),
});
Diagnostic::new(
DiagnosticCode::Clippy("needless_return"),
"replace return <expr>; with <expr>",
display_range,
)
.with_fixes(fixes(ctx, d))
}
fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<Assist>> {
let return_expr = return_expr(ctx, d)?;
let stmt = expr_stmt(ctx, d);
let range = stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax).text_range();
let replacement =
return_expr.expr().map_or_else(String::new, |expr| format!("{}", expr.syntax().text()));
let edit = TextEdit::replace(range, replacement);
let source_change = SourceChange::from_text_edit(d.file_id.original_file(ctx.sema.db), edit);
Some(vec![fix(
"remove_trailing_return",
"Replace return <expr>; with <expr>",
source_change,
range,
)])
}
fn return_expr(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ReturnExpr> {
let root = ctx.sema.db.parse_or_expand(d.file_id);
let expr = d.return_expr.to_node(&root);
ast::ReturnExpr::cast(expr.syntax().clone())
}
fn expr_stmt(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ExprStmt> {
let return_expr = return_expr(ctx, d)?;
return_expr.syntax().parent().and_then(ast::ExprStmt::cast)
}
#[cfg(test)]
mod tests {
use crate::tests::{check_diagnostics, check_fix};
#[test]
fn remove_trailing_return() {
check_diagnostics(
r#"
fn foo() -> u8 {
return 2;
} //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
"#,
);
}
#[test]
fn remove_trailing_return_inner_function() {
check_diagnostics(
r#"
fn foo() -> u8 {
fn bar() -> u8 {
return 2;
} //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
bar()
}
"#,
);
}
#[test]
fn remove_trailing_return_closure() {
check_diagnostics(
r#"
fn foo() -> u8 {
let bar = || return 2;
bar() //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
}
"#,
);
check_diagnostics(
r#"
fn foo() -> u8 {
let bar = || {
return 2;
};//^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
bar()
}
"#,
);
}
#[test]
fn remove_trailing_return_unit() {
check_diagnostics(
r#"
fn foo() {
return
} //^^^^^^ 💡 weak: replace return <expr>; with <expr>
"#,
);
}
#[test]
fn remove_trailing_return_no_semi() {
check_diagnostics(
r#"
fn foo() -> u8 {
return 2
} //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
"#,
);
}
#[test]
fn no_diagnostic_if_no_return_keyword() {
check_diagnostics(
r#"
fn foo() -> u8 {
3
}
"#,
);
}
#[test]
fn no_diagnostic_if_not_last_statement() {
check_diagnostics(
r#"
fn foo() -> u8 {
if true { return 2; }
3
}
"#,
);
}
#[test]
fn replace_with_expr() {
check_fix(
r#"
fn foo() -> u8 {
return$0 2;
}
"#,
r#"
fn foo() -> u8 {
2
}
"#,
);
}
#[test]
fn replace_with_unit() {
check_fix(
r#"
fn foo() {
return$0/*ensure tidy is happy*/
}
"#,
r#"
fn foo() {
/*ensure tidy is happy*/
}
"#,
);
}
#[test]
fn replace_with_expr_no_semi() {
check_fix(
r#"
fn foo() -> u8 {
return$0 2
}
"#,
r#"
fn foo() -> u8 {
2
}
"#,
);
}
#[test]
fn replace_in_inner_function() {
check_fix(
r#"
fn foo() -> u8 {
fn bar() -> u8 {
return$0 2;
}
bar()
}
"#,
r#"
fn foo() -> u8 {
fn bar() -> u8 {
2
}
bar()
}
"#,
);
}
#[test]
fn replace_in_closure() {
check_fix(
r#"
fn foo() -> u8 {
let bar = || return$0 2;
bar()
}
"#,
r#"
fn foo() -> u8 {
let bar = || 2;
bar()
}
"#,
);
check_fix(
r#"
fn foo() -> u8 {
let bar = || {
return$0 2;
};
bar()
}
"#,
r#"
fn foo() -> u8 {
let bar = || {
2
};
bar()
}
"#,
);
}
}

View File

@ -186,7 +186,9 @@ fn str_ref_to_owned(
#[cfg(test)]
mod tests {
use crate::tests::{check_diagnostics, check_fix, check_no_fix};
use crate::tests::{
check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix,
};
#[test]
fn missing_reference() {
@ -718,7 +720,7 @@ struct Bar {
#[test]
fn return_no_value() {
check_diagnostics(
check_diagnostics_with_disabled(
r#"
fn f() -> i32 {
return;
@ -727,6 +729,7 @@ fn f() -> i32 {
}
fn g() { return; }
"#,
std::iter::once("needless_return".to_string()),
);
}

View File

@ -43,6 +43,7 @@ mod handlers {
pub(crate) mod no_such_field;
pub(crate) mod private_assoc_item;
pub(crate) mod private_field;
pub(crate) mod remove_trailing_return;
pub(crate) mod remove_unnecessary_else;
pub(crate) mod replace_filter_map_next_with_find_map;
pub(crate) mod trait_impl_incorrect_safety;
@ -383,6 +384,7 @@ pub fn diagnostics(
AnyDiagnostic::UnusedVariable(d) => handlers::unused_variables::unused_variables(&ctx, &d),
AnyDiagnostic::BreakOutsideOfLoop(d) => handlers::break_outside_of_loop::break_outside_of_loop(&ctx, &d),
AnyDiagnostic::MismatchedTupleStructPatArgCount(d) => handlers::mismatched_arg_count::mismatched_tuple_struct_pat_arg_count(&ctx, &d),
AnyDiagnostic::RemoveTrailingReturn(d) => handlers::remove_trailing_return::remove_trailing_return(&ctx, &d),
AnyDiagnostic::RemoveUnnecessaryElse(d) => handlers::remove_unnecessary_else::remove_unnecessary_else(&ctx, &d),
};
res.push(d)