From c0a0664d120609fa37aa950a11d5e6c0df176770 Mon Sep 17 00:00:00 2001 From: hkalbasi Date: Sun, 19 Feb 2023 01:47:44 +0330 Subject: [PATCH] Support "or patterns" MIR lowering --- crates/hir-ty/src/consteval/tests.rs | 42 +++++++++++- crates/hir-ty/src/mir/lower.rs | 95 ++++++++++++++++++++++++---- 2 files changed, 122 insertions(+), 15 deletions(-) diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index b7a466c389c..f05688aa55b 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -559,11 +559,11 @@ fn function_param_patterns() { check_number( r#" const fn f(c @ (a, b): &(u8, u8)) -> u8 { - *a + *b + (*c).1 + *a + *b + c.0 + (*c).1 } const GOAL: u8 = f(&(2, 3)); "#, - 8, + 10, ); check_number( r#" @@ -641,6 +641,44 @@ fn options() { ); } +#[test] +fn or_pattern() { + check_number( + r#" + const GOAL: u8 = { + let (a | a) = 2; + a + }; + "#, + 2, + ); + check_number( + r#" + //- minicore: option + const fn f(x: Option) -> i32 { + let (Some(a) | Some(a)) = x else { return 2; }; + a + } + const GOAL: i32 = f(Some(10)) + f(None); + "#, + 12, + ); + check_number( + r#" + //- minicore: option + const fn f(x: Option, y: Option) -> i32 { + match (x, y) { + (Some(x), Some(y)) => x * y, + (Some(a), _) | (_, Some(a)) => a, + _ => 10, + } + } + const GOAL: i32 = f(Some(10), Some(20)) + f(Some(30), None) + f(None, Some(40)) + f(None, None); + "#, + 280, + ); +} + #[test] fn array_and_index() { check_number( diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 8f28d62db03..1bcdd3a5057 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -155,12 +155,7 @@ impl MirLowerCtx<'_> { if let Some(p) = self.lower_expr_as_place(expr_id) { return Ok((p, prev_block)); } - let mut ty = self.expr_ty(expr_id); - if let Some(x) = self.infer.expr_adjustments.get(&expr_id) { - if let Some(x) = x.last() { - ty = x.target.clone(); - } - } + let ty = self.expr_ty_after_adjustments(expr_id); let place = self.temp(ty)?; Ok((place.into(), self.lower_expr_to_place(expr_id, place.into(), prev_block)?)) } @@ -323,7 +318,7 @@ impl MirLowerCtx<'_> { current, None, cond_place, - self.expr_ty(*expr), + self.expr_ty_after_adjustments(*expr), *pat, BindingAnnotation::Unannotated, )?; @@ -339,7 +334,53 @@ impl MirLowerCtx<'_> { self.lower_block_to_place(None, statements, current, *tail, place) } Expr::Block { id: _, statements, tail, label } => { - self.lower_block_to_place(*label, statements, current, *tail, place) + if label.is_some() { + not_supported!("block with label"); + } + for statement in statements.iter() { + match statement { + hir_def::expr::Statement::Let { + pat, + initializer, + else_branch, + type_ref: _, + } => match initializer { + Some(expr_id) => { + let else_block; + let init_place; + (init_place, current) = + self.lower_expr_to_some_place(*expr_id, current)?; + (current, else_block) = self.pattern_match( + current, + None, + init_place, + self.expr_ty_after_adjustments(*expr_id), + *pat, + BindingAnnotation::Unannotated, + )?; + match (else_block, else_branch) { + (None, _) => (), + (Some(else_block), None) => { + self.set_terminator(else_block, Terminator::Unreachable); + } + (Some(else_block), Some(else_branch)) => { + let (_, b) = self + .lower_expr_to_some_place(*else_branch, else_block)?; + self.set_terminator(b, Terminator::Unreachable); + } + } + } + None => continue, + }, + hir_def::expr::Statement::Expr { expr, has_semi: _ } => { + (_, current) = self.lower_expr_to_some_place(*expr, current)?; + } + } + } + match tail { + Some(tail) => self.lower_expr_to_place(*tail, place, current), + None => Ok(current), + } } Expr::Loop { body, label } => self.lower_loop(current, *label, |this, begin, _| { let (_, block) = this.lower_expr_to_some_place(*body, begin)?; @@ -364,7 +405,7 @@ impl MirLowerCtx<'_> { } Expr::For { .. } => not_supported!("for loop"), Expr::Call { callee, args, .. } => { - let callee_ty = self.expr_ty(*callee); + let callee_ty = self.expr_ty_after_adjustments(*callee); match &callee_ty.data(Interner).kind { chalk_ir::TyKind::FnDef(..) => { let func = Operand::from_bytes(vec![], callee_ty.clone()); @@ -414,7 +455,7 @@ impl MirLowerCtx<'_> { } Expr::Match { expr, arms } => { let (cond_place, mut current) = self.lower_expr_to_some_place(*expr, current)?; - let cond_ty = self.expr_ty(*expr); + let cond_ty = self.expr_ty_after_adjustments(*expr); let end = self.new_basic_block(); for MatchArm { pat, guard, expr } in arms.iter() { if guard.is_some() { @@ -524,7 +565,7 @@ impl MirLowerCtx<'_> { } Expr::Field { expr, name } => { let (mut current_place, current) = self.lower_expr_to_some_place(*expr, current)?; - if let TyKind::Tuple(..) = self.expr_ty(*expr).kind(Interner) { + if let TyKind::Tuple(..) = self.expr_ty_after_adjustments(*expr).kind(Interner) { let index = name .as_tuple_index() .ok_or(MirLowerError::TypeError("named field on tuple"))?; @@ -623,7 +664,7 @@ impl MirLowerCtx<'_> { Expr::Index { base, index } => { let mut p_base; (p_base, current) = self.lower_expr_to_some_place(*base, current)?; - let l_index = self.temp(self.expr_ty(*index))?; + let l_index = self.temp(self.expr_ty_after_adjustments(*index))?; current = self.lower_expr_to_place(*index, l_index.into(), current)?; p_base.projection.push(ProjectionElem::Index(l_index)); self.push_assignment(current, place, Operand::Copy(p_base).into()); @@ -878,6 +919,16 @@ impl MirLowerCtx<'_> { self.infer[e].clone() } + fn expr_ty_after_adjustments(&self, e: ExprId) -> Ty { + let mut ty = None; + if let Some(x) = self.infer.expr_adjustments.get(&e) { + if let Some(x) = x.last() { + ty = Some(x.target.clone()); + } + } + ty.unwrap_or_else(|| self.expr_ty(e)) + } + fn push_assignment(&mut self, block: BasicBlockId, place: Place, rvalue: Rvalue) { self.result.basic_blocks[block].statements.push(Statement::Assign(place, rvalue)); } @@ -928,7 +979,25 @@ impl MirLowerCtx<'_> { binding_mode, )? } - Pat::Or(_) => not_supported!("or pattern"), + Pat::Or(pats) => { + let then_target = self.new_basic_block(); + let mut finished = false; + for pat in &**pats { + let (next, next_else) = + self.pattern_match(current, None, cond_place.clone(), cond_ty.clone(), *pat, binding_mode)?; + self.set_goto(next, then_target); + match next_else { + Some(t) => { + current = t; + } + None => { + finished = true; + break; + } + } + } + (then_target, (!finished).then_some(current)) + } Pat::Record { .. } => not_supported!("record pattern"), Pat::Range { .. } => not_supported!("range pattern"), Pat::Slice { .. } => not_supported!("slice pattern"),