diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index 8a9a5d254df..f7914b578e4 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -1008,6 +1008,57 @@ fn call(f: &&&&&impl Fn(u8) -> u8, x: u8) -> u8 { ); } +#[test] +fn dyn_trait() { + check_number( + r#" + //- minicore: coerce_unsized, index, slice + trait Foo { + fn foo(&self) -> u8 { 10 } + } + struct S1; + struct S2; + struct S3; + impl Foo for S1 { + fn foo(&self) -> u8 { 1 } + } + impl Foo for S2 { + fn foo(&self) -> u8 { 2 } + } + impl Foo for S3 {} + const GOAL: u8 = { + let x: &[&dyn Foo] = &[&S1, &S2, &S3]; + x[0].foo() + x[1].foo() + x[2].foo() + }; + "#, + 13, + ); + check_number( + r#" + //- minicore: coerce_unsized, index, slice + trait Foo { + fn foo(&self) -> i32 { 10 } + } + trait Bar { + fn bar(&self) -> i32 { 20 } + } + + struct S; + impl Foo for S { + fn foo(&self) -> i32 { 200 } + } + impl Bar for dyn Foo { + fn bar(&self) -> i32 { 700 } + } + const GOAL: i32 = { + let x: &dyn Foo = &S; + x.bar() + x.foo() + }; + "#, + 900, + ); +} + #[test] fn array_and_index() { check_number( diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs index f105c94086c..6244b98104f 100644 --- a/crates/hir-ty/src/method_resolution.rs +++ b/crates/hir-ty/src/method_resolution.rs @@ -5,7 +5,7 @@ use std::{ops::ControlFlow, sync::Arc}; use base_db::{CrateId, Edition}; -use chalk_ir::{cast::Cast, Mutability, TyKind, UniverseIndex}; +use chalk_ir::{cast::Cast, Mutability, TyKind, UniverseIndex, WhereClause}; use hir_def::{ data::ImplData, item_scope::ItemScope, lang_item::LangItem, nameres::DefMap, AssocItemId, BlockId, ConstId, FunctionId, HasModule, ImplId, ItemContainerId, Lookup, ModuleDefId, @@ -692,6 +692,38 @@ pub fn lookup_impl_const( .unwrap_or((const_id, subs)) } +/// Checks if the self parameter of `Trait` method is the `dyn Trait` and we should +/// call the method using the vtable. +pub fn is_dyn_method( + db: &dyn HirDatabase, + _env: Arc, + func: FunctionId, + fn_subst: Substitution, +) -> Option { + let ItemContainerId::TraitId(trait_id) = func.lookup(db.upcast()).container else { + return None; + }; + let trait_params = db.generic_params(trait_id.into()).type_or_consts.len(); + let fn_params = fn_subst.len(Interner) - trait_params; + let trait_ref = TraitRef { + trait_id: to_chalk_trait_id(trait_id), + substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)), + }; + let self_ty = trait_ref.self_type_parameter(Interner); + if let TyKind::Dyn(d) = self_ty.kind(Interner) { + let is_my_trait_in_bounds = d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() { + // rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter + // what the generics are, we are sure that the method is come from the vtable. + WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id, + _ => false, + }); + if is_my_trait_in_bounds { + return Some(fn_params); + } + } + None +} + /// Looks up the impl method that actually runs for the trait method `func`. /// /// Returns `func` if it's not a method defined in a trait or the lookup failed. @@ -701,9 +733,8 @@ pub fn lookup_impl_method( func: FunctionId, fn_subst: Substitution, ) -> (FunctionId, Substitution) { - let trait_id = match func.lookup(db.upcast()).container { - ItemContainerId::TraitId(id) => id, - _ => return (func, fn_subst), + let ItemContainerId::TraitId(trait_id) = func.lookup(db.upcast()).container else { + return (func, fn_subst) }; let trait_params = db.generic_params(trait_id.into()).type_or_consts.len(); let fn_params = fn_subst.len(Interner) - trait_params; diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 88ef92a4ae6..7293156a978 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -23,10 +23,10 @@ infer::{normalize, PointerCast}, layout::layout_of_ty, mapping::from_chalk, - method_resolution::lookup_impl_method, + method_resolution::{is_dyn_method, lookup_impl_method}, traits::FnTrait, CallableDefId, Const, ConstScalar, FnDefId, Interner, MemoryMap, Substitution, - TraitEnvironment, Ty, TyBuilder, TyExt, + TraitEnvironment, Ty, TyBuilder, TyExt, GenericArgData, }; use super::{ @@ -34,6 +34,15 @@ Operand, Place, ProjectionElem, Rvalue, StatementKind, Terminator, UnOp, }; +macro_rules! from_bytes { + ($ty:tt, $value:expr) => { + ($ty::from_le_bytes(match ($value).try_into() { + Ok(x) => x, + Err(_) => return Err(MirEvalError::TypeError("mismatched size")), + })) + }; +} + #[derive(Debug, Default)] struct VTableMap { ty_to_id: HashMap, @@ -54,6 +63,11 @@ fn id(&mut self, ty: Ty) -> usize { fn ty(&self, id: usize) -> Result<&Ty> { self.id_to_ty.get(id).ok_or(MirEvalError::InvalidVTableId(id)) } + + fn ty_of_bytes(&self, bytes: &[u8]) -> Result<&Ty> { + let id = from_bytes!(usize, bytes); + self.ty(id) + } } pub struct Evaluator<'a> { @@ -110,15 +124,6 @@ pub(crate) fn to_vec(self, memory: &Evaluator<'_>) -> Result> { } } -macro_rules! from_bytes { - ($ty:tt, $value:expr) => { - ($ty::from_le_bytes(match ($value).try_into() { - Ok(x) => x, - Err(_) => return Err(MirEvalError::TypeError("mismatched size")), - })) - }; -} - impl Address { fn from_bytes(x: &[u8]) -> Result { Ok(Address::from_usize(from_bytes!(usize, x))) @@ -781,7 +786,18 @@ fn eval_rvalue<'a>( } _ => not_supported!("slice unsizing from non pointers"), }, - TyKind::Dyn(_) => not_supported!("dyn pointer unsize cast"), + TyKind::Dyn(_) => match ¤t_ty.data(Interner).kind { + TyKind::Raw(_, ty) | TyKind::Ref(_, _, ty) => { + let vtable = self.vtable_map.id(ty.clone()); + let addr = + self.eval_operand(operand, locals)?.get(&self)?; + let mut r = Vec::with_capacity(16); + r.extend(addr.iter().copied()); + r.extend(vtable.to_le_bytes().into_iter()); + Owned(r) + } + _ => not_supported!("dyn unsizing from non pointers"), + }, _ => not_supported!("unknown unsized cast"), } } @@ -1227,44 +1243,8 @@ fn exec_fn_def( let arg_bytes = args .iter() .map(|x| Ok(self.eval_operand(x, &locals)?.get(&self)?.to_owned())) - .collect::>>()? - .into_iter(); - let function_data = self.db.function_data(def); - let is_intrinsic = match &function_data.abi { - Some(abi) => *abi == Interned::new_str("rust-intrinsic"), - None => match def.lookup(self.db.upcast()).container { - hir_def::ItemContainerId::ExternBlockId(block) => { - let id = block.lookup(self.db.upcast()).id; - id.item_tree(self.db.upcast())[id.value].abi.as_deref() - == Some("rust-intrinsic") - } - _ => false, - }, - }; - let result = if is_intrinsic { - self.exec_intrinsic( - function_data.name.as_text().unwrap_or_default().as_str(), - arg_bytes, - generic_args, - &locals, - )? - } else if let Some(x) = self.detect_lang_function(def) { - self.exec_lang_item(x, arg_bytes)? - } else { - let (imp, generic_args) = lookup_impl_method( - self.db, - self.trait_env.clone(), - def, - generic_args.clone(), - ); - let generic_args = self.subst_filler(&generic_args, &locals); - let def = imp.into(); - let mir_body = - self.db.mir_body(def).map_err(|e| MirEvalError::MirLowerError(imp, e))?; - self.interpret_mir(&mir_body, arg_bytes, generic_args) - .map_err(|e| MirEvalError::InFunction(imp, Box::new(e)))? - }; - self.write_memory(dest_addr, &result)?; + .collect::>>()?; + self.exec_fn_with_args(def, arg_bytes, generic_args, locals, dest_addr)?; } CallableDefId::StructId(id) => { let (size, variant_layout, tag) = @@ -1284,6 +1264,77 @@ fn exec_fn_def( Ok(()) } + fn exec_fn_with_args( + &mut self, + def: FunctionId, + arg_bytes: Vec>, + generic_args: Substitution, + locals: &Locals<'_>, + dest_addr: Address, + ) -> Result<()> { + let function_data = self.db.function_data(def); + let is_intrinsic = match &function_data.abi { + Some(abi) => *abi == Interned::new_str("rust-intrinsic"), + None => match def.lookup(self.db.upcast()).container { + hir_def::ItemContainerId::ExternBlockId(block) => { + let id = block.lookup(self.db.upcast()).id; + id.item_tree(self.db.upcast())[id.value].abi.as_deref() + == Some("rust-intrinsic") + } + _ => false, + }, + }; + let result = if is_intrinsic { + self.exec_intrinsic( + function_data.name.as_text().unwrap_or_default().as_str(), + arg_bytes.iter().cloned(), + generic_args, + &locals, + )? + } else if let Some(x) = self.detect_lang_function(def) { + self.exec_lang_item(x, &arg_bytes)? + } else { + if let Some(self_ty_idx) = + is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone()) + { + // In the layout of current possible receiver, which at the moment of writing this code is one of + // `&T`, `&mut T`, `Box`, `Rc`, `Arc`, and `Pin

