From 71cc39e1f2f4cd7e8cd6fa09c65ecab06ce9cc80 Mon Sep 17 00:00:00 2001 From: Jason Newcomb Date: Sun, 30 Jul 2023 00:00:19 -0400 Subject: [PATCH] Add debug assertions to `implements_trait` Improve debug assertions for `make_projection` --- clippy_lints/src/derive.rs | 4 +- clippy_lints/src/incorrect_impls.rs | 8 +- clippy_lints/src/loops/explicit_iter_loop.rs | 2 +- clippy_lints/src/methods/or_fun_call.rs | 2 +- clippy_lints/src/needless_pass_by_value.rs | 10 +- clippy_utils/src/ty.rs | 184 +++++++++++-------- 6 files changed, 117 insertions(+), 93 deletions(-) diff --git a/clippy_lints/src/derive.rs b/clippy_lints/src/derive.rs index c343f248d06..ad7b3b27383 100644 --- a/clippy_lints/src/derive.rs +++ b/clippy_lints/src/derive.rs @@ -471,12 +471,12 @@ fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_r if let Some(def_id) = trait_ref.trait_def_id(); if cx.tcx.is_diagnostic_item(sym::PartialEq, def_id); let param_env = param_env_for_derived_eq(cx.tcx, adt.did(), eq_trait_def_id); - if !implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, []); + if !implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, &[]); // If all of our fields implement `Eq`, we can implement `Eq` too if adt .all_fields() .map(|f| f.ty(cx.tcx, args)) - .all(|ty| implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, [])); + .all(|ty| implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, &[])); then { span_lint_and_sugg( cx, diff --git a/clippy_lints/src/incorrect_impls.rs b/clippy_lints/src/incorrect_impls.rs index bce010bb168..f4edb17ae80 100644 --- a/clippy_lints/src/incorrect_impls.rs +++ b/clippy_lints/src/incorrect_impls.rs @@ -189,12 +189,8 @@ fn check_impl_item(&mut self, cx: &LateContext<'_>, impl_item: &ImplItem<'_>) { .diagnostic_items(trait_impl.def_id.krate) .name_to_id .get(&sym::Ord) - && implements_trait( - cx, - hir_ty_to_ty(cx.tcx, imp.self_ty), - *ord_def_id, - trait_impl.args, - ) + && trait_impl.self_ty() == trait_impl.args.type_at(1) + && implements_trait(cx, hir_ty_to_ty(cx.tcx, imp.self_ty), *ord_def_id, &[]) { // If the `cmp` call likely needs to be fully qualified in the suggestion // (like `std::cmp::Ord::cmp`). It's unfortunate we must put this here but we can't diff --git a/clippy_lints/src/loops/explicit_iter_loop.rs b/clippy_lints/src/loops/explicit_iter_loop.rs index a84a0a6eeb8..7b8c88235a9 100644 --- a/clippy_lints/src/loops/explicit_iter_loop.rs +++ b/clippy_lints/src/loops/explicit_iter_loop.rs @@ -109,7 +109,7 @@ fn is_ref_iterable<'tcx>( && let sig = cx.tcx.liberate_late_bound_regions(fn_id, cx.tcx.fn_sig(fn_id).skip_binder()) && let &[req_self_ty, req_res_ty] = &**sig.inputs_and_output && let param_env = cx.tcx.param_env(fn_id) - && implements_trait_with_env(cx.tcx, param_env, req_self_ty, trait_id, []) + && implements_trait_with_env(cx.tcx, param_env, req_self_ty, trait_id, &[]) && let Some(into_iter_ty) = make_normalized_projection_with_regions(cx.tcx, param_env, trait_id, sym!(IntoIter), [req_self_ty]) && let req_res_ty = normalize_with_regions(cx.tcx, param_env, req_res_ty) diff --git a/clippy_lints/src/methods/or_fun_call.rs b/clippy_lints/src/methods/or_fun_call.rs index 23039f2f3f1..8b2f57160af 100644 --- a/clippy_lints/src/methods/or_fun_call.rs +++ b/clippy_lints/src/methods/or_fun_call.rs @@ -57,7 +57,7 @@ fn check_unwrap_or_default( cx.tcx .get_diagnostic_item(sym::Default) .map_or(false, |default_trait_id| { - implements_trait(cx, output_ty, default_trait_id, args) + implements_trait(cx, output_ty, default_trait_id, &[]) }) } else { false diff --git a/clippy_lints/src/needless_pass_by_value.rs b/clippy_lints/src/needless_pass_by_value.rs index 5e26601537f..5ee26966fa7 100644 --- a/clippy_lints/src/needless_pass_by_value.rs +++ b/clippy_lints/src/needless_pass_by_value.rs @@ -2,7 +2,7 @@ use clippy_utils::ptr::get_spans; use clippy_utils::source::{snippet, snippet_opt}; use clippy_utils::ty::{ - implements_trait, implements_trait_with_env, is_copy, is_type_diagnostic_item, is_type_lang_item, + implements_trait, implements_trait_with_env_from_iter, is_copy, is_type_diagnostic_item, is_type_lang_item, }; use clippy_utils::{get_trait_def_id, is_self, paths}; use if_chain::if_chain; @@ -182,7 +182,13 @@ fn check_fn( if !ty.is_mutable_ptr(); if !is_copy(cx, ty); if ty.is_sized(cx.tcx, cx.param_env); - if !allowed_traits.iter().any(|&t| implements_trait_with_env(cx.tcx, cx.param_env, ty, t, [None])); + if !allowed_traits.iter().any(|&t| implements_trait_with_env_from_iter( + cx.tcx, + cx.param_env, + ty, + t, + [Option::>::None], + )); if !implements_borrow_trait; if !all_borrowable_trait; diff --git a/clippy_utils/src/ty.rs b/clippy_utils/src/ty.rs index 44ec8813f19..dbb6b986952 100644 --- a/clippy_utils/src/ty.rs +++ b/clippy_utils/src/ty.rs @@ -3,6 +3,7 @@ #![allow(clippy::module_name_repetitions)] use core::ops::ControlFlow; +use itertools::Itertools; use rustc_ast::ast::Mutability; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_hir as hir; @@ -13,17 +14,19 @@ use rustc_infer::infer::TyCtxtInferExt; use rustc_lint::LateContext; use rustc_middle::mir::interpret::{ConstValue, Scalar}; +use rustc_middle::traits::EvaluationResult; use rustc_middle::ty::layout::ValidityRequirement; use rustc_middle::ty::{ - self, AdtDef, AliasTy, AssocKind, Binder, BoundRegion, FnSig, GenericArg, GenericArgKind, GenericArgsRef, IntTy, - List, ParamEnv, Region, RegionKind, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, - UintTy, VariantDef, VariantDiscr, + self, AdtDef, AliasTy, AssocKind, Binder, BoundRegion, FnSig, GenericArg, GenericArgKind, GenericArgsRef, + GenericParamDefKind, IntTy, List, ParamEnv, Region, RegionKind, ToPredicate, TraitRef, Ty, TyCtxt, + TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, UintTy, VariantDef, VariantDiscr, }; use rustc_span::symbol::Ident; use rustc_span::{sym, Span, Symbol, DUMMY_SP}; use rustc_target::abi::{Size, VariantIdx}; -use rustc_trait_selection::infer::InferCtxtExt; +use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _; use rustc_trait_selection::traits::query::normalize::QueryNormalizeExt; +use rustc_trait_selection::traits::{Obligation, ObligationCause}; use std::iter; use crate::{match_def_path, path_res, paths}; @@ -207,15 +210,9 @@ pub fn implements_trait<'tcx>( cx: &LateContext<'tcx>, ty: Ty<'tcx>, trait_id: DefId, - ty_params: &[GenericArg<'tcx>], + args: &[GenericArg<'tcx>], ) -> bool { - implements_trait_with_env( - cx.tcx, - cx.param_env, - ty, - trait_id, - ty_params.iter().map(|&arg| Some(arg)), - ) + implements_trait_with_env_from_iter(cx.tcx, cx.param_env, ty, trait_id, args.iter().map(|&x| Some(x))) } /// Same as `implements_trait` but allows using a `ParamEnv` different from the lint context. @@ -224,7 +221,18 @@ pub fn implements_trait_with_env<'tcx>( param_env: ParamEnv<'tcx>, ty: Ty<'tcx>, trait_id: DefId, - ty_params: impl IntoIterator>>, + args: &[GenericArg<'tcx>], +) -> bool { + implements_trait_with_env_from_iter(tcx, param_env, ty, trait_id, args.iter().map(|&x| Some(x))) +} + +/// Same as `implements_trait_from_env` but takes the arguments as an iterator. +pub fn implements_trait_with_env_from_iter<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + ty: Ty<'tcx>, + trait_id: DefId, + args: impl IntoIterator>>>, ) -> bool { // Clippy shouldn't have infer types assert!(!ty.has_infer()); @@ -233,19 +241,37 @@ pub fn implements_trait_with_env<'tcx>( if ty.has_escaping_bound_vars() { return false; } + let infcx = tcx.infer_ctxt().build(); - let orig = TypeVariableOrigin { - kind: TypeVariableOriginKind::MiscVariable, - span: DUMMY_SP, - }; - let ty_params = tcx.mk_args_from_iter( - ty_params + let trait_ref = TraitRef::new( + tcx, + trait_id, + Some(GenericArg::from(ty)) .into_iter() - .map(|arg| arg.unwrap_or_else(|| infcx.next_ty_var(orig).into())), + .chain(args.into_iter().map(|arg| { + arg.into().unwrap_or_else(|| { + let orig = TypeVariableOrigin { + kind: TypeVariableOriginKind::MiscVariable, + span: DUMMY_SP, + }; + infcx.next_ty_var(orig).into() + }) + })), ); + + debug_assert_eq!(tcx.def_kind(trait_id), DefKind::Trait); + #[cfg(debug_assertions)] + assert_generic_args_match(tcx, trait_id, trait_ref.args); + + let obligation = Obligation { + cause: ObligationCause::dummy(), + param_env, + recursion_depth: 0, + predicate: ty::Binder::dummy(trait_ref).without_const().to_predicate(tcx), + }; infcx - .type_implements_trait(trait_id, [ty.into()].into_iter().chain(ty_params), param_env) - .must_apply_modulo_regions() + .evaluate_obligation(&obligation) + .is_ok_and(EvaluationResult::must_apply_modulo_regions) } /// Checks whether this type implements `Drop`. @@ -1014,12 +1040,60 @@ pub fn approx_ty_size<'tcx>(cx: &LateContext<'tcx>, ty: Ty<'tcx>) -> u64 { } } +/// Asserts that the given arguments match the generic parameters of the given item. +#[allow(dead_code)] +fn assert_generic_args_match<'tcx>(tcx: TyCtxt<'tcx>, did: DefId, args: &[GenericArg<'tcx>]) { + let g = tcx.generics_of(did); + let parent = g.parent.map(|did| tcx.generics_of(did)); + let count = g.parent_count + g.params.len(); + let params = parent + .map_or([].as_slice(), |p| p.params.as_slice()) + .iter() + .chain(&g.params) + .map(|x| &x.kind); + + assert!( + count == args.len(), + "wrong number of arguments for `{did:?}`: expected `{count}`, found {}\n\ + note: the expected arguments are: `[{}]`\n\ + the given arguments are: `{args:#?}`", + args.len(), + params.clone().map(GenericParamDefKind::descr).format(", "), + ); + + if let Some((idx, (param, arg))) = + params + .clone() + .zip(args.iter().map(|&x| x.unpack())) + .enumerate() + .find(|(_, (param, arg))| match (param, arg) { + (GenericParamDefKind::Lifetime, GenericArgKind::Lifetime(_)) + | (GenericParamDefKind::Type { .. }, GenericArgKind::Type(_)) + | (GenericParamDefKind::Const { .. }, GenericArgKind::Const(_)) => false, + ( + GenericParamDefKind::Lifetime + | GenericParamDefKind::Type { .. } + | GenericParamDefKind::Const { .. }, + _, + ) => true, + }) + { + panic!( + "incorrect argument for `{did:?}` at index `{idx}`: expected a {}, found `{arg:?}`\n\ + note: the expected arguments are `[{}]`\n\ + the given arguments are `{args:#?}`", + param.descr(), + params.clone().map(GenericParamDefKind::descr).format(", "), + ); + } +} + /// Makes the projection type for the named associated type in the given impl or trait impl. /// /// This function is for associated types which are "known" to exist, and as such, will only return /// `None` when debug assertions are disabled in order to prevent ICE's. With debug assertions /// enabled this will check that the named associated type exists, the correct number of -/// substitutions are given, and that the correct kinds of substitutions are given (lifetime, +/// arguments are given, and that the correct kinds of arguments are given (lifetime, /// constant or type). This will not check if type normalization would succeed. pub fn make_projection<'tcx>( tcx: TyCtxt<'tcx>, @@ -1043,49 +1117,7 @@ fn helper<'tcx>( return None; }; #[cfg(debug_assertions)] - { - let generics = tcx.generics_of(assoc_item.def_id); - let generic_count = generics.parent_count + generics.params.len(); - let params = generics - .parent - .map_or([].as_slice(), |id| &*tcx.generics_of(id).params) - .iter() - .chain(&generics.params) - .map(|x| &x.kind); - - debug_assert!( - generic_count == args.len(), - "wrong number of args for `{:?}`: found `{}` expected `{generic_count}`.\n\ - note: the expected parameters are: {:#?}\n\ - the given arguments are: `{args:#?}`", - assoc_item.def_id, - args.len(), - params.map(ty::GenericParamDefKind::descr).collect::>(), - ); - - if let Some((idx, (param, arg))) = params - .clone() - .zip(args.iter().map(GenericArg::unpack)) - .enumerate() - .find(|(_, (param, arg))| { - !matches!( - (param, arg), - (ty::GenericParamDefKind::Lifetime, GenericArgKind::Lifetime(_)) - | (ty::GenericParamDefKind::Type { .. }, GenericArgKind::Type(_)) - | (ty::GenericParamDefKind::Const { .. }, GenericArgKind::Const(_)) - ) - }) - { - debug_assert!( - false, - "mismatched subst type at index {idx}: expected a {}, found `{arg:?}`\n\ - note: the expected parameters are {:#?}\n\ - the given arguments are {args:#?}", - param.descr(), - params.map(ty::GenericParamDefKind::descr).collect::>() - ); - } - } + assert_generic_args_match(tcx, assoc_item.def_id, args); Some(tcx.mk_alias_ty(assoc_item.def_id, args)) } @@ -1100,7 +1132,7 @@ fn helper<'tcx>( /// Normalizes the named associated type in the given impl or trait impl. /// /// This function is for associated types which are "known" to be valid with the given -/// substitutions, and as such, will only return `None` when debug assertions are disabled in order +/// arguments, and as such, will only return `None` when debug assertions are disabled in order /// to prevent ICE's. With debug assertions enabled this will check that type normalization /// succeeds as well as everything checked by `make_projection`. pub fn make_normalized_projection<'tcx>( @@ -1112,17 +1144,12 @@ pub fn make_normalized_projection<'tcx>( ) -> Option> { fn helper<'tcx>(tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, ty: AliasTy<'tcx>) -> Option> { #[cfg(debug_assertions)] - if let Some((i, subst)) = ty - .args - .iter() - .enumerate() - .find(|(_, subst)| subst.has_late_bound_regions()) - { + if let Some((i, arg)) = ty.args.iter().enumerate().find(|(_, arg)| arg.has_late_bound_regions()) { debug_assert!( false, "args contain late-bound region at index `{i}` which can't be normalized.\n\ use `TyCtxt::erase_late_bound_regions`\n\ - note: subst is `{subst:#?}`", + note: arg is `{arg:#?}`", ); return None; } @@ -1190,17 +1217,12 @@ pub fn make_normalized_projection_with_regions<'tcx>( ) -> Option> { fn helper<'tcx>(tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, ty: AliasTy<'tcx>) -> Option> { #[cfg(debug_assertions)] - if let Some((i, subst)) = ty - .args - .iter() - .enumerate() - .find(|(_, subst)| subst.has_late_bound_regions()) - { + if let Some((i, arg)) = ty.args.iter().enumerate().find(|(_, arg)| arg.has_late_bound_regions()) { debug_assert!( false, "args contain late-bound region at index `{i}` which can't be normalized.\n\ use `TyCtxt::erase_late_bound_regions`\n\ - note: subst is `{subst:#?}`", + note: arg is `{arg:#?}`", ); return None; }