From c623319a30d50d88091120fe5ac9354572d56662 Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Sat, 9 Mar 2024 03:29:52 +0100 Subject: [PATCH] Lower deref patterns to MIR This handles using deref patterns to choose the correct match arm. This does not handle bindings or guards. Co-authored-by: Deadbeef --- .../src/mem_categorization.rs | 8 +- .../rustc_mir_build/src/build/matches/mod.rs | 7 ++ .../rustc_mir_build/src/build/matches/test.rs | 88 +++++++++++++------ .../rustc_mir_build/src/build/matches/util.rs | 15 ++-- tests/ui/pattern/deref-patterns/bindings.rs | 37 ++++++++ tests/ui/pattern/deref-patterns/branch.rs | 40 +++++++++ .../cant_move_out_of_pattern.rs | 24 +++++ .../cant_move_out_of_pattern.stderr | 27 ++++++ tests/ui/pattern/deref-patterns/typeck.rs | 16 ++-- .../ui/pattern/deref-patterns/typeck_fail.rs | 17 ++++ .../pattern/deref-patterns/typeck_fail.stderr | 19 ++++ 11 files changed, 259 insertions(+), 39 deletions(-) create mode 100644 tests/ui/pattern/deref-patterns/bindings.rs create mode 100644 tests/ui/pattern/deref-patterns/branch.rs create mode 100644 tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.rs create mode 100644 tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.stderr create mode 100644 tests/ui/pattern/deref-patterns/typeck_fail.rs create mode 100644 tests/ui/pattern/deref-patterns/typeck_fail.stderr diff --git a/compiler/rustc_hir_typeck/src/mem_categorization.rs b/compiler/rustc_hir_typeck/src/mem_categorization.rs index 859877962fe..15f802560fd 100644 --- a/compiler/rustc_hir_typeck/src/mem_categorization.rs +++ b/compiler/rustc_hir_typeck/src/mem_categorization.rs @@ -711,13 +711,19 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> { self.cat_pattern_(place_with_id, subpat, op)?; } - PatKind::Box(subpat) | PatKind::Ref(subpat, _) | PatKind::Deref(subpat) => { + PatKind::Box(subpat) | PatKind::Ref(subpat, _) => { // box p1, &p1, &mut p1. we can ignore the mutability of // PatKind::Ref since that information is already contained // in the type. let subplace = self.cat_deref(pat, place_with_id)?; self.cat_pattern_(subplace, subpat, op)?; } + PatKind::Deref(subpat) => { + let ty = self.pat_ty_adjusted(subpat)?; + // A deref pattern generates a temporary. + let place = self.cat_rvalue(pat.hir_id, ty); + self.cat_pattern_(place, subpat, op)?; + } PatKind::Slice(before, ref slice, after) => { let Some(element_ty) = place_with_id.place.ty().builtin_index() else { diff --git a/compiler/rustc_mir_build/src/build/matches/mod.rs b/compiler/rustc_mir_build/src/build/matches/mod.rs index 9730473c428..f16dddeaa68 100644 --- a/compiler/rustc_mir_build/src/build/matches/mod.rs +++ b/compiler/rustc_mir_build/src/build/matches/mod.rs @@ -1163,6 +1163,7 @@ enum TestCase<'pat, 'tcx> { Constant { value: mir::Const<'tcx> }, Range(&'pat PatRange<'tcx>), Slice { len: usize, variable_length: bool }, + Deref { temp: Place<'tcx> }, Or { pats: Box<[FlatPat<'pat, 'tcx>]> }, } @@ -1222,6 +1223,12 @@ enum TestKind<'tcx> { /// Test that the length of the slice is equal to `len`. Len { len: u64, op: BinOp }, + + /// Call `Deref::deref` on the value. + Deref { + /// Temporary to store the result of `deref()`. + temp: Place<'tcx>, + }, } /// A test to perform to determine which [`Candidate`] matches a value. diff --git a/compiler/rustc_mir_build/src/build/matches/test.rs b/compiler/rustc_mir_build/src/build/matches/test.rs index 690879b9488..f6827fd7c52 100644 --- a/compiler/rustc_mir_build/src/build/matches/test.rs +++ b/compiler/rustc_mir_build/src/build/matches/test.rs @@ -42,6 +42,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestKind::Len { len: len as u64, op } } + TestCase::Deref { temp } => TestKind::Deref { temp }, + TestCase::Or { .. } => bug!("or-patterns should have already been handled"), TestCase::Irrefutable { .. } => span_bug!( @@ -143,35 +145,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ); } let re_erased = tcx.lifetimes.re_erased; - let ref_string = self.temp(Ty::new_imm_ref(tcx, re_erased, ty), test.span); let ref_str_ty = Ty::new_imm_ref(tcx, re_erased, tcx.types.str_); let ref_str = self.temp(ref_str_ty, test.span); - let deref = tcx.require_lang_item(LangItem::Deref, None); - let method = trait_method(tcx, deref, sym::deref, [ty]); let eq_block = self.cfg.start_new_block(); - self.cfg.push_assign( - block, - source_info, - ref_string, - Rvalue::Ref(re_erased, BorrowKind::Shared, place), - ); - self.cfg.terminate( - block, - source_info, - TerminatorKind::Call { - func: Operand::Constant(Box::new(ConstOperand { - span: test.span, - user_ty: None, - const_: method, - })), - args: vec![Spanned { node: Operand::Move(ref_string), span: DUMMY_SP }], - destination: ref_str, - target: Some(eq_block), - unwind: UnwindAction::Continue, - call_source: CallSource::Misc, - fn_span: source_info.span, - }, - ); + // `let ref_str: &str = ::deref(&place);` + self.call_deref(block, eq_block, place, ty, ref_str, test.span); self.non_scalar_compare( eq_block, success_block, @@ -270,9 +248,57 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { Operand::Move(expected), ); } + + TestKind::Deref { temp } => { + let ty = place_ty.ty; + let target = target_block(TestBranch::Success); + self.call_deref(block, target, place, ty, temp, test.span); + } } } + /// Perform `let temp = ::deref(&place)`. + pub(super) fn call_deref( + &mut self, + block: BasicBlock, + target_block: BasicBlock, + place: Place<'tcx>, + ty: Ty<'tcx>, + temp: Place<'tcx>, + span: Span, + ) { + let source_info = self.source_info(span); + let re_erased = self.tcx.lifetimes.re_erased; + let deref = self.tcx.require_lang_item(LangItem::Deref, None); + let method = trait_method(self.tcx, deref, sym::deref, [ty]); + let ref_src = self.temp(Ty::new_imm_ref(self.tcx, re_erased, ty), span); + // `let ref_src = &src_place;` + self.cfg.push_assign( + block, + source_info, + ref_src, + Rvalue::Ref(re_erased, BorrowKind::Shared, place), + ); + // `let temp = ::deref(ref_src);` + self.cfg.terminate( + block, + source_info, + TerminatorKind::Call { + func: Operand::Constant(Box::new(ConstOperand { + span, + user_ty: None, + const_: method, + })), + args: vec![Spanned { node: Operand::Move(ref_src), span }], + destination: temp, + target: Some(target_block), + unwind: UnwindAction::Continue, + call_source: CallSource::Misc, + fn_span: source_info.span, + }, + ); + } + /// Compare using the provided built-in comparison operator fn compare( &mut self, @@ -660,13 +686,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } + (TestKind::Deref { temp: test_temp }, TestCase::Deref { temp }) + if test_temp == temp => + { + fully_matched = true; + Some(TestBranch::Success) + } + ( TestKind::Switch { .. } | TestKind::SwitchInt { .. } | TestKind::If | TestKind::Len { .. } | TestKind::Range { .. } - | TestKind::Eq { .. }, + | TestKind::Eq { .. } + | TestKind::Deref { .. }, _, ) => { fully_matched = false; diff --git a/compiler/rustc_mir_build/src/build/matches/util.rs b/compiler/rustc_mir_build/src/build/matches/util.rs index d6376b7b0dc..4d8248c0e39 100644 --- a/compiler/rustc_mir_build/src/build/matches/util.rs +++ b/compiler/rustc_mir_build/src/build/matches/util.rs @@ -5,8 +5,8 @@ use rustc_data_structures::fx::FxIndexSet; use rustc_infer::infer::type_variable::TypeVariableOrigin; use rustc_middle::mir::*; use rustc_middle::thir::{self, *}; -use rustc_middle::ty; use rustc_middle::ty::TypeVisitableExt; +use rustc_middle::ty::{self, Ty}; impl<'a, 'tcx> Builder<'a, 'tcx> { pub(crate) fn field_match_pairs<'pat>( @@ -249,10 +249,15 @@ impl<'pat, 'tcx> MatchPair<'pat, 'tcx> { default_irrefutable() } - PatKind::DerefPattern { .. } => { - // FIXME(deref_patterns) - // Treat it like a wildcard for now. - default_irrefutable() + PatKind::DerefPattern { ref subpattern } => { + // Create a new temporary for each deref pattern. + // FIXME(deref_patterns): dedup temporaries to avoid multiple `deref()` calls? + let temp = cx.temp( + Ty::new_imm_ref(cx.tcx, cx.tcx.lifetimes.re_erased, subpattern.ty), + pattern.span, + ); + subpairs.push(MatchPair::new(PlaceBuilder::from(temp).deref(), subpattern, cx)); + TestCase::Deref { temp } } }; diff --git a/tests/ui/pattern/deref-patterns/bindings.rs b/tests/ui/pattern/deref-patterns/bindings.rs new file mode 100644 index 00000000000..75d59c1967a --- /dev/null +++ b/tests/ui/pattern/deref-patterns/bindings.rs @@ -0,0 +1,37 @@ +//@ run-pass +#![feature(deref_patterns)] +#![allow(incomplete_features)] + +fn simple_vec(vec: Vec) -> u32 { + match vec { + deref!([]) => 100, + // FIXME(deref_patterns): fake borrows break guards + // deref!([x]) if x == 4 => x + 4, + deref!([x]) => x, + deref!([1, x]) => x + 200, + deref!(ref slice) => slice.iter().sum(), + _ => 2000, + } +} + +fn nested_vec(vecvec: Vec>) -> u32 { + match vecvec { + deref!([]) => 0, + deref!([deref!([x])]) => x, + deref!([deref!([0, x]) | deref!([1, x])]) => x, + deref!([ref x]) => x.iter().sum(), + deref!([deref!([]), deref!([1, x, y])]) => y - x, + _ => 2000, + } +} + +fn main() { + assert_eq!(simple_vec(vec![1]), 1); + assert_eq!(simple_vec(vec![1, 2]), 202); + assert_eq!(simple_vec(vec![1, 2, 3]), 6); + + assert_eq!(nested_vec(vec![vec![0, 42]]), 42); + assert_eq!(nested_vec(vec![vec![1, 42]]), 42); + assert_eq!(nested_vec(vec![vec![1, 2, 3]]), 6); + assert_eq!(nested_vec(vec![vec![], vec![1, 2, 3]]), 1); +} diff --git a/tests/ui/pattern/deref-patterns/branch.rs b/tests/ui/pattern/deref-patterns/branch.rs new file mode 100644 index 00000000000..1bac1006d9d --- /dev/null +++ b/tests/ui/pattern/deref-patterns/branch.rs @@ -0,0 +1,40 @@ +//@ run-pass +// Test the execution of deref patterns. +#![feature(deref_patterns)] +#![allow(incomplete_features)] + +fn branch(vec: Vec) -> u32 { + match vec { + deref!([]) => 0, + deref!([1, _, 3]) => 1, + deref!([2, ..]) => 2, + _ => 1000, + } +} + +fn nested(vec: Vec>) -> u32 { + match vec { + deref!([deref!([]), ..]) => 1, + deref!([deref!([0, ..]), deref!([1, ..])]) => 2, + _ => 1000, + } +} + +fn main() { + assert!(matches!(Vec::::new(), deref!([]))); + assert!(matches!(vec![1], deref!([1]))); + assert!(matches!(&vec![1], deref!([1]))); + assert!(matches!(vec![&1], deref!([1]))); + assert!(matches!(vec![vec![1]], deref!([deref!([1])]))); + + assert_eq!(branch(vec![]), 0); + assert_eq!(branch(vec![1, 2, 3]), 1); + assert_eq!(branch(vec![3, 2, 1]), 1000); + assert_eq!(branch(vec![2]), 2); + assert_eq!(branch(vec![2, 3]), 2); + assert_eq!(branch(vec![3, 2]), 1000); + + assert_eq!(nested(vec![vec![], vec![2]]), 1); + assert_eq!(nested(vec![vec![0], vec![1]]), 2); + assert_eq!(nested(vec![vec![0, 2], vec![1, 2]]), 2); +} diff --git a/tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.rs b/tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.rs new file mode 100644 index 00000000000..84b5ec09dc7 --- /dev/null +++ b/tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.rs @@ -0,0 +1,24 @@ +#![feature(deref_patterns)] +#![allow(incomplete_features)] + +use std::rc::Rc; + +struct Struct; + +fn cant_move_out_box(b: Box) -> Struct { + match b { + //~^ ERROR: cannot move out of a shared reference + deref!(x) => x, + _ => unreachable!(), + } +} + +fn cant_move_out_rc(rc: Rc) -> Struct { + match rc { + //~^ ERROR: cannot move out of a shared reference + deref!(x) => x, + _ => unreachable!(), + } +} + +fn main() {} diff --git a/tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.stderr b/tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.stderr new file mode 100644 index 00000000000..108db6d9e4b --- /dev/null +++ b/tests/ui/pattern/deref-patterns/cant_move_out_of_pattern.stderr @@ -0,0 +1,27 @@ +error[E0507]: cannot move out of a shared reference + --> $DIR/cant_move_out_of_pattern.rs:9:11 + | +LL | match b { + | ^ +LL | +LL | deref!(x) => x, + | - + | | + | data moved here + | move occurs because `x` has type `Struct`, which does not implement the `Copy` trait + +error[E0507]: cannot move out of a shared reference + --> $DIR/cant_move_out_of_pattern.rs:17:11 + | +LL | match rc { + | ^^ +LL | +LL | deref!(x) => x, + | - + | | + | data moved here + | move occurs because `x` has type `Struct`, which does not implement the `Copy` trait + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0507`. diff --git a/tests/ui/pattern/deref-patterns/typeck.rs b/tests/ui/pattern/deref-patterns/typeck.rs index ead6dcdbaf0..f23f7042cd8 100644 --- a/tests/ui/pattern/deref-patterns/typeck.rs +++ b/tests/ui/pattern/deref-patterns/typeck.rs @@ -4,6 +4,8 @@ use std::rc::Rc; +struct Struct; + fn main() { let vec: Vec = Vec::new(); match vec { @@ -22,10 +24,12 @@ fn main() { deref!(1..) => {} _ => {} } - // FIXME(deref_patterns): fails to typecheck because `"foo"` has type &str but deref creates a - // place of type `str`. - // match "foo".to_string() { - // box "foo" => {} - // _ => {} - // } + let _: &Struct = match &Rc::new(Struct) { + deref!(x) => x, + _ => unreachable!(), + }; + let _: &[Struct] = match &Rc::new(vec![Struct]) { + deref!(deref!(x)) => x, + _ => unreachable!(), + }; } diff --git a/tests/ui/pattern/deref-patterns/typeck_fail.rs b/tests/ui/pattern/deref-patterns/typeck_fail.rs new file mode 100644 index 00000000000..040118449ec --- /dev/null +++ b/tests/ui/pattern/deref-patterns/typeck_fail.rs @@ -0,0 +1,17 @@ +#![feature(deref_patterns)] +#![allow(incomplete_features)] + +fn main() { + // FIXME(deref_patterns): fails to typecheck because `"foo"` has type &str but deref creates a + // place of type `str`. + match "foo".to_string() { + deref!("foo") => {} + //~^ ERROR: mismatched types + _ => {} + } + match &"foo".to_string() { + deref!("foo") => {} + //~^ ERROR: mismatched types + _ => {} + } +} diff --git a/tests/ui/pattern/deref-patterns/typeck_fail.stderr b/tests/ui/pattern/deref-patterns/typeck_fail.stderr new file mode 100644 index 00000000000..1c14802745a --- /dev/null +++ b/tests/ui/pattern/deref-patterns/typeck_fail.stderr @@ -0,0 +1,19 @@ +error[E0308]: mismatched types + --> $DIR/typeck_fail.rs:8:16 + | +LL | match "foo".to_string() { + | ----------------- this expression has type `String` +LL | deref!("foo") => {} + | ^^^^^ expected `str`, found `&str` + +error[E0308]: mismatched types + --> $DIR/typeck_fail.rs:13:16 + | +LL | match &"foo".to_string() { + | ------------------ this expression has type `&String` +LL | deref!("foo") => {} + | ^^^^^ expected `str`, found `&str` + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0308`.