Fix capture analysis for by-move closure bodies
This commit is contained in:
parent
88c2f4f5f5
commit
a1a1f41027
@ -3,6 +3,8 @@
|
||||
//! be a coroutine body that takes all of its upvars by-move, and which we stash
|
||||
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use rustc_data_structures::unord::UnordSet;
|
||||
use rustc_hir as hir;
|
||||
use rustc_middle::mir::visit::MutVisitor;
|
||||
@ -26,36 +28,68 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
|
||||
if coroutine_ty.references_error() {
|
||||
return;
|
||||
}
|
||||
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
|
||||
|
||||
let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
|
||||
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
|
||||
let args = args.as_coroutine();
|
||||
|
||||
let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
|
||||
if coroutine_kind == ty::ClosureKind::FnOnce {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut by_ref_fields = UnordSet::default();
|
||||
let by_move_upvars = Ty::new_tup_from_iter(
|
||||
tcx,
|
||||
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
|
||||
if capture.is_by_ref() {
|
||||
by_ref_fields.insert(FieldIdx::from_usize(idx));
|
||||
}
|
||||
capture.place.ty()
|
||||
}),
|
||||
let parent_def_id = tcx.local_parent(coroutine_def_id);
|
||||
let ty::CoroutineClosure(_, parent_args) =
|
||||
*tcx.type_of(parent_def_id).instantiate_identity().kind()
|
||||
else {
|
||||
bug!();
|
||||
};
|
||||
let parent_args = parent_args.as_coroutine_closure();
|
||||
let parent_upvars_ty = parent_args.tupled_upvars_ty();
|
||||
let tupled_inputs_ty = tcx.instantiate_bound_regions_with_erased(
|
||||
parent_args.coroutine_closure_sig().map_bound(|sig| sig.tupled_inputs_ty),
|
||||
);
|
||||
let num_args = tupled_inputs_ty.tuple_fields().len();
|
||||
|
||||
let mut by_ref_fields = UnordSet::default();
|
||||
for (idx, (coroutine_capture, parent_capture)) in tcx
|
||||
.closure_captures(coroutine_def_id)
|
||||
.iter()
|
||||
// By construction we capture all the args first.
|
||||
.skip(num_args)
|
||||
.zip_eq(tcx.closure_captures(parent_def_id))
|
||||
.enumerate()
|
||||
{
|
||||
// This argument is captured by-move from the parent closure, but by-ref
|
||||
// from the inner async block. That means that it's being borrowed from
|
||||
// the closure body -- we need to change the coroutine take it by move.
|
||||
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
|
||||
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
|
||||
}
|
||||
|
||||
// Make sure we're actually talking about the same capture.
|
||||
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
|
||||
}
|
||||
|
||||
let by_move_coroutine_ty = Ty::new_coroutine(
|
||||
tcx,
|
||||
coroutine_def_id.to_def_id(),
|
||||
ty::CoroutineArgs::new(
|
||||
tcx,
|
||||
ty::CoroutineArgsParts {
|
||||
parent_args: args.as_coroutine().parent_args(),
|
||||
parent_args: args.parent_args(),
|
||||
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
|
||||
resume_ty: args.as_coroutine().resume_ty(),
|
||||
yield_ty: args.as_coroutine().yield_ty(),
|
||||
return_ty: args.as_coroutine().return_ty(),
|
||||
witness: args.as_coroutine().witness(),
|
||||
tupled_upvars_ty: by_move_upvars,
|
||||
resume_ty: args.resume_ty(),
|
||||
yield_ty: args.yield_ty(),
|
||||
return_ty: args.return_ty(),
|
||||
witness: args.witness(),
|
||||
// Concatenate the args + closure's captures (since they're all by move).
|
||||
tupled_upvars_ty: Ty::new_tup_from_iter(
|
||||
tcx,
|
||||
tupled_inputs_ty
|
||||
.tuple_fields()
|
||||
.iter()
|
||||
.chain(parent_upvars_ty.tuple_fields()),
|
||||
),
|
||||
},
|
||||
)
|
||||
.args,
|
||||
|
89
src/tools/miri/tests/pass/async-closure-captures.rs
Normal file
89
src/tools/miri/tests/pass/async-closure-captures.rs
Normal file
@ -0,0 +1,89 @@
|
||||
#![feature(async_closure, noop_waker)]
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::pin;
|
||||
use std::task::*;
|
||||
|
||||
pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
|
||||
let mut fut = pin!(fut);
|
||||
let ctx = &mut Context::from_waker(Waker::noop());
|
||||
|
||||
loop {
|
||||
match fut.as_mut().poll(ctx) {
|
||||
Poll::Pending => {}
|
||||
Poll::Ready(t) => break t,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
block_on(async_main());
|
||||
}
|
||||
|
||||
async fn call<T>(f: &impl async Fn() -> T) -> T {
|
||||
f().await
|
||||
}
|
||||
|
||||
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
|
||||
f().await
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(unused)]
|
||||
struct Hello(i32);
|
||||
|
||||
async fn async_main() {
|
||||
// Capture something by-ref
|
||||
{
|
||||
let x = Hello(0);
|
||||
let c = async || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
|
||||
let x = &Hello(1);
|
||||
let c = async || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
}
|
||||
|
||||
// Capture something and consume it (force to `AsyncFnOnce`)
|
||||
{
|
||||
let x = Hello(2);
|
||||
let c = async || {
|
||||
println!("{x:?}");
|
||||
drop(x);
|
||||
};
|
||||
call_once(c).await;
|
||||
}
|
||||
|
||||
// Capture something with `move`, don't consume it
|
||||
{
|
||||
let x = Hello(3);
|
||||
let c = async move || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
|
||||
let x = &Hello(4);
|
||||
let c = async move || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
}
|
||||
|
||||
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
|
||||
{
|
||||
let x = Hello(5);
|
||||
let c = async move || {
|
||||
println!("{x:?}");
|
||||
drop(x);
|
||||
};
|
||||
call_once(c).await;
|
||||
}
|
||||
}
|
10
src/tools/miri/tests/pass/async-closure-captures.stdout
Normal file
10
src/tools/miri/tests/pass/async-closure-captures.stdout
Normal file
@ -0,0 +1,10 @@
|
||||
Hello(0)
|
||||
Hello(0)
|
||||
Hello(1)
|
||||
Hello(1)
|
||||
Hello(2)
|
||||
Hello(3)
|
||||
Hello(3)
|
||||
Hello(4)
|
||||
Hello(4)
|
||||
Hello(5)
|
80
tests/ui/async-await/async-closures/captures.rs
Normal file
80
tests/ui/async-await/async-closures/captures.rs
Normal file
@ -0,0 +1,80 @@
|
||||
//@ aux-build:block-on.rs
|
||||
//@ edition:2021
|
||||
//@ run-pass
|
||||
//@ check-run-results
|
||||
|
||||
#![feature(async_closure)]
|
||||
|
||||
extern crate block_on;
|
||||
|
||||
fn main() {
|
||||
block_on::block_on(async_main());
|
||||
}
|
||||
|
||||
async fn call<T>(f: &impl async Fn() -> T) -> T {
|
||||
f().await
|
||||
}
|
||||
|
||||
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
|
||||
f().await
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(unused)]
|
||||
struct Hello(i32);
|
||||
|
||||
async fn async_main() {
|
||||
// Capture something by-ref
|
||||
{
|
||||
let x = Hello(0);
|
||||
let c = async || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
|
||||
let x = &Hello(1);
|
||||
let c = async || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
}
|
||||
|
||||
// Capture something and consume it (force to `AsyncFnOnce`)
|
||||
{
|
||||
let x = Hello(2);
|
||||
let c = async || {
|
||||
println!("{x:?}");
|
||||
drop(x);
|
||||
};
|
||||
call_once(c).await;
|
||||
}
|
||||
|
||||
// Capture something with `move`, don't consume it
|
||||
{
|
||||
let x = Hello(3);
|
||||
let c = async move || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
|
||||
let x = &Hello(4);
|
||||
let c = async move || {
|
||||
println!("{x:?}");
|
||||
};
|
||||
call(&c).await;
|
||||
call_once(c).await;
|
||||
}
|
||||
|
||||
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
|
||||
{
|
||||
let x = Hello(5);
|
||||
let c = async move || {
|
||||
println!("{x:?}");
|
||||
drop(x);
|
||||
};
|
||||
call_once(c).await;
|
||||
}
|
||||
}
|
10
tests/ui/async-await/async-closures/captures.run.stdout
Normal file
10
tests/ui/async-await/async-closures/captures.run.stdout
Normal file
@ -0,0 +1,10 @@
|
||||
Hello(0)
|
||||
Hello(0)
|
||||
Hello(1)
|
||||
Hello(1)
|
||||
Hello(2)
|
||||
Hello(3)
|
||||
Hello(3)
|
||||
Hello(4)
|
||||
Hello(4)
|
||||
Hello(5)
|
Loading…
Reference in New Issue
Block a user