diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 54342f1b7c4..012812cea24 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -50,6 +50,7 @@ per_ns::PerNs, resolver::{HasResolver, Resolver}, src::HasSource as _, + type_ref::ConstScalar, AdtId, AssocItemId, AssocItemLoc, AttrDefId, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, FunctionId, GenericDefId, HasModule, ImplId, ItemContainerId, LifetimeParamId, LocalEnumVariantId, LocalFieldId, Lookup, MacroExpander, MacroId, ModuleId, StaticId, StructId, @@ -65,8 +66,9 @@ primitive::UintTy, traits::FnTrait, AliasTy, CallableDefId, CallableSig, Canonical, CanonicalVarKinds, Cast, ClosureId, - GenericArgData, Interner, ParamKind, QuantifiedWhereClause, Scalar, Substitution, - TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyDefId, TyExt, TyKind, WhereClause, + ConcreteConst, ConstValue, GenericArgData, Interner, ParamKind, QuantifiedWhereClause, Scalar, + Substitution, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyDefId, TyExt, TyKind, + WhereClause, }; use itertools::Itertools; use nameres::diagnostics::DefDiagnosticKind; @@ -3232,6 +3234,19 @@ pub fn tuple_fields(&self, _db: &dyn HirDatabase) -> Vec { } } + pub fn as_array(&self, _db: &dyn HirDatabase) -> Option<(Type, usize)> { + if let TyKind::Array(ty, len) = &self.ty.kind(Interner) { + match len.data(Interner).value { + ConstValue::Concrete(ConcreteConst { interned: ConstScalar::UInt(len) }) => { + Some((self.derived(ty.clone()), len as usize)) + } + _ => None, + } + } else { + None + } + } + pub fn autoderef<'a>(&'a self, db: &'a dyn HirDatabase) -> impl Iterator + 'a { self.autoderef_(db).map(move |ty| self.derived(ty)) } diff --git a/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/crates/ide-assists/src/handlers/add_missing_match_arms.rs index 0461cc790eb..5d81e8cfeac 100644 --- a/crates/ide-assists/src/handlers/add_missing_match_arms.rs +++ b/crates/ide-assists/src/handlers/add_missing_match_arms.rs @@ -140,6 +140,31 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) }) .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat)); ((Box::new(missing_pats) as Box>).peekable(), is_non_exhaustive) + } else if let Some((enum_def, len)) = resolve_array_of_enum_def(&ctx.sema, &expr) { + let is_non_exhaustive = enum_def.is_non_exhaustive(ctx.db(), module.krate()); + let variants = enum_def.variants(ctx.db()); + + if len.pow(variants.len() as u32) > 256 { + return None; + } + + let variants_of_enums = vec![variants.clone(); len]; + + let missing_pats = variants_of_enums + .into_iter() + .multi_cartesian_product() + .inspect(|_| cov_mark::hit!(add_missing_match_arms_lazy_computation)) + .map(|variants| { + let is_hidden = variants + .iter() + .any(|variant| variant.should_be_hidden(ctx.db(), module.krate())); + let patterns = variants.into_iter().filter_map(|variant| { + build_pat(ctx.db(), module, variant.clone(), ctx.config.prefer_no_std) + }); + (ast::Pat::from(make::slice_pat(patterns)), is_hidden) + }) + .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat)); + ((Box::new(missing_pats) as Box>).peekable(), is_non_exhaustive) } else { return None; }; @@ -266,6 +291,9 @@ fn is_variant_missing(existing_pats: &[Pat], var: &Pat) -> bool { fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool { match (pat, var) { (Pat::WildcardPat(_), _) => true, + (Pat::SlicePat(spat), Pat::SlicePat(svar)) => { + spat.pats().zip(svar.pats()).all(|(p, v)| does_pat_match_variant(&p, &v)) + } (Pat::TuplePat(tpat), Pat::TuplePat(tvar)) => { tpat.fields().zip(tvar.fields()).all(|(p, v)| does_pat_match_variant(&p, &v)) } @@ -280,7 +308,7 @@ enum ExtendedEnum { Enum(hir::Enum), } -#[derive(Eq, PartialEq, Clone, Copy)] +#[derive(Eq, PartialEq, Clone, Copy, Debug)] enum ExtendedVariant { True, False, @@ -340,15 +368,30 @@ fn resolve_tuple_of_enum_def( .tuple_fields(sema.db) .iter() .map(|ty| { - ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() { - Some(Adt::Enum(e)) => Some(lift_enum(e)), - // For now we only handle expansion for a tuple of enums. Here - // we map non-enum items to None and rely on `collect` to - // convert Vec> into Option>. - _ => ty.is_bool().then_some(ExtendedEnum::Bool), + ty.autoderef(sema.db).find_map(|ty| { + match ty.as_adt() { + Some(Adt::Enum(e)) => Some(lift_enum(e)), + // For now we only handle expansion for a tuple of enums. Here + // we map non-enum items to None and rely on `collect` to + // convert Vec> into Option>. + _ => ty.is_bool().then_some(ExtendedEnum::Bool), + } }) }) - .collect() + .collect::>>() + .and_then(|list| if list.is_empty() { None } else { Some(list) }) +} + +fn resolve_array_of_enum_def( + sema: &Semantics<'_, RootDatabase>, + expr: &ast::Expr, +) -> Option<(ExtendedEnum, usize)> { + sema.type_of_expr(expr)?.adjusted().as_array(sema.db).and_then(|(ty, len)| { + ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() { + Some(Adt::Enum(e)) => Some((lift_enum(e), len)), + _ => ty.is_bool().then_some((ExtendedEnum::Bool, len)), + }) + }) } fn build_pat( @@ -377,7 +420,6 @@ fn build_pat( } ast::StructKind::Unit => make::path_pat(path), }; - Some(pat) } ExtendedVariant::True => Some(ast::Pat::from(make::literal_pat("true"))), @@ -573,6 +615,86 @@ fn foo(a: bool) { ) } + #[test] + fn fill_boolean_array() { + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match [a]$0 { + } +} +"#, + r#" +fn foo(a: bool) { + match [a] { + $0[true] => todo!(), + [false] => todo!(), + } +} +"#, + ); + + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match [a,]$0 { + } +} +"#, + r#" +fn foo(a: bool) { + match [a,] { + $0[true] => todo!(), + [false] => todo!(), + } +} +"#, + ); + + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match [a, a]$0 { + [true, true] => todo!(), + } +} +"#, + r#" +fn foo(a: bool) { + match [a, a] { + [true, true] => todo!(), + $0[true, false] => todo!(), + [false, true] => todo!(), + [false, false] => todo!(), + } +} +"#, + ); + + check_assist( + add_missing_match_arms, + r#" +fn foo(a: bool) { + match [a, a]$0 { + } +} +"#, + r#" +fn foo(a: bool) { + match [a, a] { + $0[true, true] => todo!(), + [true, false] => todo!(), + [false, true] => todo!(), + [false, false] => todo!(), + } +} +"#, + ) + } + #[test] fn partial_fill_boolean_tuple() { check_assist( diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index d5b3296980c..a35983435c7 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -520,6 +520,15 @@ fn from_text(text: &str) -> ast::LiteralPat { } } +pub fn slice_pat(pats: impl IntoIterator) -> ast::SlicePat { + let pats_str = pats.into_iter().join(", "); + return from_text(&format!("[{pats_str}]")); + + fn from_text(text: &str) -> ast::SlicePat { + ast_from_text(&format!("fn f() {{ match () {{{text} => ()}} }}")) + } +} + /// Creates a tuple of patterns from an iterator of patterns. /// /// Invariant: `pats` must be length > 0