From 051dae222164a04c703c5c0914a304d35206a62f Mon Sep 17 00:00:00 2001 From: hkalbasi Date: Tue, 14 Mar 2023 17:02:38 +0330 Subject: [PATCH] Support record pattern MIR lowering --- crates/hir-ty/src/consteval/tests.rs | 41 +++++- crates/hir-ty/src/method_resolution.rs | 13 +- crates/hir-ty/src/mir/eval.rs | 15 +- crates/hir-ty/src/mir/lower.rs | 193 +++++++++++++++++++------ lib/la-arena/src/lib.rs | 2 +- 5 files changed, 201 insertions(+), 63 deletions(-) diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index f7914b578e4..1d298f96091 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -555,6 +555,38 @@ struct Point { "#, 17, ); + check_number( + r#" + struct Point { + x: i32, + y: i32, + } + + const GOAL: i32 = { + let p = Point { x: 5, y: 2 }; + let p2 = Point { x: 3, ..p }; + p.x * 1000 + p.y * 100 + p2.x * 10 + p2.y + }; + "#, + 5232, + ); + check_number( + r#" + struct Point { + x: i32, + y: i32, + } + + const GOAL: i32 = { + let p = Point { x: 5, y: 2 }; + let Point { x, y } = p; + let Point { x: x2, .. } = p; + let Point { y: y2, .. } = p; + x * 1000 + y * 100 + x2 * 10 + y2 + }; + "#, + 5252, + ); } #[test] @@ -599,13 +631,14 @@ fn tuples() { ); check_number( r#" - struct TupleLike(i32, u8, i64, u16); - const GOAL: u8 = { + struct TupleLike(i32, i64, u8, u16); + const GOAL: i64 = { let a = TupleLike(10, 20, 3, 15); - a.1 + let TupleLike(b, .., c) = a; + a.1 * 100 + b as i64 + c as i64 }; "#, - 20, + 2025, ); check_number( r#" diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs index 6244b98104f..2003d24038b 100644 --- a/crates/hir-ty/src/method_resolution.rs +++ b/crates/hir-ty/src/method_resolution.rs @@ -711,12 +711,13 @@ pub fn is_dyn_method( }; let self_ty = trait_ref.self_type_parameter(Interner); if let TyKind::Dyn(d) = self_ty.kind(Interner) { - let is_my_trait_in_bounds = d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() { - // rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter - // what the generics are, we are sure that the method is come from the vtable. - WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id, - _ => false, - }); + let is_my_trait_in_bounds = + d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() { + // rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter + // what the generics are, we are sure that the method is come from the vtable. + WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id, + _ => false, + }); if is_my_trait_in_bounds { return Some(fn_params); } diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 3001832d795..787665bb637 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -25,8 +25,8 @@ mapping::from_chalk, method_resolution::{is_dyn_method, lookup_impl_method}, traits::FnTrait, - CallableDefId, Const, ConstScalar, FnDefId, Interner, MemoryMap, Substitution, - TraitEnvironment, Ty, TyBuilder, TyExt, GenericArgData, + CallableDefId, Const, ConstScalar, FnDefId, GenericArgData, Interner, MemoryMap, Substitution, + TraitEnvironment, Ty, TyBuilder, TyExt, }; use super::{ @@ -1315,10 +1315,13 @@ fn exec_fn_with_args( args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec(); let generics_for_target = Substitution::from_iter( Interner, - generic_args - .iter(Interner) - .enumerate() - .map(|(i, x)| if i == self_ty_idx { &ty } else { x }) + generic_args.iter(Interner).enumerate().map(|(i, x)| { + if i == self_ty_idx { + &ty + } else { + x + } + }), ); return self.exec_fn_with_args( def, diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 4fc3c67a6e1..d36d9e946ca 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -4,16 +4,17 @@ use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind}; use hir_def::{ + adt::VariantData, body::Body, expr::{ Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId, - RecordLitField, + RecordFieldPat, RecordLitField, }, lang_item::{LangItem, LangItemTarget}, layout::LayoutError, path::Path, resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, - DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, TraitId, + DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, LocalFieldId, TraitId, }; use hir_expand::name::Name; use la_arena::ArenaMap; @@ -106,6 +107,12 @@ fn unresolved_path(db: &dyn HirDatabase, p: &Path) -> Self { type Result = std::result::Result; +enum AdtPatternShape<'a> { + Tuple { args: &'a [PatId], ellipsis: Option }, + Record { args: &'a [RecordFieldPat] }, + Unit, +} + impl MirLowerCtx<'_> { fn temp(&mut self, ty: Ty) -> Result { if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) { @@ -444,7 +451,8 @@ fn lower_expr_to_place_without_adjust( current, pat.into(), Some(end), - &[pat], &None)?; + AdtPatternShape::Tuple { args: &[pat], ellipsis: None }, + )?; if let Some((_, block)) = this.lower_expr_as_place(current, body, true)? { this.set_goto(block, begin); } @@ -573,7 +581,17 @@ fn lower_expr_to_place_without_adjust( Ok(None) } Expr::Yield { .. } => not_supported!("yield"), - Expr::RecordLit { fields, path, .. } => { + Expr::RecordLit { fields, path, spread, ellipsis: _, is_assignee_expr: _ } => { + let spread_place = match spread { + &Some(x) => { + let Some((p, c)) = self.lower_expr_as_place(current, x, true)? else { + return Ok(None); + }; + current = c; + Some(p) + }, + None => None, + }; let variant_id = self .infer .variant_resolution_for_expr(expr_id) @@ -603,9 +621,24 @@ fn lower_expr_to_place_without_adjust( place, Rvalue::Aggregate( AggregateKind::Adt(variant_id, subst), - operands.into_iter().map(|x| x).collect::>().ok_or( - MirLowerError::TypeError("missing field in record literal"), - )?, + match spread_place { + Some(sp) => operands.into_iter().enumerate().map(|(i, x)| { + match x { + Some(x) => x, + None => { + let mut p = sp.clone(); + p.projection.push(ProjectionElem::Field(FieldId { + parent: variant_id, + local_id: LocalFieldId::from_raw(RawIdx::from(i as u32)), + })); + Operand::Copy(p) + }, + } + }).collect(), + None => operands.into_iter().map(|x| x).collect::>().ok_or( + MirLowerError::TypeError("missing field in record literal"), + )?, + }, ), expr_id.into(), ); @@ -1021,14 +1054,11 @@ fn pattern_match( self.pattern_match_tuple_like( current, current_else, - args.iter().enumerate().map(|(i, x)| { - ( - PlaceElem::TupleField(i), - *x, - subst.at(Interner, i).assert_ty_ref(Interner).clone(), - ) - }), + args, *ellipsis, + subst.iter(Interner).enumerate().map(|(i, x)| { + (PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone()) + }), &cond_place, binding_mode, )? @@ -1062,7 +1092,21 @@ fn pattern_match( } (then_target, current_else) } - Pat::Record { .. } => not_supported!("record pattern"), + Pat::Record { args, .. } => { + let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else { + not_supported!("unresolved variant"); + }; + self.pattern_matching_variant( + cond_ty, + binding_mode, + cond_place, + variant, + current, + pattern.into(), + current_else, + AdtPatternShape::Record { args: &*args }, + )? + } Pat::Range { .. } => not_supported!("range pattern"), Pat::Slice { .. } => not_supported!("slice pattern"), Pat::Path(_) => { @@ -1077,8 +1121,7 @@ fn pattern_match( current, pattern.into(), current_else, - &[], - &None, + AdtPatternShape::Unit, )? } Pat::Lit(l) => { @@ -1160,8 +1203,7 @@ fn pattern_match( current, pattern.into(), current_else, - args, - ellipsis, + AdtPatternShape::Tuple { args, ellipsis: *ellipsis }, )? } Pat::Ref { .. } => not_supported!("& pattern"), @@ -1179,15 +1221,13 @@ fn pattern_matching_variant( current: BasicBlockId, span: MirSpan, current_else: Option, - args: &[PatId], - ellipsis: &Option, + shape: AdtPatternShape<'_>, ) -> Result<(BasicBlockId, Option)> { pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); let subst = match cond_ty.kind(Interner) { TyKind::Adt(_, s) => s, _ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")), }; - let fields_type = self.db.field_types(variant); Ok(match variant { VariantId::EnumVariantId(v) => { let e = self.db.const_eval_discriminant(v)? as u128; @@ -1208,35 +1248,26 @@ fn pattern_matching_variant( }, ); let enum_data = self.db.enum_data(v.parent); - let fields = - enum_data.variants[v.local_id].variant_data.fields().iter().map(|(x, _)| { - ( - PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), - fields_type[x].clone().substitute(Interner, subst), - ) - }); - self.pattern_match_tuple_like( + self.pattern_matching_variant_fields( + shape, + &enum_data.variants[v.local_id].variant_data, + variant, + subst, next, Some(else_target), - args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)), - *ellipsis, &cond_place, binding_mode, )? } VariantId::StructId(s) => { let struct_data = self.db.struct_data(s); - let fields = struct_data.variant_data.fields().iter().map(|(x, _)| { - ( - PlaceElem::Field(FieldId { parent: s.into(), local_id: x }), - fields_type[x].clone().substitute(Interner, subst), - ) - }); - self.pattern_match_tuple_like( + self.pattern_matching_variant_fields( + shape, + &struct_data.variant_data, + variant, + subst, current, current_else, - args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)), - *ellipsis, &cond_place, binding_mode, )? @@ -1247,18 +1278,69 @@ fn pattern_matching_variant( }) } - fn pattern_match_tuple_like( + fn pattern_matching_variant_fields( + &mut self, + shape: AdtPatternShape<'_>, + variant_data: &VariantData, + v: VariantId, + subst: &Substitution, + current: BasicBlockId, + current_else: Option, + cond_place: &Place, + binding_mode: BindingAnnotation, + ) -> Result<(BasicBlockId, Option)> { + let fields_type = self.db.field_types(v); + Ok(match shape { + AdtPatternShape::Record { args } => { + let it = args + .iter() + .map(|x| { + let field_id = + variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?; + Ok(( + PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }), + x.pat, + fields_type[field_id].clone().substitute(Interner, subst), + )) + }) + .collect::>>()?; + self.pattern_match_adt( + current, + current_else, + it.into_iter(), + cond_place, + binding_mode, + )? + } + AdtPatternShape::Tuple { args, ellipsis } => { + let fields = variant_data.fields().iter().map(|(x, _)| { + ( + PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), + fields_type[x].clone().substitute(Interner, subst), + ) + }); + self.pattern_match_tuple_like( + current, + current_else, + args, + ellipsis, + fields, + cond_place, + binding_mode, + )? + } + AdtPatternShape::Unit => (current, current_else), + }) + } + + fn pattern_match_adt( &mut self, mut current: BasicBlockId, mut current_else: Option, args: impl Iterator, - ellipsis: Option, cond_place: &Place, binding_mode: BindingAnnotation, ) -> Result<(BasicBlockId, Option)> { - if ellipsis.is_some() { - not_supported!("tuple like pattern with ellipsis"); - } for (proj, arg, ty) in args { let mut cond_place = cond_place.clone(); cond_place.projection.push(proj); @@ -1268,6 +1350,25 @@ fn pattern_match_tuple_like( Ok((current, current_else)) } + fn pattern_match_tuple_like( + &mut self, + current: BasicBlockId, + current_else: Option, + args: &[PatId], + ellipsis: Option, + fields: impl DoubleEndedIterator + Clone, + cond_place: &Place, + binding_mode: BindingAnnotation, + ) -> Result<(BasicBlockId, Option)> { + let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); + let it = al + .iter() + .zip(fields.clone()) + .chain(ar.iter().rev().zip(fields.rev())) + .map(|(x, y)| (y.0, *x, y.1)); + self.pattern_match_adt(current, current_else, it, cond_place, binding_mode) + } + fn discr_temp_place(&mut self) -> Place { match &self.discr_temp { Some(x) => x.clone(), diff --git a/lib/la-arena/src/lib.rs b/lib/la-arena/src/lib.rs index ccaaf399176..f6597efd8fd 100644 --- a/lib/la-arena/src/lib.rs +++ b/lib/la-arena/src/lib.rs @@ -295,7 +295,7 @@ pub fn alloc(&mut self, value: T) -> Idx { /// ``` pub fn iter( &self, - ) -> impl Iterator, &T)> + ExactSizeIterator + DoubleEndedIterator { + ) -> impl Iterator, &T)> + ExactSizeIterator + DoubleEndedIterator + Clone { self.data.iter().enumerate().map(|(idx, value)| (Idx::from_raw(RawIdx(idx as u32)), value)) }