Auto merge of #123572 - Mark-Simulacrum:vtable-methods, r=oli-obk

Increase vtable layout size

This improves LLVM's codegen by allowing vtable loads to be hoisted out of loops (as just one example). The calculation here is an under-approximation but works for simple trait hierarchies (e.g., FnMut will be improved). We have a runtime assert that the approximation is accurate, so there's no risk of UB as a result of getting this wrong.

```rust
#[no_mangle]
pub fn foo(elements: &[u32], callback: &mut dyn Callback) {
    for element in elements.iter() {
        if *element != 0 {
            callback.call(*element);
        }
    }
}

pub trait Callback {
    fn call(&mut self, _: u32);
}
```

Simplifying a bit (e.g., numbering ends up different):

```diff
 ; Function Attrs: nonlazybind uwtable
-define void @foo(ptr noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(24) %callback.1) unnamed_addr #0 {
+define void @foo(ptr noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(32) %callback.1) unnamed_addr #0 {
 start:
   %_15 = getelementptr inbounds i32, ptr %elements.0, i64 %elements.1
@@ -13,4 +13,5 @@
 bb4.lr.ph:                                        ; preds = %start
   %1 = getelementptr inbounds i8, ptr %callback.1, i64 24
+  %2 = load ptr, ptr %1, align 8, !nonnull !3
   br label %bb4

 bb6:                                              ; preds = %bb4
-  %4 = load ptr, ptr %1, align 8, !invariant.load !3, !nonnull !3
-  tail call void %4(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
+  tail call void %2(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
   br label %bb7
 }
```
This commit is contained in:
bors 2024-06-01 14:31:07 +00:00
commit f2208b3297
11 changed files with 144 additions and 127 deletions

View File

@ -12,7 +12,6 @@ use rustc_middle::query::Providers;
use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
use rustc_session::parse::feature_err;
use rustc_span::{sym, ErrorGuaranteed};
use rustc_trait_selection::traits;
mod builtin;
mod inherent_impls;
@ -199,7 +198,7 @@ fn check_object_overlap<'tcx>(
// With the feature enabled, the trait is not implemented automatically,
// so this is valid.
} else {
let mut supertrait_def_ids = traits::supertrait_def_ids(tcx, component_def_id);
let mut supertrait_def_ids = tcx.supertrait_def_ids(component_def_id);
if supertrait_def_ids.any(|d| d == trait_def_id) {
let span = tcx.def_span(impl_def_id);
return Err(struct_span_code_err!(

View File

@ -827,25 +827,14 @@ where
});
}
let mk_dyn_vtable = || {
let mk_dyn_vtable = |principal: Option<ty::PolyExistentialTraitRef<'tcx>>| {
let min_count = ty::vtable_min_entries(tcx, principal);
Ty::new_imm_ref(
tcx,
tcx.lifetimes.re_static,
Ty::new_array(tcx, tcx.types.usize, 3),
// FIXME: properly type (e.g. usize and fn pointers) the fields.
Ty::new_array(tcx, tcx.types.usize, min_count.try_into().unwrap()),
)
/* FIXME: use actual fn pointers
Warning: naively computing the number of entries in the
vtable by counting the methods on the trait + methods on
all parent traits does not work, because some methods can
be not object safe and thus excluded from the vtable.
Increase this counter if you tried to implement this but
failed to do it without duplicating a lot of code from
other places in the compiler: 2
Ty::new_tup(tcx,&[
Ty::new_array(tcx,tcx.types.usize, 3),
Ty::new_array(tcx,Option<fn()>),
])
*/
};
let metadata = if let Some(metadata_def_id) = tcx.lang_items().metadata_type()
@ -864,16 +853,16 @@ where
// `std::mem::uninitialized::<&dyn Trait>()`, for example.
if let ty::Adt(def, args) = metadata.kind()
&& Some(def.did()) == tcx.lang_items().dyn_metadata()
&& args.type_at(0).is_trait()
&& let ty::Dynamic(data, _, ty::Dyn) = args.type_at(0).kind()
{
mk_dyn_vtable()
mk_dyn_vtable(data.principal())
} else {
metadata
}
} else {
match tcx.struct_tail_erasing_lifetimes(pointee, cx.param_env()).kind() {
ty::Slice(_) | ty::Str => tcx.types.usize,
ty::Dynamic(_, _, ty::Dyn) => mk_dyn_vtable(),
ty::Dynamic(data, _, ty::Dyn) => mk_dyn_vtable(data.principal()),
_ => bug!("TyAndLayout::field({:?}): not applicable", this),
}
};

