From 80bb15ed91d728e39acc2dcb5a3ed5db59862687 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 2 Mar 2024 21:40:18 -0500 Subject: [PATCH] Add compiler support for parsing `f16` and `f128` --- compiler/rustc_ast/src/util/literal.rs | 2 + .../rustc_middle/src/mir/interpret/value.rs | 22 +++++++++- compiler/rustc_middle/src/ty/consts/int.rs | 44 ++++++++++++++++++- compiler/rustc_mir_build/src/build/mod.rs | 8 ++-- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/compiler/rustc_ast/src/util/literal.rs b/compiler/rustc_ast/src/util/literal.rs index 5ed2762b726..a17c7708e4a 100644 --- a/compiler/rustc_ast/src/util/literal.rs +++ b/compiler/rustc_ast/src/util/literal.rs @@ -276,8 +276,10 @@ fn filtered_float_lit( Some(suffix) => LitKind::Float( symbol, ast::LitFloatType::Suffixed(match suffix { + sym::f16 => ast::FloatTy::F16, sym::f32 => ast::FloatTy::F32, sym::f64 => ast::FloatTy::F64, + sym::f128 => ast::FloatTy::F128, _ => return Err(LitError::InvalidFloatSuffix(suffix)), }), ), diff --git a/compiler/rustc_middle/src/mir/interpret/value.rs b/compiler/rustc_middle/src/mir/interpret/value.rs index e937c17c8ac..24d4a79c7d7 100644 --- a/compiler/rustc_middle/src/mir/interpret/value.rs +++ b/compiler/rustc_middle/src/mir/interpret/value.rs @@ -3,7 +3,7 @@ use std::fmt; use either::{Either, Left, Right}; use rustc_apfloat::{ - ieee::{Double, Single}, + ieee::{Double, Half, Quad, Single}, Float, }; use rustc_macros::HashStable; @@ -201,6 +201,11 @@ impl Scalar { Self::from_int(i, cx.data_layout().pointer_size) } + #[inline] + pub fn from_f16(f: Half) -> Self { + Scalar::Int(f.into()) + } + #[inline] pub fn from_f32(f: Single) -> Self { Scalar::Int(f.into()) @@ -211,6 +216,11 @@ impl Scalar { Scalar::Int(f.into()) } + #[inline] + pub fn from_f128(f: Quad) -> Self { + Scalar::Int(f.into()) + } + /// This is almost certainly not the method you want! You should dispatch on the type /// and use `to_{u8,u16,...}`/`scalar_to_ptr` to perform ptr-to-int / int-to-ptr casts as needed. /// @@ -422,6 +432,11 @@ impl<'tcx, Prov: Provenance> Scalar { Ok(F::from_bits(self.to_bits(Size::from_bits(F::BITS))?)) } + #[inline] + pub fn to_f16(self) -> InterpResult<'tcx, Half> { + self.to_float() + } + #[inline] pub fn to_f32(self) -> InterpResult<'tcx, Single> { self.to_float() @@ -431,4 +446,9 @@ impl<'tcx, Prov: Provenance> Scalar { pub fn to_f64(self) -> InterpResult<'tcx, Double> { self.to_float() } + + #[inline] + pub fn to_f128(self) -> InterpResult<'tcx, Quad> { + self.to_float() + } } diff --git a/compiler/rustc_middle/src/ty/consts/int.rs b/compiler/rustc_middle/src/ty/consts/int.rs index a70e01645f4..71d7dfd8b01 100644 --- a/compiler/rustc_middle/src/ty/consts/int.rs +++ b/compiler/rustc_middle/src/ty/consts/int.rs @@ -1,4 +1,4 @@ -use rustc_apfloat::ieee::{Double, Single}; +use rustc_apfloat::ieee::{Double, Half, Quad, Single}; use rustc_apfloat::Float; use rustc_errors::{DiagArgValue, IntoDiagArg}; use rustc_serialize::{Decodable, Decoder, Encodable, Encoder}; @@ -369,6 +369,11 @@ impl ScalarInt { Ok(F::from_bits(self.to_bits(Size::from_bits(F::BITS))?)) } + #[inline] + pub fn try_to_f16(self) -> Result { + self.try_to_float() + } + #[inline] pub fn try_to_f32(self) -> Result { self.try_to_float() @@ -378,6 +383,11 @@ impl ScalarInt { pub fn try_to_f64(self) -> Result { self.try_to_float() } + + #[inline] + pub fn try_to_f128(self) -> Result { + self.try_to_float() + } } macro_rules! from { @@ -450,6 +460,22 @@ impl TryFrom for char { } } +impl From for ScalarInt { + #[inline] + fn from(f: Half) -> Self { + // We trust apfloat to give us properly truncated data. + Self { data: f.to_bits(), size: NonZero::new((Half::BITS / 8) as u8).unwrap() } + } +} + +impl TryFrom for Half { + type Error = Size; + #[inline] + fn try_from(int: ScalarInt) -> Result { + int.to_bits(Size::from_bytes(2)).map(Self::from_bits) + } +} + impl From for ScalarInt { #[inline] fn from(f: Single) -> Self { @@ -482,6 +508,22 @@ impl TryFrom for Double { } } +impl From for ScalarInt { + #[inline] + fn from(f: Quad) -> Self { + // We trust apfloat to give us properly truncated data. + Self { data: f.to_bits(), size: NonZero::new((Quad::BITS / 8) as u8).unwrap() } + } +} + +impl TryFrom for Quad { + type Error = Size; + #[inline] + fn try_from(int: ScalarInt) -> Result { + int.to_bits(Size::from_bytes(16)).map(Self::from_bits) + } +} + impl fmt::Debug for ScalarInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Dispatch to LowerHex below. diff --git a/compiler/rustc_mir_build/src/build/mod.rs b/compiler/rustc_mir_build/src/build/mod.rs index 45954bdb114..615a5e69319 100644 --- a/compiler/rustc_mir_build/src/build/mod.rs +++ b/compiler/rustc_mir_build/src/build/mod.rs @@ -1,7 +1,7 @@ use crate::build::expr::as_place::PlaceBuilder; use crate::build::scope::DropKind; use itertools::Itertools; -use rustc_apfloat::ieee::{Double, Single}; +use rustc_apfloat::ieee::{Double, Half, Quad, Single}; use rustc_apfloat::Float; use rustc_ast::attr; use rustc_data_structures::fx::FxHashMap; @@ -1053,7 +1053,8 @@ pub(crate) fn parse_float_into_scalar( ) -> Option { let num = num.as_str(); match float_ty { - ty::FloatTy::F16 => unimplemented!("f16_f128"), + // FIXME(f16_f128): When available, compare to the library parser as with `f32` and `f64` + ty::FloatTy::F16 => num.parse::().ok().map(Scalar::from_f16), ty::FloatTy::F32 => { let Ok(rust_f) = num.parse::() else { return None }; let mut f = num @@ -1100,7 +1101,8 @@ pub(crate) fn parse_float_into_scalar( Some(Scalar::from_f64(f)) } - ty::FloatTy::F128 => unimplemented!("f16_f128"), + // FIXME(f16_f128): When available, compare to the library parser as with `f32` and `f64` + ty::FloatTy::F128 => num.parse::().ok().map(Scalar::from_f128), } }