diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index b82f878ea87..ab82bbd46a6 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -712,7 +712,7 @@ impl<'hir> LoweringContext<'_, 'hir> { let full_span = expr.span.to(await_kw_span); match self.coroutine_kind { Some(hir::CoroutineKind::Async(_)) => {} - Some(hir::CoroutineKind::Coroutine) | None => { + Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => { self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks { await_kw_span, item_span: self.current_item, @@ -936,8 +936,8 @@ impl<'hir> LoweringContext<'_, 'hir> { } Some(movability) } - Some(hir::CoroutineKind::Async(_)) => { - panic!("non-`async` closure body turned `async` during lowering"); + Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => { + panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering"); } None => { if movability == Movability::Static { @@ -1446,6 +1446,7 @@ impl<'hir> LoweringContext<'_, 'hir> { fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> { match self.coroutine_kind { Some(hir::CoroutineKind::Coroutine) => {} + Some(hir::CoroutineKind::Gen(_)) => {} Some(hir::CoroutineKind::Async(_)) => { self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }); } diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs index 928c25d1061..9db7dec1cf8 100644 --- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs @@ -2505,6 +2505,11 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { }; let kind = match use_span.coroutine_kind() { Some(coroutine_kind) => match coroutine_kind { + CoroutineKind::Gen(kind) => match kind { + CoroutineSource::Block => "gen block", + CoroutineSource::Closure => "gen closure", + _ => bug!("gen block/closure expected, but gen function found."), + }, CoroutineKind::Async(async_kind) => match async_kind { CoroutineSource::Block => "async block", CoroutineSource::Closure => "async closure", diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs index e630ac7c4fa..d38cfbc54d7 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs @@ -698,6 +698,20 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { " of async function" } }, + Some(hir::CoroutineKind::Gen(gen)) => match gen { + hir::CoroutineSource::Block => " of gen block", + hir::CoroutineSource::Closure => " of gen closure", + hir::CoroutineSource::Fn => { + let parent_item = + hir.get_by_def_id(hir.get_parent_item(mir_hir_id).def_id); + let output = &parent_item + .fn_decl() + .expect("coroutine lowered from gen fn should be in fn") + .output; + span = output.span(); + " of gen function" + } + }, Some(hir::CoroutineKind::Coroutine) => " of coroutine", None => " of closure", }; diff --git a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs index 5900c764073..1a85eb8dd79 100644 --- a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs +++ b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs @@ -560,6 +560,9 @@ pub fn push_item_name(tcx: TyCtxt<'_>, def_id: DefId, qualified: bool, output: & fn coroutine_kind_label(coroutine_kind: Option) -> &'static str { match coroutine_kind { + Some(CoroutineKind::Gen(CoroutineSource::Block)) => "gen_block", + Some(CoroutineKind::Gen(CoroutineSource::Closure)) => "gen_closure", + Some(CoroutineKind::Gen(CoroutineSource::Fn)) => "gen_fn", Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block", Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure", Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn", diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 259af4f565b..22ee39634a5 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -1513,6 +1513,9 @@ pub enum CoroutineKind { /// An explicit `async` block or the body of an async function. Async(CoroutineSource), + /// An explicit `gen` block or the body of a `gen` function. + Gen(CoroutineSource), + /// A coroutine literal created via a `yield` inside a closure. Coroutine, } @@ -1529,6 +1532,14 @@ impl fmt::Display for CoroutineKind { k.fmt(f) } CoroutineKind::Coroutine => f.write_str("coroutine"), + CoroutineKind::Gen(k) => { + if f.alternate() { + f.write_str("`gen` ")?; + } else { + f.write_str("gen ")? + } + k.fmt(f) + } } } } @@ -2242,6 +2253,7 @@ impl From for YieldSource { // Guess based on the kind of the current coroutine. CoroutineKind::Coroutine => Self::Yield, CoroutineKind::Async(_) => Self::Await { expr: None }, + CoroutineKind::Gen(_) => Self::Yield, } } } diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index e7060dac844..14cd3881a06 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -58,15 +58,16 @@ pub(super) fn check_fn<'a, 'tcx>( if let Some(kind) = body.coroutine_kind && can_be_coroutine.is_some() { - let yield_ty = if kind == hir::CoroutineKind::Coroutine { - let yield_ty = fcx.next_ty_var(TypeVariableOrigin { - kind: TypeVariableOriginKind::TypeInference, - span, - }); - fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType); - yield_ty - } else { - Ty::new_unit(tcx) + let yield_ty = match kind { + hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => { + let yield_ty = fcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::TypeInference, + span, + }); + fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType); + yield_ty + } + hir::CoroutineKind::Async(..) => Ty::new_unit(tcx), }; // Resume type defaults to `()` if the coroutine has no argument. diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index affa83fa348..ef1d0c867ea 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -148,8 +148,14 @@ impl AssertKind { RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero", ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion", ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion", + ResumedAfterReturn(CoroutineKind::Gen(_)) => { + bug!("`gen fn` should just keep returning `None` after the first time") + } ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking", ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking", + ResumedAfterPanic(CoroutineKind::Gen(_)) => { + bug!("`gen fn` should just keep returning `None` after panicking") + } BoundsCheck { .. } | MisalignedPointerDereference { .. } => { bug!("Unexpected AssertKind") } @@ -236,10 +242,14 @@ impl AssertKind { DivisionByZero(_) => middle_assert_divide_by_zero, RemainderByZero(_) => middle_assert_remainder_by_zero, ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return, + // FIXME(gen_blocks): custom error message for `gen` blocks + ResumedAfterReturn(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_return, ResumedAfterReturn(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_return } ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic, + // FIXME(gen_blocks): custom error message for `gen` blocks + ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_panic, ResumedAfterPanic(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_panic } diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs index be48c0e8926..d29d1902404 100644 --- a/compiler/rustc_middle/src/ty/util.rs +++ b/compiler/rustc_middle/src/ty/util.rs @@ -749,6 +749,7 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() { rustc_hir::CoroutineKind::Async(..) => "async closure", rustc_hir::CoroutineKind::Coroutine => "coroutine", + rustc_hir::CoroutineKind::Gen(..) => "gen closure", }, _ => def_kind.descr(def_id), } @@ -766,6 +767,7 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() { rustc_hir::CoroutineKind::Async(..) => "an", rustc_hir::CoroutineKind::Coroutine => "a", + rustc_hir::CoroutineKind::Gen(..) => "a", }, _ => def_kind.article(), } diff --git a/compiler/rustc_smir/src/rustc_smir/mod.rs b/compiler/rustc_smir/src/rustc_smir/mod.rs index 928536721c9..219f6d284dd 100644 --- a/compiler/rustc_smir/src/rustc_smir/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/mod.rs @@ -880,18 +880,28 @@ impl<'tcx> Stable<'tcx> for mir::AggregateKind<'tcx> { } } +impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineSource { + type T = stable_mir::mir::CoroutineSource; + fn stable(&self, _: &mut Tables<'tcx>) -> Self::T { + use rustc_hir::CoroutineSource; + match self { + CoroutineSource::Block => stable_mir::mir::CoroutineSource::Block, + CoroutineSource::Closure => stable_mir::mir::CoroutineSource::Closure, + CoroutineSource::Fn => stable_mir::mir::CoroutineSource::Fn, + } + } +} + impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind { type T = stable_mir::mir::CoroutineKind; - fn stable(&self, _: &mut Tables<'tcx>) -> Self::T { - use rustc_hir::{CoroutineKind, CoroutineSource}; + fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { + use rustc_hir::CoroutineKind; match self { - CoroutineKind::Async(async_gen) => { - let async_gen = match async_gen { - CoroutineSource::Block => stable_mir::mir::CoroutineSource::Block, - CoroutineSource::Closure => stable_mir::mir::CoroutineSource::Closure, - CoroutineSource::Fn => stable_mir::mir::CoroutineSource::Fn, - }; - stable_mir::mir::CoroutineKind::Async(async_gen) + CoroutineKind::Async(source) => { + stable_mir::mir::CoroutineKind::Async(source.stable(tables)) + } + CoroutineKind::Gen(source) => { + stable_mir::mir::CoroutineKind::Gen(source.stable(tables)) } CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine, } diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index b9a08056ad1..3087b734cf2 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -2425,6 +2425,21 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { CoroutineKind::Async(CoroutineSource::Closure) => { format!("future created by async closure is not {trait_name}") } + CoroutineKind::Gen(CoroutineSource::Fn) => self + .tcx + .parent(coroutine_did) + .as_local() + .map(|parent_did| hir.local_def_id_to_hir_id(parent_did)) + .and_then(|parent_hir_id| hir.opt_name(parent_hir_id)) + .map(|name| { + format!("iterator returned by `{name}` is not {trait_name}") + })?, + CoroutineKind::Gen(CoroutineSource::Block) => { + format!("iterator created by gen block is not {trait_name}") + } + CoroutineKind::Gen(CoroutineSource::Closure) => { + format!("iterator created by gen closure is not {trait_name}") + } }) }) .unwrap_or_else(|| format!("{future_or_coroutine} is not {trait_name}")); @@ -2905,7 +2920,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { } ObligationCauseCode::SizedCoroutineInterior(coroutine_def_id) => { let what = match self.tcx.coroutine_kind(coroutine_def_id) { - None | Some(hir::CoroutineKind::Coroutine) => "yield", + None | Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) => "yield", Some(hir::CoroutineKind::Async(..)) => "await", }; err.note(format!( diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs index 353a69ef6d3..030813a4ac3 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs @@ -1614,6 +1614,9 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> { hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block", hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function", hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure", + hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block", + hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function", + hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure", }) } diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index c2e79eaaf3f..b686004b2fe 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -137,6 +137,7 @@ pub enum UnOp { pub enum CoroutineKind { Async(CoroutineSource), Coroutine, + Gen(CoroutineSource), } #[derive(Clone, Debug)]