Rework the ByMoveBody shim to actually work correctly
This commit is contained in:
parent
1921968cc5
commit
3674032eb2
@ -60,14 +60,13 @@
|
||||
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
|
||||
//! we use this "by move" body instead.
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use rustc_data_structures::unord::UnordSet;
|
||||
use rustc_data_structures::unord::UnordMap;
|
||||
use rustc_hir as hir;
|
||||
use rustc_middle::hir::place::{Projection, ProjectionKind};
|
||||
use rustc_middle::mir::visit::MutVisitor;
|
||||
use rustc_middle::mir::{self, dump_mir, MirPass};
|
||||
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
|
||||
use rustc_target::abi::FieldIdx;
|
||||
use rustc_target::abi::{FieldIdx, VariantIdx};
|
||||
|
||||
pub struct ByMoveBody;
|
||||
|
||||
@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
|
||||
.tuple_fields()
|
||||
.len();
|
||||
|
||||
let mut by_ref_fields = UnordSet::default();
|
||||
for (idx, (coroutine_capture, parent_capture)) in tcx
|
||||
let mut field_remapping = UnordMap::default();
|
||||
|
||||
let mut parent_captures =
|
||||
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
|
||||
|
||||
for (child_field_idx, child_capture) in tcx
|
||||
.closure_captures(coroutine_def_id)
|
||||
.iter()
|
||||
.copied()
|
||||
// By construction we capture all the args first.
|
||||
.skip(num_args)
|
||||
.zip_eq(tcx.closure_captures(parent_def_id))
|
||||
.enumerate()
|
||||
{
|
||||
// This upvar is captured by-move from the parent closure, but by-ref
|
||||
// from the inner async block. That means that it's being borrowed from
|
||||
// the outer closure body -- we need to change the coroutine to take the
|
||||
// upvar by value.
|
||||
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
|
||||
assert_ne!(
|
||||
coroutine_kind,
|
||||
ty::ClosureKind::FnOnce,
|
||||
"`FnOnce` coroutine-closures return coroutines that capture from \
|
||||
their body; it will always result in a borrowck error!"
|
||||
);
|
||||
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
|
||||
}
|
||||
loop {
|
||||
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
|
||||
bug!("we ran out of parent captures!")
|
||||
};
|
||||
|
||||
// Make sure we're actually talking about the same capture.
|
||||
// FIXME(async_closures): We could look at the `hir::Upvar` instead?
|
||||
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
|
||||
if !std::iter::zip(
|
||||
&child_capture.place.projections,
|
||||
&parent_capture.place.projections,
|
||||
)
|
||||
.all(|(child, parent)| child.kind == parent.kind)
|
||||
{
|
||||
// Skip this field.
|
||||
let _ = parent_captures.next().unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
let child_precise_captures =
|
||||
&child_capture.place.projections[parent_capture.place.projections.len()..];
|
||||
|
||||
let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
|
||||
if needs_deref {
|
||||
assert_ne!(
|
||||
coroutine_kind,
|
||||
ty::ClosureKind::FnOnce,
|
||||
"`FnOnce` coroutine-closures return coroutines that capture from \
|
||||
their body; it will always result in a borrowck error!"
|
||||
);
|
||||
}
|
||||
|
||||
let mut parent_capture_ty = parent_capture.place.ty();
|
||||
parent_capture_ty = match parent_capture.info.capture_kind {
|
||||
ty::UpvarCapture::ByValue => parent_capture_ty,
|
||||
ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
|
||||
tcx,
|
||||
tcx.lifetimes.re_erased,
|
||||
parent_capture_ty,
|
||||
kind.to_mutbl_lossy(),
|
||||
),
|
||||
};
|
||||
|
||||
field_remapping.insert(
|
||||
FieldIdx::from_usize(child_field_idx + num_args),
|
||||
(
|
||||
FieldIdx::from_usize(parent_field_idx + num_args),
|
||||
parent_capture_ty,
|
||||
needs_deref,
|
||||
child_precise_captures,
|
||||
),
|
||||
);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if coroutine_kind == ty::ClosureKind::FnOnce {
|
||||
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
|
||||
return;
|
||||
}
|
||||
|
||||
let by_move_coroutine_ty = tcx
|
||||
@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
|
||||
);
|
||||
|
||||
let mut by_move_body = body.clone();
|
||||
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
|
||||
MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
|
||||
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
|
||||
by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
|
||||
coroutine_def_id: coroutine_def_id.to_def_id(),
|
||||
@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
|
||||
|
||||
struct MakeByMoveBody<'tcx> {
|
||||
tcx: TyCtxt<'tcx>,
|
||||
by_ref_fields: UnordSet<FieldIdx>,
|
||||
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,
|
||||
by_move_coroutine_ty: Ty<'tcx>,
|
||||
}
|
||||
|
||||
@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
|
||||
location: mir::Location,
|
||||
) {
|
||||
if place.local == ty::CAPTURE_STRUCT_LOCAL
|
||||
&& let Some((&mir::ProjectionElem::Field(idx, ty), projection)) =
|
||||
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
|
||||
place.projection.split_first()
|
||||
&& self.by_ref_fields.contains(&idx)
|
||||
&& let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
|
||||
self.field_remapping.get(&idx)
|
||||
{
|
||||
let (begin, end) = projection.split_first().unwrap();
|
||||
// FIXME(async_closures): I'm actually a bit surprised to see that we always
|
||||
// initially deref the by-ref upvars. If this is not actually true, then we
|
||||
// will at least get an ICE that explains why this isn't true :^)
|
||||
assert_eq!(*begin, mir::ProjectionElem::Deref);
|
||||
// Peel one ref off of the ty.
|
||||
let peeled_ty = ty.builtin_deref(true).unwrap().ty;
|
||||
let final_deref = if needs_deref {
|
||||
let Some((mir::ProjectionElem::Deref, rest)) = projection.split_first() else {
|
||||
bug!();
|
||||
};
|
||||
rest
|
||||
} else {
|
||||
projection
|
||||
};
|
||||
|
||||
let additional_projections =
|
||||
additional_projections.iter().map(|elem| match elem.kind {
|
||||
ProjectionKind::Deref => mir::ProjectionElem::Deref,
|
||||
ProjectionKind::Field(idx, VariantIdx::ZERO) => {
|
||||
mir::ProjectionElem::Field(idx, elem.ty)
|
||||
}
|
||||
_ => unreachable!("precise captures only through fields and derefs"),
|
||||
});
|
||||
|
||||
*place = mir::Place {
|
||||
local: place.local,
|
||||
projection: self.tcx.mk_place_elems_from_iter(
|
||||
[mir::ProjectionElem::Field(idx, peeled_ty)]
|
||||
[mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
|
||||
.into_iter()
|
||||
.chain(end.iter().copied()),
|
||||
.chain(additional_projections)
|
||||
.chain(final_deref.iter().copied()),
|
||||
),
|
||||
};
|
||||
}
|
||||
|
@ -0,0 +1,29 @@
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
untouched
|
||||
|
||||
after call
|
||||
drop first
|
||||
after await
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
drop first
|
||||
after await
|
||||
uncaptured
|
@ -0,0 +1,29 @@
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
fixed
|
||||
after await
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
untouched
|
||||
|
||||
after call
|
||||
drop first
|
||||
after await
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
drop first
|
||||
after await
|
||||
uncaptured
|
@ -0,0 +1,29 @@
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
fixed
|
||||
after await
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
after await
|
||||
fixed
|
||||
untouched
|
||||
|
||||
after call
|
||||
drop first
|
||||
after await
|
||||
uncaptured
|
||||
|
||||
after call
|
||||
drop first
|
||||
after await
|
||||
uncaptured
|
157
tests/ui/async-await/async-closures/precise-captures.rs
Normal file
157
tests/ui/async-await/async-closures/precise-captures.rs
Normal file
@ -0,0 +1,157 @@
|
||||
//@ aux-build:block-on.rs
|
||||
//@ edition:2021
|
||||
//@ run-pass
|
||||
//@ check-run-results
|
||||
//@ revisions: call call_once force_once
|
||||
|
||||
// call - Call the closure regularly.
|
||||
// call_once - Call the closure w/ `async FnOnce`, so exercising the by_move shim.
|
||||
// force_once - Force the closure mode to `FnOnce`, so exercising what was fixed
|
||||
// in <https://github.com/rust-lang/rust/pull/123350>.
|
||||
|
||||
#![feature(async_closure)]
|
||||
#![allow(unused_mut)]
|
||||
|
||||
extern crate block_on;
|
||||
|
||||
#[cfg(any(call, force_once))]
|
||||
macro_rules! call {
|
||||
($c:expr) => { ($c)() }
|
||||
}
|
||||
|
||||
#[cfg(call_once)]
|
||||
async fn call_once(f: impl async FnOnce()) {
|
||||
f().await
|
||||
}
|
||||
|
||||
#[cfg(call_once)]
|
||||
macro_rules! call {
|
||||
($c:expr) => { call_once($c) }
|
||||
}
|
||||
|
||||
#[cfg(not(force_once))]
|
||||
macro_rules! guidance {
|
||||
($c:expr) => { $c }
|
||||
}
|
||||
|
||||
#[cfg(force_once)]
|
||||
fn infer_fnonce(c: impl async FnOnce()) -> impl async FnOnce() { c }
|
||||
|
||||
#[cfg(force_once)]
|
||||
macro_rules! guidance {
|
||||
($c:expr) => { infer_fnonce($c) }
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Drop(&'static str);
|
||||
|
||||
impl std::ops::Drop for Drop {
|
||||
fn drop(&mut self) {
|
||||
println!("{}", self.0);
|
||||
}
|
||||
}
|
||||
|
||||
struct S {
|
||||
a: i32,
|
||||
b: Drop,
|
||||
c: Drop,
|
||||
}
|
||||
|
||||
async fn async_main() {
|
||||
// Precise capture struct
|
||||
{
|
||||
let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||
let mut c = guidance!(async || {
|
||||
s.a = 2;
|
||||
let w = &mut s.b;
|
||||
w.0 = "fixed";
|
||||
});
|
||||
s.c.0 = "uncaptured";
|
||||
let fut = call!(c);
|
||||
println!("after call");
|
||||
fut.await;
|
||||
println!("after await");
|
||||
}
|
||||
println!();
|
||||
|
||||
// Precise capture &mut struct
|
||||
{
|
||||
let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||
let mut c = guidance!(async || {
|
||||
s.a = 2;
|
||||
let w = &mut s.b;
|
||||
w.0 = "fixed";
|
||||
});
|
||||
s.c.0 = "uncaptured";
|
||||
let fut = call!(c);
|
||||
println!("after call");
|
||||
fut.await;
|
||||
println!("after await");
|
||||
}
|
||||
println!();
|
||||
|
||||
// Precise capture struct by move
|
||||
{
|
||||
let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||
let mut c = guidance!(async move || {
|
||||
s.a = 2;
|
||||
let w = &mut s.b;
|
||||
w.0 = "fixed";
|
||||
});
|
||||
s.c.0 = "uncaptured";
|
||||
let fut = call!(c);
|
||||
println!("after call");
|
||||
fut.await;
|
||||
println!("after await");
|
||||
}
|
||||
println!();
|
||||
|
||||
// Precise capture &mut struct by move
|
||||
{
|
||||
let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||
let mut c = guidance!(async move || {
|
||||
s.a = 2;
|
||||
let w = &mut s.b;
|
||||
w.0 = "fixed";
|
||||
});
|
||||
// `s` is still captured fully as `&mut S`.
|
||||
let fut = call!(c);
|
||||
println!("after call");
|
||||
fut.await;
|
||||
println!("after await");
|
||||
}
|
||||
println!();
|
||||
|
||||
// Precise capture struct, consume field
|
||||
{
|
||||
let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
|
||||
let c = guidance!(async move || {
|
||||
// s.a = 2; // FIXME(async_closures): Figure out why this fails
|
||||
drop(s.b);
|
||||
});
|
||||
s.c.0 = "uncaptured";
|
||||
let fut = call!(c);
|
||||
println!("after call");
|
||||
fut.await;
|
||||
println!("after await");
|
||||
}
|
||||
println!();
|
||||
|
||||
// Precise capture struct by move, consume field
|
||||
{
|
||||
let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
|
||||
let c = guidance!(async move || {
|
||||
// s.a = 2; // FIXME(async_closures): Figure out why this fails
|
||||
drop(s.b);
|
||||
});
|
||||
s.c.0 = "uncaptured";
|
||||
let fut = call!(c);
|
||||
println!("after call");
|
||||
fut.await;
|
||||
println!("after await");
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
block_on::block_on(async_main());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user