From a40e390860987a23f9b899abc5947f1525d3709c Mon Sep 17 00:00:00 2001
From: Phil Ellison <phil.j.ellison@gmail.com>
Date: Sun, 11 Aug 2019 15:00:37 +0100
Subject: [PATCH] Check type rather than just name in ok-wrapping diagnostic.
 Add test for handling generic functions (which currently fails)

---
 crates/ra_hir/src/expr/validation.rs | 48 ++++++++++++++++++++++------
 crates/ra_hir/src/name.rs            |  2 ++
 crates/ra_ide_api/src/diagnostics.rs | 37 +++++++++++++++++++++
 3 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/crates/ra_hir/src/expr/validation.rs b/crates/ra_hir/src/expr/validation.rs
index ca7db61bc43..339a7b84862 100644
--- a/crates/ra_hir/src/expr/validation.rs
+++ b/crates/ra_hir/src/expr/validation.rs
@@ -6,10 +6,13 @@ use ra_syntax::ast::{AstNode, RecordLit};
 use super::{Expr, ExprId, RecordLitField};
 use crate::{
     adt::AdtDef,
+    code_model::Enum,
     diagnostics::{DiagnosticSink, MissingFields, MissingOkInTailExpr},
     expr::AstPtr,
+    name,
+    path::{PathKind, PathSegment},
     ty::{InferenceResult, Ty, TypeCtor},
-    Function, HasSource, HirDatabase, Name, Path,
+    Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution
 };
 use ra_syntax::ast;
 
@@ -106,18 +109,45 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
             Some(m) => m,
             None => return,
         };
+
+        let std_result_path = Path {
+            kind: PathKind::Abs,
+            segments: vec![
+                PathSegment { name: name::STD, args_and_bindings: None },
+                PathSegment { name: name::RESULT_MOD, args_and_bindings: None },
+                PathSegment { name: name::RESULT_TYPE, args_and_bindings: None },
+            ]
+        };
+
+        let resolver = self.func.resolver(db);
+        let std_result_enum = match resolver.resolve_path_segments(db, &std_result_path).into_fully_resolved() {
+            PerNs { types: Some(Resolution::Def(ModuleDef::Enum(e))), .. } => e,
+            _ => return,
+        };
+
+        let std_result_type = std_result_enum.ty(db);
+
+        fn enum_from_type(ty: &Ty) -> Option<Enum> {
+            match ty {
+                Ty::Apply(t) => {
+                    match t.ctor {
+                        TypeCtor::Adt(AdtDef::Enum(e)) => Some(e),
+                        _ => None,
+                    }
+                }
+                _ => None
+            }
+        }
+
+        if enum_from_type(&mismatch.expected) != enum_from_type(&std_result_type) {
+            return;
+        }
+
         let ret = match &mismatch.expected {
             Ty::Apply(t) => t,
             _ => return,
         };
-        let ret_enum = match ret.ctor {
-            TypeCtor::Adt(AdtDef::Enum(e)) => e,
-            _ => return,
-        };
-        let enum_name = ret_enum.name(db);
-        if enum_name.is_none() || enum_name.unwrap().to_string() != "Result" {
-            return;
-        }
+
         let params = &ret.parameters;
         if params.len() == 2 && &params[0] == &mismatch.actual {
             let source_map = self.func.body_source_map(db);
diff --git a/crates/ra_hir/src/name.rs b/crates/ra_hir/src/name.rs
index 6d14eea8ecf..9c4822d917f 100644
--- a/crates/ra_hir/src/name.rs
+++ b/crates/ra_hir/src/name.rs
@@ -120,6 +120,8 @@ pub(crate) const TRY: Name = Name::new(SmolStr::new_inline_from_ascii(3, b"Try")
 pub(crate) const OK: Name = Name::new(SmolStr::new_inline_from_ascii(2, b"Ok"));
 pub(crate) const FUTURE_MOD: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"future"));
 pub(crate) const FUTURE_TYPE: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Future"));
+pub(crate) const RESULT_MOD: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"result"));
+pub(crate) const RESULT_TYPE: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Result"));
 pub(crate) const OUTPUT: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Output"));
 
 fn resolve_name(text: &SmolStr) -> SmolStr {
diff --git a/crates/ra_ide_api/src/diagnostics.rs b/crates/ra_ide_api/src/diagnostics.rs
index 5e25991c624..57454719c2e 100644
--- a/crates/ra_ide_api/src/diagnostics.rs
+++ b/crates/ra_ide_api/src/diagnostics.rs
@@ -281,6 +281,43 @@ fn div(x: i32, y: i32) -> Result<i32, String> {
         check_apply_diagnostic_fix_for_target_file("/main.rs", before, after);
     }
 
+    #[test]
+    fn test_wrap_return_type_handles_generic_functions() {
+        let before = r#"
+            //- /main.rs
+            use std::{default::Default, result::Result::{self, Ok, Err}};
+
+            fn div<T: Default, i32>(x: i32) -> Result<T, i32> {
+                if x == 0 {
+                    return Err(7);
+                }
+                T::default()
+            }
+
+            //- /std/lib.rs
+            pub mod result {
+                pub enum Result<T, E> { Ok(T), Err(E) }
+            }
+            pub mod default {
+                pub trait Default {
+                    fn default() -> Self;
+                }
+            }
+        "#;
+// The formatting here is a bit odd due to how the parse_fixture function works in test_utils -
+// it strips empty lines and leading whitespace. The important part of this test is that the final
+// `x / y` expr is now wrapped in `Ok(..)`
+        let after = r#"use std::{default::Default, result::Result::{self, Ok, Err}};
+fn div<T: Default>(x: i32) -> Result<T, i32> {
+    if x == 0 {
+        return Err(7);
+    }
+    Ok(T::default())
+}
+"#;
+        check_apply_diagnostic_fix_for_target_file("/main.rs", before, after);
+    }
+
     #[test]
     fn test_wrap_return_type_handles_type_aliases() {
         let before = r#"