Lower let expressions

This commit is contained in:
Chayim Refael Friedman 2022-01-23 05:39:26 +02:00
parent de8633f15f
commit 6bf6f4ff1d
4 changed files with 87 additions and 110 deletions

View File

@ -28,7 +28,7 @@ use crate::{
db::DefDatabase,
expr::{
dummy_expr_id, Array, BindingAnnotation, Expr, ExprId, Label, LabelId, Literal, MatchArm,
MatchGuard, Pat, PatId, RecordFieldPat, RecordLitField, Statement,
Pat, PatId, RecordFieldPat, RecordLitField, Statement,
},
intern::Interned,
item_scope::BuiltinShadowMode,
@ -155,9 +155,6 @@ impl ExprCollector<'_> {
fn alloc_expr_desugared(&mut self, expr: Expr) -> ExprId {
self.make_expr(expr, Err(SyntheticSyntax))
}
fn unit(&mut self) -> ExprId {
self.alloc_expr_desugared(Expr::Tuple { exprs: Box::default() })
}
fn missing_expr(&mut self) -> ExprId {
self.alloc_expr_desugared(Expr::Missing)
}
@ -215,33 +212,15 @@ impl ExprCollector<'_> {
}
});
let condition = match e.condition() {
None => self.missing_expr(),
Some(condition) => match condition.pat() {
None => self.collect_expr_opt(condition.expr()),
// if let -- desugar to match
Some(pat) => {
let pat = self.collect_pat(pat);
let match_expr = self.collect_expr_opt(condition.expr());
let placeholder_pat = self.missing_pat();
let arms = vec![
MatchArm { pat, expr: then_branch, guard: None },
MatchArm {
pat: placeholder_pat,
expr: else_branch.unwrap_or_else(|| self.unit()),
guard: None,
},
]
.into();
return Some(
self.alloc_expr(Expr::Match { expr: match_expr, arms }, syntax_ptr),
);
}
},
};
let condition = self.collect_expr_opt(e.condition());
self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr)
}
ast::Expr::LetExpr(e) => {
let pat = self.collect_pat_opt(e.pat());
let expr = self.collect_expr_opt(e.expr());
self.alloc_expr(Expr::Let { pat, expr }, syntax_ptr)
}
ast::Expr::BlockExpr(e) => match e.modifier() {
Some(ast::BlockModifier::Try(_)) => {
let body = self.collect_block(e);
@ -282,31 +261,7 @@ impl ExprCollector<'_> {
let label = e.label().map(|label| self.collect_label(label));
let body = self.collect_block_opt(e.loop_body());
let condition = match e.condition() {
None => self.missing_expr(),
Some(condition) => match condition.pat() {
None => self.collect_expr_opt(condition.expr()),
// if let -- desugar to match
Some(pat) => {
cov_mark::hit!(infer_resolve_while_let);
let pat = self.collect_pat(pat);
let match_expr = self.collect_expr_opt(condition.expr());
let placeholder_pat = self.missing_pat();
let break_ =
self.alloc_expr_desugared(Expr::Break { expr: None, label: None });
let arms = vec![
MatchArm { pat, expr: body, guard: None },
MatchArm { pat: placeholder_pat, expr: break_, guard: None },
]
.into();
let match_expr =
self.alloc_expr_desugared(Expr::Match { expr: match_expr, arms });
return Some(
self.alloc_expr(Expr::Loop { body: match_expr, label }, syntax_ptr),
);
}
},
};
let condition = self.collect_expr_opt(e.condition());
self.alloc_expr(Expr::While { condition, body, label }, syntax_ptr)
}
@ -352,15 +307,9 @@ impl ExprCollector<'_> {
self.check_cfg(&arm).map(|()| MatchArm {
pat: self.collect_pat_opt(arm.pat()),
expr: self.collect_expr_opt(arm.expr()),
guard: arm.guard().map(|guard| match guard.pat() {
Some(pat) => MatchGuard::IfLet {
pat: self.collect_pat(pat),
expr: self.collect_expr_opt(guard.expr()),
},
None => {
MatchGuard::If { expr: self.collect_expr_opt(guard.expr()) }
}
}),
guard: arm
.guard()
.map(|guard| self.collect_expr_opt(guard.condition())),
})
})
.collect()

