From 2157f3173161dae18621ccdfb88a1446eb2d41ff Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Fri, 22 Sep 2023 09:14:39 +0000 Subject: [PATCH] Add a way to decouple the implementation and the declaration of a TyCtxt method. --- compiler/rustc_codegen_gcc/src/lib.rs | 3 +- compiler/rustc_codegen_llvm/src/lib.rs | 3 +- .../rustc_codegen_ssa/src/traits/backend.rs | 3 +- .../rustc_const_eval/src/const_eval/mod.rs | 7 +- compiler/rustc_const_eval/src/lib.rs | 8 +-- compiler/rustc_interface/src/passes.rs | 13 ++-- compiler/rustc_middle/src/hooks/mod.rs | 64 +++++++++++++++++++ compiler/rustc_middle/src/lib.rs | 1 + compiler/rustc_middle/src/mir/pretty.rs | 2 +- compiler/rustc_middle/src/query/mod.rs | 10 --- compiler/rustc_middle/src/ty/context.rs | 6 ++ src/librustdoc/core.rs | 2 +- src/tools/clippy/clippy_utils/src/consts.rs | 2 +- 13 files changed, 96 insertions(+), 28 deletions(-) create mode 100644 compiler/rustc_middle/src/hooks/mod.rs diff --git a/compiler/rustc_codegen_gcc/src/lib.rs b/compiler/rustc_codegen_gcc/src/lib.rs index 697ae015fed..1567fea3e1c 100644 --- a/compiler/rustc_codegen_gcc/src/lib.rs +++ b/compiler/rustc_codegen_gcc/src/lib.rs @@ -82,6 +82,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::query::Providers; use rustc_middle::ty::TyCtxt; +use rustc_middle::hooks; use rustc_session::config::{Lto, OptLevel, OutputFilenames}; use rustc_session::Session; use rustc_span::Symbol; @@ -127,7 +128,7 @@ fn init(&self, sess: &Session) { *self.supports_128bit_integers.lock().expect("lock") = check_context.get_last_error() == Ok(None); } - fn provide(&self, providers: &mut Providers) { + fn provide(&self, providers: &mut Providers, _: &mut hooks::Providers) { // FIXME(antoyo) compute list of enabled features from cli flags providers.global_backend_features = |_tcx, ()| vec![]; } diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index ac199624e34..0f2acfd30f7 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -39,6 +39,7 @@ use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::hooks; use rustc_middle::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; @@ -268,7 +269,7 @@ fn init(&self, sess: &Session) { llvm_util::init(sess); // Make sure llvm is inited } - fn provide(&self, providers: &mut Providers) { + fn provide(&self, providers: &mut Providers, _hooks: &mut hooks::Providers) { providers.global_backend_features = |tcx, ()| llvm_util::global_llvm_features(tcx.sess, true) } diff --git a/compiler/rustc_codegen_ssa/src/traits/backend.rs b/compiler/rustc_codegen_ssa/src/traits/backend.rs index 0a02ca6b317..e00243075b9 100644 --- a/compiler/rustc_codegen_ssa/src/traits/backend.rs +++ b/compiler/rustc_codegen_ssa/src/traits/backend.rs @@ -11,6 +11,7 @@ use rustc_errors::ErrorGuaranteed; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::hooks; use rustc_middle::query::{ExternProviders, Providers}; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, LayoutOf, TyAndLayout}; use rustc_middle::ty::{Ty, TyCtxt}; @@ -84,7 +85,7 @@ fn metadata_loader(&self) -> Box { Box::new(crate::back::metadata::DefaultMetadataLoader) } - fn provide(&self, _providers: &mut Providers) {} + fn provide(&self, _providers: &mut Providers, _hooks: &mut hooks::Providers) {} fn provide_extern(&self, _providers: &mut ExternProviders) {} fn codegen_crate<'tcx>( &self, diff --git a/compiler/rustc_const_eval/src/const_eval/mod.rs b/compiler/rustc_const_eval/src/const_eval/mod.rs index 886d7972a15..bcbe996be7d 100644 --- a/compiler/rustc_const_eval/src/const_eval/mod.rs +++ b/compiler/rustc_const_eval/src/const_eval/mod.rs @@ -4,6 +4,7 @@ use crate::interpret::{intern_const_alloc_recursive, InternKind, InterpCx, Scalar}; use rustc_middle::mir; use rustc_middle::mir::interpret::{EvalToValTreeResult, GlobalId}; +use rustc_middle::query::TyCtxtAt; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_span::{source_map::DUMMY_SP, symbol::Symbol}; @@ -86,17 +87,17 @@ pub(crate) fn eval_to_valtree<'tcx>( #[instrument(skip(tcx), level = "debug")] pub(crate) fn try_destructure_mir_constant_for_diagnostics<'tcx>( - tcx: TyCtxt<'tcx>, + tcx: TyCtxtAt<'tcx>, val: mir::ConstValue<'tcx>, ty: Ty<'tcx>, ) -> Option> { let param_env = ty::ParamEnv::reveal_all(); - let ecx = mk_eval_cx(tcx, DUMMY_SP, param_env, CanAccessStatics::No); + let ecx = mk_eval_cx(tcx.tcx, tcx.span, param_env, CanAccessStatics::No); let op = ecx.const_val_to_op(val, ty, None).ok()?; // We go to `usize` as we cannot allocate anything bigger anyway. let (field_count, variant, down) = match ty.kind() { - ty::Array(_, len) => (len.eval_target_usize(tcx, param_env) as usize, None, op), + ty::Array(_, len) => (len.eval_target_usize(tcx.tcx, param_env) as usize, None, op), ty::Adt(def, _) if def.variants().is_empty() => { return None; } diff --git a/compiler/rustc_const_eval/src/lib.rs b/compiler/rustc_const_eval/src/lib.rs index c126f749bf3..8ac2b519760 100644 --- a/compiler/rustc_const_eval/src/lib.rs +++ b/compiler/rustc_const_eval/src/lib.rs @@ -39,11 +39,11 @@ use rustc_errors::{DiagnosticMessage, SubdiagnosticMessage}; use rustc_fluent_macro::fluent_messages; use rustc_middle::query::Providers; -use rustc_middle::ty; +use rustc_middle::{hooks, ty}; fluent_messages! { "../messages.ftl" } -pub fn provide(providers: &mut Providers) { +pub fn provide(providers: &mut Providers, hooks: &mut hooks::Providers) { const_eval::provide(providers); providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider; providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider; @@ -52,8 +52,8 @@ pub fn provide(providers: &mut Providers) { let (param_env, raw) = param_env_and_value.into_parts(); const_eval::eval_to_valtree(tcx, param_env, raw) }; - providers.try_destructure_mir_constant_for_diagnostics = - |tcx, (cv, ty)| const_eval::try_destructure_mir_constant_for_diagnostics(tcx, cv, ty); + hooks.try_destructure_mir_constant_for_diagnostics = + const_eval::try_destructure_mir_constant_for_diagnostics; providers.valtree_to_const_val = |tcx, (ty, valtree)| { const_eval::valtree_to_const_value(tcx, ty::ParamEnv::empty().and(ty), valtree) }; diff --git a/compiler/rustc_interface/src/passes.rs b/compiler/rustc_interface/src/passes.rs index e5ae6d5b5d6..6de7686fbd9 100644 --- a/compiler/rustc_interface/src/passes.rs +++ b/compiler/rustc_interface/src/passes.rs @@ -18,6 +18,7 @@ use rustc_metadata::creader::CStore; use rustc_middle::arena::Arena; use rustc_middle::dep_graph::DepGraph; +use rustc_middle::hooks; use rustc_middle::query::{ExternProviders, Providers}; use rustc_middle::ty::{self, GlobalCtxt, RegisteredTools, TyCtxt}; use rustc_mir_build as mir_build; @@ -645,15 +646,16 @@ fn output_filenames(tcx: TyCtxt<'_>, (): ()) -> Arc { outputs.into() } -pub static DEFAULT_QUERY_PROVIDERS: LazyLock = LazyLock::new(|| { +pub static DEFAULT_QUERY_PROVIDERS: LazyLock<(Providers, hooks::Providers)> = LazyLock::new(|| { let providers = &mut Providers::default(); + let hooks = &mut hooks::Providers::default(); providers.analysis = analysis; providers.hir_crate = rustc_ast_lowering::lower_to_hir; providers.output_filenames = output_filenames; providers.resolver_for_lowering = resolver_for_lowering; providers.early_lint_checks = early_lint_checks; proc_macro_decls::provide(providers); - rustc_const_eval::provide(providers); + rustc_const_eval::provide(providers, hooks); rustc_middle::hir::provide(providers); mir_borrowck::provide(providers); mir_build::provide(providers); @@ -672,7 +674,7 @@ fn output_filenames(tcx: TyCtxt<'_>, (): ()) -> Arc { rustc_lint::provide(providers); rustc_symbol_mangling::provide(providers); rustc_codegen_ssa::provide(providers); - *providers + (*providers, *hooks) }); pub static DEFAULT_EXTERN_QUERY_PROVIDERS: LazyLock = LazyLock::new(|| { @@ -702,8 +704,8 @@ pub fn create_global_ctxt<'tcx>( let query_result_on_disk_cache = rustc_incremental::load_query_result_cache(sess); let codegen_backend = compiler.codegen_backend(); - let mut local_providers = *DEFAULT_QUERY_PROVIDERS; - codegen_backend.provide(&mut local_providers); + let (mut local_providers, mut hooks) = *DEFAULT_QUERY_PROVIDERS; + codegen_backend.provide(&mut local_providers, &mut hooks); let mut extern_providers = *DEFAULT_EXTERN_QUERY_PROVIDERS; codegen_backend.provide_extern(&mut extern_providers); @@ -732,6 +734,7 @@ pub fn create_global_ctxt<'tcx>( query_result_on_disk_cache, incremental, ), + hooks, ) }) }) diff --git a/compiler/rustc_middle/src/hooks/mod.rs b/compiler/rustc_middle/src/hooks/mod.rs new file mode 100644 index 00000000000..cb84a5936b3 --- /dev/null +++ b/compiler/rustc_middle/src/hooks/mod.rs @@ -0,0 +1,64 @@ +use crate::mir; +use crate::query::TyCtxtAt; +use crate::ty::{Ty, TyCtxt}; +use rustc_span::DUMMY_SP; + +macro_rules! declare_hooks { + ($($(#[$attr:meta])*hook $name:ident($($arg:ident: $K:ty),*) -> $V:ty;)*) => { + + impl<'tcx> TyCtxt<'tcx> { + $( + $(#[$attr])* + #[inline(always)] + #[must_use] + pub fn $name(self, $($arg: $K,)*) -> $V + { + self.at(DUMMY_SP).$name($($arg,)*) + } + )* + } + + impl<'tcx> TyCtxtAt<'tcx> { + $( + $(#[$attr])* + #[inline(always)] + #[must_use] + pub fn $name(self, $($arg: $K,)*) -> $V + { + (self.tcx.hooks.$name)(self, $($arg,)*) + } + )* + } + + pub struct Providers { + $(pub $name: for<'tcx> fn( + TyCtxtAt<'tcx>, + $($arg: $K,)* + ) -> $V,)* + } + + impl Default for Providers { + fn default() -> Self { + Providers { + $($name: |_, $($arg,)*| bug!( + "`tcx.{}{:?}` cannot be called as `{}` was never assigned to a provider function.\n", + stringify!($name), + ($($arg,)*), + stringify!($name), + ),)* + } + } + } + + impl Copy for Providers {} + impl Clone for Providers { + fn clone(&self) -> Self { *self } + } + }; +} + +declare_hooks! { + /// Tries to destructure an `mir::Const` ADT or array into its variant index + /// and its field values. This should only be used for pretty printing. + hook try_destructure_mir_constant_for_diagnostics(val: mir::ConstValue<'tcx>, ty: Ty<'tcx>) -> Option>; +} diff --git a/compiler/rustc_middle/src/lib.rs b/compiler/rustc_middle/src/lib.rs index 50b69181d67..fe4fc3761b3 100644 --- a/compiler/rustc_middle/src/lib.rs +++ b/compiler/rustc_middle/src/lib.rs @@ -89,6 +89,7 @@ pub mod arena; pub mod error; pub mod hir; +pub mod hooks; pub mod infer; pub mod lint; pub mod metadata; diff --git a/compiler/rustc_middle/src/mir/pretty.rs b/compiler/rustc_middle/src/mir/pretty.rs index cc2a5aa62c8..8488367189f 100644 --- a/compiler/rustc_middle/src/mir/pretty.rs +++ b/compiler/rustc_middle/src/mir/pretty.rs @@ -1691,7 +1691,7 @@ fn pretty_print_const_value_tcx<'tcx>( (_, ty::Array(..) | ty::Tuple(..) | ty::Adt(..)) if !ty.has_non_region_param() => { let ct = tcx.lift(ct).unwrap(); let ty = tcx.lift(ty).unwrap(); - if let Some(contents) = tcx.try_destructure_mir_constant_for_diagnostics((ct, ty)) { + if let Some(contents) = tcx.try_destructure_mir_constant_for_diagnostics(ct, ty) { let fields: Vec<(ConstValue<'_>, Ty<'_>)> = contents.fields.to_vec(); match *ty.kind() { ty::Array(..) => { diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index d888a2c0fb6..6928f59c3fd 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -1100,16 +1100,6 @@ desc { "destructuring type level constant"} } - /// Tries to destructure an `mir::Const` ADT or array into its variant index - /// and its field values. This should only be used for pretty printing. - query try_destructure_mir_constant_for_diagnostics( - key: (mir::ConstValue<'tcx>, Ty<'tcx>) - ) -> Option> { - desc { "destructuring MIR constant"} - no_hash - eval_always - } - query const_caller_location(key: (rustc_span::Symbol, u32, u32)) -> mir::ConstValue<'tcx> { desc { "getting a &core::panic::Location referring to a span" } } diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 9ff4b64f48b..25eafce0d9e 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -554,6 +554,10 @@ pub struct GlobalCtxt<'tcx> { /// Common consts, pre-interned for your convenience. pub consts: CommonConsts<'tcx>, + /// Hooks to be able to register functions in other crates that can then still + /// be called from rustc_middle. + pub(crate) hooks: crate::hooks::Providers, + untracked: Untracked, pub query_system: QuerySystem<'tcx>, @@ -703,6 +707,7 @@ pub fn create_global_ctxt( dep_graph: DepGraph, query_kinds: &'tcx [DepKindStruct<'tcx>], query_system: QuerySystem<'tcx>, + hooks: crate::hooks::Providers, ) -> GlobalCtxt<'tcx> { let data_layout = s.target.parse_data_layout().unwrap_or_else(|err| { s.emit_fatal(err); @@ -721,6 +726,7 @@ pub fn create_global_ctxt( hir_arena, interners, dep_graph, + hooks, prof: s.prof.clone(), types: common_types, lifetimes: common_lifetimes, diff --git a/src/librustdoc/core.rs b/src/librustdoc/core.rs index 7cd25ef444b..cc4d4fd11fa 100644 --- a/src/librustdoc/core.rs +++ b/src/librustdoc/core.rs @@ -286,7 +286,7 @@ pub(crate) fn create_config( let body = hir.body(hir.body_owned_by(def_id)); debug!("visiting body for {def_id:?}"); EmitIgnoredResolutionErrors::new(tcx).visit_body(body); - (rustc_interface::DEFAULT_QUERY_PROVIDERS.typeck)(tcx, def_id) + (rustc_interface::DEFAULT_QUERY_PROVIDERS.0.typeck)(tcx, def_id) }; }), make_codegen_backend: None, diff --git a/src/tools/clippy/clippy_utils/src/consts.rs b/src/tools/clippy/clippy_utils/src/consts.rs index a136de86240..6b1a738aaa9 100644 --- a/src/tools/clippy/clippy_utils/src/consts.rs +++ b/src/tools/clippy/clippy_utils/src/consts.rs @@ -718,7 +718,7 @@ fn field_of_struct<'tcx>( field: &Ident, ) -> Option> { if let mir::Const::Val(result, ty) = result - && let Some(dc) = lcx.tcx.try_destructure_mir_constant_for_diagnostics((result, ty)) + && let Some(dc) = lcx.tcx.try_destructure_mir_constant_for_diagnostics(result, ty) && let Some(dc_variant) = dc.variant && let Some(variant) = adt_def.variants().get(dc_variant) && let Some(field_idx) = variant.fields.iter().position(|el| el.name == field.name)