Auto merge of #92048 - Urgau:num-midpoint, r=scottmcm

Add midpoint function for all integers and floating numbers

This pull-request adds the `midpoint` function to `{u,i}{8,16,32,64,128,size}`, `NonZeroU{8,16,32,64,size}` and `f{32,64}`.

This new function is analog to the [C++ midpoint](https://en.cppreference.com/w/cpp/numeric/midpoint) function, and basically compute `(a + b) / 2` with a rounding towards ~~`a`~~ negative infinity in the case of integers. Or simply said: `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a sufficiently-large signed integral type.

Note that unlike the C++ function this pull-request does not implement this function on pointers (`*const T` or `*mut T`). This could be implemented in a future pull-request if desire.

### Implementation

For `f32` and `f64` the implementation in based on the `libcxx` [one](18ab892ff7/libcxx/include/__numeric/midpoint.h (L65-L77)). I originally tried many different approach but all of them failed or lead me with a poor version of the `libcxx`. Note that `libstdc++` has a very similar one; Microsoft STL implementation is also basically the same as `libcxx`. It unfortunately doesn't seems like a better way exist.

For unsigned integers I created the macro `midpoint_impl!`, this macro has two branches:
 - The first one take `$SelfT` and is used when there is no unsigned integer with at least the double of bits. The code simply use this formula `a + (b - a) / 2` with the arguments in the correct order and signs to have the good rounding.
 - The second branch is used when a `$WideT` (at least double of bits as `$SelfT`) is provided, using a wider number means that no overflow can occur, this greatly improve the codegen (no branch and less instructions).

For signed integers the code basically forwards the signed numbers to the unsigned version of midpoint by mapping the signed numbers to their unsigned numbers (`ex: i8 [-128; 127] to [0; 255]`) and vice versa.
I originally created a version that worked directly on the signed numbers but the code was "ugly" and not understandable. Despite this mapping "overhead" the codegen is better than my most optimized version on signed integers.

~~Note that in the case of unsigned numbers I tried to be smart and used `#[cfg(target_pointer_width = "64")]` to determine if using the wide version was better or not by looking at the assembly on godbolt. This was applied to `u32`, `u64` and `usize` and doesn't change the behavior only the assembly code generated.~~
This commit is contained in:
bors 2023-05-14 19:33:02 +00:00
commit 18bfe5d8a9
10 changed files with 313 additions and 3 deletions

View file

@ -133,6 +133,7 @@
#![feature(const_maybe_uninit_assume_init)]
#![feature(const_maybe_uninit_uninit_array)]
#![feature(const_nonnull_new)]
#![feature(const_num_midpoint)]
#![feature(const_option)]
#![feature(const_option_ext)]
#![feature(const_pin)]

View file

@ -940,6 +940,42 @@ impl f32 {
}
}
/// Calculates the middle point of `self` and `rhs`.
///
/// This returns NaN when *either* argument is NaN or if a combination of
/// +inf and -inf is provided as arguments.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
/// assert_eq!(1f32.midpoint(4.0), 2.5);
/// assert_eq!((-5.5f32).midpoint(8.0), 1.25);
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
pub fn midpoint(self, other: f32) -> f32 {
const LO: f32 = f32::MIN_POSITIVE * 2.;
const HI: f32 = f32::MAX / 2.;
let (a, b) = (self, other);
let abs_a = a.abs_private();
let abs_b = b.abs_private();
if abs_a <= HI && abs_b <= HI {
// Overflow is impossible
(a + b) / 2.
} else if abs_a < LO {
// Not safe to halve a
a + (b / 2.)
} else if abs_b < LO {
// Not safe to halve b
(a / 2.) + b
} else {
// Not safe to halve a and b
(a / 2.) + (b / 2.)
}
}
/// Rounds toward zero and converts to any primitive integer type,
/// assuming that the value is finite and fits in that type.
///

View file

