diff --git a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs index c0b7db332e2..b04bd6ba098 100644 --- a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs +++ b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs @@ -675,6 +675,194 @@ impl Clone for Foo { ) } + #[test] + fn add_custom_impl_partial_ord_record_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord +#[derive(Partial$0Ord)] +struct Foo { + bin: usize, +} +"#, + r#" +struct Foo { + bin: usize, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + self.bin.partial_cmp(other.bin) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_record_struct_multi_field() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord +#[derive(Partial$0Ord)] +struct Foo { + bin: usize, + bar: usize, + baz: usize, +} +"#, + r#" +struct Foo { + bin: usize, + bar: usize, + baz: usize, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + (self.bin, self.bar, self.baz).partial_cmp((other.bin, other.bar, other.baz)) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_tuple_struct() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord +#[derive(Partial$0Ord)] +struct Foo(usize, usize, usize); +"#, + r#" +struct Foo(usize, usize, usize); + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + (self.0, self.1, self.2).partial_cmp((other.0, other.1, other.2)) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord +#[derive(Partial$0Ord)] +enum Foo { + Bin, + Bar, + Baz, +} +"#, + r#" +enum Foo { + Bin, + Bar, + Baz, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)) + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_record_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord +#[derive(Partial$0Ord)] +enum Foo { + Bar { + bin: String, + }, + Baz { + qux: String, + fez: String, + }, + Qux {}, + Bin, +} +"#, + r#" +enum Foo { + Bar { + bin: String, + }, + Baz { + qux: String, + fez: String, + }, + Qux {}, + Bin, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin.partial_cmp(r_bin), + (Self::Baz { qux: l_qux, fez: l_fez }, Self::Baz { qux: r_qux, fez: r_fez }) => { + (l_qux, l_fez).partial_cmp((r_qux, r_fez)) + } + _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)), + } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_partial_ord_tuple_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: ord +#[derive(Partial$0Ord)] +enum Foo { + Bar(String), + Baz(String, String), + Qux(), + Bin, +} +"#, + r#" +enum Foo { + Bar(String), + Baz(String, String), + Qux(), + Bin, +} + +impl PartialOrd for Foo { + $0fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Bar(l0), Self::Bar(r0)) => l0.partial_cmp(r0), + (Self::Baz(l0, l1), Self::Baz(r0, r1)) => { + (l0, l1).partial_cmp((r0, r1)) + } + _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)), + } + } +} +"#, + ) + } + #[test] fn add_custom_impl_partial_eq_record_struct() { check_assist( diff --git a/crates/ide_assists/src/utils/gen_trait_fn_body.rs b/crates/ide_assists/src/utils/gen_trait_fn_body.rs index 6915460209b..c883e6fb11b 100644 --- a/crates/ide_assists/src/utils/gen_trait_fn_body.rs +++ b/crates/ide_assists/src/utils/gen_trait_fn_body.rs @@ -21,6 +21,7 @@ pub(crate) fn gen_trait_fn_body( "Default" => gen_default_impl(adt, func), "Hash" => gen_hash_impl(adt, func), "PartialEq" => gen_partial_eq(adt, func), + "PartialOrd" => gen_partial_ord(adt, func), _ => None, } } @@ -572,6 +573,200 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { Some(()) } +fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { + fn gen_partial_cmp_call(lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr { + let method = make::name_ref("partial_cmp"); + make::expr_method_call(lhs, method, make::arg_list(Some(rhs))) + } + fn gen_partial_cmp_call2(mut lhs: Vec, mut rhs: Vec) -> ast::Expr { + let (lhs, rhs) = match (lhs.len(), rhs.len()) { + (1, 1) => (lhs.pop().unwrap(), rhs.pop().unwrap()), + _ => (make::expr_tuple(lhs.into_iter()), make::expr_tuple(rhs.into_iter())), + }; + let method = make::name_ref("partial_cmp"); + make::expr_method_call(lhs, method, make::arg_list(Some(rhs))) + } + + fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField { + let pat = make::ext::simple_ident_pat(make::name(&pat_name)); + let name_ref = make::name_ref(field_name); + make::record_pat_field(name_ref, pat.into()) + } + + fn gen_record_pat(record_name: ast::Path, fields: Vec) -> ast::RecordPat { + let list = make::record_pat_field_list(fields); + make::record_pat_with_fields(record_name, list) + } + + fn gen_variant_path(variant: &ast::Variant) -> Option { + make::ext::path_from_idents(["Self", &variant.name()?.to_string()]) + } + + fn gen_tuple_field(field_name: &String) -> ast::Pat { + ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name))) + } + + // FIXME: return `None` if the trait carries a generic type; we can only + // generate this code `Self` for the time being. + + let body = match adt { + // `Hash` cannot be derived for unions, so no default impl can be provided. + ast::Adt::Union(_) => return None, + + ast::Adt::Enum(enum_) => { + // => std::mem::discriminant(self) == std::mem::discriminant(other) + let lhs_name = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone()))); + let rhs_name = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone()))); + let ord_check = gen_partial_cmp_call(lhs, rhs); + + let mut case_count = 0; + let mut arms = vec![]; + for variant in enum_.variant_list()?.variants() { + case_count += 1; + match variant.field_list() { + // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin, + Some(ast::FieldList::RecordFieldList(list)) => { + let mut l_pat_fields = vec![]; + let mut r_pat_fields = vec![]; + let mut l_fields = vec![]; + let mut r_fields = vec![]; + + for field in list.fields() { + let field_name = field.name()?.to_string(); + + let l_name = &format!("l_{}", field_name); + l_pat_fields.push(gen_record_pat_field(&field_name, &l_name)); + + let r_name = &format!("r_{}", field_name); + r_pat_fields.push(gen_record_pat_field(&field_name, &r_name)); + + let lhs = make::expr_path(make::ext::ident_path(l_name)); + let rhs = make::expr_path(make::ext::ident_path(r_name)); + l_fields.push(lhs); + r_fields.push(rhs); + } + + let left_pat = gen_record_pat(gen_variant_path(&variant)?, l_pat_fields); + let right_pat = gen_record_pat(gen_variant_path(&variant)?, r_pat_fields); + let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]); + + let len = l_fields.len(); + if len != 0 { + let mut expr = gen_partial_cmp_call2(l_fields, r_fields); + if len >= 2 { + expr = make::block_expr(None, Some(expr)) + .indent(ast::edit::IndentLevel(1)) + .into(); + } + arms.push(make::match_arm(Some(tuple_pat.into()), None, expr)); + } + } + + Some(ast::FieldList::TupleFieldList(list)) => { + let mut l_pat_fields = vec![]; + let mut r_pat_fields = vec![]; + let mut l_fields = vec![]; + let mut r_fields = vec![]; + + for (i, _) in list.fields().enumerate() { + let field_name = format!("{}", i); + + let l_name = format!("l{}", field_name); + l_pat_fields.push(gen_tuple_field(&l_name)); + + let r_name = format!("r{}", field_name); + r_pat_fields.push(gen_tuple_field(&r_name)); + + let lhs = make::expr_path(make::ext::ident_path(&l_name)); + let rhs = make::expr_path(make::ext::ident_path(&r_name)); + l_fields.push(lhs); + r_fields.push(rhs); + } + + let left_pat = + make::tuple_struct_pat(gen_variant_path(&variant)?, l_pat_fields); + let right_pat = + make::tuple_struct_pat(gen_variant_path(&variant)?, r_pat_fields); + let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]); + + let len = l_fields.len(); + if len != 0 { + let mut expr = gen_partial_cmp_call2(l_fields, r_fields); + if len >= 2 { + expr = make::block_expr(None, Some(expr)) + .indent(ast::edit::IndentLevel(1)) + .into(); + } + arms.push(make::match_arm(Some(tuple_pat.into()), None, expr)); + } + } + None => continue, + } + } + + let expr = match arms.len() { + 0 => ord_check, + _ => { + if case_count > arms.len() { + let lhs = make::wildcard_pat().into(); + arms.push(make::match_arm(Some(lhs), None, ord_check)); + } + + let match_target = make::expr_tuple(vec![lhs_name, rhs_name]); + let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1)); + make::expr_match(match_target, list) + } + }; + + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + ast::Adt::Struct(strukt) => match strukt.field_list() { + Some(ast::FieldList::RecordFieldList(field_list)) => { + let mut l_fields = vec![]; + let mut r_fields = vec![]; + for field in field_list.fields() { + let lhs = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_field(lhs, &field.name()?.to_string()); + let rhs = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_field(rhs, &field.name()?.to_string()); + l_fields.push(lhs); + r_fields.push(rhs); + } + + let expr = gen_partial_cmp_call2(l_fields, r_fields); + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + + Some(ast::FieldList::TupleFieldList(field_list)) => { + let mut l_fields = vec![]; + let mut r_fields = vec![]; + for (i, _) in field_list.fields().enumerate() { + let idx = format!("{}", i); + let lhs = make::expr_path(make::ext::ident_path("self")); + let lhs = make::expr_field(lhs, &idx); + let rhs = make::expr_path(make::ext::ident_path("other")); + let rhs = make::expr_field(rhs, &idx); + l_fields.push(lhs); + r_fields.push(rhs); + } + let expr = gen_partial_cmp_call2(l_fields, r_fields); + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + + // No fields in the body means there's nothing to hash. + None => { + let expr = make::expr_literal("true").into(); + make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1)) + } + }, + }; + + ted::replace(func.body()?.syntax(), body.clone_for_update().syntax()); + Some(()) +} + fn make_discriminant() -> Option { Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?)) }