` where `P` is one of possible recievers, + // the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on + // the type. + let ty = self + .vtable_map + .ty_of_bytes(&arg_bytes[0][self.ptr_size()..self.ptr_size() * 2])?; + let ty = GenericArgData::Ty(ty.clone()).intern(Interner); + let mut args_for_target = arg_bytes; + args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec(); + let generics_for_target = Substitution::from_iter( + Interner, + generic_args + .iter(Interner) + .enumerate() + .map(|(i, x)| if i == self_ty_idx { &ty } else { x }) + ); + return self.exec_fn_with_args( + def, + args_for_target, + generics_for_target, + locals, + dest_addr, + ); + } + let (imp, generic_args) = + lookup_impl_method(self.db, self.trait_env.clone(), def, generic_args.clone()); + let generic_args = self.subst_filler(&generic_args, &locals); + let def = imp.into(); + let mir_body = + self.db.mir_body(def).map_err(|e| MirEvalError::MirLowerError(imp, e))?; + self.interpret_mir(&mir_body, arg_bytes.iter().cloned(), generic_args) + .map_err(|e| MirEvalError::InFunction(imp, Box::new(e)))? + }; + self.write_memory(dest_addr, &result)?; + Ok(()) + } + fn exec_fn_trait( &mut self, ft: FnTrait, @@ -1317,12 +1368,9 @@ fn exec_fn_trait( Ok(()) } - fn exec_lang_item( - &self, - x: LangItem, - mut args: std::vec::IntoIter>, - ) -> Result> { + fn exec_lang_item(&self, x: LangItem, args: &[Vec]) -> Result> { use LangItem::*; + let mut args = args.iter(); match x { PanicFmt | BeginPanic => Err(MirEvalError::Panic), SliceLen => { diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 7a5ca089420..4fc3c67a6e1 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -230,7 +230,14 @@ fn lower_expr_to_place_without_adjust( self.lower_const(c, current, place, expr_id.into())?; return Ok(Some(current)) }, - _ => not_supported!("associated functions and types"), + hir_def::AssocItemId::FunctionId(_) => { + // FnDefs are zero sized, no action is needed. + return Ok(Some(current)) + } + hir_def::AssocItemId::TypeAliasId(_) => { + // FIXME: If it is unreachable, use proper error instead of `not_supported`. + not_supported!("associated functions and types") + }, } } else if let Some(variant) = self .infer