@ -951,6 +951,42 @@ impl f64 {
}
}
/// Calculates the middle point of `self` and `rhs`.
///
/// This returns NaN when *either* argument is NaN or if a combination of
/// +inf and -inf is provided as arguments.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
/// assert_eq!(1f64.midpoint(4.0), 2.5);
/// assert_eq!((-5.5f64).midpoint(8.0), 1.25);
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
pub fn midpoint(self, other: f64) -> f64 {
const LO: f64 = f64::MIN_POSITIVE * 2.;
const HI: f64 = f64::MAX / 2.;
let (a, b) = (self, other);
let abs_a = a.abs_private();
let abs_b = b.abs_private();
if abs_a <= HI && abs_b <= HI {
// Overflow is impossible
(a + b) / 2.
} else if abs_a < LO {
// Not safe to halve a
a + (b / 2.)
} else if abs_b < LO {
// Not safe to halve b
(a / 2.) + b
} else {
// Not safe to halve a and b
(a / 2.) + (b / 2.)
}
}
/// Rounds toward zero and converts to any primitive integer type,
/// assuming that the value is finite and fits in that type.
///

View file

@ -2332,6 +2332,44 @@ macro_rules! int_impl {
}
}
/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-1), -1);")]
#[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(0), -1);")]
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[rustc_allow_const_fn_unstable(const_num_midpoint)]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: Self) -> Self {
const U: $UnsignedT = <$SelfT>::MIN.unsigned_abs();
// Map an $SelfT to an $UnsignedT
// ex: i8 [-128; 127] to [0; 255]
const fn map(a: $SelfT) -> $UnsignedT {
(a as $UnsignedT) ^ U
}
// Map an $UnsignedT to an $SelfT
// ex: u8 [0; 255] to [-128; 127]
const fn demap(a: $UnsignedT) -> $SelfT {
(a ^ U) as $SelfT
}
demap(<$UnsignedT>::midpoint(map(self), map(rhs)))
}
/// Returns the logarithm of the number with respect to an arbitrary base,
/// rounded down.
///

View file

@ -95,6 +95,57 @@ depending on the target pointer size.
};
}
macro_rules! midpoint_impl {
($SelfT:ty, unsigned) => {
/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
#[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
// Use the well known branchless algorthim from Hacker's Delight to compute
// `(a + b) / 2` without overflowing: `((a ^ b) >> 1) + (a & b)`.
((self ^ rhs) >> 1) + (self & rhs)
}
};
($SelfT:ty, $WideT:ty, unsigned) => {
/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
#[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
((self as $WideT + rhs as $WideT) / 2) as $SelfT
}
};
}
macro_rules! widening_impl {
($SelfT:ty, $WideT:ty, $BITS:literal, unsigned) => {
/// Calculates the complete product `self * rhs` without the possibility to overflow.
@ -455,6 +506,7 @@ impl u8 {
bound_condition = "",
}
widening_impl! { u8, u16, 8, unsigned }
midpoint_impl! { u8, u16, unsigned }
/// Checks if the value is within the ASCII range.
///
@ -1066,6 +1118,7 @@ impl u16 {
bound_condition = "",
}
widening_impl! { u16, u32, 16, unsigned }
midpoint_impl! { u16, u32, unsigned }
/// Checks if the value is a Unicode surrogate code point, which are disallowed values for [`char`].
///
@ -1114,6 +1167,7 @@ impl u32 {
bound_condition = "",
}
widening_impl! { u32, u64, 32, unsigned }
midpoint_impl! { u32, u64, unsigned }
}
impl u64 {
@ -1137,6 +1191,7 @@ impl u64 {
bound_condition = "",
}
widening_impl! { u64, u128, 64, unsigned }
midpoint_impl! { u64, u128, unsigned }
}
impl u128 {
@ -1161,6 +1216,7 @@ impl u128 {
from_xe_bytes_doc = "",
bound_condition = "",
}
midpoint_impl! { u128, unsigned }
}
#[cfg(target_pointer_width = "16")]
@ -1185,6 +1241,7 @@ impl usize {
bound_condition = " on 16-bit targets",
}
widening_impl! { usize, u32, 16, unsigned }
midpoint_impl! { usize, u32, unsigned }
}
#[cfg(target_pointer_width = "32")]
@ -1209,6 +1266,7 @@ impl usize {
bound_condition = " on 32-bit targets",
}
widening_impl! { usize, u64, 32, unsigned }
midpoint_impl! { usize, u64, unsigned }
}
#[cfg(target_pointer_width = "64")]
@ -1233,6 +1291,7 @@ impl usize {
bound_condition = " on 64-bit targets",
}
widening_impl! { usize, u128, 64, unsigned }
midpoint_impl! { usize, u128, unsigned }
}
impl usize {

View file

@ -493,6 +493,43 @@ macro_rules! nonzero_unsigned_operations {
pub const fn ilog10(self) -> u32 {
super::int_log10::$Int(self.0)
}
/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("# use std::num::", stringify!($Ty), ";")]
///
/// # fn main() { test().unwrap(); }
/// # fn test() -> Option<()> {
#[doc = concat!("let one = ", stringify!($Ty), "::new(1)?;")]
#[doc = concat!("let two = ", stringify!($Ty), "::new(2)?;")]
#[doc = concat!("let four = ", stringify!($Ty), "::new(4)?;")]
///
/// assert_eq!(one.midpoint(four), two);
/// assert_eq!(four.midpoint(one), two);
/// # Some(())
/// # }
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[rustc_allow_const_fn_unstable(const_num_midpoint)]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: Self) -> Self {
// SAFETY: The only way to get `0` with midpoint is to have two opposite or
// near opposite numbers: (-5, 5), (0, 1), (0, 0) which is impossible because
// of the unsignedness of this number and also because $Ty is guaranteed to
// never being 0.
unsafe { $Ty::new_unchecked(self.get().midpoint(rhs.get())) }
}
}
)+
}