View File

@ -3,6 +3,8 @@ use std::fmt;
use crate::mir::interpret::{alloc_range, AllocId, Allocation, Pointer, Scalar};
use crate::ty::{self, Instance, PolyTraitRef, Ty, TyCtxt};
use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def_id::DefId;
use rustc_macros::HashStable;
#[derive(Clone, Copy, PartialEq, HashStable)]
@ -40,12 +42,69 @@ impl<'tcx> fmt::Debug for VtblEntry<'tcx> {
impl<'tcx> TyCtxt<'tcx> {
    /// The metadata entries `[drop_in_place, size, align]`, in this order.
    pub const COMMON_VTABLE_ENTRIES: &'tcx [VtblEntry<'tcx>] =
        &[VtblEntry::MetadataDropInPlace, VtblEntry::MetadataSize, VtblEntry::MetadataAlign];

    /// Creates an iterator over the `DefId` of `trait_def_id` and of every trait
    /// reachable from it through super-predicates.
    ///
    /// The starting trait is seeded both as pending work and as already seen, so
    /// it is yielded exactly once even when the hierarchy refers back to it.
    pub fn supertrait_def_ids(self, trait_def_id: DefId) -> SupertraitDefIds<'tcx> {
        let stack = vec![trait_def_id];
        let visited = std::iter::once(trait_def_id).collect();
        SupertraitDefIds { tcx: self, stack, visited }
    }
}
// Positions of the individual entries within `TyCtxt::COMMON_VTABLE_ENTRIES`;
// the values follow the order in which that slice lists its entries.

/// Index of `VtblEntry::MetadataDropInPlace` in `TyCtxt::COMMON_VTABLE_ENTRIES`.
pub const COMMON_VTABLE_ENTRIES_DROPINPLACE: usize = 0;
/// Index of `VtblEntry::MetadataSize` in `TyCtxt::COMMON_VTABLE_ENTRIES`.
pub const COMMON_VTABLE_ENTRIES_SIZE: usize = 1;
/// Index of `VtblEntry::MetadataAlign` in `TyCtxt::COMMON_VTABLE_ENTRIES`.
pub const COMMON_VTABLE_ENTRIES_ALIGN: usize = 2;
/// Iterator over the `DefId`s of a trait and all of its (transitive) supertraits.
///
/// Each trait is yielded at most once; the order follows the internal work stack.
pub struct SupertraitDefIds<'tcx> {
    /// Used to look up the super-predicates of each trait that is popped.
    tcx: TyCtxt<'tcx>,
    /// Traits that have been discovered but not yet yielded.
    stack: Vec<DefId>,
    /// Every trait that has ever been placed on `stack`.
    visited: FxHashSet<DefId>,
}

impl Iterator for SupertraitDefIds<'_> {
    type Item = DefId;

    /// Pops the next pending trait, queues those of its direct supertraits that
    /// have not been seen before, and yields the popped trait.
    fn next(&mut self) -> Option<DefId> {
        let current = self.stack.pop()?;
        let super_predicates = self.tcx.super_predicates_of(current);

        for (clause, _) in super_predicates.predicates.iter() {
            // Only trait clauses name a supertrait; everything else is ignored.
            let Some(trait_clause) = clause.as_trait_clause() else {
                continue;
            };
            let super_def_id = trait_clause.def_id();
            // `insert` reports `false` for traits that were queued earlier.
            if self.visited.insert(super_def_id) {
                self.stack.push(super_def_id);
            }
        }

        Some(current)
    }
}
/// Computes a lower bound on the number of vtable entries for a trait object
/// with the given principal trait.
///
/// No self type is available here, so the result has to be derived purely from
/// the trait and supertrait definitions. The regular `vtable_entries` code cannot
/// be reused because it describes one specific instantiation (e.g., with Vacant
/// slots when bounds aren't satisfied). This is a best-effort approximation that
/// avoids duplicating that code.
///
/// Layout computation for e.g. `&dyn Trait` relies on this value, so it must
/// never exceed the real entry count; that is asserted when the vtable itself
/// is actually built.
///
/// Without a principal trait only the common metadata entries are counted.
pub(crate) fn vtable_min_entries<'tcx>(
    tcx: TyCtxt<'tcx>,
    trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
) -> usize {
    let common = TyCtxt::COMMON_VTABLE_ENTRIES.len();

    match trait_ref {
        None => common,
        Some(principal) => {
            // The supertrait walk yields the principal trait itself as well.
            let own_entries: usize = tcx
                .supertrait_def_ids(principal.def_id())
                .map(|def_id| tcx.own_existential_vtable_entries(def_id).len())
                .sum();
            common + own_entries
        }
    }
}
/// Retrieves an allocation that represents the contents of a vtable.
/// Since this is a query, allocations are cached and not duplicated.
pub(super) fn vtable_allocation_provider<'tcx>(
@ -63,6 +122,9 @@ pub(super) fn vtable_allocation_provider<'tcx>(
TyCtxt::COMMON_VTABLE_ENTRIES
};
// This confirms that the layout computation for &dyn Trait has an accurate sizing.
assert!(vtable_entries.len() >= vtable_min_entries(tcx, poly_trait_ref));
let layout = tcx
.layout_of(ty::ParamEnv::reveal_all().and(ty))
.expect("failed to build vtable representation");

