diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index 99dbb342268..60b52fba219 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -3,6 +3,8 @@ //! be a coroutine body that takes all of its upvars by-move, and which we stash //! into the `CoroutineInfo` for all coroutines returned by coroutine-closures. +use itertools::Itertools; + use rustc_data_structures::unord::UnordSet; use rustc_hir as hir; use rustc_middle::mir::visit::MutVisitor; @@ -26,36 +28,68 @@ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { if coroutine_ty.references_error() { return; } - let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") }; - let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(); + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") }; + let args = args.as_coroutine(); + + let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap(); if coroutine_kind == ty::ClosureKind::FnOnce { return; } - let mut by_ref_fields = UnordSet::default(); - let by_move_upvars = Ty::new_tup_from_iter( - tcx, - tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { - if capture.is_by_ref() { - by_ref_fields.insert(FieldIdx::from_usize(idx)); - } - capture.place.ty() - }), + let parent_def_id = tcx.local_parent(coroutine_def_id); + let ty::CoroutineClosure(_, parent_args) = + *tcx.type_of(parent_def_id).instantiate_identity().kind() + else { + bug!(); + }; + let parent_args = parent_args.as_coroutine_closure(); + let parent_upvars_ty = parent_args.tupled_upvars_ty(); + let tupled_inputs_ty = tcx.instantiate_bound_regions_with_erased( + parent_args.coroutine_closure_sig().map_bound(|sig| sig.tupled_inputs_ty), ); + let num_args = tupled_inputs_ty.tuple_fields().len(); + + let mut by_ref_fields = UnordSet::default(); + for (idx, (coroutine_capture, parent_capture)) in tcx + .closure_captures(coroutine_def_id) + .iter() + // By construction we capture all the args first. + .skip(num_args) + .zip_eq(tcx.closure_captures(parent_def_id)) + .enumerate() + { + // This argument is captured by-move from the parent closure, but by-ref + // from the inner async block. That means that it's being borrowed from + // the closure body -- we need to change the coroutine take it by move. + if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() { + by_ref_fields.insert(FieldIdx::from_usize(num_args + idx)); + } + + // Make sure we're actually talking about the same capture. + assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty()); + } + let by_move_coroutine_ty = Ty::new_coroutine( tcx, coroutine_def_id.to_def_id(), ty::CoroutineArgs::new( tcx, ty::CoroutineArgsParts { - parent_args: args.as_coroutine().parent_args(), + parent_args: args.parent_args(), kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), - resume_ty: args.as_coroutine().resume_ty(), - yield_ty: args.as_coroutine().yield_ty(), - return_ty: args.as_coroutine().return_ty(), - witness: args.as_coroutine().witness(), - tupled_upvars_ty: by_move_upvars, + resume_ty: args.resume_ty(), + yield_ty: args.yield_ty(), + return_ty: args.return_ty(), + witness: args.witness(), + // Concatenate the args + closure's captures (since they're all by move). + tupled_upvars_ty: Ty::new_tup_from_iter( + tcx, + tupled_inputs_ty + .tuple_fields() + .iter() + .chain(parent_upvars_ty.tuple_fields()), + ), }, ) .args, diff --git a/src/tools/miri/tests/pass/async-closure-captures.rs b/src/tools/miri/tests/pass/async-closure-captures.rs new file mode 100644 index 00000000000..acff4a38338 --- /dev/null +++ b/src/tools/miri/tests/pass/async-closure-captures.rs @@ -0,0 +1,89 @@ +#![feature(async_closure, noop_waker)] + +use std::future::Future; +use std::pin::pin; +use std::task::*; + +pub fn block_on(fut: impl Future) -> T { + let mut fut = pin!(fut); + let ctx = &mut Context::from_waker(Waker::noop()); + + loop { + match fut.as_mut().poll(ctx) { + Poll::Pending => {} + Poll::Ready(t) => break t, + } + } +} + +fn main() { + block_on(async_main()); +} + +async fn call(f: &impl async Fn() -> T) -> T { + f().await +} + +async fn call_once(f: impl async FnOnce() -> T) -> T { + f().await +} + +#[derive(Debug)] +#[allow(unused)] +struct Hello(i32); + +async fn async_main() { + // Capture something by-ref + { + let x = Hello(0); + let c = async || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + + let x = &Hello(1); + let c = async || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + } + + // Capture something and consume it (force to `AsyncFnOnce`) + { + let x = Hello(2); + let c = async || { + println!("{x:?}"); + drop(x); + }; + call_once(c).await; + } + + // Capture something with `move`, don't consume it + { + let x = Hello(3); + let c = async move || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + + let x = &Hello(4); + let c = async move || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + } + + // Capture something with `move`, also consume it (so `AsyncFnOnce`) + { + let x = Hello(5); + let c = async move || { + println!("{x:?}"); + drop(x); + }; + call_once(c).await; + } +} diff --git a/src/tools/miri/tests/pass/async-closure-captures.stdout b/src/tools/miri/tests/pass/async-closure-captures.stdout new file mode 100644 index 00000000000..a0db6d236fe --- /dev/null +++ b/src/tools/miri/tests/pass/async-closure-captures.stdout @@ -0,0 +1,10 @@ +Hello(0) +Hello(0) +Hello(1) +Hello(1) +Hello(2) +Hello(3) +Hello(3) +Hello(4) +Hello(4) +Hello(5) diff --git a/tests/ui/async-await/async-closures/captures.rs b/tests/ui/async-await/async-closures/captures.rs new file mode 100644 index 00000000000..46bbf53f0a7 --- /dev/null +++ b/tests/ui/async-await/async-closures/captures.rs @@ -0,0 +1,80 @@ +//@ aux-build:block-on.rs +//@ edition:2021 +//@ run-pass +//@ check-run-results + +#![feature(async_closure)] + +extern crate block_on; + +fn main() { + block_on::block_on(async_main()); +} + +async fn call(f: &impl async Fn() -> T) -> T { + f().await +} + +async fn call_once(f: impl async FnOnce() -> T) -> T { + f().await +} + +#[derive(Debug)] +#[allow(unused)] +struct Hello(i32); + +async fn async_main() { + // Capture something by-ref + { + let x = Hello(0); + let c = async || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + + let x = &Hello(1); + let c = async || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + } + + // Capture something and consume it (force to `AsyncFnOnce`) + { + let x = Hello(2); + let c = async || { + println!("{x:?}"); + drop(x); + }; + call_once(c).await; + } + + // Capture something with `move`, don't consume it + { + let x = Hello(3); + let c = async move || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + + let x = &Hello(4); + let c = async move || { + println!("{x:?}"); + }; + call(&c).await; + call_once(c).await; + } + + // Capture something with `move`, also consume it (so `AsyncFnOnce`) + { + let x = Hello(5); + let c = async move || { + println!("{x:?}"); + drop(x); + }; + call_once(c).await; + } +} diff --git a/tests/ui/async-await/async-closures/captures.run.stdout b/tests/ui/async-await/async-closures/captures.run.stdout new file mode 100644 index 00000000000..a0db6d236fe --- /dev/null +++ b/tests/ui/async-await/async-closures/captures.run.stdout @@ -0,0 +1,10 @@ +Hello(0) +Hello(0) +Hello(1) +Hello(1) +Hello(2) +Hello(3) +Hello(3) +Hello(4) +Hello(4) +Hello(5)