View file

@ -53,6 +53,7 @@
#![feature(maybe_uninit_uninit_array_transpose)]
#![feature(min_specialization)]
#![feature(numfmt)]
#![feature(num_midpoint)]
#![feature(step_trait)]
#![feature(str_internals)]
#![feature(std_internals)]

View file

@ -364,6 +364,32 @@ macro_rules! int_module {
assert_eq!((0 as $T).borrowing_sub($T::MIN, false), ($T::MIN, true));
assert_eq!((0 as $T).borrowing_sub($T::MIN, true), ($T::MAX, false));
}
#[test]
fn test_midpoint() {
assert_eq!(<$T>::midpoint(1, 3), 2);
assert_eq!(<$T>::midpoint(3, 1), 2);
assert_eq!(<$T>::midpoint(0, 0), 0);
assert_eq!(<$T>::midpoint(0, 2), 1);
assert_eq!(<$T>::midpoint(2, 0), 1);
assert_eq!(<$T>::midpoint(2, 2), 2);
assert_eq!(<$T>::midpoint(1, 4), 2);
assert_eq!(<$T>::midpoint(4, 1), 2);
assert_eq!(<$T>::midpoint(3, 4), 3);
assert_eq!(<$T>::midpoint(4, 3), 3);
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), -1);
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), -1);
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);
assert_eq!(<$T>::midpoint(<$T>::MIN, 6), <$T>::MIN / 2 + 3);
assert_eq!(<$T>::midpoint(6, <$T>::MIN), <$T>::MIN / 2 + 3);
assert_eq!(<$T>::midpoint(<$T>::MAX, 6), <$T>::MAX / 2 + 3);
assert_eq!(<$T>::midpoint(6, <$T>::MAX), <$T>::MAX / 2 + 3);
}
}
};
}

View file

