Improve pattern matching MIR lowering
This commit is contained in:
parent
051dae2221
commit
9564773d5e
@ -1030,9 +1030,16 @@ fn collect_pat_(&mut self, pat: ast::Pat, binding_list: &mut BindingList) -> Pat
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
ast::Pat::LiteralPat(lit) => {
|
||||
ast::Pat::LiteralPat(lit) => 'b: {
|
||||
if let Some(ast_lit) = lit.literal() {
|
||||
let expr = Expr::Literal(ast_lit.kind().into());
|
||||
let mut hir_lit: Literal = ast_lit.kind().into();
|
||||
if lit.minus_token().is_some() {
|
||||
let Some(h) = hir_lit.negate() else {
|
||||
break 'b Pat::Missing;
|
||||
};
|
||||
hir_lit = h;
|
||||
}
|
||||
let expr = Expr::Literal(hir_lit);
|
||||
let expr_ptr = AstPtr::new(&ast::Expr::Literal(ast_lit));
|
||||
let expr_id = self.alloc_expr(expr, expr_ptr);
|
||||
Pat::Lit(expr_id)
|
||||
@ -1144,11 +1151,11 @@ fn from(ast_lit_kind: ast::LiteralKind) -> Self {
|
||||
FloatTypeWrapper::new(lit.float_value().unwrap_or(Default::default())),
|
||||
builtin,
|
||||
)
|
||||
} else if let builtin @ Some(_) = lit.suffix().and_then(BuiltinInt::from_suffix) {
|
||||
Literal::Int(lit.value().unwrap_or(0) as i128, builtin)
|
||||
} else {
|
||||
let builtin = lit.suffix().and_then(BuiltinUint::from_suffix);
|
||||
} else if let builtin @ Some(_) = lit.suffix().and_then(BuiltinUint::from_suffix) {
|
||||
Literal::Uint(lit.value().unwrap_or(0), builtin)
|
||||
} else {
|
||||
let builtin = lit.suffix().and_then(BuiltinInt::from_suffix);
|
||||
Literal::Int(lit.value().unwrap_or(0) as i128, builtin)
|
||||
}
|
||||
}
|
||||
LiteralKind::FloatNumber(lit) => {
|
||||
|
@ -92,6 +92,16 @@ pub enum Literal {
|
||||
Float(FloatTypeWrapper, Option<BuiltinFloat>),
|
||||
}
|
||||
|
||||
impl Literal {
|
||||
pub fn negate(self) -> Option<Self> {
|
||||
if let Literal::Int(i, k) = self {
|
||||
Some(Literal::Int(-i, k))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum Expr {
|
||||
/// This is produced if the syntax tree does not have a required expression piece.
|
||||
|
@ -685,6 +685,36 @@ const fn f(x: Season) -> i32 {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_matching_literal() {
|
||||
check_number(
|
||||
r#"
|
||||
const fn f(x: i32) -> i32 {
|
||||
match x {
|
||||
-1 => 1,
|
||||
1 => 10,
|
||||
_ => 100,
|
||||
}
|
||||
}
|
||||
const GOAL: i32 = f(-1) + f(1) + f(0) + f(-5);
|
||||
"#,
|
||||
211
|
||||
);
|
||||
check_number(
|
||||
r#"
|
||||
const fn f(x: &str) -> u8 {
|
||||
match x {
|
||||
"foo" => 1,
|
||||
"bar" => 10,
|
||||
_ => 100,
|
||||
}
|
||||
}
|
||||
const GOAL: u8 = f("foo") + f("bar");
|
||||
"#,
|
||||
11
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_matching_ergonomics() {
|
||||
check_number(
|
||||
@ -698,6 +728,16 @@ const fn f(x: &(u8, u8)) -> u8 {
|
||||
"#,
|
||||
5,
|
||||
);
|
||||
check_number(
|
||||
r#"
|
||||
const GOAL: u8 = {
|
||||
let a = &(2, 3);
|
||||
let &(x, y) = a;
|
||||
x + y
|
||||
};
|
||||
"#,
|
||||
5,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -781,6 +821,33 @@ const fn f(&self, (a, b): &(u8, u8)) -> u8 {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn match_guards() {
|
||||
check_number(
|
||||
r#"
|
||||
//- minicore: option, eq
|
||||
impl<T: PartialEq> PartialEq for Option<T> {
|
||||
fn eq(&self, other: &Rhs) -> bool {
|
||||
match (self, other) {
|
||||
(Some(x), Some(y)) => x == y,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
fn f(x: Option<i32>) -> i32 {
|
||||
match x {
|
||||
y if y == Some(42) => 42000,
|
||||
Some(y) => y,
|
||||
None => 10
|
||||
}
|
||||
}
|
||||
const GOAL: i32 = f(Some(42)) + f(Some(2)) + f(None);
|
||||
"#,
|
||||
42012,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn options() {
|
||||
check_number(
|
||||
@ -983,6 +1050,51 @@ fn mult3(x: u8) -> u8 {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enum_variant_as_function() {
|
||||
check_number(
|
||||
r#"
|
||||
//- minicore: option
|
||||
const GOAL: u8 = {
|
||||
let f = Some;
|
||||
f(3).unwrap_or(2)
|
||||
};
|
||||
"#,
|
||||
3,
|
||||
);
|
||||
check_number(
|
||||
r#"
|
||||
//- minicore: option
|
||||
const GOAL: u8 = {
|
||||
let f: fn(u8) -> Option<u8> = Some;
|
||||
f(3).unwrap_or(2)
|
||||
};
|
||||
"#,
|
||||
3,
|
||||
);
|
||||
check_number(
|
||||
r#"
|
||||
//- minicore: coerce_unsized, index, slice
|
||||
enum Foo {
|
||||
Add2(u8),
|
||||
Mult3(u8),
|
||||
}
|
||||
use Foo::*;
|
||||
const fn f(x: Foo) -> u8 {
|
||||
match x {
|
||||
Add2(x) => x + 2,
|
||||
Mult3(x) => x * 3,
|
||||
}
|
||||
}
|
||||
const GOAL: u8 = {
|
||||
let x = [Add2, Mult3];
|
||||
f(x[0](1)) + f(x[1](5))
|
||||
};
|
||||
"#,
|
||||
18,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn function_traits() {
|
||||
check_number(
|
||||
|
@ -423,6 +423,7 @@ fn interpret_mir(
|
||||
args: impl Iterator<Item = Vec<u8>>,
|
||||
subst: Substitution,
|
||||
) -> Result<Vec<u8>> {
|
||||
dbg!(body.dbg(self.db));
|
||||
if let Some(x) = self.stack_depth_limit.checked_sub(1) {
|
||||
self.stack_depth_limit = x;
|
||||
} else {
|
||||
@ -581,7 +582,14 @@ fn eval_rvalue<'a>(
|
||||
let mut ty = self.operand_ty(lhs, locals)?;
|
||||
while let TyKind::Ref(_, _, z) = ty.kind(Interner) {
|
||||
ty = z.clone();
|
||||
let size = self.size_of_sized(&ty, locals, "operand of binary op")?;
|
||||
let size = if ty.kind(Interner) == &TyKind::Str {
|
||||
let ns = from_bytes!(usize, &lc[self.ptr_size()..self.ptr_size() * 2]);
|
||||
lc = &lc[..self.ptr_size()];
|
||||
rc = &rc[..self.ptr_size()];
|
||||
ns
|
||||
} else {
|
||||
self.size_of_sized(&ty, locals, "operand of binary op")?
|
||||
};
|
||||
lc = self.read_memory(Address::from_bytes(lc)?, size)?;
|
||||
rc = self.read_memory(Address::from_bytes(rc)?, size)?;
|
||||
}
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind};
|
||||
use hir_def::{
|
||||
adt::VariantData,
|
||||
adt::{VariantData, StructKind},
|
||||
body::Body,
|
||||
expr::{
|
||||
Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId,
|
||||
@ -28,6 +28,9 @@
|
||||
use super::*;
|
||||
|
||||
mod as_place;
|
||||
mod pattern_matching;
|
||||
|
||||
use pattern_matching::AdtPatternShape;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LoopBlocks {
|
||||
@ -107,12 +110,6 @@ fn unresolved_path(db: &dyn HirDatabase, p: &Path) -> Self {
|
||||
|
||||
type Result<T> = std::result::Result<T, MirLowerError>;
|
||||
|
||||
enum AdtPatternShape<'a> {
|
||||
Tuple { args: &'a [PatId], ellipsis: Option<usize> },
|
||||
Record { args: &'a [RecordFieldPat] },
|
||||
Unit,
|
||||
}
|
||||
|
||||
impl MirLowerCtx<'_> {
|
||||
fn temp(&mut self, ty: Ty) -> Result<LocalId> {
|
||||
if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) {
|
||||
@ -275,15 +272,19 @@ fn lower_expr_to_place_without_adjust(
|
||||
Ok(Some(current))
|
||||
}
|
||||
ValueNs::EnumVariantId(variant_id) => {
|
||||
let ty = self.infer.type_of_expr[expr_id].clone();
|
||||
let current = self.lower_enum_variant(
|
||||
variant_id,
|
||||
current,
|
||||
place,
|
||||
ty,
|
||||
vec![],
|
||||
expr_id.into(),
|
||||
)?;
|
||||
let variant_data = &self.db.enum_data(variant_id.parent).variants[variant_id.local_id];
|
||||
if variant_data.variant_data.kind() == StructKind::Unit {
|
||||
let ty = self.infer.type_of_expr[expr_id].clone();
|
||||
current = self.lower_enum_variant(
|
||||
variant_id,
|
||||
current,
|
||||
place,
|
||||
ty,
|
||||
vec![],
|
||||
expr_id.into(),
|
||||
)?;
|
||||
}
|
||||
// Otherwise its a tuple like enum, treated like a zero sized function, so no action is needed
|
||||
Ok(Some(current))
|
||||
}
|
||||
ValueNs::GenericParam(p) => {
|
||||
@ -517,10 +518,7 @@ fn lower_expr_to_place_without_adjust(
|
||||
let cond_ty = self.expr_ty_after_adjustments(*expr);
|
||||
let mut end = None;
|
||||
for MatchArm { pat, guard, expr } in arms.iter() {
|
||||
if guard.is_some() {
|
||||
not_supported!("pattern matching with guard");
|
||||
}
|
||||
let (then, otherwise) = self.pattern_match(
|
||||
let (then, mut otherwise) = self.pattern_match(
|
||||
current,
|
||||
None,
|
||||
cond_place.clone(),
|
||||
@ -528,6 +526,16 @@ fn lower_expr_to_place_without_adjust(
|
||||
*pat,
|
||||
BindingAnnotation::Unannotated,
|
||||
)?;
|
||||
let then = if let &Some(guard) = guard {
|
||||
let next = self.new_basic_block();
|
||||
let o = otherwise.get_or_insert_with(|| self.new_basic_block());
|
||||
if let Some((discr, c)) = self.lower_expr_to_some_operand(guard, then)? {
|
||||
self.set_terminator(c, Terminator::SwitchInt { discr, targets: SwitchTargets::static_if(1, next, *o) });
|
||||
}
|
||||
next
|
||||
} else {
|
||||
then
|
||||
};
|
||||
if let Some(block) = self.lower_expr_to_place(*expr, place.clone(), then)? {
|
||||
let r = end.get_or_insert_with(|| self.new_basic_block());
|
||||
self.set_goto(block, *r);
|
||||
@ -922,7 +930,7 @@ fn lower_enum_variant(
|
||||
) -> Result<BasicBlockId> {
|
||||
let subst = match ty.kind(Interner) {
|
||||
TyKind::Adt(_, subst) => subst.clone(),
|
||||
_ => not_supported!("Non ADT enum"),
|
||||
_ => implementation_error!("Non ADT enum"),
|
||||
};
|
||||
self.push_assignment(
|
||||
prev_block,
|
||||
@ -1020,355 +1028,6 @@ fn push_assignment(
|
||||
self.push_statement(block, StatementKind::Assign(place, rvalue).with_span(span));
|
||||
}
|
||||
|
||||
/// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if
|
||||
/// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which
|
||||
/// can be the `current` block) and one for the mismatched path. If the input pattern is irrefutable, the
|
||||
/// mismatched path block is `None`.
|
||||
///
|
||||
/// By default, it will create a new block for mismatched path. If you already have one, you can provide it with
|
||||
/// `current_else` argument to save an unneccessary jump. If `current_else` isn't `None`, the result mismatched path
|
||||
/// wouldn't be `None` as well. Note that this function will add jumps to the beginning of the `current_else` block,
|
||||
/// so it should be an empty block.
|
||||
fn pattern_match(
|
||||
&mut self,
|
||||
mut current: BasicBlockId,
|
||||
mut current_else: Option<BasicBlockId>,
|
||||
mut cond_place: Place,
|
||||
mut cond_ty: Ty,
|
||||
pattern: PatId,
|
||||
mut binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
Ok(match &self.body.pats[pattern] {
|
||||
Pat::Missing => return Err(MirLowerError::IncompleteExpr),
|
||||
Pat::Wild => (current, current_else),
|
||||
Pat::Tuple { args, ellipsis } => {
|
||||
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
|
||||
let subst = match cond_ty.kind(Interner) {
|
||||
TyKind::Tuple(_, s) => s,
|
||||
_ => {
|
||||
return Err(MirLowerError::TypeError(
|
||||
"non tuple type matched with tuple pattern",
|
||||
))
|
||||
}
|
||||
};
|
||||
self.pattern_match_tuple_like(
|
||||
current,
|
||||
current_else,
|
||||
args,
|
||||
*ellipsis,
|
||||
subst.iter(Interner).enumerate().map(|(i, x)| {
|
||||
(PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone())
|
||||
}),
|
||||
&cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
Pat::Or(pats) => {
|
||||
let then_target = self.new_basic_block();
|
||||
let mut finished = false;
|
||||
for pat in &**pats {
|
||||
let (next, next_else) = self.pattern_match(
|
||||
current,
|
||||
None,
|
||||
cond_place.clone(),
|
||||
cond_ty.clone(),
|
||||
*pat,
|
||||
binding_mode,
|
||||
)?;
|
||||
self.set_goto(next, then_target);
|
||||
match next_else {
|
||||
Some(t) => {
|
||||
current = t;
|
||||
}
|
||||
None => {
|
||||
finished = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if !finished {
|
||||
let ce = *current_else.get_or_insert_with(|| self.new_basic_block());
|
||||
self.set_goto(current, ce);
|
||||
}
|
||||
(then_target, current_else)
|
||||
}
|
||||
Pat::Record { args, .. } => {
|
||||
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
|
||||
not_supported!("unresolved variant");
|
||||
};
|
||||
self.pattern_matching_variant(
|
||||
cond_ty,
|
||||
binding_mode,
|
||||
cond_place,
|
||||
variant,
|
||||
current,
|
||||
pattern.into(),
|
||||
current_else,
|
||||
AdtPatternShape::Record { args: &*args },
|
||||
)?
|
||||
}
|
||||
Pat::Range { .. } => not_supported!("range pattern"),
|
||||
Pat::Slice { .. } => not_supported!("slice pattern"),
|
||||
Pat::Path(_) => {
|
||||
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
|
||||
not_supported!("unresolved variant");
|
||||
};
|
||||
self.pattern_matching_variant(
|
||||
cond_ty,
|
||||
binding_mode,
|
||||
cond_place,
|
||||
variant,
|
||||
current,
|
||||
pattern.into(),
|
||||
current_else,
|
||||
AdtPatternShape::Unit,
|
||||
)?
|
||||
}
|
||||
Pat::Lit(l) => {
|
||||
let then_target = self.new_basic_block();
|
||||
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
|
||||
match &self.body.exprs[*l] {
|
||||
Expr::Literal(l) => match l {
|
||||
hir_def::expr::Literal::Int(x, _) => {
|
||||
self.set_terminator(
|
||||
current,
|
||||
Terminator::SwitchInt {
|
||||
discr: Operand::Copy(cond_place),
|
||||
targets: SwitchTargets::static_if(
|
||||
*x as u128,
|
||||
then_target,
|
||||
else_target,
|
||||
),
|
||||
},
|
||||
);
|
||||
}
|
||||
hir_def::expr::Literal::Uint(x, _) => {
|
||||
self.set_terminator(
|
||||
current,
|
||||
Terminator::SwitchInt {
|
||||
discr: Operand::Copy(cond_place),
|
||||
targets: SwitchTargets::static_if(*x, then_target, else_target),
|
||||
},
|
||||
);
|
||||
}
|
||||
_ => not_supported!("non int path literal"),
|
||||
},
|
||||
_ => not_supported!("expression path literal"),
|
||||
}
|
||||
(then_target, Some(else_target))
|
||||
}
|
||||
Pat::Bind { id, subpat } => {
|
||||
let target_place = self.result.binding_locals[*id];
|
||||
let mode = self.body.bindings[*id].mode;
|
||||
if let Some(subpat) = subpat {
|
||||
(current, current_else) = self.pattern_match(
|
||||
current,
|
||||
current_else,
|
||||
cond_place.clone(),
|
||||
cond_ty,
|
||||
*subpat,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) {
|
||||
binding_mode = mode;
|
||||
}
|
||||
self.push_storage_live(*id, current);
|
||||
self.push_assignment(
|
||||
current,
|
||||
target_place.into(),
|
||||
match binding_mode {
|
||||
BindingAnnotation::Unannotated | BindingAnnotation::Mutable => {
|
||||
Operand::Copy(cond_place).into()
|
||||
}
|
||||
BindingAnnotation::Ref => Rvalue::Ref(BorrowKind::Shared, cond_place),
|
||||
BindingAnnotation::RefMut => Rvalue::Ref(
|
||||
BorrowKind::Mut { allow_two_phase_borrow: false },
|
||||
cond_place,
|
||||
),
|
||||
},
|
||||
pattern.into(),
|
||||
);
|
||||
(current, current_else)
|
||||
}
|
||||
Pat::TupleStruct { path: _, args, ellipsis } => {
|
||||
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
|
||||
not_supported!("unresolved variant");
|
||||
};
|
||||
self.pattern_matching_variant(
|
||||
cond_ty,
|
||||
binding_mode,
|
||||
cond_place,
|
||||
variant,
|
||||
current,
|
||||
pattern.into(),
|
||||
current_else,
|
||||
AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
|
||||
)?
|
||||
}
|
||||
Pat::Ref { .. } => not_supported!("& pattern"),
|
||||
Pat::Box { .. } => not_supported!("box pattern"),
|
||||
Pat::ConstBlock(_) => not_supported!("const block pattern"),
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_matching_variant(
|
||||
&mut self,
|
||||
mut cond_ty: Ty,
|
||||
mut binding_mode: BindingAnnotation,
|
||||
mut cond_place: Place,
|
||||
variant: VariantId,
|
||||
current: BasicBlockId,
|
||||
span: MirSpan,
|
||||
current_else: Option<BasicBlockId>,
|
||||
shape: AdtPatternShape<'_>,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
|
||||
let subst = match cond_ty.kind(Interner) {
|
||||
TyKind::Adt(_, s) => s,
|
||||
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
|
||||
};
|
||||
Ok(match variant {
|
||||
VariantId::EnumVariantId(v) => {
|
||||
let e = self.db.const_eval_discriminant(v)? as u128;
|
||||
let next = self.new_basic_block();
|
||||
let tmp = self.discr_temp_place();
|
||||
self.push_assignment(
|
||||
current,
|
||||
tmp.clone(),
|
||||
Rvalue::Discriminant(cond_place.clone()),
|
||||
span,
|
||||
);
|
||||
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
|
||||
self.set_terminator(
|
||||
current,
|
||||
Terminator::SwitchInt {
|
||||
discr: Operand::Copy(tmp),
|
||||
targets: SwitchTargets::static_if(e, next, else_target),
|
||||
},
|
||||
);
|
||||
let enum_data = self.db.enum_data(v.parent);
|
||||
self.pattern_matching_variant_fields(
|
||||
shape,
|
||||
&enum_data.variants[v.local_id].variant_data,
|
||||
variant,
|
||||
subst,
|
||||
next,
|
||||
Some(else_target),
|
||||
&cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
VariantId::StructId(s) => {
|
||||
let struct_data = self.db.struct_data(s);
|
||||
self.pattern_matching_variant_fields(
|
||||
shape,
|
||||
&struct_data.variant_data,
|
||||
variant,
|
||||
subst,
|
||||
current,
|
||||
current_else,
|
||||
&cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
VariantId::UnionId(_) => {
|
||||
return Err(MirLowerError::TypeError("pattern matching on union"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_matching_variant_fields(
|
||||
&mut self,
|
||||
shape: AdtPatternShape<'_>,
|
||||
variant_data: &VariantData,
|
||||
v: VariantId,
|
||||
subst: &Substitution,
|
||||
current: BasicBlockId,
|
||||
current_else: Option<BasicBlockId>,
|
||||
cond_place: &Place,
|
||||
binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
let fields_type = self.db.field_types(v);
|
||||
Ok(match shape {
|
||||
AdtPatternShape::Record { args } => {
|
||||
let it = args
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let field_id =
|
||||
variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?;
|
||||
Ok((
|
||||
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
|
||||
x.pat,
|
||||
fields_type[field_id].clone().substitute(Interner, subst),
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
self.pattern_match_adt(
|
||||
current,
|
||||
current_else,
|
||||
it.into_iter(),
|
||||
cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
AdtPatternShape::Tuple { args, ellipsis } => {
|
||||
let fields = variant_data.fields().iter().map(|(x, _)| {
|
||||
(
|
||||
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
|
||||
fields_type[x].clone().substitute(Interner, subst),
|
||||
)
|
||||
});
|
||||
self.pattern_match_tuple_like(
|
||||
current,
|
||||
current_else,
|
||||
args,
|
||||
ellipsis,
|
||||
fields,
|
||||
cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
AdtPatternShape::Unit => (current, current_else),
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_match_adt(
|
||||
&mut self,
|
||||
mut current: BasicBlockId,
|
||||
mut current_else: Option<BasicBlockId>,
|
||||
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>,
|
||||
cond_place: &Place,
|
||||
binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
for (proj, arg, ty) in args {
|
||||
let mut cond_place = cond_place.clone();
|
||||
cond_place.projection.push(proj);
|
||||
(current, current_else) =
|
||||
self.pattern_match(current, current_else, cond_place, ty, arg, binding_mode)?;
|
||||
}
|
||||
Ok((current, current_else))
|
||||
}
|
||||
|
||||
fn pattern_match_tuple_like(
|
||||
&mut self,
|
||||
current: BasicBlockId,
|
||||
current_else: Option<BasicBlockId>,
|
||||
args: &[PatId],
|
||||
ellipsis: Option<usize>,
|
||||
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone,
|
||||
cond_place: &Place,
|
||||
binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
|
||||
let it = al
|
||||
.iter()
|
||||
.zip(fields.clone())
|
||||
.chain(ar.iter().rev().zip(fields.rev()))
|
||||
.map(|(x, y)| (y.0, *x, y.1));
|
||||
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode)
|
||||
}
|
||||
|
||||
fn discr_temp_place(&mut self) -> Place {
|
||||
match &self.discr_temp {
|
||||
Some(x) => x.clone(),
|
||||
@ -1546,22 +1205,6 @@ fn lower_block_to_place(
|
||||
}
|
||||
}
|
||||
|
||||
fn pattern_matching_dereference(
|
||||
cond_ty: &mut Ty,
|
||||
binding_mode: &mut BindingAnnotation,
|
||||
cond_place: &mut Place,
|
||||
) {
|
||||
while let Some((ty, _, mu)) = cond_ty.as_reference() {
|
||||
if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref {
|
||||
*binding_mode = BindingAnnotation::RefMut;
|
||||
} else {
|
||||
*binding_mode = BindingAnnotation::Ref;
|
||||
}
|
||||
*cond_ty = ty.clone();
|
||||
cond_place.projection.push(ProjectionElem::Deref);
|
||||
}
|
||||
}
|
||||
|
||||
fn cast_kind(source_ty: &Ty, target_ty: &Ty) -> Result<CastKind> {
|
||||
Ok(match (source_ty.kind(Interner), target_ty.kind(Interner)) {
|
||||
(TyKind::Scalar(s), TyKind::Scalar(t)) => match (s, t) {
|
||||
|
399
crates/hir-ty/src/mir/lower/pattern_matching.rs
Normal file
399
crates/hir-ty/src/mir/lower/pattern_matching.rs
Normal file
@ -0,0 +1,399 @@
|
||||
//! MIR lowering for patterns
|
||||
|
||||
use super::*;
|
||||
|
||||
macro_rules! not_supported {
|
||||
($x: expr) => {
|
||||
return Err(MirLowerError::NotSupported(format!($x)))
|
||||
};
|
||||
}
|
||||
|
||||
pub(super) enum AdtPatternShape<'a> {
|
||||
Tuple { args: &'a [PatId], ellipsis: Option<usize> },
|
||||
Record { args: &'a [RecordFieldPat] },
|
||||
Unit,
|
||||
}
|
||||
|
||||
impl MirLowerCtx<'_> {
|
||||
/// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if
|
||||
/// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which
|
||||
/// can be the `current` block) and one for the mismatched path. If the input pattern is irrefutable, the
|
||||
/// mismatched path block is `None`.
|
||||
///
|
||||
/// By default, it will create a new block for mismatched path. If you already have one, you can provide it with
|
||||
/// `current_else` argument to save an unneccessary jump. If `current_else` isn't `None`, the result mismatched path
|
||||
/// wouldn't be `None` as well. Note that this function will add jumps to the beginning of the `current_else` block,
|
||||
/// so it should be an empty block.
|
||||
pub(super) fn pattern_match(
|
||||
&mut self,
|
||||
mut current: BasicBlockId,
|
||||
mut current_else: Option<BasicBlockId>,
|
||||
mut cond_place: Place,
|
||||
mut cond_ty: Ty,
|
||||
pattern: PatId,
|
||||
mut binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
Ok(match &self.body.pats[pattern] {
|
||||
Pat::Missing => return Err(MirLowerError::IncompleteExpr),
|
||||
Pat::Wild => (current, current_else),
|
||||
Pat::Tuple { args, ellipsis } => {
|
||||
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
|
||||
let subst = match cond_ty.kind(Interner) {
|
||||
TyKind::Tuple(_, s) => s,
|
||||
_ => {
|
||||
return Err(MirLowerError::TypeError(
|
||||
"non tuple type matched with tuple pattern",
|
||||
))
|
||||
}
|
||||
};
|
||||
self.pattern_match_tuple_like(
|
||||
current,
|
||||
current_else,
|
||||
args,
|
||||
*ellipsis,
|
||||
subst.iter(Interner).enumerate().map(|(i, x)| {
|
||||
(PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone())
|
||||
}),
|
||||
&cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
Pat::Or(pats) => {
|
||||
let then_target = self.new_basic_block();
|
||||
let mut finished = false;
|
||||
for pat in &**pats {
|
||||
let (next, next_else) = self.pattern_match(
|
||||
current,
|
||||
None,
|
||||
cond_place.clone(),
|
||||
cond_ty.clone(),
|
||||
*pat,
|
||||
binding_mode,
|
||||
)?;
|
||||
self.set_goto(next, then_target);
|
||||
match next_else {
|
||||
Some(t) => {
|
||||
current = t;
|
||||
}
|
||||
None => {
|
||||
finished = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if !finished {
|
||||
let ce = *current_else.get_or_insert_with(|| self.new_basic_block());
|
||||
self.set_goto(current, ce);
|
||||
}
|
||||
(then_target, current_else)
|
||||
}
|
||||
Pat::Record { args, .. } => {
|
||||
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
|
||||
not_supported!("unresolved variant");
|
||||
};
|
||||
self.pattern_matching_variant(
|
||||
cond_ty,
|
||||
binding_mode,
|
||||
cond_place,
|
||||
variant,
|
||||
current,
|
||||
pattern.into(),
|
||||
current_else,
|
||||
AdtPatternShape::Record { args: &*args },
|
||||
)?
|
||||
}
|
||||
Pat::Range { .. } => not_supported!("range pattern"),
|
||||
Pat::Slice { .. } => not_supported!("slice pattern"),
|
||||
Pat::Path(_) => {
|
||||
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
|
||||
not_supported!("unresolved variant");
|
||||
};
|
||||
self.pattern_matching_variant(
|
||||
cond_ty,
|
||||
binding_mode,
|
||||
cond_place,
|
||||
variant,
|
||||
current,
|
||||
pattern.into(),
|
||||
current_else,
|
||||
AdtPatternShape::Unit,
|
||||
)?
|
||||
}
|
||||
Pat::Lit(l) => match &self.body.exprs[*l] {
|
||||
Expr::Literal(l) => {
|
||||
let c = self.lower_literal_to_operand(cond_ty, l)?;
|
||||
self.pattern_match_const(current_else, current, c, cond_place, pattern)?
|
||||
}
|
||||
_ => not_supported!("expression path literal"),
|
||||
},
|
||||
Pat::Bind { id, subpat } => {
|
||||
let target_place = self.result.binding_locals[*id];
|
||||
let mode = self.body.bindings[*id].mode;
|
||||
if let Some(subpat) = subpat {
|
||||
(current, current_else) = self.pattern_match(
|
||||
current,
|
||||
current_else,
|
||||
cond_place.clone(),
|
||||
cond_ty,
|
||||
*subpat,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) {
|
||||
binding_mode = mode;
|
||||
}
|
||||
self.push_storage_live(*id, current);
|
||||
self.push_assignment(
|
||||
current,
|
||||
target_place.into(),
|
||||
match binding_mode {
|
||||
BindingAnnotation::Unannotated | BindingAnnotation::Mutable => {
|
||||
Operand::Copy(cond_place).into()
|
||||
}
|
||||
BindingAnnotation::Ref => Rvalue::Ref(BorrowKind::Shared, cond_place),
|
||||
BindingAnnotation::RefMut => Rvalue::Ref(
|
||||
BorrowKind::Mut { allow_two_phase_borrow: false },
|
||||
cond_place,
|
||||
),
|
||||
},
|
||||
pattern.into(),
|
||||
);
|
||||
(current, current_else)
|
||||
}
|
||||
Pat::TupleStruct { path: _, args, ellipsis } => {
|
||||
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
|
||||
not_supported!("unresolved variant");
|
||||
};
|
||||
self.pattern_matching_variant(
|
||||
cond_ty,
|
||||
binding_mode,
|
||||
cond_place,
|
||||
variant,
|
||||
current,
|
||||
pattern.into(),
|
||||
current_else,
|
||||
AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
|
||||
)?
|
||||
}
|
||||
Pat::Ref { pat, mutability: _ } => {
|
||||
if let Some((ty, _, _)) = cond_ty.as_reference() {
|
||||
cond_ty = ty.clone();
|
||||
cond_place.projection.push(ProjectionElem::Deref);
|
||||
self.pattern_match(
|
||||
current,
|
||||
current_else,
|
||||
cond_place,
|
||||
cond_ty,
|
||||
*pat,
|
||||
binding_mode,
|
||||
)?
|
||||
} else {
|
||||
return Err(MirLowerError::TypeError("& pattern for non reference"));
|
||||
}
|
||||
}
|
||||
Pat::Box { .. } => not_supported!("box pattern"),
|
||||
Pat::ConstBlock(_) => not_supported!("const block pattern"),
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_match_const(
|
||||
&mut self,
|
||||
current_else: Option<BasicBlockId>,
|
||||
current: BasicBlockId,
|
||||
c: Operand,
|
||||
cond_place: Place,
|
||||
pattern: Idx<Pat>,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
let then_target = self.new_basic_block();
|
||||
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
|
||||
let discr: Place = self.temp(TyBuilder::bool())?.into();
|
||||
self.push_assignment(
|
||||
current,
|
||||
discr.clone(),
|
||||
Rvalue::CheckedBinaryOp(BinOp::Eq, c, Operand::Copy(cond_place)),
|
||||
pattern.into(),
|
||||
);
|
||||
let discr = Operand::Copy(discr);
|
||||
self.set_terminator(
|
||||
current,
|
||||
Terminator::SwitchInt {
|
||||
discr,
|
||||
targets: SwitchTargets::static_if(1, then_target, else_target),
|
||||
},
|
||||
);
|
||||
Ok((then_target, Some(else_target)))
|
||||
}
|
||||
|
||||
pub(super) fn pattern_matching_variant(
|
||||
&mut self,
|
||||
mut cond_ty: Ty,
|
||||
mut binding_mode: BindingAnnotation,
|
||||
mut cond_place: Place,
|
||||
variant: VariantId,
|
||||
current: BasicBlockId,
|
||||
span: MirSpan,
|
||||
current_else: Option<BasicBlockId>,
|
||||
shape: AdtPatternShape<'_>,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
|
||||
let subst = match cond_ty.kind(Interner) {
|
||||
TyKind::Adt(_, s) => s,
|
||||
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
|
||||
};
|
||||
Ok(match variant {
|
||||
VariantId::EnumVariantId(v) => {
|
||||
let e = self.db.const_eval_discriminant(v)? as u128;
|
||||
let next = self.new_basic_block();
|
||||
let tmp = self.discr_temp_place();
|
||||
self.push_assignment(
|
||||
current,
|
||||
tmp.clone(),
|
||||
Rvalue::Discriminant(cond_place.clone()),
|
||||
span,
|
||||
);
|
||||
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
|
||||
self.set_terminator(
|
||||
current,
|
||||
Terminator::SwitchInt {
|
||||
discr: Operand::Copy(tmp),
|
||||
targets: SwitchTargets::static_if(e, next, else_target),
|
||||
},
|
||||
);
|
||||
let enum_data = self.db.enum_data(v.parent);
|
||||
self.pattern_matching_variant_fields(
|
||||
shape,
|
||||
&enum_data.variants[v.local_id].variant_data,
|
||||
variant,
|
||||
subst,
|
||||
next,
|
||||
Some(else_target),
|
||||
&cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
VariantId::StructId(s) => {
|
||||
let struct_data = self.db.struct_data(s);
|
||||
self.pattern_matching_variant_fields(
|
||||
shape,
|
||||
&struct_data.variant_data,
|
||||
variant,
|
||||
subst,
|
||||
current,
|
||||
current_else,
|
||||
&cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
VariantId::UnionId(_) => {
|
||||
return Err(MirLowerError::TypeError("pattern matching on union"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_matching_variant_fields(
|
||||
&mut self,
|
||||
shape: AdtPatternShape<'_>,
|
||||
variant_data: &VariantData,
|
||||
v: VariantId,
|
||||
subst: &Substitution,
|
||||
current: BasicBlockId,
|
||||
current_else: Option<BasicBlockId>,
|
||||
cond_place: &Place,
|
||||
binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
let fields_type = self.db.field_types(v);
|
||||
Ok(match shape {
|
||||
AdtPatternShape::Record { args } => {
|
||||
let it = args
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let field_id =
|
||||
variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?;
|
||||
Ok((
|
||||
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
|
||||
x.pat,
|
||||
fields_type[field_id].clone().substitute(Interner, subst),
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
self.pattern_match_adt(
|
||||
current,
|
||||
current_else,
|
||||
it.into_iter(),
|
||||
cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
AdtPatternShape::Tuple { args, ellipsis } => {
|
||||
let fields = variant_data.fields().iter().map(|(x, _)| {
|
||||
(
|
||||
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
|
||||
fields_type[x].clone().substitute(Interner, subst),
|
||||
)
|
||||
});
|
||||
self.pattern_match_tuple_like(
|
||||
current,
|
||||
current_else,
|
||||
args,
|
||||
ellipsis,
|
||||
fields,
|
||||
cond_place,
|
||||
binding_mode,
|
||||
)?
|
||||
}
|
||||
AdtPatternShape::Unit => (current, current_else),
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_match_adt(
|
||||
&mut self,
|
||||
mut current: BasicBlockId,
|
||||
mut current_else: Option<BasicBlockId>,
|
||||
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>,
|
||||
cond_place: &Place,
|
||||
binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
for (proj, arg, ty) in args {
|
||||
let mut cond_place = cond_place.clone();
|
||||
cond_place.projection.push(proj);
|
||||
(current, current_else) =
|
||||
self.pattern_match(current, current_else, cond_place, ty, arg, binding_mode)?;
|
||||
}
|
||||
Ok((current, current_else))
|
||||
}
|
||||
|
||||
fn pattern_match_tuple_like(
|
||||
&mut self,
|
||||
current: BasicBlockId,
|
||||
current_else: Option<BasicBlockId>,
|
||||
args: &[PatId],
|
||||
ellipsis: Option<usize>,
|
||||
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone,
|
||||
cond_place: &Place,
|
||||
binding_mode: BindingAnnotation,
|
||||
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
|
||||
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
|
||||
let it = al
|
||||
.iter()
|
||||
.zip(fields.clone())
|
||||
.chain(ar.iter().rev().zip(fields.rev()))
|
||||
.map(|(x, y)| (y.0, *x, y.1));
|
||||
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode)
|
||||
}
|
||||
}
|
||||
|
||||
fn pattern_matching_dereference(
|
||||
cond_ty: &mut Ty,
|
||||
binding_mode: &mut BindingAnnotation,
|
||||
cond_place: &mut Place,
|
||||
) {
|
||||
while let Some((ty, _, mu)) = cond_ty.as_reference() {
|
||||
if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref {
|
||||
*binding_mode = BindingAnnotation::RefMut;
|
||||
} else {
|
||||
*binding_mode = BindingAnnotation::Ref;
|
||||
}
|
||||
*cond_ty = ty.clone();
|
||||
cond_place.projection.push(ProjectionElem::Deref);
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//! A pretty-printer for MIR.
|
||||
|
||||
use std::fmt::{Display, Write};
|
||||
use std::fmt::{Display, Write, Debug};
|
||||
|
||||
use hir_def::{body::Body, expr::BindingId};
|
||||
use hir_expand::name::Name;
|
||||
@ -23,6 +23,18 @@ pub fn pretty_print(&self, db: &dyn HirDatabase) -> String {
|
||||
ctx.for_body();
|
||||
ctx.result
|
||||
}
|
||||
|
||||
// String with lines is rendered poorly in `dbg!` macros, which I use very much, so this
|
||||
// function exists to solve that.
|
||||
pub fn dbg(&self, db: &dyn HirDatabase) -> impl Debug {
|
||||
struct StringDbg(String);
|
||||
impl Debug for StringDbg {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&self.0)
|
||||
}
|
||||
}
|
||||
StringDbg(self.pretty_print(db))
|
||||
}
|
||||
}
|
||||
|
||||
struct MirPrettyCtx<'a> {
|
||||
|
@ -1376,6 +1376,7 @@ pub struct LiteralPat {
|
||||
}
|
||||
impl LiteralPat {
|
||||
pub fn literal(&self) -> Option<Literal> { support::child(&self.syntax) }
|
||||
pub fn minus_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![-]) }
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
|
@ -597,7 +597,10 @@ pub fn and<U>(self, optb: Option<U>) -> Option<U> {
|
||||
loop {}
|
||||
}
|
||||
pub fn unwrap_or(self, default: T) -> T {
|
||||
loop {}
|
||||
match self {
|
||||
Some(val) => val,
|
||||
None => default,
|
||||
}
|
||||
}
|
||||
// region:fn
|
||||
pub fn and_then<U, F>(self, f: F) -> Option<U>
|
||||
|
Loading…
Reference in New Issue
Block a user