diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index e95e1755363..b700864f7dd 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -101,6 +101,18 @@ fn bit_op() { check_number(r#"const GOAL: i8 = 1 << 8"#, 0); } +#[test] +fn floating_point() { + check_number( + r#"const GOAL: f64 = 2.0 + 3.0 * 5.5 - 8.;"#, + i128::from_le_bytes(pad16(&f64::to_le_bytes(10.5), true)), + ); + check_number( + r#"const GOAL: f32 = 2.0 + 3.0 * 5.5 - 8.;"#, + i128::from_le_bytes(pad16(&f32::to_le_bytes(10.5), true)), + ); +} + #[test] fn casts() { check_number(r#"const GOAL: usize = 12 as *const i32 as usize"#, 12); diff --git a/crates/hir-ty/src/mir.rs b/crates/hir-ty/src/mir.rs index 4846bbfe5fd..3ac208666a7 100644 --- a/crates/hir-ty/src/mir.rs +++ b/crates/hir-ty/src/mir.rs @@ -649,6 +649,20 @@ pub enum BinOp { Offset, } +impl BinOp { + fn run_compare(&self, l: T, r: T) -> bool { + match self { + BinOp::Ge => l >= r, + BinOp::Gt => l > r, + BinOp::Le => l <= r, + BinOp::Lt => l < r, + BinOp::Eq => l == r, + BinOp::Ne => l != r, + x => panic!("`run_compare` called on operator {x:?}"), + } + } +} + impl Display for BinOp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(match self { diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 01d27f2672c..9811cd9192b 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -48,7 +48,7 @@ macro_rules! from_bytes { ($ty:tt, $value:expr) => { ($ty::from_le_bytes(match ($value).try_into() { Ok(x) => x, - Err(_) => return Err(MirEvalError::TypeError("mismatched size")), + Err(_) => return Err(MirEvalError::TypeError(stringify!(mismatched size in constructing $ty))), })) }; } @@ -797,70 +797,122 @@ impl Evaluator<'_> { lc = self.read_memory(Address::from_bytes(lc)?, size)?; rc = self.read_memory(Address::from_bytes(rc)?, size)?; } - let is_signed = matches!(ty.as_builtin(), Some(BuiltinType::Int(_))); - let l128 = i128::from_le_bytes(pad16(lc, is_signed)); - let r128 = i128::from_le_bytes(pad16(rc, is_signed)); - match op { - BinOp::Ge | BinOp::Gt | BinOp::Le | BinOp::Lt | BinOp::Eq | BinOp::Ne => { - let r = match op { - BinOp::Ge => l128 >= r128, - BinOp::Gt => l128 > r128, - BinOp::Le => l128 <= r128, - BinOp::Lt => l128 < r128, - BinOp::Eq => l128 == r128, - BinOp::Ne => l128 != r128, - _ => unreachable!(), - }; - let r = r as u8; - Owned(vec![r]) - } - BinOp::BitAnd - | BinOp::BitOr - | BinOp::BitXor - | BinOp::Add - | BinOp::Mul - | BinOp::Div - | BinOp::Rem - | BinOp::Sub => { - let r = match op { - BinOp::Add => l128.overflowing_add(r128).0, - BinOp::Mul => l128.overflowing_mul(r128).0, - BinOp::Div => l128.checked_div(r128).ok_or_else(|| { - MirEvalError::Panic(format!("Overflow in {op:?}")) - })?, - BinOp::Rem => l128.checked_rem(r128).ok_or_else(|| { - MirEvalError::Panic(format!("Overflow in {op:?}")) - })?, - BinOp::Sub => l128.overflowing_sub(r128).0, - BinOp::BitAnd => l128 & r128, - BinOp::BitOr => l128 | r128, - BinOp::BitXor => l128 ^ r128, - _ => unreachable!(), - }; - let r = r.to_le_bytes(); - for &k in &r[lc.len()..] { - if k != 0 && (k != 255 || !is_signed) { - return Err(MirEvalError::Panic(format!("Overflow in {op:?}"))); + if let TyKind::Scalar(chalk_ir::Scalar::Float(f)) = ty.kind(Interner) { + match f { + chalk_ir::FloatTy::F32 => { + let l = from_bytes!(f32, lc); + let r = from_bytes!(f32, rc); + match op { + BinOp::Ge + | BinOp::Gt + | BinOp::Le + | BinOp::Lt + | BinOp::Eq + | BinOp::Ne => { + let r = op.run_compare(l, r) as u8; + Owned(vec![r]) + } + BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => { + let r = match op { + BinOp::Add => l + r, + BinOp::Sub => l - r, + BinOp::Mul => l * r, + BinOp::Div => l / r, + _ => unreachable!(), + }; + Owned(r.to_le_bytes().into()) + } + x => not_supported!( + "invalid binop {x:?} on floating point operators" + ), + } + } + chalk_ir::FloatTy::F64 => { + let l = from_bytes!(f64, lc); + let r = from_bytes!(f64, rc); + match op { + BinOp::Ge + | BinOp::Gt + | BinOp::Le + | BinOp::Lt + | BinOp::Eq + | BinOp::Ne => { + let r = op.run_compare(l, r) as u8; + Owned(vec![r]) + } + BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => { + let r = match op { + BinOp::Add => l + r, + BinOp::Sub => l - r, + BinOp::Mul => l * r, + BinOp::Div => l / r, + _ => unreachable!(), + }; + Owned(r.to_le_bytes().into()) + } + x => not_supported!( + "invalid binop {x:?} on floating point operators" + ), } } - Owned(r[0..lc.len()].into()) } - BinOp::Shl | BinOp::Shr => { - let shift_amount = if r128 < 0 { - return Err(MirEvalError::Panic(format!("Overflow in {op:?}"))); - } else if r128 > 128 { - return Err(MirEvalError::Panic(format!("Overflow in {op:?}"))); - } else { - r128 as u8 - }; - let r = match op { - BinOp::Shl => l128 << shift_amount, - BinOp::Shr => l128 >> shift_amount, - _ => unreachable!(), - }; - Owned(r.to_le_bytes()[0..lc.len()].into()) + } else { + let is_signed = matches!(ty.as_builtin(), Some(BuiltinType::Int(_))); + let l128 = i128::from_le_bytes(pad16(lc, is_signed)); + let r128 = i128::from_le_bytes(pad16(rc, is_signed)); + match op { + BinOp::Ge | BinOp::Gt | BinOp::Le | BinOp::Lt | BinOp::Eq | BinOp::Ne => { + let r = op.run_compare(l128, r128) as u8; + Owned(vec![r]) + } + BinOp::BitAnd + | BinOp::BitOr + | BinOp::BitXor + | BinOp::Add + | BinOp::Mul + | BinOp::Div + | BinOp::Rem + | BinOp::Sub => { + let r = match op { + BinOp::Add => l128.overflowing_add(r128).0, + BinOp::Mul => l128.overflowing_mul(r128).0, + BinOp::Div => l128.checked_div(r128).ok_or_else(|| { + MirEvalError::Panic(format!("Overflow in {op:?}")) + })?, + BinOp::Rem => l128.checked_rem(r128).ok_or_else(|| { + MirEvalError::Panic(format!("Overflow in {op:?}")) + })?, + BinOp::Sub => l128.overflowing_sub(r128).0, + BinOp::BitAnd => l128 & r128, + BinOp::BitOr => l128 | r128, + BinOp::BitXor => l128 ^ r128, + _ => unreachable!(), + }; + let r = r.to_le_bytes(); + for &k in &r[lc.len()..] { + if k != 0 && (k != 255 || !is_signed) { + return Err(MirEvalError::Panic(format!("Overflow in {op:?}"))); + } + } + Owned(r[0..lc.len()].into()) + } + BinOp::Shl | BinOp::Shr => { + let shift_amount = if r128 < 0 { + return Err(MirEvalError::Panic(format!("Overflow in {op:?}"))); + } else if r128 > 128 { + return Err(MirEvalError::Panic(format!("Overflow in {op:?}"))); + } else { + r128 as u8 + }; + let r = match op { + BinOp::Shl => l128 << shift_amount, + BinOp::Shr => l128 >> shift_amount, + _ => unreachable!(), + }; + Owned(r.to_le_bytes()[0..lc.len()].into()) + } + BinOp::Offset => not_supported!("offset binop"), } - BinOp::Offset => not_supported!("offset binop"), } } Rvalue::Discriminant(p) => {