Enforce builtin binop expectations on single references

Also don't enforce them on non-builtin types
This commit is contained in:
Ryo Yoshida 2023-01-17 19:48:25 +09:00
parent fa874627f0
commit 461435adab
No known key found for this signature in database
GPG Key ID: E25698A930586171
3 changed files with 243 additions and 33 deletions

View File

@ -1,6 +1,6 @@
//! Various extensions traits for Chalk types.
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, UintTy};
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, TyVariableKind, UintTy};
use hir_def::{
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinType, BuiltinUint},
generics::TypeOrConstParamData,
@ -18,6 +18,8 @@
pub trait TyExt {
fn is_unit(&self) -> bool;
fn is_integral(&self) -> bool;
fn is_floating_point(&self) -> bool;
fn is_never(&self) -> bool;
fn is_unknown(&self) -> bool;
fn is_ty_var(&self) -> bool;
@ -51,6 +53,21 @@ fn is_unit(&self) -> bool {
matches!(self.kind(Interner), TyKind::Tuple(0, _))
}
fn is_integral(&self) -> bool {
matches!(
self.kind(Interner),
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
| TyKind::InferenceVar(_, TyVariableKind::Integer)
)
}
fn is_floating_point(&self) -> bool {
matches!(
self.kind(Interner),
TyKind::Scalar(Scalar::Float(_)) | TyKind::InferenceVar(_, TyVariableKind::Float)
)
}
fn is_never(&self) -> bool {
matches!(self.kind(Interner), TyKind::Never)
}

View File

@ -1071,11 +1071,9 @@ fn infer_overloadable_binop(
let ret_ty = self.normalize_associated_types_in(ret_ty);
// use knowledge of built-in binary ops, which can sometimes help inference
if let Some(builtin_rhs) = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()) {
self.unify(&builtin_rhs, &rhs_ty);
}
if let Some(builtin_ret) = self.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty) {
if self.is_builtin_binop(&lhs_ty, &rhs_ty, op) {
// use knowledge of built-in binary ops, which can sometimes help inference
let builtin_ret = self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op);
self.unify(&builtin_ret, &ret_ty);
}
@ -1545,7 +1543,10 @@ fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty)
fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Option<Ty> {
Some(match op {
BinaryOp::LogicOp(..) => TyKind::Scalar(Scalar::Bool).intern(Interner),
BinaryOp::Assignment { op: None } => lhs_ty,
BinaryOp::Assignment { op: None } => {
stdx::never!("Simple assignment operator is not binary op.");
return None;
}
BinaryOp::CmpOp(CmpOp::Eq { .. }) => match self
.resolve_ty_shallow(&lhs_ty)
.kind(Interner)
@ -1565,6 +1566,126 @@ fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Opt
})
}
/// Dereferences a single level of immutable referencing.
fn deref_ty_if_possible(&mut self, ty: &Ty) -> Ty {
let ty = self.resolve_ty_shallow(ty);
match ty.kind(Interner) {
TyKind::Ref(Mutability::Not, _, inner) => self.resolve_ty_shallow(inner),
_ => ty,
}
}
/// Enforces expectations on lhs type and rhs type depending on the operator and returns the
/// output type of the binary op.
fn enforce_builtin_binop_types(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> Ty {
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
let lhs = self.deref_ty_if_possible(lhs);
let rhs = self.deref_ty_if_possible(rhs);
let (op, is_assign) = match op {
BinaryOp::Assignment { op: Some(inner) } => (BinaryOp::ArithOp(inner), true),
_ => (op, false),
};
let output_ty = match op {
BinaryOp::LogicOp(_) => {
let bool_ = self.result.standard_types.bool_.clone();
self.unify(&lhs, &bool_);
self.unify(&rhs, &bool_);
bool_
}
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
// result type is same as LHS always
lhs
}
BinaryOp::ArithOp(_) => {
// LHS, RHS, and result will have the same type
self.unify(&lhs, &rhs);
lhs
}
BinaryOp::CmpOp(_) => {
// LHS and RHS will have the same type
self.unify(&lhs, &rhs);
self.result.standard_types.bool_.clone()
}
BinaryOp::Assignment { op: None } => {
stdx::never!("Simple assignment operator is not binary op.");
lhs
}
BinaryOp::Assignment { .. } => unreachable!("handled above"),
};
if is_assign {
self.result.standard_types.unit.clone()
} else {
output_ty
}
}
fn is_builtin_binop(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> bool {
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
let lhs = self.deref_ty_if_possible(lhs);
let rhs = self.deref_ty_if_possible(rhs);
let op = match op {
BinaryOp::Assignment { op: Some(inner) } => BinaryOp::ArithOp(inner),
_ => op,
};
match op {
BinaryOp::LogicOp(_) => true,
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
lhs.is_integral() && rhs.is_integral()
}
BinaryOp::ArithOp(
ArithOp::Add | ArithOp::Sub | ArithOp::Mul | ArithOp::Div | ArithOp::Rem,
) => {
lhs.is_integral() && rhs.is_integral()
|| lhs.is_floating_point() && rhs.is_floating_point()
}
BinaryOp::ArithOp(ArithOp::BitAnd | ArithOp::BitOr | ArithOp::BitXor) => {
lhs.is_integral() && rhs.is_integral()
|| lhs.is_floating_point() && rhs.is_floating_point()
|| matches!(
(lhs.kind(Interner), rhs.kind(Interner)),
(TyKind::Scalar(Scalar::Bool), TyKind::Scalar(Scalar::Bool))
)
}
BinaryOp::CmpOp(_) => {
let is_scalar = |kind| {
matches!(
kind,
&TyKind::Scalar(_)
| TyKind::FnDef(..)
| TyKind::Function(_)
| TyKind::Raw(..)
| TyKind::InferenceVar(
_,
TyVariableKind::Integer | TyVariableKind::Float
)
)
};
is_scalar(lhs.kind(Interner)) && is_scalar(rhs.kind(Interner))
}
BinaryOp::Assignment { op: None } => {
stdx::never!("Simple assignment operator is not binary op.");
false
}
BinaryOp::Assignment { .. } => unreachable!("handled above"),
}
}
fn with_breakable_ctx<T>(
&mut self,
kind: BreakableKind,

View File

@ -3507,14 +3507,9 @@ trait Request {
fn bin_op_adt_with_rhs_primitive() {
check_infer_with_mismatches(
r#"
#[lang = "add"]
pub trait Add<Rhs = Self> {
type Output;
fn add(self, rhs: Rhs) -> Self::Output;
}
//- minicore: add
struct Wrapper(u32);
impl Add<u32> for Wrapper {
impl core::ops::Add<u32> for Wrapper {
type Output = Self;
fn add(self, rhs: u32) -> Wrapper {
Wrapper(rhs)
@ -3527,29 +3522,106 @@ fn main(){
}"#,
expect![[r#"
72..76 'self': Self
78..81 'rhs': Rhs
192..196 'self': Wrapper
198..201 'rhs': u32
219..247 '{ ... }': Wrapper
229..236 'Wrapper': Wrapper(u32) -> Wrapper
229..241 'Wrapper(rhs)': Wrapper
237..240 'rhs': u32
259..345 '{ ...um; }': ()
269..276 'wrapped': Wrapper
279..286 'Wrapper': Wrapper(u32) -> Wrapper
279..290 'Wrapper(10)': Wrapper
287..289 '10': u32
300..303 'num': u32
311..312 '2': u32
322..325 'res': Wrapper
328..335 'wrapped': Wrapper
328..341 'wrapped + num': Wrapper
338..341 'num': u32
95..99 'self': Wrapper
101..104 'rhs': u32
122..150 '{ ... }': Wrapper
132..139 'Wrapper': Wrapper(u32) -> Wrapper
132..144 'Wrapper(rhs)': Wrapper
140..143 'rhs': u32
162..248 '{ ...um; }': ()
172..179 'wrapped': Wrapper
182..189 'Wrapper': Wrapper(u32) -> Wrapper
182..193 'Wrapper(10)': Wrapper
190..192 '10': u32
203..206 'num': u32
214..215 '2': u32
225..228 'res': Wrapper
231..238 'wrapped': Wrapper
231..244 'wrapped + num': Wrapper
241..244 'num': u32
"#]],
)
}
#[test]
fn builtin_binop_expectation_works_on_single_reference() {
check_types(
r#"
//- minicore: add
use core::ops::Add;
impl Add<i32> for i32 { type Output = i32 }
impl Add<&i32> for i32 { type Output = i32 }
impl Add<u32> for u32 { type Output = u32 }
impl Add<&u32> for u32 { type Output = u32 }
struct V<T>;
impl<T> V<T> {
fn default() -> Self { loop {} }
fn get(&self, _: &T) -> &T { loop {} }
}
fn take_u32(_: u32) {}
fn minimized() {
let v = V::default();
let p = v.get(&0);
//^ &u32
take_u32(42 + p);
}
"#,
);
}
#[test]
fn no_builtin_binop_expectation_for_general_ty_var() {
// FIXME: Ideally type mismatch should be reported on `take_u32(42 - p)`.
check_types(
r#"
//- minicore: add
use core::ops::Add;
impl Add<i32> for i32 { type Output = i32; }
impl Add<&i32> for i32 { type Output = i32; }
// This is needed to prevent chalk from giving unique solution to `i32: Add<&?0>` after applying
// fallback to integer type variable for `42`.
impl Add<&()> for i32 { type Output = (); }
struct V<T>;
impl<T> V<T> {
fn default() -> Self { loop {} }
fn get(&self) -> &T { loop {} }
}
fn take_u32(_: u32) {}
fn minimized() {
let v = V::default();
let p = v.get();
//^ &{unknown}
take_u32(42 + p);
}
"#,
);
}
#[test]
fn no_builtin_binop_expectation_for_non_builtin_types() {
check_no_mismatches(
r#"
//- minicore: default, eq
struct S;
impl Default for S { fn default() -> Self { S } }
impl Default for i32 { fn default() -> Self { 0 } }
impl PartialEq<S> for i32 { fn eq(&self, _: &S) -> bool { true } }
impl PartialEq<i32> for i32 { fn eq(&self, _: &S) -> bool { true } }
fn take_s(_: S) {}
fn test() {
let s = Default::default();
let _eq = 0 == s;
take_s(s);
}
"#,
)
}
#[test]
fn array_length() {
check_infer(