Rework the ByMoveBody shim to actually work correctly

This commit is contained in:
Michael Goulet 2024-04-05 15:12:34 -04:00
parent 1921968cc5
commit 3674032eb2
5 changed files with 336 additions and 36 deletions

View File

@ -60,14 +60,13 @@
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
//! we use this "by move" body instead.
use itertools::Itertools;
use rustc_data_structures::unord::UnordSet;
use rustc_data_structures::unord::UnordMap;
use rustc_hir as hir;
use rustc_middle::hir::place::{Projection, ProjectionKind};
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{self, dump_mir, MirPass};
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
use rustc_target::abi::FieldIdx;
use rustc_target::abi::{FieldIdx, VariantIdx};
pub struct ByMoveBody;
@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
.tuple_fields()
.len();
let mut by_ref_fields = UnordSet::default();
for (idx, (coroutine_capture, parent_capture)) in tcx
let mut field_remapping = UnordMap::default();
let mut parent_captures =
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
for (child_field_idx, child_capture) in tcx
.closure_captures(coroutine_def_id)
.iter()
.copied()
// By construction we capture all the args first.
.skip(num_args)
.zip_eq(tcx.closure_captures(parent_def_id))
.enumerate()
{
// This upvar is captured by-move from the parent closure, but by-ref
// from the inner async block. That means that it's being borrowed from
// the outer closure body -- we need to change the coroutine to take the
// upvar by value.
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
assert_ne!(
coroutine_kind,
ty::ClosureKind::FnOnce,
"`FnOnce` coroutine-closures return coroutines that capture from \
their body; it will always result in a borrowck error!"
);
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
}
loop {
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
bug!("we ran out of parent captures!")
};
// Make sure we're actually talking about the same capture.
// FIXME(async_closures): We could look at the `hir::Upvar` instead?
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
if !std::iter::zip(
&child_capture.place.projections,
&parent_capture.place.projections,
)
.all(|(child, parent)| child.kind == parent.kind)
{
// Skip this field.
let _ = parent_captures.next().unwrap();
continue;
}
let child_precise_captures =
&child_capture.place.projections[parent_capture.place.projections.len()..];
let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
if needs_deref {
assert_ne!(
coroutine_kind,
ty::ClosureKind::FnOnce,
"`FnOnce` coroutine-closures return coroutines that capture from \
their body; it will always result in a borrowck error!"
);
}
let mut parent_capture_ty = parent_capture.place.ty();
parent_capture_ty = match parent_capture.info.capture_kind {
ty::UpvarCapture::ByValue => parent_capture_ty,
ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
tcx,
tcx.lifetimes.re_erased,
parent_capture_ty,
kind.to_mutbl_lossy(),
),
};
field_remapping.insert(
FieldIdx::from_usize(child_field_idx + num_args),
(
FieldIdx::from_usize(parent_field_idx + num_args),
parent_capture_ty,
needs_deref,
child_precise_captures,
),
);
break;
}
}
if coroutine_kind == ty::ClosureKind::FnOnce {
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
return;
}
let by_move_coroutine_ty = tcx
@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
);
let mut by_move_body = body.clone();
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
coroutine_def_id: coroutine_def_id.to_def_id(),
@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
struct MakeByMoveBody<'tcx> {
tcx: TyCtxt<'tcx>,
by_ref_fields: UnordSet<FieldIdx>,
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,
by_move_coroutine_ty: Ty<'tcx>,
}
@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
location: mir::Location,
) {
if place.local == ty::CAPTURE_STRUCT_LOCAL
&& let Some((&mir::ProjectionElem::Field(idx, ty), projection)) =
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
place.projection.split_first()
&& self.by_ref_fields.contains(&idx)
&& let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
self.field_remapping.get(&idx)
{
let (begin, end) = projection.split_first().unwrap();
// FIXME(async_closures): I'm actually a bit surprised to see that we always
// initially deref the by-ref upvars. If this is not actually true, then we
// will at least get an ICE that explains why this isn't true :^)
assert_eq!(*begin, mir::ProjectionElem::Deref);
// Peel one ref off of the ty.
let peeled_ty = ty.builtin_deref(true).unwrap().ty;
let final_deref = if needs_deref {
let Some((mir::ProjectionElem::Deref, rest)) = projection.split_first() else {
bug!();
};
rest
} else {
projection
};
let additional_projections =
additional_projections.iter().map(|elem| match elem.kind {
ProjectionKind::Deref => mir::ProjectionElem::Deref,
ProjectionKind::Field(idx, VariantIdx::ZERO) => {
mir::ProjectionElem::Field(idx, elem.ty)
}
_ => unreachable!("precise captures only through fields and derefs"),
});
*place = mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems_from_iter(
[mir::ProjectionElem::Field(idx, peeled_ty)]
[mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
.into_iter()
.chain(end.iter().copied()),
.chain(additional_projections)
.chain(final_deref.iter().copied()),
),
};
}

View File

@ -0,0 +1,29 @@
after call
after await
fixed
uncaptured
after call
after await
fixed
uncaptured
after call
after await
fixed
uncaptured
after call
after await
fixed
untouched
after call
drop first
after await
uncaptured
after call
drop first
after await
uncaptured

View File

@ -0,0 +1,29 @@
after call
after await
fixed
uncaptured
after call
after await
fixed
uncaptured
after call
fixed
after await
uncaptured
after call
after await
fixed
untouched
after call
drop first
after await
uncaptured
after call
drop first
after await
uncaptured

View File

@ -0,0 +1,29 @@
after call
after await
fixed
uncaptured
after call
after await
fixed
uncaptured
after call
fixed
after await
uncaptured
after call
after await
fixed
untouched
after call
drop first
after await
uncaptured
after call
drop first
after await
uncaptured

View File

@ -0,0 +1,157 @@
//@ aux-build:block-on.rs
//@ edition:2021
//@ run-pass
//@ check-run-results
//@ revisions: call call_once force_once
// call - Call the closure regularly.
// call_once - Call the closure w/ `async FnOnce`, so exercising the by_move shim.
// force_once - Force the closure mode to `FnOnce`, so exercising what was fixed
// in <https://github.com/rust-lang/rust/pull/123350>.
#![feature(async_closure)]
#![allow(unused_mut)]
extern crate block_on;
#[cfg(any(call, force_once))]
macro_rules! call {
($c:expr) => { ($c)() }
}
#[cfg(call_once)]
async fn call_once(f: impl async FnOnce()) {
f().await
}
#[cfg(call_once)]
macro_rules! call {
($c:expr) => { call_once($c) }
}
#[cfg(not(force_once))]
macro_rules! guidance {
($c:expr) => { $c }
}
#[cfg(force_once)]
fn infer_fnonce(c: impl async FnOnce()) -> impl async FnOnce() { c }
#[cfg(force_once)]
macro_rules! guidance {
($c:expr) => { infer_fnonce($c) }
}
#[derive(Debug)]
struct Drop(&'static str);
impl std::ops::Drop for Drop {
fn drop(&mut self) {
println!("{}", self.0);
}
}
struct S {
a: i32,
b: Drop,
c: Drop,
}
async fn async_main() {
// Precise capture struct
{
let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
let mut c = guidance!(async || {
s.a = 2;
let w = &mut s.b;
w.0 = "fixed";
});
s.c.0 = "uncaptured";
let fut = call!(c);
println!("after call");
fut.await;
println!("after await");
}
println!();
// Precise capture &mut struct
{
let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
let mut c = guidance!(async || {
s.a = 2;
let w = &mut s.b;
w.0 = "fixed";
});
s.c.0 = "uncaptured";
let fut = call!(c);
println!("after call");
fut.await;
println!("after await");
}
println!();
// Precise capture struct by move
{
let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
let mut c = guidance!(async move || {
s.a = 2;
let w = &mut s.b;
w.0 = "fixed";
});
s.c.0 = "uncaptured";
let fut = call!(c);
println!("after call");
fut.await;
println!("after await");
}
println!();
// Precise capture &mut struct by move
{
let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
let mut c = guidance!(async move || {
s.a = 2;
let w = &mut s.b;
w.0 = "fixed";
});
// `s` is still captured fully as `&mut S`.
let fut = call!(c);
println!("after call");
fut.await;
println!("after await");
}
println!();
// Precise capture struct, consume field
{
let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
let c = guidance!(async move || {
// s.a = 2; // FIXME(async_closures): Figure out why this fails
drop(s.b);
});
s.c.0 = "uncaptured";
let fut = call!(c);
println!("after call");
fut.await;
println!("after await");
}
println!();
// Precise capture struct by move, consume field
{
let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
let c = guidance!(async move || {
// s.a = 2; // FIXME(async_closures): Figure out why this fails
drop(s.b);
});
s.c.0 = "uncaptured";
let fut = call!(c);
println!("after call");
fut.await;
println!("after await");
}
}
fn main() {
block_on::block_on(async_main());
}