From 454bff768228e8dee53462ee63f766521191f8b0 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 18 Dec 2023 01:42:19 +0000 Subject: [PATCH] Ensure `yield` expressions desugar correctly in async generators --- compiler/rustc_ast_lowering/src/expr.rs | 52 ++++++++++++------- .../coroutine/async-gen-yield-ty-is-unit.rs | 17 ++++++ 2 files changed, 49 insertions(+), 20 deletions(-) create mode 100644 tests/ui/coroutine/async-gen-yield-ty-is-unit.rs diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 11b5131b8d7..704f124dbcb 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -917,12 +917,13 @@ impl<'hir> LoweringContext<'_, 'hir> { let poll_expr = { let awaitee = self.expr_ident(span, awaitee_ident, awaitee_pat_hid); let ref_mut_awaitee = self.expr_mut_addr_of(span, awaitee); - let task_context = if let Some(task_context_hid) = self.task_context { - self.expr_ident_mut(span, task_context_ident, task_context_hid) - } else { - // Use of `await` outside of an async context, we cannot use `task_context` here. - self.expr_err(span, self.tcx.sess.span_delayed_bug(span, "no task_context hir id")) + + let Some(task_context_hid) = self.task_context else { + unreachable!("use of `await` outside of an async context."); }; + + let task_context = self.expr_ident_mut(span, task_context_ident, task_context_hid); + let new_unchecked = self.expr_call_lang_item_fn_mut( span, hir::LangItem::PinNewUnchecked, @@ -991,16 +992,14 @@ impl<'hir> LoweringContext<'_, 'hir> { ); let yield_expr = self.arena.alloc(yield_expr); - if let Some(task_context_hid) = self.task_context { - let lhs = self.expr_ident(span, task_context_ident, task_context_hid); - let assign = - self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span))); - self.stmt_expr(span, assign) - } else { - // Use of `await` outside of an async context. Return `yield_expr` so that we can - // proceed with type checking. - self.stmt(span, hir::StmtKind::Semi(yield_expr)) - } + let Some(task_context_hid) = self.task_context else { + unreachable!("use of `await` outside of an async context."); + }; + + let lhs = self.expr_ident(span, task_context_ident, task_context_hid); + let assign = + self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span))); + self.stmt_expr(span, assign) }; let loop_block = self.block_all(span, arena_vec![self; inner_match_stmt, yield_stmt], None); @@ -1635,19 +1634,32 @@ impl<'hir> LoweringContext<'_, 'hir> { } }; - let mut yielded = + let yielded = opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span)); if is_async_gen { - // yield async_gen_ready($expr); - yielded = self.expr_call_lang_item_fn( + // `yield $expr` is transformed into `task_context = yield async_gen_ready($expr)`. + // This ensures that we store our resumed `ResumeContext` correctly, and also that + // the apparent value of the `yield` expression is `()`. + let wrapped_yielded = self.expr_call_lang_item_fn( span, hir::LangItem::AsyncGenReady, std::slice::from_ref(yielded), ); - } + let yield_expr = self.arena.alloc( + self.expr(span, hir::ExprKind::Yield(wrapped_yielded, hir::YieldSource::Yield)), + ); - hir::ExprKind::Yield(yielded, hir::YieldSource::Yield) + let Some(task_context_hid) = self.task_context else { + unreachable!("use of `await` outside of an async context."); + }; + let task_context_ident = Ident::with_dummy_span(sym::_task_context); + let lhs = self.expr_ident(span, task_context_ident, task_context_hid); + + hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span)) + } else { + hir::ExprKind::Yield(yielded, hir::YieldSource::Yield) + } } /// Desugar `ExprForLoop` from: `[opt_ident]: for in ` into: diff --git a/tests/ui/coroutine/async-gen-yield-ty-is-unit.rs b/tests/ui/coroutine/async-gen-yield-ty-is-unit.rs new file mode 100644 index 00000000000..aac74d3eacb --- /dev/null +++ b/tests/ui/coroutine/async-gen-yield-ty-is-unit.rs @@ -0,0 +1,17 @@ +// compile-flags: --edition 2024 -Zunstable-options +// check-pass + +#![feature(async_iterator, gen_blocks, noop_waker)] + +use std::{async_iter::AsyncIterator, pin::pin, task::{Context, Waker}}; + +async gen fn gen_fn() -> &'static str { + yield "hello" +} + +pub fn main() { + let async_iterator = pin!(gen_fn()); + let waker = Waker::noop(); + let ctx = &mut Context::from_waker(&waker); + async_iterator.poll_next(ctx); +}