Rollup merge of #119666 - compiler-errors:construct-coroutine-info-immediately, r=cjgillot

Populate `yield` and `resume` types in MIR body while body is being initialized

I found it weird that we went back and populated these types *after* the body was constructed. Let's just do it all at once.
This commit is contained in:
Michael Goulet 2024-01-06 21:51:46 -05:00 committed by GitHub
commit 854d1131ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 78 deletions

View File

@ -969,7 +969,7 @@ pub fn promote_candidates<'tcx>(
0,
vec![],
body.span,
body.coroutine_kind(),
None,
body.tainted_by_errors,
);
promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);

View File

@ -263,6 +263,23 @@ pub struct CoroutineInfo<'tcx> {
pub coroutine_kind: CoroutineKind,
}
impl<'tcx> CoroutineInfo<'tcx> {
// Sets up `CoroutineInfo` for a pre-coroutine-transform MIR body.
pub fn initial(
coroutine_kind: CoroutineKind,
yield_ty: Ty<'tcx>,
resume_ty: Ty<'tcx>,
) -> CoroutineInfo<'tcx> {
CoroutineInfo {
coroutine_kind,
yield_ty: Some(yield_ty),
resume_ty: Some(resume_ty),
coroutine_drop: None,
coroutine_layout: None,
}
}
}
/// The lowered representation of a single function.
#[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)]
pub struct Body<'tcx> {
@ -367,7 +384,7 @@ pub fn new(
arg_count: usize,
var_debug_info: Vec<VarDebugInfo<'tcx>>,
span: Span,
coroutine_kind: Option<CoroutineKind>,
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
tainted_by_errors: Option<ErrorGuaranteed>,
) -> Self {
// We need `arg_count` locals, and one for the return place.
@ -384,15 +401,7 @@ pub fn new(
source,
basic_blocks: BasicBlocks::new(basic_blocks),
source_scopes,
coroutine: coroutine_kind.map(|coroutine_kind| {
Box::new(CoroutineInfo {
yield_ty: None,
resume_ty: None,
coroutine_drop: None,
coroutine_layout: None,
coroutine_kind,
})
}),
coroutine,
local_decls,
user_type_annotations,
arg_count,

View File

@ -86,8 +86,6 @@ fn index(&self, index: $id) -> &Self::Output {
}
}
pub const UPVAR_ENV_PARAM: ParamId = ParamId::from_u32(0);
thir_with_elements! {
body_type: BodyTy<'tcx>,

View File

@ -9,7 +9,7 @@
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::{CoroutineKind, Node};
use rustc_hir::Node;
use rustc_index::bit_set::GrowableBitSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
check_overflow: bool,
fn_span: Span,
arg_count: usize,
coroutine_kind: Option<CoroutineKind>,
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
/// The current set of scopes, updated as we traverse;
/// see the `scope` module for more details.
@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
) -> Body<'tcx> {
let span = tcx.def_span(fn_def);
let fn_id = tcx.local_def_id_to_hir_id(fn_def);
let coroutine_kind = tcx.coroutine_kind(fn_def);
// The representation of thir for `-Zunpretty=thir-tree` relies on
// the entry expression being the last element of `thir.exprs`.
@ -488,17 +487,15 @@ fn construct_fn<'tcx>(
let arguments = &thir.params;
let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let coroutine_sig = match coroutine_ty.kind() {
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
_ => {
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
}
};
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
} else {
(None, None, fn_sig.output())
let return_ty = fn_sig.output();
let coroutine = match tcx.type_of(fn_def).instantiate_identity().kind() {
ty::Coroutine(_, args) => Some(Box::new(CoroutineInfo::initial(
tcx.coroutine_kind(fn_def).unwrap(),
args.as_coroutine().yield_ty(),
args.as_coroutine().resume_ty(),
))),
ty::Closure(..) | ty::FnDef(..) => None,
ty => span_bug!(span_with_body, "unexpected type of body: {ty:?}"),
};
if let Some(custom_mir_attr) =
@ -529,7 +526,7 @@ fn construct_fn<'tcx>(
safety,
return_ty,
return_ty_span,
coroutine_kind,
coroutine,
);
let call_site_scope =
@ -563,11 +560,6 @@ fn construct_fn<'tcx>(
None
};
if coroutine_kind.is_some() {
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
}
body
}
@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>(
fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> {
let span = tcx.def_span(def_id);
let hir_id = tcx.local_def_id_to_hir_id(def_id);
let coroutine_kind = tcx.coroutine_kind(def_id);
let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
let (inputs, output, coroutine) = match tcx.def_kind(def_id) {
DefKind::Const
| DefKind::AssocConst
| DefKind::AnonConst
| DefKind::InlineConst
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
let sig = tcx.liberate_late_bound_regions(
def_id.to_def_id(),
tcx.fn_sig(def_id).instantiate_identity(),
);
(sig.inputs().to_vec(), sig.output(), None, None)
}
DefKind::Closure if coroutine_kind.is_some() => {
let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
bug!("expected type of coroutine-like closure to be a coroutine")
};
let args = args.as_coroutine();
let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty();
let return_ty = args.return_ty();
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
(sig.inputs().to_vec(), sig.output(), None)
}
DefKind::Closure => {
let closure_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Closure(_, args) = closure_ty.kind() else {
bug!("expected type of closure to be a closure")
};
let args = args.as_closure();
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
let self_ty = match args.kind() {
ty::ClosureKind::Fn => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnOnce => closure_ty,
};
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
match closure_ty.kind() {
ty::Closure(_, args) => {
let args = args.as_closure();
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
let self_ty = match args.kind() {
ty::ClosureKind::Fn => {
Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
}
ty::ClosureKind::FnMut => {
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
}
ty::ClosureKind::FnOnce => closure_ty,
};
(
[self_ty].into_iter().chain(sig.inputs().to_vec()).collect(),
sig.output(),
None,
)
}
ty::Coroutine(_, args) => {
let args = args.as_coroutine();
let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty();
let return_ty = args.return_ty();
(
vec![closure_ty, args.resume_ty()],
return_ty,
Some(Box::new(CoroutineInfo::initial(
tcx.coroutine_kind(def_id).unwrap(),
yield_ty,
resume_ty,
))),
)
}
_ => {
span_bug!(span, "expected type of closure body to be a closure or coroutine");
}
}
}
dk => bug!("{:?} is not a body: {:?}", def_id, dk),
dk => span_bug!(span, "{:?} is not a body: {:?}", def_id, dk),
};
let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE };
@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable);
let mut body = Body::new(
Body::new(
MirSource::item(def_id.to_def_id()),
cfg.basic_blocks,
source_scopes,
@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
inputs.len(),
vec![],
span,
coroutine_kind,
coroutine,
Some(guar),
);
body.coroutine.as_mut().map(|gen| {
gen.yield_ty = yield_ty;
gen.resume_ty = resume_ty;
});
body
)
}
impl<'a, 'tcx> Builder<'a, 'tcx> {
@ -728,7 +728,7 @@ fn new(
safety: Safety,
return_ty: Ty<'tcx>,
return_span: Span,
coroutine_kind: Option<CoroutineKind>,
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
) -> Builder<'a, 'tcx> {
let tcx = infcx.tcx;
let attrs = tcx.hir().attrs(hir_id);
@ -759,7 +759,7 @@ fn new(
cfg: CFG { basic_blocks: IndexVec::new() },
fn_span: span,
arg_count,
coroutine_kind,
coroutine,
scopes: scope::Scopes::new(),
block_context: BlockContext::new(),
source_scopes: IndexVec::new(),
@ -803,7 +803,7 @@ fn finish(self) -> Body<'tcx> {
self.arg_count,
self.var_debug_info,
self.fn_span,
self.coroutine_kind,
self.coroutine,
None,
)
}

View File

@ -706,7 +706,7 @@ fn leave_top_scope(&mut self, block: BasicBlock) -> BasicBlock {
// If we are emitting a `drop` statement, we need to have the cached
// diverge cleanup pads ready in case that drop panics.
let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup());
let is_coroutine = self.coroutine_kind.is_some();
let is_coroutine = self.coroutine.is_some();
let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX };
let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes");
@ -960,7 +960,7 @@ pub(crate) fn schedule_drop(
// path, we only need to invalidate the cache for drops that happen on
// the unwind or coroutine drop paths. This means that for
// non-coroutines we don't need to invalidate caches for `DropKind::Storage`.
let invalidate_caches = needs_drop || self.coroutine_kind.is_some();
let invalidate_caches = needs_drop || self.coroutine.is_some();
for scope in self.scopes.scopes.iter_mut().rev() {
if invalidate_caches {
scope.invalidate_cache();
@ -1073,7 +1073,7 @@ fn diverge_cleanup_target(&mut self, target_scope: region::Scope, span: Span) ->
return cached_drop;
}
let is_coroutine = self.coroutine_kind.is_some();
let is_coroutine = self.coroutine.is_some();
for scope in &mut self.scopes.scopes[uncached_scope..=target] {
for drop in &scope.drops {
if is_coroutine || drop.kind == DropKind::Value {
@ -1318,7 +1318,7 @@ fn build_exit_tree(
blocks[ROOT_NODE] = continue_block;
drops.build_mir::<ExitScopes>(&mut self.cfg, &mut blocks);
let is_coroutine = self.coroutine_kind.is_some();
let is_coroutine = self.coroutine.is_some();
// Link the exit drop tree to unwind drop tree.
if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) {
@ -1355,7 +1355,7 @@ fn build_exit_tree(
/// Build the unwind and coroutine drop trees.
pub(crate) fn build_drop_trees(&mut self) {
if self.coroutine_kind.is_some() {
if self.coroutine.is_some() {
self.build_coroutine_drop_trees();
} else {
Self::build_unwind_tree(