From a8c931410020840584a2efa5f77239a9c5fcb85c Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 8 Dec 2021 16:44:48 -0500 Subject: [PATCH] remove implicit .await from `core::future::join` --- library/core/src/future/join.rs | 100 ++++++++++++++++---------------- library/core/tests/future.rs | 17 ++++-- 2 files changed, 64 insertions(+), 53 deletions(-) diff --git a/library/core/src/future/join.rs b/library/core/src/future/join.rs index 03d106c969b..bed9f3dd51c 100644 --- a/library/core/src/future/join.rs +++ b/library/core/src/future/join.rs @@ -22,7 +22,7 @@ use crate::task::Poll; /// async fn two() -> usize { 2 } /// /// # let _ = async { -/// let x = join!(one(), two()); +/// let x = join!(one(), two()).await; /// assert_eq!(x, (1, 2)); /// # }; /// ``` @@ -39,7 +39,7 @@ use crate::task::Poll; /// async fn three() -> usize { 3 } /// /// # let _ = async { -/// let x = join!(one(), two(), three()); +/// let x = join!(one(), two(), three()).await; /// assert_eq!(x, (1, 2, 3)); /// # }; /// ``` @@ -71,61 +71,63 @@ pub macro join { }, @rest: () ) => {{ - // The futures and whether they have completed - let mut state = ( $( UnsafeCell::new(($fut, false)), )* ); + async move { + // The futures and whether they have completed + let mut state = ( $( UnsafeCell::new(($fut, false)), )* ); - // Make sure the futures don't panic - // if polled after completion, and - // store their output separately - let mut futures = ($( - ({ - let ( $($pos,)* state, .. ) = &state; + // Make sure the futures don't panic + // if polled after completion, and + // store their output separately + let mut futures = ($( + ({ + let ( $($pos,)* state, .. ) = &state; - poll_fn(move |cx| { - // SAFETY: each future borrows a distinct element - // of the tuple - let (fut, done) = unsafe { &mut *state.get() }; + poll_fn(move |cx| { + // SAFETY: each future borrows a distinct element + // of the tuple + let (fut, done) = unsafe { &mut *state.get() }; - if *done { - return Poll::Ready(None) - } + if *done { + return Poll::Ready(None) + } + + // SAFETY: The futures are never moved + match unsafe { Pin::new_unchecked(fut).poll(cx) } { + Poll::Ready(val) => { + *done = true; + Poll::Ready(Some(val)) + } + Poll::Pending => Poll::Pending + } + }) + }, None), + )*); + + poll_fn(move |cx| { + let mut done = true; + + $( + let ( $($pos,)* (fut, out), .. ) = &mut futures; // SAFETY: The futures are never moved match unsafe { Pin::new_unchecked(fut).poll(cx) } { - Poll::Ready(val) => { - *done = true; - Poll::Ready(Some(val)) - } - Poll::Pending => Poll::Pending + Poll::Ready(Some(val)) => *out = Some(val), + // the future was already done + Poll::Ready(None) => {}, + Poll::Pending => done = false, } - }) - }, None), - )*); + )* - poll_fn(move |cx| { - let mut done = true; - - $( - let ( $($pos,)* (fut, out), .. ) = &mut futures; - - // SAFETY: The futures are never moved - match unsafe { Pin::new_unchecked(fut).poll(cx) } { - Poll::Ready(Some(val)) => *out = Some(val), - // the future was already done - Poll::Ready(None) => {}, - Poll::Pending => done = false, + if done { + // Extract all the outputs + Poll::Ready(($({ + let ( $($pos,)* (_, val), .. ) = &mut futures; + val.unwrap() + }),*)) + } else { + Poll::Pending } - )* - - if done { - // Extract all the outputs - Poll::Ready(($({ - let ( $($pos,)* (_, val), .. ) = &mut futures; - val.unwrap() - }),*)) - } else { - Poll::Pending - } - }).await + }).await + } }} } diff --git a/library/core/tests/future.rs b/library/core/tests/future.rs index f47dcc70434..73249b1b8a4 100644 --- a/library/core/tests/future.rs +++ b/library/core/tests/future.rs @@ -32,13 +32,13 @@ fn poll_n(val: usize, num: usize) -> PollN { #[test] fn test_join() { block_on(async move { - let x = join!(async { 0 }); + let x = join!(async { 0 }).await; assert_eq!(x, 0); - let x = join!(async { 0 }, async { 1 }); + let x = join!(async { 0 }, async { 1 }).await; assert_eq!(x, (0, 1)); - let x = join!(async { 0 }, async { 1 }, async { 2 }); + let x = join!(async { 0 }, async { 1 }, async { 2 }).await; assert_eq!(x, (0, 1, 2)); let x = join!( @@ -50,8 +50,17 @@ fn test_join() { poll_n(5, 3), poll_n(6, 4), poll_n(7, 1) - ); + ) + .await; assert_eq!(x, (0, 1, 2, 3, 4, 5, 6, 7)); + + let y = String::new(); + let x = join!(async { + println!("{}", &y); + 1 + }) + .await; + assert_eq!(x, 1); }); }