remove trailing return in trailing match expression

This commit is contained in:
davidsemakula 2024-01-31 07:04:49 +03:00
parent cad222ff1b
commit 98e6f43a2f
3 changed files with 72 additions and 8 deletions

View File

@ -272,6 +272,12 @@ impl ExprValidator {
self.check_for_trailing_return(*else_branch, body); self.check_for_trailing_return(*else_branch, body);
} }
} }
Expr::Match { arms, .. } => {
for arm in arms.iter() {
let MatchArm { expr, .. } = arm;
self.check_for_trailing_return(*expr, body);
}
}
Expr::Return { .. } => { Expr::Return { .. } => {
self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn { self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn {
return_expr: body_expr, return_expr: body_expr,

View File

@ -11,7 +11,7 @@ use cfg::{CfgExpr, CfgOptions};
use either::Either; use either::Either;
use hir_def::{body::SyntheticSyntax, hir::ExprOrPatId, path::ModPath, AssocItemId, DefWithBodyId}; use hir_def::{body::SyntheticSyntax, hir::ExprOrPatId, path::ModPath, AssocItemId, DefWithBodyId};
use hir_expand::{name::Name, HirFileId, InFile}; use hir_expand::{name::Name, HirFileId, InFile};
use syntax::{ast, AstPtr, SyntaxError, SyntaxNodePtr, TextRange}; use syntax::{ast, AstNode, AstPtr, SyntaxError, SyntaxNodePtr, TextRange};
use crate::{AssocItem, Field, Local, MacroKind, Trait, Type}; use crate::{AssocItem, Field, Local, MacroKind, Trait, Type};
@ -459,6 +459,8 @@ impl AnyDiagnostic {
} }
BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => { BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => {
if let Ok(source_ptr) = source_map.expr_syntax(return_expr) { if let Ok(source_ptr) = source_map.expr_syntax(return_expr) {
// Filters out desugared return expressions (e.g. desugared try operators).
if ast::ReturnExpr::can_cast(source_ptr.value.kind()) {
return Some( return Some(
RemoveTrailingReturn { RemoveTrailingReturn {
file_id: source_ptr.file_id, file_id: source_ptr.file_id,
@ -468,6 +470,7 @@ impl AnyDiagnostic {
); );
} }
} }
}
BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr } => { BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr } => {
if let Ok(source_ptr) = source_map.expr_syntax(if_expr) { if let Ok(source_ptr) = source_map.expr_syntax(if_expr) {
if let Some(ptr) = source_ptr.value.cast::<ast::IfExpr>() { if let Some(ptr) = source_ptr.value.cast::<ast::IfExpr>() {

View File

@ -147,6 +147,21 @@ fn foo(x: usize) -> u8 {
); );
} }
#[test]
fn remove_trailing_return_in_match() {
check_diagnostics(
r#"
fn foo<T, E>(x: Result<T, E>) -> u8 {
match x {
Ok(_) => return 1,
//^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
Err(_) => return 0,
} //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
}
"#,
);
}
#[test] #[test]
fn no_diagnostic_if_no_return_keyword() { fn no_diagnostic_if_no_return_keyword() {
check_diagnostics( check_diagnostics(
@ -316,6 +331,46 @@ fn foo(x: usize) -> u8 {
0 0
} }
} }
"#,
);
}
#[test]
fn replace_in_match() {
check_fix(
r#"
fn foo<T, E>(x: Result<T, E>) -> u8 {
match x {
Ok(_) => return$0 1,
Err(_) => 0,
}
}
"#,
r#"
fn foo<T, E>(x: Result<T, E>) -> u8 {
match x {
Ok(_) => 1,
Err(_) => 0,
}
}
"#,
);
check_fix(
r#"
fn foo<T, E>(x: Result<T, E>) -> u8 {
match x {
Ok(_) => 1,
Err(_) => return$0 0,
}
}
"#,
r#"
fn foo<T, E>(x: Result<T, E>) -> u8 {
match x {
Ok(_) => 1,
Err(_) => 0,
}
}
"#, "#,
); );
} }