From 842ac87747c4a6f8002ada6bab04d97320d206fc Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Thu, 13 Jan 2022 21:20:17 -0500 Subject: [PATCH] Use bitmask trait --- crates/core_simd/src/masks.rs | 22 ++----- crates/core_simd/src/masks/bitmask.rs | 12 +--- crates/core_simd/src/masks/full_masks.rs | 35 ++--------- crates/core_simd/src/masks/to_bitmask.rs | 78 ++++++++++++++++++++++++ crates/core_simd/tests/masks.rs | 6 +- 5 files changed, 93 insertions(+), 60 deletions(-) create mode 100644 crates/core_simd/src/masks/to_bitmask.rs diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index ae1fef53da8..22514728ffa 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -12,8 +12,10 @@ )] mod mask_impl; -use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount}; +mod to_bitmask; +pub use to_bitmask::ToBitMask; + +use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; @@ -216,22 +218,6 @@ where } } - /// Convert this mask to a bitmask, with one bit set per lane. - #[cfg(feature = "generic_const_exprs")] - #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask(self) -> [u8; LaneCount::::BITMASK_LEN] { - self.0.to_bitmask() - } - - /// Convert a bitmask to a mask. - #[cfg(feature = "generic_const_exprs")] - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask(bitmask: [u8; LaneCount::::BITMASK_LEN]) -> Self { - Self(mask_impl::Mask::from_bitmask(bitmask)) - } - /// Returns true if any lane is set, or false otherwise. #[inline] #[must_use = "method returns a new bool and does not mutate the original value"] diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index b4217dc87ba..f20f83ecb38 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -115,20 +115,14 @@ where unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask(self) -> [u8; LaneCount::::BITMASK_LEN] { - // Safety: these are the same type and we are laundering the generic + pub unsafe fn to_bitmask_intrinsic(self) -> U { unsafe { core::mem::transmute_copy(&self.0) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask(bitmask: [u8; LaneCount::::BITMASK_LEN]) -> Self { - // Safety: these are the same type and we are laundering the generic - Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData) + pub unsafe fn from_bitmask_intrinsic(bitmask: U) -> Self { + unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) } } #[inline] diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index e5bb784bb91..b20b0a4b708 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -109,41 +109,16 @@ where unsafe { Mask(intrinsics::simd_cast(self.0)) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask(self) -> [u8; LaneCount::::BITMASK_LEN] { - unsafe { - let mut bitmask: [u8; LaneCount::::BITMASK_LEN] = - intrinsics::simd_bitmask(self.0); - - // There is a bug where LLVM appears to implement this operation with the wrong - // bit order. - // TODO fix this in a better way - if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { - *x = x.reverse_bits(); - } - } - - bitmask - } + pub unsafe fn to_bitmask_intrinsic(self) -> U { + // Safety: caller must only return bitmask types + unsafe { intrinsics::simd_bitmask(self.0) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask(mut bitmask: [u8; LaneCount::::BITMASK_LEN]) -> Self { + pub unsafe fn from_bitmask_intrinsic(bitmask: U) -> Self { + // Safety: caller must only pass bitmask types unsafe { - // There is a bug where LLVM appears to implement this operation with the wrong - // bit order. - // TODO fix this in a better way - if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { - *x = x.reverse_bits(); - } - } - Self::from_int_unchecked(intrinsics::simd_select_bitmask( bitmask, Self::splat(true).to_int(), diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs new file mode 100644 index 00000000000..3a9f89f19eb --- /dev/null +++ b/crates/core_simd/src/masks/to_bitmask.rs @@ -0,0 +1,78 @@ +use super::{mask_impl, Mask, MaskElement}; + +/// Converts masks to and from bitmasks. +/// +/// In a bitmask, each bit represents if the corresponding lane in the mask is set. +pub trait ToBitMask { + /// Converts a mask to a bitmask. + fn to_bitmask(self) -> BitMask; + + /// Converts a bitmask to a mask. + fn from_bitmask(bitmask: BitMask) -> Self; +} + +macro_rules! impl_integer_intrinsic { + { $(unsafe impl ToBitMask<$int:ty> for Mask<_, $lanes:literal>)* } => { + $( + impl ToBitMask<$int> for Mask { + fn to_bitmask(self) -> $int { + unsafe { self.0.to_bitmask_intrinsic() } + } + + fn from_bitmask(bitmask: $int) -> Self { + unsafe { Self(mask_impl::Mask::from_bitmask_intrinsic(bitmask)) } + } + } + )* + } +} + +impl_integer_intrinsic! { + unsafe impl ToBitMask for Mask<_, 8> + unsafe impl ToBitMask for Mask<_, 16> + unsafe impl ToBitMask for Mask<_, 32> + unsafe impl ToBitMask for Mask<_, 64> +} + +macro_rules! impl_integer_via { + { $(impl ToBitMask<$int:ty, via $via:ty> for Mask<_, $lanes:literal>)* } => { + $( + impl ToBitMask<$int> for Mask { + fn to_bitmask(self) -> $int { + let bitmask: $via = self.to_bitmask(); + bitmask as _ + } + + fn from_bitmask(bitmask: $int) -> Self { + Self::from_bitmask(bitmask as $via) + } + } + )* + } +} + +impl_integer_via! { + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 8> + + impl ToBitMask for Mask<_, 16> + impl ToBitMask for Mask<_, 16> + + impl ToBitMask for Mask<_, 32> +} + +#[cfg(target_pointer_width = "32")] +impl_integer_via! { + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 16> + impl ToBitMask for Mask<_, 32> +} + +#[cfg(target_pointer_width = "64")] +impl_integer_via! { + impl ToBitMask for Mask<_, 8> + impl ToBitMask for Mask<_, 16> + impl ToBitMask for Mask<_, 32> + impl ToBitMask for Mask<_, 64> +} diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 6a8ecd33a73..965c0fa2635 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -68,16 +68,16 @@ macro_rules! test_mask_api { assert_eq!(core_simd::Mask::<$type, 8>::from_int(int), mask); } - #[cfg(feature = "generic_const_exprs")] #[test] fn roundtrip_bitmask_conversion() { + use core_simd::ToBitMask; let values = [ true, false, false, true, false, false, true, false, true, true, false, false, false, false, false, true, ]; let mask = core_simd::Mask::<$type, 16>::from_array(values); - let bitmask = mask.to_bitmask(); - assert_eq!(bitmask, [0b01001001, 0b10000011]); + let bitmask: u16 = mask.to_bitmask(); + assert_eq!(bitmask, 0b1000001101001001); assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask); } }