View File

@ -1,7 +1,5 @@
//! Dealing with trait goals, i.e. `T: Trait<'a, U>`.
use crate::traits::supertrait_def_ids;
use super::assembly::structural_traits::AsyncCallableRelevantTypes;
use super::assembly::{self, structural_traits, Candidate};
use super::{EvalCtxt, GoalSource, SolverMode};
@ -837,7 +835,8 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
let a_auto_traits: FxIndexSet<DefId> = a_data
.auto_traits()
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
supertrait_def_ids(self.interner(), principal_def_id)
self.interner()
.supertrait_def_ids(principal_def_id)
.filter(|def_id| self.interner().trait_is_auto(*def_id))
}))
.collect();

View File

@ -65,10 +65,7 @@ pub use self::structural_normalize::StructurallyNormalizeExt;
pub use self::util::elaborate;
pub use self::util::{expand_trait_aliases, TraitAliasExpander, TraitAliasExpansionInfo};
pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices};
pub use self::util::{
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item,
SupertraitDefIds,
};
pub use self::util::{supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item};
pub use self::util::{with_replaced_escaping_bound_vars, BoundVarReplacer, PlaceholderReplacer};
pub use rustc_infer::traits::*;

View File

@ -26,6 +26,7 @@ use rustc_middle::ty::{TypeVisitableExt, Upcast};
use rustc_session::lint::builtin::WHERE_CLAUSES_OBJECT_SAFETY;
use rustc_span::symbol::Symbol;
use rustc_span::Span;
use rustc_target::abi::Abi;
use smallvec::SmallVec;
use std::iter;
@ -44,7 +45,8 @@ pub fn hir_ty_lowering_object_safety_violations(
trait_def_id: DefId,
) -> Vec<ObjectSafetyViolation> {
debug_assert!(tcx.generics_of(trait_def_id).has_self);
let violations = traits::supertrait_def_ids(tcx, trait_def_id)
let violations = tcx
.supertrait_def_ids(trait_def_id)
.map(|def_id| predicates_reference_self(tcx, def_id, true))
.filter(|spans| !spans.is_empty())
.map(ObjectSafetyViolation::SupertraitSelf)
@ -58,7 +60,7 @@ fn object_safety_violations(tcx: TyCtxt<'_>, trait_def_id: DefId) -> &'_ [Object
debug!("object_safety_violations: {:?}", trait_def_id);
tcx.arena.alloc_from_iter(
traits::supertrait_def_ids(tcx, trait_def_id)
tcx.supertrait_def_ids(trait_def_id)
.flat_map(|def_id| object_safety_violations_for_trait(tcx, def_id)),
)
}
@ -145,6 +147,14 @@ fn object_safety_violations_for_trait(
violations.push(ObjectSafetyViolation::SupertraitNonLifetimeBinder(spans));
}
if violations.is_empty() {
for item in tcx.associated_items(trait_def_id).in_definition_order() {
if let ty::AssocKind::Fn = item.kind {
check_receiver_correct(tcx, trait_def_id, *item);
}
}
}
debug!(
"object_safety_violations_for_trait(trait_def_id={:?}) = {:?}",
trait_def_id, violations
@ -493,59 +503,8 @@ fn virtual_call_violations_for_method<'tcx>(
};
errors.push(MethodViolationCode::UndispatchableReceiver(span));
} else {
// Do sanity check to make sure the receiver actually has the layout of a pointer.
use rustc_target::abi::Abi;
let param_env = tcx.param_env(method.def_id);
let abi_of_ty = |ty: Ty<'tcx>| -> Option<Abi> {
match tcx.layout_of(param_env.and(ty)) {
Ok(layout) => Some(layout.abi),
Err(err) => {
// #78372
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!("error: {err}\n while computing layout for type {ty:?}"),
);
None
}
}
};
// e.g., `Rc<()>`
let unit_receiver_ty =
receiver_for_self_ty(tcx, receiver_ty, tcx.types.unit, method.def_id);
match abi_of_ty(unit_receiver_ty) {
Some(Abi::Scalar(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!(
"receiver when `Self = ()` should have a Scalar ABI; found {abi:?}"
),
);
}
}
let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);
// e.g., `Rc<dyn Trait>`
let trait_object_receiver =
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method.def_id);
match abi_of_ty(trait_object_receiver) {
Some(Abi::ScalarPair(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!(
"receiver when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
),
);
}
}
// We confirm that the `receiver_is_dispatchable` is accurate later,
// see `check_receiver_correct`. It should be kept in sync with this code.
}
}
@ -606,6 +565,55 @@ fn virtual_call_violations_for_method<'tcx>(
errors
}
/// This code checks that `receiver_is_dispatchable` is correctly implemented.
///
/// This check is outlined from the object safety check to avoid cycles with
/// layout computation, which relies on knowing whether methods are object safe.
pub fn check_receiver_correct<'tcx>(tcx: TyCtxt<'tcx>, trait_def_id: DefId, method: ty::AssocItem) {
if !is_vtable_safe_method(tcx, trait_def_id, method) {
return;
}
let method_def_id = method.def_id;
let sig = tcx.fn_sig(method_def_id).instantiate_identity();
let param_env = tcx.param_env(method_def_id);
let receiver_ty = tcx.liberate_late_bound_regions(method_def_id, sig.input(0));
if receiver_ty == tcx.types.self_param {
// Assumed OK, may change later if unsized_locals permits `self: Self` as dispatchable.
return;
}
// e.g., `Rc<()>`
let unit_receiver_ty = receiver_for_self_ty(tcx, receiver_ty, tcx.types.unit, method_def_id);
match tcx.layout_of(param_env.and(unit_receiver_ty)).map(|l| l.abi) {
Ok(Abi::Scalar(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method_def_id),
format!("receiver {unit_receiver_ty:?} when `Self = ()` should have a Scalar ABI; found {abi:?}"),
);
}
}
let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);
// e.g., `Rc<dyn Trait>`
let trait_object_receiver =
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method_def_id);
match tcx.layout_of(param_env.and(trait_object_receiver)).map(|l| l.abi) {
Ok(Abi::ScalarPair(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method_def_id),
format!(
"receiver {trait_object_receiver:?} when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
),
);
}
}
}
/// Performs a type instantiation to produce the version of `receiver_ty` when `Self = self_ty`.
/// For example, for `receiver_ty = Rc<Self>` and `self_ty = Foo`, returns `Rc<Foo>`.
fn receiver_for_self_ty<'tcx>(

