Don't populate yield and resume types after the fact

This commit is contained in:
Michael Goulet 2024-01-06 17:00:24 +00:00
parent 9212108a9b
commit 5e2b66fc9d
5 changed files with 85 additions and 78 deletions

View File

@ -969,7 +969,7 @@ pub fn promote_candidates<'tcx>(
0, 0,
vec![], vec![],
body.span, body.span,
body.coroutine_kind(), None,
body.tainted_by_errors, body.tainted_by_errors,
); );
promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial); promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);

View File

@ -263,6 +263,23 @@ pub struct CoroutineInfo<'tcx> {
pub coroutine_kind: CoroutineKind, pub coroutine_kind: CoroutineKind,
} }
impl<'tcx> CoroutineInfo<'tcx> {
// Sets up `CoroutineInfo` for a pre-coroutine-transform MIR body.
pub fn initial(
coroutine_kind: CoroutineKind,
yield_ty: Ty<'tcx>,
resume_ty: Ty<'tcx>,
) -> CoroutineInfo<'tcx> {
CoroutineInfo {
coroutine_kind,
yield_ty: Some(yield_ty),
resume_ty: Some(resume_ty),
coroutine_drop: None,
coroutine_layout: None,
}
}
}
/// The lowered representation of a single function. /// The lowered representation of a single function.
#[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)] #[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)]
pub struct Body<'tcx> { pub struct Body<'tcx> {
@ -367,7 +384,7 @@ impl<'tcx> Body<'tcx> {
arg_count: usize, arg_count: usize,
var_debug_info: Vec<VarDebugInfo<'tcx>>, var_debug_info: Vec<VarDebugInfo<'tcx>>,
span: Span, span: Span,
coroutine_kind: Option<CoroutineKind>, coroutine: Option<Box<CoroutineInfo<'tcx>>>,
tainted_by_errors: Option<ErrorGuaranteed>, tainted_by_errors: Option<ErrorGuaranteed>,
) -> Self { ) -> Self {
// We need `arg_count` locals, and one for the return place. // We need `arg_count` locals, and one for the return place.
@ -384,15 +401,7 @@ impl<'tcx> Body<'tcx> {
source, source,
basic_blocks: BasicBlocks::new(basic_blocks), basic_blocks: BasicBlocks::new(basic_blocks),
source_scopes, source_scopes,
coroutine: coroutine_kind.map(|coroutine_kind| { coroutine,
Box::new(CoroutineInfo {
yield_ty: None,
resume_ty: None,
coroutine_drop: None,
coroutine_layout: None,
coroutine_kind,
})
}),
local_decls, local_decls,
user_type_annotations, user_type_annotations,
arg_count, arg_count,

View File

@ -86,8 +86,6 @@ macro_rules! thir_with_elements {
} }
} }
pub const UPVAR_ENV_PARAM: ParamId = ParamId::from_u32(0);
thir_with_elements! { thir_with_elements! {
body_type: BodyTy<'tcx>, body_type: BodyTy<'tcx>,

View File

@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed;
use rustc_hir as hir; use rustc_hir as hir;
use rustc_hir::def::DefKind; use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::{CoroutineKind, Node}; use rustc_hir::Node;
use rustc_index::bit_set::GrowableBitSet; use rustc_index::bit_set::GrowableBitSet;
use rustc_index::{Idx, IndexSlice, IndexVec}; use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt}; use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
check_overflow: bool, check_overflow: bool,
fn_span: Span, fn_span: Span,
arg_count: usize, arg_count: usize,
coroutine_kind: Option<CoroutineKind>, coroutine: Option<Box<CoroutineInfo<'tcx>>>,
/// The current set of scopes, updated as we traverse; /// The current set of scopes, updated as we traverse;
/// see the `scope` module for more details. /// see the `scope` module for more details.
@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
) -> Body<'tcx> { ) -> Body<'tcx> {
let span = tcx.def_span(fn_def); let span = tcx.def_span(fn_def);
let fn_id = tcx.local_def_id_to_hir_id(fn_def); let fn_id = tcx.local_def_id_to_hir_id(fn_def);
let coroutine_kind = tcx.coroutine_kind(fn_def);
// The representation of thir for `-Zunpretty=thir-tree` relies on // The representation of thir for `-Zunpretty=thir-tree` relies on
// the entry expression being the last element of `thir.exprs`. // the entry expression being the last element of `thir.exprs`.
@ -488,17 +487,15 @@ fn construct_fn<'tcx>(
let arguments = &thir.params; let arguments = &thir.params;
let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() { let return_ty = fn_sig.output();
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty; let coroutine = match tcx.type_of(fn_def).instantiate_identity().kind() {
let coroutine_sig = match coroutine_ty.kind() { ty::Coroutine(_, args) => Some(Box::new(CoroutineInfo::initial(
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(), tcx.coroutine_kind(fn_def).unwrap(),
_ => { args.as_coroutine().yield_ty(),
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty) args.as_coroutine().resume_ty(),
} ))),
}; ty::Closure(..) | ty::FnDef(..) => None,
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty) ty => span_bug!(span_with_body, "unexpected type of body: {ty:?}"),
} else {
(None, None, fn_sig.output())
}; };
if let Some(custom_mir_attr) = if let Some(custom_mir_attr) =
@ -529,7 +526,7 @@ fn construct_fn<'tcx>(
safety, safety,
return_ty, return_ty,
return_ty_span, return_ty_span,
coroutine_kind, coroutine,
); );
let call_site_scope = let call_site_scope =
@ -563,11 +560,6 @@ fn construct_fn<'tcx>(
None None
}; };
if coroutine_kind.is_some() {
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
}
body body
} }
@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>(
fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> { fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> {
let span = tcx.def_span(def_id); let span = tcx.def_span(def_id);
let hir_id = tcx.local_def_id_to_hir_id(def_id); let hir_id = tcx.local_def_id_to_hir_id(def_id);
let coroutine_kind = tcx.coroutine_kind(def_id);
let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) { let (inputs, output, coroutine) = match tcx.def_kind(def_id) {
DefKind::Const DefKind::Const
| DefKind::AssocConst | DefKind::AssocConst
| DefKind::AnonConst | DefKind::AnonConst
| DefKind::InlineConst | DefKind::InlineConst
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None), | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => { DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
let sig = tcx.liberate_late_bound_regions( let sig = tcx.liberate_late_bound_regions(
def_id.to_def_id(), def_id.to_def_id(),
tcx.fn_sig(def_id).instantiate_identity(), tcx.fn_sig(def_id).instantiate_identity(),
); );
(sig.inputs().to_vec(), sig.output(), None, None) (sig.inputs().to_vec(), sig.output(), None)
} }
DefKind::Closure if coroutine_kind.is_some() => { DefKind::Closure => {
let coroutine_ty = tcx.type_of(def_id).instantiate_identity(); let closure_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Coroutine(_, args) = coroutine_ty.kind() else { match closure_ty.kind() {
bug!("expected type of coroutine-like closure to be a coroutine") ty::Closure(_, args) => {
let args = args.as_closure();
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
let self_ty = match args.kind() {
ty::ClosureKind::Fn => {
Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
}
ty::ClosureKind::FnMut => {
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
}
ty::ClosureKind::FnOnce => closure_ty,
}; };
(
[self_ty].into_iter().chain(sig.inputs().to_vec()).collect(),
sig.output(),
None,
)
}
ty::Coroutine(_, args) => {
let args = args.as_coroutine(); let args = args.as_coroutine();
let resume_ty = args.resume_ty(); let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty(); let yield_ty = args.yield_ty();
let return_ty = args.return_ty(); let return_ty = args.return_ty();
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty)) (
vec![closure_ty, args.resume_ty()],
return_ty,
Some(Box::new(CoroutineInfo::initial(
tcx.coroutine_kind(def_id).unwrap(),
yield_ty,
resume_ty,
))),
)
} }
DefKind::Closure => { _ => {
let closure_ty = tcx.type_of(def_id).instantiate_identity(); span_bug!(span, "expected type of closure body to be a closure or coroutine");
let ty::Closure(_, args) = closure_ty.kind() else {
bug!("expected type of closure to be a closure")
};
let args = args.as_closure();
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
let self_ty = match args.kind() {
ty::ClosureKind::Fn => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnOnce => closure_ty,
};
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
} }
dk => bug!("{:?} is not a body: {:?}", def_id, dk), }
}
dk => span_bug!(span, "{:?} is not a body: {:?}", def_id, dk),
}; };
let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE }; let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE };
@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable); cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable);
let mut body = Body::new( Body::new(
MirSource::item(def_id.to_def_id()), MirSource::item(def_id.to_def_id()),
cfg.basic_blocks, cfg.basic_blocks,
source_scopes, source_scopes,
@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
inputs.len(), inputs.len(),
vec![], vec![],
span, span,
coroutine_kind, coroutine,
Some(guar), Some(guar),
); )
body.coroutine.as_mut().map(|gen| {
gen.yield_ty = yield_ty;
gen.resume_ty = resume_ty;
});
body
} }
impl<'a, 'tcx> Builder<'a, 'tcx> { impl<'a, 'tcx> Builder<'a, 'tcx> {
@ -728,7 +728,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
safety: Safety, safety: Safety,
return_ty: Ty<'tcx>, return_ty: Ty<'tcx>,
return_span: Span, return_span: Span,
coroutine_kind: Option<CoroutineKind>, coroutine: Option<Box<CoroutineInfo<'tcx>>>,
) -> Builder<'a, 'tcx> { ) -> Builder<'a, 'tcx> {
let tcx = infcx.tcx; let tcx = infcx.tcx;
let attrs = tcx.hir().attrs(hir_id); let attrs = tcx.hir().attrs(hir_id);
@ -759,7 +759,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
cfg: CFG { basic_blocks: IndexVec::new() }, cfg: CFG { basic_blocks: IndexVec::new() },
fn_span: span, fn_span: span,
arg_count, arg_count,
coroutine_kind, coroutine,
scopes: scope::Scopes::new(), scopes: scope::Scopes::new(),
block_context: BlockContext::new(), block_context: BlockContext::new(),
source_scopes: IndexVec::new(), source_scopes: IndexVec::new(),
@ -803,7 +803,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.arg_count, self.arg_count,
self.var_debug_info, self.var_debug_info,
self.fn_span, self.fn_span,
self.coroutine_kind, self.coroutine,
None, None,
) )
} }

