6220: implement binary operator overloading type inference r=flodiebold a=ruabmbua

Extend type inference of *binary operator expression*, by adding support for operator overloads.

Before this merge request, the type inference of binary expressions could only resolve operations done on built-in primitive types. This merge requests adds a code path, which is executed in case the built-in inference could not get any results. It resolves the proper operator overload trait in *core::ops* via lang items, and then resolves the associated *Output* type.

```rust
struct V2([f32; 2]);

#[lang = "add"]
pub trait Add<Rhs = Self> {
    /// The resulting type after applying the `+` operator.
    type Output;

    /// Performs the `+` operation.
    #[must_use]
    fn add(self, rhs: Rhs) -> Self::Output;
}

impl Add<V2> for V2 {
    type Output = V2;

    fn add(self, rhs: V2) -> V2 {
        let x = self.0[0] + rhs.0[0];
        let y = self.0[1] + rhs.0[1];
        V2([x, y])
    }
}

fn test() {
    let va = V2([0.0, 1.0]);
    let vb = V2([0.0, 1.0]);

    let r = va + vb; // This infers to V2 now
}
```

There is a problem with operator overloads, which do not explicitly set the *Rhs* type parameter in the respective impl block. 

**Example:**

```rust
impl Add for V2 {
    type Output = V2;

    fn add(self, rhs: V2) -> V2 {
        let x = self.0[0] + rhs.0[0];
        let y = self.0[1] + rhs.0[1];
        V2([x, y])
    }
}
```

In this case, the trait solver does not realize, that the *Rhs* type parameter is actually self in the context of the impl block. This stops type inference in its tracks, and it can not resolve the associated *Output* type.

I guess we can still merge this back, because it increases the amount of resolved types, and does not regress anything (in the tests).

Somewhat blocked by https://github.com/rust-analyzer/rust-analyzer/issues/5685
Resolves  https://github.com/rust-analyzer/rust-analyzer/issues/5544

Co-authored-by: Roland Ruckerbauer <roland.rucky@gmail.com>
This commit is contained in:
bors[bot] 2020-10-15 18:02:27 +00:00 committed by GitHub
commit 0d45802d67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 126 additions and 5 deletions

View File

@ -22,7 +22,7 @@ use arena::map::ArenaMap;
use hir_def::{
body::Body,
data::{ConstData, FunctionData, StaticData},
expr::{BindingAnnotation, ExprId, PatId},
expr::{ArithOp, BinaryOp, BindingAnnotation, ExprId, PatId},
lang_item::LangItemTarget,
path::{path, Path},
resolver::{HasResolver, Resolver, TypeNs},
@ -586,6 +586,28 @@ impl<'a> InferenceContext<'a> {
self.db.trait_data(trait_).associated_type_by_name(&name![Output])
}
fn resolve_binary_op_output(&self, bop: &BinaryOp) -> Option<TypeAliasId> {
let lang_item = match bop {
BinaryOp::ArithOp(aop) => match aop {
ArithOp::Add => "add",
ArithOp::Sub => "sub",
ArithOp::Mul => "mul",
ArithOp::Div => "div",
ArithOp::Shl => "shl",
ArithOp::Shr => "shr",
ArithOp::Rem => "rem",
ArithOp::BitXor => "bitxor",
ArithOp::BitOr => "bitor",
ArithOp::BitAnd => "bitand",
},
_ => return None,
};
let trait_ = self.resolve_lang_item(lang_item)?.as_trait();
self.db.trait_data(trait_?).associated_type_by_name(&name![Output])
}
fn resolve_boxed_box(&self) -> Option<AdtId> {
let struct_ = self.resolve_lang_item("owned_box")?.as_struct()?;
Some(struct_.into())

View File

@ -12,6 +12,7 @@ use hir_def::{
};
use hir_expand::name::{name, Name};
use syntax::ast::RangeOp;
use test_utils::mark;
use crate::{
autoderef, method_resolution, op,
@ -531,13 +532,22 @@ impl<'a> InferenceContext<'a> {
_ => Expectation::none(),
};
let lhs_ty = self.infer_expr(*lhs, &lhs_expectation);
// FIXME: find implementation of trait corresponding to operation
// symbol and resolve associated `Output` type
let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty.clone());
let rhs_ty = self.infer_expr(*rhs, &Expectation::has_type(rhs_expectation));
// FIXME: similar as above, return ty is often associated trait type
op::binary_op_return_ty(*op, lhs_ty, rhs_ty)
let ret = op::binary_op_return_ty(*op, lhs_ty.clone(), rhs_ty.clone());
if ret == Ty::Unknown {
mark::hit!(infer_expr_inner_binary_operator_overload);
self.resolve_associated_type_with_params(
lhs_ty,
self.resolve_binary_op_output(op),
&[rhs_ty],
)
} else {
ret
}
}
_ => Ty::Unknown,
},

View File

@ -1,4 +1,5 @@
use expect_test::expect;
use test_utils::mark;
use super::{check_infer, check_types};
@ -2225,3 +2226,91 @@ fn generic_default_depending_on_other_type_arg_forward() {
"#]],
);
}
#[test]
fn infer_operator_overload() {
mark::check!(infer_expr_inner_binary_operator_overload);
check_infer(
r#"
struct V2([f32; 2]);
#[lang = "add"]
pub trait Add<Rhs = Self> {
/// The resulting type after applying the `+` operator.
type Output;
/// Performs the `+` operation.
#[must_use]
fn add(self, rhs: Rhs) -> Self::Output;
}
impl Add<V2> for V2 {
type Output = V2;
fn add(self, rhs: V2) -> V2 {
let x = self.0[0] + rhs.0[0];
let y = self.0[1] + rhs.0[1];
V2([x, y])
}
}
fn test() {
let va = V2([0.0, 1.0]);
let vb = V2([0.0, 1.0]);
let r = va + vb;
}
"#,
expect![[r#"
207..211 'self': Self
213..216 'rhs': Rhs
299..303 'self': V2
305..308 'rhs': V2
320..422 '{ ... }': V2
334..335 'x': f32
338..342 'self': V2
338..344 'self.0': [f32; _]
338..347 'self.0[0]': {unknown}
338..358 'self.0...s.0[0]': f32
345..346 '0': i32
350..353 'rhs': V2
350..355 'rhs.0': [f32; _]
350..358 'rhs.0[0]': {unknown}
356..357 '0': i32
372..373 'y': f32
376..380 'self': V2
376..382 'self.0': [f32; _]
376..385 'self.0[1]': {unknown}
376..396 'self.0...s.0[1]': f32
383..384 '1': i32
388..391 'rhs': V2
388..393 'rhs.0': [f32; _]
388..396 'rhs.0[1]': {unknown}
394..395 '1': i32
406..408 'V2': V2([f32; _]) -> V2
406..416 'V2([x, y])': V2
409..415 '[x, y]': [f32; _]
410..411 'x': f32
413..414 'y': f32
436..519 '{ ... vb; }': ()
446..448 'va': V2
451..453 'V2': V2([f32; _]) -> V2
451..465 'V2([0.0, 1.0])': V2
454..464 '[0.0, 1.0]': [f32; _]
455..458 '0.0': f32
460..463 '1.0': f32
475..477 'vb': V2
480..482 'V2': V2([f32; _]) -> V2
480..494 'V2([0.0, 1.0])': V2
483..493 '[0.0, 1.0]': [f32; _]
484..487 '0.0': f32
489..492 '1.0': f32
505..506 'r': V2
509..511 'va': V2
509..516 'va + vb': V2
514..516 'vb': V2
"#]],
);
}