View File

@ -1004,7 +1004,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let a_auto_traits: FxIndexSet<DefId> = a_data
.auto_traits()
.chain(principal_def_id_a.into_iter().flat_map(|principal_def_id| {
util::supertrait_def_ids(self.tcx(), principal_def_id)
self.tcx()
.supertrait_def_ids(principal_def_id)
.filter(|def_id| self.tcx().trait_is_auto(*def_id))
}))
.collect();

View File

@ -2591,8 +2591,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
let a_auto_traits: FxIndexSet<DefId> = a_data
.auto_traits()
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
util::supertrait_def_ids(tcx, principal_def_id)
.filter(|def_id| tcx.trait_is_auto(*def_id))
tcx.supertrait_def_ids(principal_def_id).filter(|def_id| tcx.trait_is_auto(*def_id))
}))
.collect();

View File

@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use super::NormalizeExt;
use super::{ObligationCause, PredicateObligation, SelectionContext};
use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
use rustc_data_structures::fx::FxIndexMap;
use rustc_errors::Diag;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::{InferCtxt, InferOk};
@ -161,43 +161,6 @@ impl<'tcx> Iterator for TraitAliasExpander<'tcx> {
}
}
///////////////////////////////////////////////////////////////////////////
// Iterator over def-IDs of supertraits
///////////////////////////////////////////////////////////////////////////
/// Iterator state for enumerating the `DefId`s of a trait and its supertraits.
pub struct SupertraitDefIds<'tcx> {
// Used to query `super_predicates_of` for each trait taken off the stack.
tcx: TyCtxt<'tcx>,
// Traits that still have to be yielded.
stack: Vec<DefId>,
// Traits that have already been pushed onto `stack`, so none is queued twice.
visited: FxHashSet<DefId>,
}
pub fn supertrait_def_ids(tcx: TyCtxt<'_>, trait_def_id: DefId) -> SupertraitDefIds<'_> {
SupertraitDefIds {
tcx,
stack: vec![trait_def_id],
visited: Some(trait_def_id).into_iter().collect(),
}
}
impl Iterator for SupertraitDefIds<'_> {
    type Item = DefId;

    /// Yields the next pending trait after queueing its not-yet-seen direct
    /// supertraits for later iterations.
    fn next(&mut self) -> Option<DefId> {
        let popped = self.stack.pop()?;
        let super_predicates = self.tcx.super_predicates_of(popped).predicates;

        let seen = &mut self.visited;
        let newly_found = super_predicates.iter().filter_map(|(pred, _)| {
            // Non-trait clauses do not contribute a supertrait.
            let super_def_id = pred.as_trait_clause()?.def_id();
            // Keep only traits that were not already queued.
            seen.insert(super_def_id).then_some(super_def_id)
        });
        self.stack.extend(newly_found);

        Some(popped)
    }
}
///////////////////////////////////////////////////////////////////////////
// Other
///////////////////////////////////////////////////////////////////////////

