From 568703c4bd5102c2d596e75db492c37d0102d16b Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 9 Apr 2024 11:11:17 -0400 Subject: [PATCH] Use a helper to zip together parent and child captures for coroutine-closures --- compiler/rustc_middle/src/ty/closure.rs | 67 ++++++++++++++++ compiler/rustc_middle/src/ty/mod.rs | 7 +- .../src/coroutine/by_move_body.rs | 78 +++---------------- 3 files changed, 81 insertions(+), 71 deletions(-) diff --git a/compiler/rustc_middle/src/ty/closure.rs b/compiler/rustc_middle/src/ty/closure.rs index 95d1e08b58b..211d403998f 100644 --- a/compiler/rustc_middle/src/ty/closure.rs +++ b/compiler/rustc_middle/src/ty/closure.rs @@ -6,6 +6,7 @@ use std::fmt::Write; use crate::query::Providers; +use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxIndexMap; use rustc_hir as hir; use rustc_hir::def_id::LocalDefId; @@ -415,6 +416,72 @@ pub fn to_mutbl_lossy(self) -> hir::Mutability { } } +pub fn analyze_coroutine_closure_captures<'a, 'tcx: 'a, T>( + parent_captures: impl IntoIterator>, + child_captures: impl IntoIterator>, + mut for_each: impl FnMut((usize, &'a CapturedPlace<'tcx>), (usize, &'a CapturedPlace<'tcx>)) -> T, +) -> impl Iterator + Captures<'a> + Captures<'tcx> { + std::iter::from_coroutine(move || { + let mut child_captures = child_captures.into_iter().enumerate().peekable(); + + // One parent capture may correspond to several child captures if we end up + // refining the set of captures via edition-2021 precise captures. We want to + // match up any number of child captures with one parent capture, so we keep + // peeking off this `Peekable` until the child doesn't match anymore. + for (parent_field_idx, parent_capture) in parent_captures.into_iter().enumerate() { + // Make sure we use every field at least once, b/c why are we capturing something + // if it's not used in the inner coroutine. + let mut field_used_at_least_once = false; + + // A parent matches a child if they share the same prefix of projections. + // The child may have more, if it is capturing sub-fields out of + // something that is captured by-move in the parent closure. + while child_captures.peek().map_or(false, |(_, child_capture)| { + child_prefix_matches_parent_projections(parent_capture, child_capture) + }) { + let (child_field_idx, child_capture) = child_captures.next().unwrap(); + // This analysis only makes sense if the parent capture is a + // prefix of the child capture. + assert!( + child_capture.place.projections.len() >= parent_capture.place.projections.len(), + "parent capture ({parent_capture:#?}) expected to be prefix of \ + child capture ({child_capture:#?})" + ); + + yield for_each( + (parent_field_idx, parent_capture), + (child_field_idx, child_capture), + ); + + field_used_at_least_once = true; + } + + // Make sure the field was used at least once. + assert!( + field_used_at_least_once, + "we captured {parent_capture:#?} but it was not used in the child coroutine?" + ); + } + assert_eq!(child_captures.next(), None, "leftover child captures?"); + }) +} + +fn child_prefix_matches_parent_projections( + parent_capture: &ty::CapturedPlace<'_>, + child_capture: &ty::CapturedPlace<'_>, +) -> bool { + let HirPlaceBase::Upvar(parent_base) = parent_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + let HirPlaceBase::Upvar(child_base) = child_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + + parent_base.var_path.hir_id == child_base.var_path.hir_id + && std::iter::zip(&child_capture.place.projections, &parent_capture.place.projections) + .all(|(child, parent)| child.kind == parent.kind) +} + pub fn provide(providers: &mut Providers) { *providers = Providers { closure_typeinfo, ..*providers } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index ee4dc9744ac..e6b773ae512 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -77,9 +77,10 @@ pub use rustc_type_ir::*; pub use self::closure::{ - is_ancestor_or_same_capture, place_to_string_for_capture, BorrowKind, CaptureInfo, - CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, MinCaptureList, - RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath, CAPTURE_STRUCT_LOCAL, + analyze_coroutine_closure_captures, is_ancestor_or_same_capture, place_to_string_for_capture, + BorrowKind, CaptureInfo, CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, + MinCaptureList, RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath, + CAPTURE_STRUCT_LOCAL, }; pub use self::consts::{ Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree, diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index b26f968bf5e..3d6c1a95204 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -71,7 +71,7 @@ use rustc_data_structures::unord::UnordMap; use rustc_hir as hir; -use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind}; +use rustc_middle::hir::place::{Projection, ProjectionKind}; use rustc_middle::mir::visit::MutVisitor; use rustc_middle::mir::{self, dump_mir, MirPass}; use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt}; @@ -124,44 +124,10 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { .tuple_fields() .len(); - let mut field_remapping = UnordMap::default(); - - let mut child_captures = tcx - .closure_captures(coroutine_def_id) - .iter() - .copied() - // By construction we capture all the args first. - .skip(num_args) - .enumerate() - .peekable(); - - // One parent capture may correspond to several child captures if we end up - // refining the set of captures via edition-2021 precise captures. We want to - // match up any number of child captures with one parent capture, so we keep - // peeking off this `Peekable` until the child doesn't match anymore. - for (parent_field_idx, parent_capture) in - tcx.closure_captures(parent_def_id).iter().copied().enumerate() - { - // Make sure we use every field at least once, b/c why are we capturing something - // if it's not used in the inner coroutine. - let mut field_used_at_least_once = false; - - // A parent matches a child if they share the same prefix of projections. - // The child may have more, if it is capturing sub-fields out of - // something that is captured by-move in the parent closure. - while child_captures.peek().map_or(false, |(_, child_capture)| { - child_prefix_matches_parent_projections(parent_capture, child_capture) - }) { - let (child_field_idx, child_capture) = child_captures.next().unwrap(); - - // This analysis only makes sense if the parent capture is a - // prefix of the child capture. - assert!( - child_capture.place.projections.len() >= parent_capture.place.projections.len(), - "parent capture ({parent_capture:#?}) expected to be prefix of \ - child capture ({child_capture:#?})" - ); - + let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures( + tcx.closure_captures(parent_def_id).iter().copied(), + tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(), + |(parent_field_idx, parent_capture), (child_field_idx, child_capture)| { // Store this set of additional projections (fields and derefs). // We need to re-apply them later. let child_precise_captures = @@ -192,7 +158,7 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { ), }; - field_remapping.insert( + ( FieldIdx::from_usize(child_field_idx + num_args), ( FieldIdx::from_usize(parent_field_idx + num_args), @@ -200,18 +166,10 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { needs_deref, child_precise_captures, ), - ); - - field_used_at_least_once = true; - } - - // Make sure the field was used at least once. - assert!( - field_used_at_least_once, - "we captured {parent_capture:#?} but it was not used in the child coroutine?" - ); - } - assert_eq!(child_captures.next(), None, "leftover child captures?"); + ) + }, + ) + .collect(); if coroutine_kind == ty::ClosureKind::FnOnce { assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len()); @@ -241,22 +199,6 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { } } -fn child_prefix_matches_parent_projections( - parent_capture: &ty::CapturedPlace<'_>, - child_capture: &ty::CapturedPlace<'_>, -) -> bool { - let PlaceBase::Upvar(parent_base) = parent_capture.place.base else { - bug!("expected capture to be an upvar"); - }; - let PlaceBase::Upvar(child_base) = child_capture.place.base else { - bug!("expected capture to be an upvar"); - }; - - parent_base.var_path.hir_id == child_base.var_path.hir_id - && std::iter::zip(&child_capture.place.projections, &parent_capture.place.projections) - .all(|(child, parent)| child.kind == parent.kind) -} - struct MakeByMoveBody<'tcx> { tcx: TyCtxt<'tcx>, field_remapping: UnordMap, bool, &'tcx [Projection<'tcx>])>,