View File

@ -8,7 +8,7 @@ use rustc_hash::FxHashMap;
use crate::{
body::Body,
db::DefDatabase,
expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement},
expr::{Expr, ExprId, LabelId, Pat, PatId, Statement},
BlockId, DefWithBodyId,
};
@ -53,9 +53,9 @@ impl ExprScopes {
fn new(body: &Body) -> ExprScopes {
let mut scopes =
ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() };
let root = scopes.root_scope();
let mut root = scopes.root_scope();
scopes.add_params_bindings(body, root, &body.params);
compute_expr_scopes(body.body_expr, body, &mut scopes, root);
compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root);
scopes
}
@ -151,32 +151,32 @@ fn compute_block_scopes(
match stmt {
Statement::Let { pat, initializer, else_branch, .. } => {
if let Some(expr) = initializer {
compute_expr_scopes(*expr, body, scopes, scope);
compute_expr_scopes(*expr, body, scopes, &mut scope);
}
if let Some(expr) = else_branch {
compute_expr_scopes(*expr, body, scopes, scope);
compute_expr_scopes(*expr, body, scopes, &mut scope);
}
scope = scopes.new_scope(scope);
scopes.add_bindings(body, scope, *pat);
}
Statement::Expr { expr, .. } => {
compute_expr_scopes(*expr, body, scopes, scope);
compute_expr_scopes(*expr, body, scopes, &mut scope);
}
}
}
if let Some(expr) = tail {
compute_expr_scopes(expr, body, scopes, scope);
compute_expr_scopes(expr, body, scopes, &mut scope);
}
}
fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId) {
let make_label =
|label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));
scopes.set_scope(expr, scope);
scopes.set_scope(expr, *scope);
match &body[expr] {
Expr::Block { statements, tail, id, label } => {
let scope = scopes.new_block_scope(scope, *id, make_label(label));
let scope = scopes.new_block_scope(*scope, *id, make_label(label));
// Overwrite the old scope for the block expr, so that every block scope can be found
// via the block itself (important for blocks that only contain items, no expressions).
scopes.set_scope(expr, scope);
@ -184,46 +184,49 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
}
Expr::For { iterable, pat, body: body_expr, label } => {
compute_expr_scopes(*iterable, body, scopes, scope);
let scope = scopes.new_labeled_scope(scope, make_label(label));
let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
scopes.add_bindings(body, scope, *pat);
compute_expr_scopes(*body_expr, body, scopes, scope);
compute_expr_scopes(*body_expr, body, scopes, &mut scope);
}
Expr::While { condition, body: body_expr, label } => {
let scope = scopes.new_labeled_scope(scope, make_label(label));
compute_expr_scopes(*condition, body, scopes, scope);
compute_expr_scopes(*body_expr, body, scopes, scope);
let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
compute_expr_scopes(*condition, body, scopes, &mut scope);
compute_expr_scopes(*body_expr, body, scopes, &mut scope);
}
Expr::Loop { body: body_expr, label } => {
let scope = scopes.new_labeled_scope(scope, make_label(label));
compute_expr_scopes(*body_expr, body, scopes, scope);
let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
compute_expr_scopes(*body_expr, body, scopes, &mut scope);
}
Expr::Lambda { args, body: body_expr, .. } => {
let scope = scopes.new_scope(scope);
let mut scope = scopes.new_scope(*scope);
scopes.add_params_bindings(body, scope, args);
compute_expr_scopes(*body_expr, body, scopes, scope);
compute_expr_scopes(*body_expr, body, scopes, &mut scope);
}
Expr::Match { expr, arms } => {
compute_expr_scopes(*expr, body, scopes, scope);
for arm in arms.iter() {
let mut scope = scopes.new_scope(scope);
let mut scope = scopes.new_scope(*scope);
scopes.add_bindings(body, scope, arm.pat);
match arm.guard {
Some(MatchGuard::If { expr: guard }) => {
scopes.set_scope(guard, scope);
compute_expr_scopes(guard, body, scopes, scope);
}
Some(MatchGuard::IfLet { pat, expr: guard }) => {
scopes.set_scope(guard, scope);
compute_expr_scopes(guard, body, scopes, scope);
scope = scopes.new_scope(scope);
scopes.add_bindings(body, scope, pat);
}
_ => {}
};
scopes.set_scope(arm.expr, scope);
compute_expr_scopes(arm.expr, body, scopes, scope);
if let Some(guard) = arm.guard {
scope = scopes.new_scope(scope);
compute_expr_scopes(guard, body, scopes, &mut scope);
}
compute_expr_scopes(arm.expr, body, scopes, &mut scope);
}
}
&Expr::If { condition, then_branch, else_branch } => {
let mut then_branch_scope = scopes.new_scope(*scope);
compute_expr_scopes(condition, body, scopes, &mut then_branch_scope);
compute_expr_scopes(then_branch, body, scopes, &mut then_branch_scope);
if let Some(else_branch) = else_branch {
compute_expr_scopes(else_branch, body, scopes, scope);
}
}
&Expr::Let { pat, expr } => {
compute_expr_scopes(expr, body, scopes, scope);
*scope = scopes.new_scope(*scope);
scopes.add_bindings(body, *scope, pat);
}
e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
};
}
@ -500,8 +503,7 @@ fn foo() {
}
#[test]
fn while_let_desugaring() {
cov_mark::check!(infer_resolve_while_let);
fn while_let_adds_binding() {
do_check_local_name(
r#"
fn test() {
@ -513,5 +515,31 @@ fn test() {
"#,
75,
);
do_check_local_name(
r#"
fn test() {
let foo: Option<f32> = None;
while (((let Option::Some(_) = foo))) && let Option::Some(spam) = foo {
spam$0
}
}
"#,
107,
);
}
#[test]
fn match_guard_if_let() {
do_check_local_name(
r#"
fn test() {
let foo: Option<f32> = None;
match foo {
_ if let Option::Some(spam) = foo => spam$0,
}
}
"#,
93,
);
}
}

View File

@ -59,6 +59,10 @@ pub enum Expr {
then_branch: ExprId,
else_branch: Option<ExprId>,
},
Let {
pat: PatId,
expr: ExprId,
},
Block {
id: BlockId,
statements: Box<[Statement]>,
@ -189,17 +193,10 @@ pub enum Array {
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MatchArm {
pub pat: PatId,
pub guard: Option<MatchGuard>,
pub guard: Option<ExprId>,
pub expr: ExprId,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum MatchGuard {
If { expr: ExprId },
IfLet { pat: PatId, expr: ExprId },
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RecordLitField {
pub name: Name,
@ -232,6 +229,9 @@ impl Expr {
f(else_branch);
}
}
Expr::Let { expr, .. } => {
f(*expr);
}
Expr::Block { statements, tail, .. } => {
for stmt in statements.iter() {
match stmt {

View File

@ -108,18 +108,18 @@ fn expansion_does_not_parse_as_expression() {
check(
r#"
macro_rules! stmts {
() => { let _ = 0; }
() => { fn foo() {} }
}
fn f() { let _ = stmts!/*+errors*/(); }
"#,
expect![[r#"
macro_rules! stmts {
() => { let _ = 0; }
() => { fn foo() {} }
}
fn f() { let _ = /* parse error: expected expression */
let _ = 0;; }
fn foo() {}; }
"#]],
)
}