@ -724,7 +724,7 @@ assume_usize_width! {
}
macro_rules! test_float {
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr) => {
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => {
mod $modname {
#[test]
fn min() {
@ -845,6 +845,38 @@ macro_rules! test_float {
assert!(($nan as $fty).maximum($nan).is_nan());
}
#[test]
fn midpoint() {
assert_eq!((0.5 as $fty).midpoint(0.5), 0.5);
assert_eq!((0.5 as $fty).midpoint(2.5), 1.5);
assert_eq!((3.0 as $fty).midpoint(4.0), 3.5);
assert_eq!((-3.0 as $fty).midpoint(4.0), 0.5);
assert_eq!((3.0 as $fty).midpoint(-4.0), -0.5);
assert_eq!((-3.0 as $fty).midpoint(-4.0), -3.5);
assert_eq!((0.0 as $fty).midpoint(0.0), 0.0);
assert_eq!((-0.0 as $fty).midpoint(-0.0), -0.0);
assert_eq!((-5.0 as $fty).midpoint(5.0), 0.0);
assert_eq!(($max as $fty).midpoint($min), 0.0);
assert_eq!(($min as $fty).midpoint($max), -0.0);
assert_eq!(($max as $fty).midpoint($min_pos), $max / 2.);
assert_eq!((-$max as $fty).midpoint($min_pos), -$max / 2.);
assert_eq!(($max as $fty).midpoint(-$min_pos), $max / 2.);
assert_eq!((-$max as $fty).midpoint(-$min_pos), -$max / 2.);
assert_eq!(($min_pos as $fty).midpoint($max), $max / 2.);
assert_eq!(($min_pos as $fty).midpoint(-$max), -$max / 2.);
assert_eq!((-$min_pos as $fty).midpoint($max), $max / 2.);
assert_eq!((-$min_pos as $fty).midpoint(-$max), -$max / 2.);
assert_eq!(($max as $fty).midpoint($max), $max);
assert_eq!(($min_pos as $fty).midpoint($min_pos), $min_pos);
assert_eq!((-$min_pos as $fty).midpoint(-$min_pos), -$min_pos);
assert_eq!(($max as $fty).midpoint(5.0), $max / 2.0 + 2.5);
assert_eq!(($max as $fty).midpoint(-5.0), $max / 2.0 - 2.5);
assert_eq!(($inf as $fty).midpoint($inf), $inf);
assert_eq!(($neginf as $fty).midpoint($neginf), $neginf);
assert!(($nan as $fty).midpoint(1.0).is_nan());
assert!((1.0 as $fty).midpoint($nan).is_nan());
assert!(($nan as $fty).midpoint($nan).is_nan());
}
#[test]
fn rem_euclid() {
let a: $fty = 42.0;
assert!($inf.rem_euclid(a).is_nan());
@ -867,5 +899,23 @@ macro_rules! test_float {
};
}
test_float!(f32, f32, f32::INFINITY, f32::NEG_INFINITY, f32::NAN);
test_float!(f64, f64, f64::INFINITY, f64::NEG_INFINITY, f64::NAN);
test_float!(
f32,
f32,
f32::INFINITY,
f32::NEG_INFINITY,
f32::NAN,
f32::MIN,
f32::MAX,
f32::MIN_POSITIVE
);
test_float!(
f64,
f64,
f64::INFINITY,
f64::NEG_INFINITY,
f64::NAN,
f64::MIN,
f64::MAX,
f64::MIN_POSITIVE
);

View file

@ -252,6 +252,32 @@ macro_rules! uint_module {
assert_eq!($T::MAX.borrowing_sub(0, true), ($T::MAX - 1, false));
assert_eq!($T::MAX.borrowing_sub($T::MAX, true), ($T::MAX, true));
}
#[test]
fn test_midpoint() {
assert_eq!(<$T>::midpoint(1, 3), 2);
assert_eq!(<$T>::midpoint(3, 1), 2);
assert_eq!(<$T>::midpoint(0, 0), 0);
assert_eq!(<$T>::midpoint(0, 2), 1);
assert_eq!(<$T>::midpoint(2, 0), 1);
assert_eq!(<$T>::midpoint(2, 2), 2);
assert_eq!(<$T>::midpoint(1, 4), 2);
assert_eq!(<$T>::midpoint(4, 1), 2);
assert_eq!(<$T>::midpoint(3, 4), 3);
assert_eq!(<$T>::midpoint(4, 3), 3);
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), (<$T>::MAX - <$T>::MIN) / 2);
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), (<$T>::MAX - <$T>::MIN) / 2);
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);
assert_eq!(<$T>::midpoint(<$T>::MIN, 6), <$T>::MIN / 2 + 3);
assert_eq!(<$T>::midpoint(6, <$T>::MIN), <$T>::MIN / 2 + 3);
assert_eq!(<$T>::midpoint(<$T>::MAX, 6), (<$T>::MAX - <$T>::MIN) / 2 + 3);
assert_eq!(<$T>::midpoint(6, <$T>::MAX), (<$T>::MAX - <$T>::MIN) / 2 + 3);
}
}
};
}