View File

@ -706,7 +706,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// If we are emitting a `drop` statement, we need to have the cached // If we are emitting a `drop` statement, we need to have the cached
// diverge cleanup pads ready in case that drop panics. // diverge cleanup pads ready in case that drop panics.
let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup()); let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup());
let is_coroutine = self.coroutine_kind.is_some(); let is_coroutine = self.coroutine.is_some();
let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX }; let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX };
let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes"); let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes");
@ -960,7 +960,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// path, we only need to invalidate the cache for drops that happen on // path, we only need to invalidate the cache for drops that happen on
// the unwind or coroutine drop paths. This means that for // the unwind or coroutine drop paths. This means that for
// non-coroutines we don't need to invalidate caches for `DropKind::Storage`. // non-coroutines we don't need to invalidate caches for `DropKind::Storage`.
let invalidate_caches = needs_drop || self.coroutine_kind.is_some(); let invalidate_caches = needs_drop || self.coroutine.is_some();
for scope in self.scopes.scopes.iter_mut().rev() { for scope in self.scopes.scopes.iter_mut().rev() {
if invalidate_caches { if invalidate_caches {
scope.invalidate_cache(); scope.invalidate_cache();
@ -1073,7 +1073,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
return cached_drop; return cached_drop;
} }
let is_coroutine = self.coroutine_kind.is_some(); let is_coroutine = self.coroutine.is_some();
for scope in &mut self.scopes.scopes[uncached_scope..=target] { for scope in &mut self.scopes.scopes[uncached_scope..=target] {
for drop in &scope.drops { for drop in &scope.drops {
if is_coroutine || drop.kind == DropKind::Value { if is_coroutine || drop.kind == DropKind::Value {
@ -1318,7 +1318,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
blocks[ROOT_NODE] = continue_block; blocks[ROOT_NODE] = continue_block;
drops.build_mir::<ExitScopes>(&mut self.cfg, &mut blocks); drops.build_mir::<ExitScopes>(&mut self.cfg, &mut blocks);
let is_coroutine = self.coroutine_kind.is_some(); let is_coroutine = self.coroutine.is_some();
// Link the exit drop tree to unwind drop tree. // Link the exit drop tree to unwind drop tree.
if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) { if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) {
@ -1355,7 +1355,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
/// Build the unwind and coroutine drop trees. /// Build the unwind and coroutine drop trees.
pub(crate) fn build_drop_trees(&mut self) { pub(crate) fn build_drop_trees(&mut self) {
if self.coroutine_kind.is_some() { if self.coroutine.is_some() {
self.build_coroutine_drop_trees(); self.build_coroutine_drop_trees();
} else { } else {
Self::build_unwind_tree( Self::build_unwind_tree(