View File

@ -253,7 +253,7 @@ fn check_trait_items(cx: &LateContext<'_>, visited_trait: &Item<'_>, trait_items
// fill the set with current and super traits
fn fill_trait_set(traitt: DefId, set: &mut DefIdSet, cx: &LateContext<'_>) {
if set.insert(traitt) {
for supertrait in rustc_trait_selection::traits::supertrait_def_ids(cx.tcx, traitt) {
for supertrait in cx.tcx.supertrait_def_ids(traitt) {
fill_trait_set(supertrait, set, cx);
}
}

View File

@ -46,13 +46,13 @@
// cdb-command:dx c
// cdb-check:c [Type: ref$<unsized::Foo<dyn$<core::fmt::Debug> > >]
// cdb-check: [+0x000] pointer : 0x[...] [Type: unsized::Foo<dyn$<core::fmt::Debug> > *]
// cdb-check: [...] vtable : 0x[...] [Type: unsigned [...]int[...] (*)[3]]
// cdb-check: [...] vtable : 0x[...] [Type: unsigned [...]int[...] (*)[4]]
// cdb-command:dx _box
// cdb-check:
// cdb-check:_box [Type: alloc::boxed::Box<unsized::Foo<dyn$<core::fmt::Debug> >,alloc::alloc::Global>]
// cdb-check:[+0x000] pointer : 0x[...] [Type: unsized::Foo<dyn$<core::fmt::Debug> > *]
// cdb-check:[...] vtable : 0x[...] [Type: unsigned [...]int[...] (*)[3]]
// cdb-check:[...] vtable : 0x[...] [Type: unsigned [...]int[...] (*)[4]]
// cdb-command:dx tuple_slice
// cdb-check:tuple_slice [Type: ref$<tuple$<i32,i32,slice2$<i32> > >]
@ -62,7 +62,7 @@
// cdb-command:dx tuple_dyn
// cdb-check:tuple_dyn [Type: ref$<tuple$<i32,i32,dyn$<core::fmt::Debug> > >]
// cdb-check: [+0x000] pointer : 0x[...] [Type: tuple$<i32,i32,dyn$<core::fmt::Debug> > *]
// cdb-check: [...] vtable : 0x[...] [Type: unsigned [...]int[...] (*)[3]]
// cdb-check: [...] vtable : 0x[...] [Type: unsigned [...]int[...] (*)[4]]
#![feature(unsized_tuple_coercion)]
#![feature(omit_gdb_pretty_printer_section)]