Auto merge of #115827 - eduardosm:miri-sse-reduce-code-dup, r=RalfJung

miri: reduce code duplication in some SSE/SSE2 intrinsics

Reduces code duplication in the Miri implementation of some SSE and SSE2 intrinsics by using generics and rustc_const_eval helper functions.

There are also some other minor changes.

r? `@RalfJung`
This commit is contained in:
bors 2023-09-20 12:07:26 +00:00
commit 78e74d959b
6 changed files with 337 additions and 526 deletions

View file

@ -173,6 +173,16 @@ impl<Prov> Scalar<Prov> {
.unwrap_or_else(|| bug!("Signed value {:#x} does not fit in {} bits", i, size.bits()))
}
/// Builds an 8-bit scalar from the signed value `i` via `from_int`.
#[inline]
pub fn from_i8(i: i8) -> Self {
    let size = Size::from_bits(8);
    Self::from_int(i, size)
}
/// Builds a 16-bit scalar from the signed value `i` via `from_int`.
#[inline]
pub fn from_i16(i: i16) -> Self {
    let size = Size::from_bits(16);
    Self::from_int(i, size)
}
#[inline]
pub fn from_i32(i: i32) -> Self {
Self::from_int(i, Size::from_bits(32))
@ -400,15 +410,19 @@ impl<'tcx, Prov: Provenance> Scalar<Prov> {
Ok(i64::try_from(b).unwrap())
}
/// Reinterprets this scalar as a float of type `F`.
///
/// The raw bits are fetched with `to_uint` using a size of `F::BITS`, so
/// that call performs the size and truncation check before the bits are
/// handed to `F::from_bits`.
#[inline]
pub fn to_float<F: Float>(self) -> InterpResult<'tcx, F> {
    let width = Size::from_bits(F::BITS);
    let bits = self.to_uint(width)?;
    Ok(F::from_bits(bits))
}
#[inline]
pub fn to_f32(self) -> InterpResult<'tcx, Single> {
// Going through `u32` to check size and truncation.
Ok(Single::from_bits(self.to_u32()?.into()))
self.to_float()
}
#[inline]
pub fn to_f64(self) -> InterpResult<'tcx, Double> {
// Going through `u64` to check size and truncation.
Ok(Double::from_bits(self.to_u64()?.into()))
self.to_float()
}
}

View file

@ -14,7 +14,7 @@ use rustc_middle::mir;
use rustc_middle::ty::{
self,
layout::{IntegerExt as _, LayoutOf, TyAndLayout},
Ty, TyCtxt,
IntTy, Ty, TyCtxt, UintTy,
};
use rustc_span::{def_id::CrateNum, sym, Span, Symbol};
use rustc_target::abi::{Align, FieldIdx, FieldsShape, Integer, Size, Variants};
@ -1066,6 +1066,24 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
),
}
}
/// Returns the integer type that is twice as wide as `ty`, keeping its
/// signedness (e.g. `u8` -> `u16`, `i32` -> `i64`).
///
/// Only 8/16/32/64-bit integer types are handled; anything else is reported
/// through `span_bug!`.
fn get_twice_wide_int_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
    let this = self.eval_context_ref();
    let types = &this.tcx.types;
    match ty.kind() {
        // Signed
        ty::Int(IntTy::I8) => types.i16,
        ty::Int(IntTy::I16) => types.i32,
        ty::Int(IntTy::I32) => types.i64,
        ty::Int(IntTy::I64) => types.i128,
        // Unsigned
        ty::Uint(UintTy::U8) => types.u16,
        ty::Uint(UintTy::U16) => types.u32,
        ty::Uint(UintTy::U32) => types.u64,
        ty::Uint(UintTy::U64) => types.u128,
        _ => span_bug!(this.cur_span(), "unexpected type: {ty:?}"),
    }
}
}
impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
@ -1151,3 +1169,20 @@ pub fn get_local_crates(tcx: TyCtxt<'_>) -> Vec<CrateNum> {
pub fn target_os_is_unix(target_os: &str) -> bool {
matches!(target_os, "linux" | "macos" | "freebsd" | "android")
}
/// Encodes `b` as a SIMD mask element of the given `size`.
pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
    // A "true" lane in SIMD is the all-ones bit pattern. In two's complement
    // that is -1, and `from_int` truncates or sign-extends it to `size`
    // as required.
    Scalar::from_int(if b { -1 } else { 0 }, size)
}
/// Decodes a SIMD mask element into a `bool`.
///
/// The element is read as a signed integer of its own layout size: `0` is
/// `false`, `-1` (all bits set) is `true`, and every other value is reported
/// as undefined behavior.
pub(crate) fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
    let raw = elem.to_scalar().to_int(elem.layout.size)?;
    match raw {
        -1 => Ok(true),
        0 => Ok(false),
        _ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
    }
}

View file

@ -1,10 +1,10 @@
use rustc_apfloat::{Float, Round};
use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
use rustc_middle::{mir, ty, ty::FloatTy};
use rustc_target::abi::{Endian, HasDataLayout, Size};
use rustc_target::abi::{Endian, HasDataLayout};
use crate::*;
use helpers::check_arg_count;
use helpers::{bool_to_simd_element, check_arg_count, simd_element_to_bool};
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
@ -612,21 +612,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
}
/// Encodes `b` as a SIMD mask element of `size`: `-1` for `true` and `0`
/// for `false`, converted to a scalar with `Scalar::from_int`.
fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
// SIMD uses all-1 as pattern for "true"
let val = if b { -1 } else { 0 };
Scalar::from_int(val, size)
}
/// Decodes a SIMD mask element into a `bool`.
///
/// `elem` is read as a signed integer of its layout size; `0` maps to
/// `false`, `-1` maps to `true`, and any other value raises the
/// undefined-behavior error below.
fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
let val = elem.to_scalar().to_int(elem.layout.size)?;
Ok(match val {
0 => false,
-1 => true,
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
})
}
fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
assert!(idx < vec_len);
match endianness {

View file

@ -1,4 +1,8 @@
use crate::InterpResult;
use rustc_middle::mir;
use rustc_target::abi::Size;
use crate::*;
use helpers::bool_to_simd_element;
pub(super) mod sse;
pub(super) mod sse2;
@ -43,3 +47,155 @@ impl FloatCmpOp {
}
}
}
/// Selects the binary floating-point operation performed by `bin_op_float`
/// (and, through it, by `bin_op_simd_float_first` / `bin_op_simd_float_all`).
#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
/// <https://www.felixcloutier.com/x86/minps>
/// <https://www.felixcloutier.com/x86/minsd>
/// <https://www.felixcloutier.com/x86/minpd>
Min,
/// Maximum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/maxss>
/// <https://www.felixcloutier.com/x86/maxps>
/// <https://www.felixcloutier.com/x86/maxsd>
/// <https://www.felixcloutier.com/x86/maxpd>
Max,
}
/// Applies the scalar operation selected by `which` to `left` and `right`
/// and returns the resulting scalar.
///
/// * `Arith` is forwarded to the interpreter's `overflowing_binary_op`.
/// * `Cmp` yields a SIMD mask element (all-ones or all-zeros) of `F::BITS` bits.
/// * `Min` / `Max` return one of the two input scalars unchanged.
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
    this: &crate::MiriInterpCx<'_, 'tcx>,
    which: FloatBinOp,
    left: &ImmTy<'tcx, Provenance>,
    right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
    match which {
        FloatBinOp::Arith(op) => {
            let (val, _overflow, _ty) = this.overflowing_binary_op(op, left, right)?;
            Ok(val)
        }
        FloatBinOp::Cmp(cmp) => {
            let lhs = left.to_scalar().to_float::<F>()?;
            let rhs = right.to_scalar().to_float::<F>()?;
            // FIXME: Make sure that these operations match the semantics
            // of cmpps/cmpss/cmppd/cmpsd
            let holds = match cmp {
                FloatCmpOp::Eq => lhs == rhs,
                FloatCmpOp::Lt => lhs < rhs,
                FloatCmpOp::Le => lhs <= rhs,
                FloatCmpOp::Unord => lhs.is_nan() || rhs.is_nan(),
                FloatCmpOp::Neq => lhs != rhs,
                FloatCmpOp::Nlt => !(lhs < rhs),
                FloatCmpOp::Nle => !(lhs <= rhs),
                FloatCmpOp::Ord => !lhs.is_nan() && !rhs.is_nan(),
            };
            Ok(bool_to_simd_element(holds, Size::from_bits(F::BITS)))
        }
        FloatBinOp::Min | FloatBinOp::Max => {
            let lhs_scalar = left.to_scalar();
            let lhs = lhs_scalar.to_float::<F>()?;
            let rhs_scalar = right.to_scalar();
            let rhs = rhs_scalar.to_float::<F>()?;
            // SSE semantics to handle zero and NaN: the second operand is
            // returned when both operands are zero or when either is NaN.
            // Note that `x == F::ZERO` is true when `x` is either +0 or -0.
            let both_zero = lhs == F::ZERO && rhs == F::ZERO;
            let any_nan = lhs.is_nan() || rhs.is_nan();
            // Otherwise the ordinary ordering decides; ties go to `right`.
            let right_wins = match which {
                FloatBinOp::Min => lhs >= rhs,
                _ => lhs <= rhs,
            };
            Ok(if both_zero || any_nan || right_wins { rhs_scalar } else { lhs_scalar })
        }
    }
}
/// Applies `which` to the first lane of `left` and `right` and writes the
/// result to the first lane of `dest`. All remaining lanes of `dest` are
/// copied from `left`.
///
/// Panics (via `assert_eq!`) if the three vectors do not have the same length.
fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatBinOp,
    left: &OpTy<'tcx, Provenance>,
    right: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (lhs, lhs_len) = this.operand_to_simd(left)?;
    let (rhs, rhs_len) = this.operand_to_simd(right)?;
    let (out, out_len) = this.place_to_simd(dest)?;

    assert_eq!(out_len, lhs_len);
    assert_eq!(out_len, rhs_len);

    // Lane 0: the actual operation.
    let lhs0 = this.read_immediate(&this.project_index(&lhs, 0)?)?;
    let rhs0 = this.read_immediate(&this.project_index(&rhs, 0)?)?;
    let res0 = bin_op_float::<F>(this, which, &lhs0, &rhs0)?;
    let out0 = this.project_index(&out, 0)?;
    this.write_scalar(res0, &out0)?;

    // Lanes 1..: passed through from `left`.
    for lane in 1..out_len {
        let src = this.project_index(&lhs, lane)?;
        let dst = this.project_index(&out, lane)?;
        this.copy_op(&src, &dst, /*allow_transmute*/ false)?;
    }

    Ok(())
}
/// Applies `which` lane by lane to `left` and `right`, storing each result
/// in the corresponding lane of `dest`.
///
/// Panics (via `assert_eq!`) if the three vectors do not have the same length.
fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatBinOp,
    left: &OpTy<'tcx, Provenance>,
    right: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (lhs, lhs_len) = this.operand_to_simd(left)?;
    let (rhs, rhs_len) = this.operand_to_simd(right)?;
    let (out, out_len) = this.place_to_simd(dest)?;

    assert_eq!(out_len, lhs_len);
    assert_eq!(out_len, rhs_len);

    for lane in 0..out_len {
        let lhs_lane = this.read_immediate(&this.project_index(&lhs, lane)?)?;
        let rhs_lane = this.read_immediate(&this.project_index(&rhs, lane)?)?;
        let out_lane = this.project_index(&out, lane)?;

        let value = bin_op_float::<F>(this, which, &lhs_lane, &rhs_lane)?;
        this.write_scalar(value, &out_lane)?;
    }

    Ok(())
}

View file

@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi;
use rand::Rng as _;
use super::FloatCmpOp;
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use crate::*;
use shims::foreign_items::EmulateByNameResult;
@ -45,7 +45,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
_ => unreachable!(),
};
bin_op_ss(this, which, left, right, dest)?;
bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm_min_ps and _mm_max_ps functions.
// Note that the semantics are a bit different from Rust simd_min
@ -62,7 +62,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
_ => unreachable!(),
};
bin_op_ps(this, which, left, right, dest)?;
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm_{sqrt,rcp,rsqrt}_ss functions.
// Performs the operations on the first component of `op` and
@ -106,7 +106,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"llvm.x86.sse.cmp.ss",
)?);
bin_op_ss(this, which, left, right, dest)?;
bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp_ps function.
// Performs a comparison operation on each component of `left`
@ -121,7 +121,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"llvm.x86.sse.cmp.ps",
)?);
bin_op_ps(this, which, left, right, dest)?;
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm_{,u}comi{eq,lt,le,gt,ge,neq}_ss functions.
// Compares the first component of `left` and `right` and returns
@ -154,9 +154,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
};
this.write_scalar(Scalar::from_i32(i32::from(res)), dest)?;
}
// Use to implement _mm_cvtss_si32 and _mm_cvttss_si32.
// Converts the first component of `op` from f32 to i32.
"cvtss2si" | "cvttss2si" => {
// Used to implement the _mm_cvtss_si32, _mm_cvttss_si32,
// _mm_cvtss_si64 and _mm_cvttss_si64 functions.
// Converts the first component of `op` from f32 to i32/i64.
"cvtss2si" | "cvttss2si" | "cvtss2si64" | "cvttss2si64" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (op, _) = this.operand_to_simd(op)?;
@ -165,51 +166,26 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let rnd = match unprefixed_name {
// "current SSE rounding mode", assume nearest
// https://www.felixcloutier.com/x86/cvtss2si
"cvtss2si" => rustc_apfloat::Round::NearestTiesToEven,
"cvtss2si" | "cvtss2si64" => rustc_apfloat::Round::NearestTiesToEven,
// always truncate
// https://www.felixcloutier.com/x86/cvttss2si
"cvttss2si" => rustc_apfloat::Round::TowardZero,
"cvttss2si" | "cvttss2si64" => rustc_apfloat::Round::TowardZero,
_ => unreachable!(),
};
let res = this.float_to_int_checked(op, dest.layout.ty, rnd).unwrap_or_else(|| {
// Fall back to the minimum value according to SSE semantics.
Scalar::from_i32(i32::MIN)
Scalar::from_int(dest.layout.size.signed_int_min(), dest.layout.size)
});
this.write_scalar(res, dest)?;
}
// Use to implement _mm_cvtss_si64 and _mm_cvttss_si64.
// Converts the first component of `op` from f32 to i64.
"cvtss2si64" | "cvttss2si64" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (op, _) = this.operand_to_simd(op)?;
let op = this.read_scalar(&this.project_index(&op, 0)?)?.to_f32()?;
let rnd = match unprefixed_name {
// "current SSE rounding mode", assume nearest
// https://www.felixcloutier.com/x86/cvtss2si
"cvtss2si64" => rustc_apfloat::Round::NearestTiesToEven,
// always truncate
// https://www.felixcloutier.com/x86/cvttss2si
"cvttss2si64" => rustc_apfloat::Round::TowardZero,
_ => unreachable!(),
};
let res = this.float_to_int_checked(op, dest.layout.ty, rnd).unwrap_or_else(|| {
// Fallback to minimum acording to SSE semantics.
Scalar::from_i64(i64::MIN)
});
this.write_scalar(res, dest)?;
}
// Used to implement the _mm_cvtsi32_ss function.
// Converts `right` from i32 to f32. Returns a SIMD vector with
// Used to implement the _mm_cvtsi32_ss and _mm_cvtsi64_ss functions.
// Converts `right` from i32/i64 to f32. Returns a SIMD vector with
// the result in the first component and the remaining components
// are copied from `left`.
// https://www.felixcloutier.com/x86/cvtsi2ss
"cvtsi2ss" => {
"cvtsi2ss" | "cvtsi642ss" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
@ -218,42 +194,17 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert_eq!(dest_len, left_len);
let right = this.read_scalar(right)?.to_i32()?;
let res0 = Scalar::from_f32(Single::from_i128(right.into()).value);
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
let right = this.read_immediate(right)?;
let dest0 = this.project_index(&dest, 0)?;
let res0 = this.int_to_int_or_float(&right, dest0.layout.ty)?;
this.write_immediate(res0, &dest0)?;
for i in 1..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let dest = this.project_index(&dest, i)?;
this.write_immediate(*left, &dest)?;
}
}
// Used to implement the _mm_cvtsi64_ss function.
// Converts `right` from i64 to f32. Returns a SIMD vector with
// the result in the first component and the remaining components
// are copied from `left`.
// https://www.felixcloutier.com/x86/cvtsi2ss
"cvtsi642ss" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
let right = this.read_scalar(right)?.to_i64()?;
let res0 = Scalar::from_f32(Single::from_i128(right.into()).value);
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
for i in 1..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let dest = this.project_index(&dest, i)?;
this.write_immediate(*left, &dest)?;
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}
}
// Used to implement the _mm_movemask_ps function.
@ -281,148 +232,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
}
/// Selects the binary `f32` operation performed by `bin_op_f32`
/// (and, through it, by `bin_op_ss` / `bin_op_ps`).
#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
/// <https://www.felixcloutier.com/x86/minps>
Min,
/// Maximum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/maxss>
/// <https://www.felixcloutier.com/x86/maxps>
Max,
}
/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
///
/// `Arith` is forwarded to `overflowing_binary_op`; `Cmp` yields a 32-bit
/// mask (`u32::MAX` or `0`); `Min` / `Max` return one of the operands
/// re-encoded with `Scalar::from_f32`.
fn bin_op_f32<'tcx>(
this: &crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
match which {
FloatBinOp::Arith(which) => {
// Only the value is used; the other two tuple fields are discarded.
let (res, _, _) = this.overflowing_binary_op(which, left, right)?;
Ok(res)
}
FloatBinOp::Cmp(which) => {
let left = left.to_scalar().to_f32()?;
let right = right.to_scalar().to_f32()?;
// FIXME: Make sure that these operations match the semantics of cmpps
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
};
// All bits set encodes "true", all bits clear encodes "false".
Ok(Scalar::from_u32(if res { u32::MAX } else { 0 }))
}
FloatBinOp::Min => {
let left = left.to_scalar().to_f32()?;
let right = right.to_scalar().to_f32()?;
// SSE semantics to handle zero and NaN. Note that `x == Single::ZERO`
// is true when `x` is either +0 or -0.
// `right` is returned when both are zero, when either is NaN, or when
// `left >= right`.
if (left == Single::ZERO && right == Single::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left >= right
{
Ok(Scalar::from_f32(right))
} else {
Ok(Scalar::from_f32(left))
}
}
FloatBinOp::Max => {
let left = left.to_scalar().to_f32()?;
let right = right.to_scalar().to_f32()?;
// SSE semantics to handle zero and NaN. Note that `x == Single::ZERO`
// is true when `x` is either +0 or -0.
// `right` is returned when both are zero, when either is NaN, or when
// `left <= right`.
if (left == Single::ZERO && right == Single::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left <= right
{
Ok(Scalar::from_f32(right))
} else {
Ok(Scalar::from_f32(left))
}
}
}
}
/// Performs `which` operation on the first component of `left` and `right`
/// and copies the other components from `left`. The result is stored in `dest`.
///
/// Panics (via `assert_eq!`) if `left`, `right` and `dest` differ in length.
fn bin_op_ss<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
// Component 0: the actual operation.
let res0 = bin_op_f32(
this,
which,
&this.read_immediate(&this.project_index(&left, 0)?)?,
&this.read_immediate(&this.project_index(&right, 0)?)?,
)?;
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
// Remaining components: passed through from `left`.
for i in 1..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let dest = this.project_index(&dest, i)?;
this.write_immediate(*left, &dest)?;
}
Ok(())
}
/// Performs `which` operation on each component of `left` and
/// `right`, storing the result in `dest`.
///
/// Panics (via `assert_eq!`) if `left`, `right` and `dest` differ in length.
fn bin_op_ps<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
// Each component is read, combined with `bin_op_f32`, and written back.
for i in 0..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;
let res = bin_op_f32(this, which, &left, &right)?;
this.write_scalar(res, &dest)?;
}
Ok(())
}
#[derive(Copy, Clone)]
enum FloatUnaryOp {
/// sqrt(x)
@ -510,10 +319,11 @@ fn unary_op_ss<'tcx>(
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
for i in 1..dest_len {
let op = this.read_immediate(&this.project_index(&op, i)?)?;
let dest = this.project_index(&dest, i)?;
this.write_immediate(*op, &dest)?;
this.copy_op(
&this.project_index(&op, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}
Ok(())

View file

@ -1,14 +1,14 @@
use rustc_apfloat::{
ieee::{Double, Single},
Float as _, FloatConvert as _,
Float as _,
};
use rustc_middle::mir;
use rustc_middle::ty::layout::LayoutOf as _;
use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::abi::Size;
use rustc_target::spec::abi::Abi;
use super::FloatCmpOp;
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
use crate::*;
use shims::foreign_items::EmulateByNameResult;
@ -37,9 +37,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
// Intrinsics suffixed with "epiX" or "epuX" operate on X-bit signed or unsigned
// vectors.
match unprefixed_name {
// Used to implement the _mm_avg_epu8 function.
// Averages packed unsigned 8-bit integers in `left` and `right`.
"pavg.b" => {
// Used to implement the _mm_avg_epu8 and _mm_avg_epu16 functions.
// Averages packed unsigned 8/16-bit integers in `left` and `right`.
"pavg.b" | "pavg.w" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
@ -51,23 +51,45 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;
// Values are expanded from u8 to u16, so adds cannot overflow.
let res = u16::from(left)
.checked_add(u16::from(right))
.unwrap()
.checked_add(1)
.unwrap()
/ 2;
this.write_scalar(Scalar::from_u8(res.try_into().unwrap()), &dest)?;
// Widen the operands to avoid overflow
let twice_wide_ty = this.get_twice_wide_int_ty(left.layout.ty);
let twice_wide_layout = this.layout_of(twice_wide_ty)?;
let left = this.int_to_int_or_float(&left, twice_wide_ty)?;
let right = this.int_to_int_or_float(&right, twice_wide_ty)?;
// Calculate left + right + 1
let (added, _overflow, _ty) = this.overflowing_binary_op(
mir::BinOp::Add,
&ImmTy::from_immediate(left, twice_wide_layout),
&ImmTy::from_immediate(right, twice_wide_layout),
)?;
let (added, _overflow, _ty) = this.overflowing_binary_op(
mir::BinOp::Add,
&ImmTy::from_scalar(added, twice_wide_layout),
&ImmTy::from_uint(1u32, twice_wide_layout),
)?;
// Calculate (left + right + 1) / 2
let (divided, _overflow, _ty) = this.overflowing_binary_op(
mir::BinOp::Div,
&ImmTy::from_scalar(added, twice_wide_layout),
&ImmTy::from_uint(2u32, twice_wide_layout),
)?;
// Narrow back to the original type
let res = this.int_to_int_or_float(
&ImmTy::from_scalar(divided, twice_wide_layout),
dest.layout.ty,
)?;
this.write_immediate(res, &dest)?;
}
}
// Used to implement the _mm_avg_epu16 function.
// Averages packed unsigned 16-bit integers in `left` and `right`.
"pavg.w" => {
// Used to implement the _mm_mulhi_epi16 and _mm_mulhi_epu16 functions.
"pmulh.w" | "pmulhu.w" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
@ -79,62 +101,35 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?;
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u16()?;
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;
// Values are expanded from u16 to u32, so adds cannot overflow.
let res = u32::from(left)
.checked_add(u32::from(right))
.unwrap()
.checked_add(1)
.unwrap()
/ 2;
this.write_scalar(Scalar::from_u16(res.try_into().unwrap()), &dest)?;
}
}
// Used to implement the _mm_mulhi_epi16 function.
"pmulh.w" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
// Widen the operands to avoid overflow
let twice_wide_ty = this.get_twice_wide_int_ty(left.layout.ty);
let twice_wide_layout = this.layout_of(twice_wide_ty)?;
let left = this.int_to_int_or_float(&left, twice_wide_ty)?;
let right = this.int_to_int_or_float(&right, twice_wide_ty)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
// Multiply
let (multiplied, _overflow, _ty) = this.overflowing_binary_op(
mir::BinOp::Mul,
&ImmTy::from_immediate(left, twice_wide_layout),
&ImmTy::from_immediate(right, twice_wide_layout),
)?;
// Keep the high half
let (high, _overflow, _ty) = this.overflowing_binary_op(
mir::BinOp::Shr,
&ImmTy::from_scalar(multiplied, twice_wide_layout),
&ImmTy::from_uint(dest.layout.size.bits(), twice_wide_layout),
)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_i16()?;
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_i16()?;
let dest = this.project_index(&dest, i)?;
// Values are expanded from i16 to i32, so multiplication cannot overflow.
let res = i32::from(left).checked_mul(i32::from(right)).unwrap() >> 16;
this.write_scalar(Scalar::from_int(res, Size::from_bits(16)), &dest)?;
}
}
// Used to implement the _mm_mulhi_epu16 function.
"pmulhu.w" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?;
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u16()?;
let dest = this.project_index(&dest, i)?;
// Values are expanded from u16 to u32, so multiplication cannot overflow.
let res = u32::from(left).checked_mul(u32::from(right)).unwrap() >> 16;
this.write_scalar(Scalar::from_u16(res.try_into().unwrap()), &dest)?;
// Narrow back to the original type
let res = this.int_to_int_or_float(
&ImmTy::from_scalar(high, twice_wide_layout),
dest.layout.ty,
)?;
this.write_immediate(res, &dest)?;
}
}
// Used to implement the _mm_mul_epu32 function.
@ -431,11 +426,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let right_res =
i8::try_from(right).unwrap_or(if right < 0 { i8::MIN } else { i8::MAX });
this.write_scalar(Scalar::from_int(left_res, Size::from_bits(8)), &left_dest)?;
this.write_scalar(
Scalar::from_int(right_res, Size::from_bits(8)),
&right_dest,
)?;
this.write_scalar(Scalar::from_i8(left_res.try_into().unwrap()), &left_dest)?;
this.write_scalar(Scalar::from_i8(right_res.try_into().unwrap()), &right_dest)?;
}
}
// Used to implement the _mm_packus_epi16 function.
@ -469,7 +461,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
}
// Used to implement the _mm_packs_epi32 function.
// Converts two 16-bit integer vectors to a single 8-bit integer
// Converts two 32-bit integer vectors to a single 16-bit integer
// vector with signed saturation.
"packssdw.128" => {
let [left, right] =
@ -495,9 +487,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let right_res =
i16::try_from(right).unwrap_or(if right < 0 { i16::MIN } else { i16::MAX });
this.write_scalar(Scalar::from_int(left_res, Size::from_bits(16)), &left_dest)?;
this.write_scalar(Scalar::from_i16(left_res.try_into().unwrap()), &left_dest)?;
this.write_scalar(
Scalar::from_int(right_res, Size::from_bits(16)),
Scalar::from_i16(right_res.try_into().unwrap()),
&right_dest,
)?;
}
@ -517,7 +509,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
_ => unreachable!(),
};
bin_op_sd(this, which, left, right, dest)?;
bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
}
// Used to implement _mm_min_pd and _mm_max_pd functions.
// Note that the semantics are a bit different from Rust simd_min
@ -534,7 +526,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
_ => unreachable!(),
};
bin_op_pd(this, which, left, right, dest)?;
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
// Used to implement _mm_sqrt_sd functions.
// Performs the operations on the first component of `op` and
@ -593,7 +585,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"llvm.x86.sse2.cmp.sd",
)?);
bin_op_sd(this, which, left, right, dest)?;
bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
}
// Used to implement the _mm_cmp*_pd functions.
// Performs a comparison operation on each component of `left`
@ -608,7 +600,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"llvm.x86.sse2.cmp.pd",
)?);
bin_op_pd(this, which, left, right, dest)?;
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
// Used to implement _mm_{,u}comi{eq,lt,le,gt,ge,neq}_sd functions.
// Compares the first component of `left` and `right` and returns
@ -641,52 +633,31 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
};
this.write_scalar(Scalar::from_i32(i32::from(res)), dest)?;
}
// Used to implement the _mm_cvtpd_ps function.
// Converts packed f32 to packed f64.
"cvtpd2ps" => {
// Used to implement the _mm_cvtpd_ps and _mm_cvtps_pd functions.
// Converts packed f32/f64 to packed f64/f32.
"cvtpd2ps" | "cvtps2pd" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
// op is f64x2, dest is f32x4
assert_eq!(op_len, 2);
assert_eq!(dest_len, 4);
for i in 0..op_len {
let op = this.read_scalar(&this.project_index(&op, i)?)?.to_f64()?;
// For cvtpd2ps: op is f64x2, dest is f32x4
// For cvtps2pd: op is f32x4, dest is f64x2
// In either case, the two first values are converted
for i in 0..op_len.min(dest_len) {
let op = this.read_immediate(&this.project_index(&op, i)?)?;
let dest = this.project_index(&dest, i)?;
let res = op.convert(/*loses_info*/ &mut false).value;
this.write_scalar(Scalar::from_f32(res), &dest)?;
let res = this.float_to_float_or_int(&op, dest.layout.ty)?;
this.write_immediate(res, &dest)?;
}
// Fill the remaining with zeros
// For f32 -> f64, ignore the remaining
// For f64 -> f32, fill the remaining with zeros
for i in op_len..dest_len {
let dest = this.project_index(&dest, i)?;
this.write_scalar(Scalar::from_u32(0), &dest)?;
this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
}
}
// Used to implement the _mm_cvtps_pd function.
// Converts packed f64 to packed f32.
"cvtps2pd" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
// op is f32x4, dest is f64x2
assert_eq!(op_len, 4);
assert_eq!(dest_len, 2);
for i in 0..dest_len {
let op = this.read_scalar(&this.project_index(&op, i)?)?.to_f32()?;
let dest = this.project_index(&dest, i)?;
let res = op.convert(/*loses_info*/ &mut false).value;
this.write_scalar(Scalar::from_f64(res), &dest)?;
}
// the two remaining f32 are ignored
}
// Used to implement the _mm_cvtpd_epi32 and _mm_cvttpd_epi32 functions.
// Converts packed f64 to packed i32.
"cvtpd2dq" | "cvttpd2dq" => {
@ -726,9 +697,10 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.write_scalar(Scalar::from_i32(0), &dest)?;
}
}
// Use to implement the _mm_cvtsd_si32 and _mm_cvttsd_si32 functions.
// Converts the first component of `op` from f64 to i32.
"cvtsd2si" | "cvttsd2si" => {
// Used to implement the _mm_cvtsd_si32, _mm_cvttsd_si32,
// _mm_cvtsd_si64 and _mm_cvttsd_si64 functions.
// Converts the first component of `op` from f64 to i32/i64.
"cvtsd2si" | "cvttsd2si" | "cvtsd2si64" | "cvttsd2si64" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (op, _) = this.operand_to_simd(op)?;
@ -737,41 +709,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let rnd = match unprefixed_name {
// "current SSE rounding mode", assume nearest
// https://www.felixcloutier.com/x86/cvtsd2si
"cvtsd2si" => rustc_apfloat::Round::NearestTiesToEven,
"cvtsd2si" | "cvtsd2si64" => rustc_apfloat::Round::NearestTiesToEven,
// always truncate
// https://www.felixcloutier.com/x86/cvttsd2si
"cvttsd2si" => rustc_apfloat::Round::TowardZero,
"cvttsd2si" | "cvttsd2si64" => rustc_apfloat::Round::TowardZero,
_ => unreachable!(),
};
let res = this.float_to_int_checked(op, dest.layout.ty, rnd).unwrap_or_else(|| {
// Fallback to minimum according to SSE semantics.
Scalar::from_i32(i32::MIN)
});
this.write_scalar(res, dest)?;
}
// Use to implement the _mm_cvtsd_si64 and _mm_cvttsd_si64 functions.
// Converts the first component of `op` from f64 to i64.
"cvtsd2si64" | "cvttsd2si64" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (op, _) = this.operand_to_simd(op)?;
let op = this.read_scalar(&this.project_index(&op, 0)?)?.to_f64()?;
let rnd = match unprefixed_name {
// "current SSE rounding mode", assume nearest
// https://www.felixcloutier.com/x86/cvtsd2si
"cvtsd2si64" => rustc_apfloat::Round::NearestTiesToEven,
// always truncate
// https://www.felixcloutier.com/x86/cvttsd2si
"cvttsd2si64" => rustc_apfloat::Round::TowardZero,
_ => unreachable!(),
};
let res = this.float_to_int_checked(op, dest.layout.ty, rnd).unwrap_or_else(|| {
// Fallback to minimum according to SSE semantics.
Scalar::from_i64(i64::MIN)
Scalar::from_int(dest.layout.size.signed_int_min(), dest.layout.size)
});
this.write_scalar(res, dest)?;
@ -844,139 +791,3 @@ fn extract_first_u64<'tcx>(
// Get the first u64 from the array
this.read_scalar(&this.project_index(&op, 0)?)?.to_u64()
}
/// Selects which scalar `f64` binary operation `bin_op_f64` evaluates.
///
/// The value is `Copy`, so it can be passed by value to the per-element
/// helpers (`bin_op_sd`, `bin_op_pd`) without cloning.
#[derive(Copy, Clone)]
enum FloatBinOp {
/// Comparison using the predicate carried in the wrapped `FloatCmpOp`.
///
/// `bin_op_f64` produces an all-ones `u64` when the predicate holds and
/// zero otherwise.
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// The second operand is returned when both operands are zero (of either
/// sign) or when either operand is NaN.
///
/// <https://www.felixcloutier.com/x86/minsd>
/// <https://www.felixcloutier.com/x86/minpd>
Min,
/// Maximum value (with SSE semantics)
///
/// The second operand is returned when both operands are zero (of either
/// sign) or when either operand is NaN.
///
/// <https://www.felixcloutier.com/x86/maxsd>
/// <https://www.felixcloutier.com/x86/maxpd>
Max,
}
/// Evaluates the scalar operation `which` on the `f64` operands `left` and
/// `right` and returns the resulting scalar.
///
/// Comparisons yield an all-ones `u64` when the predicate holds and zero
/// otherwise; `Min`/`Max` yield one of the two operands as an `f64`.
/// Fails if either operand cannot be read as an `f64`.
// FIXME make this generic over apfloat type to reduce code duplication with bin_op_f32
fn bin_op_f64<'tcx>(
    which: FloatBinOp,
    left: &ImmTy<'tcx, Provenance>,
    right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
    // Every operation consumes both operands as `f64`, so decode them once,
    // left operand first.
    let lhs = left.to_scalar().to_f64()?;
    let rhs = right.to_scalar().to_f64()?;

    let result = match which {
        FloatBinOp::Cmp(predicate) => {
            // FIXME: Make sure that these operations match the semantics of cmppd
            let holds = match predicate {
                FloatCmpOp::Eq => lhs == rhs,
                FloatCmpOp::Lt => lhs < rhs,
                FloatCmpOp::Le => lhs <= rhs,
                FloatCmpOp::Unord => lhs.is_nan() || rhs.is_nan(),
                FloatCmpOp::Neq => lhs != rhs,
                // The negated forms are deliberately not rewritten as `>=` / `>`:
                // they must be true when either operand is NaN.
                FloatCmpOp::Nlt => !(lhs < rhs),
                FloatCmpOp::Nle => !(lhs <= rhs),
                FloatCmpOp::Ord => !lhs.is_nan() && !rhs.is_nan(),
            };
            let mask = if holds { u64::MAX } else { 0 };
            Scalar::from_u64(mask)
        }
        FloatBinOp::Min => {
            // SSE semantics to handle zero and NaN. Note that `x == Double::ZERO`
            // is true when `x` is either +0 or -0.
            let pick_right = (lhs == Double::ZERO && rhs == Double::ZERO)
                || lhs.is_nan()
                || rhs.is_nan()
                || lhs >= rhs;
            Scalar::from_f64(if pick_right { rhs } else { lhs })
        }
        FloatBinOp::Max => {
            // SSE semantics to handle zero and NaN. Note that `x == Double::ZERO`
            // is true when `x` is either +0 or -0.
            let pick_right = (lhs == Double::ZERO && rhs == Double::ZERO)
                || lhs.is_nan()
                || rhs.is_nan()
                || lhs <= rhs;
            Scalar::from_f64(if pick_right { rhs } else { lhs })
        }
    };
    Ok(result)
}
/// Applies `which` to the first component of `left` and `right` and writes
/// the result to the first component of `dest`. All remaining components of
/// `dest` are copied from `left`.
///
/// Panics if the three vectors do not have the same number of components.
fn bin_op_sd<'tcx>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatBinOp,
    left: &OpTy<'tcx, Provenance>,
    right: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (lhs, lhs_len) = this.operand_to_simd(left)?;
    let (rhs, rhs_len) = this.operand_to_simd(right)?;
    let (out, out_len) = this.place_to_simd(dest)?;

    assert_eq!(out_len, lhs_len);
    assert_eq!(out_len, rhs_len);

    // Only component 0 goes through the actual operation.
    let lhs_first = this.read_immediate(&this.project_index(&lhs, 0)?)?;
    let rhs_first = this.read_immediate(&this.project_index(&rhs, 0)?)?;
    let first = bin_op_f64(which, &lhs_first, &rhs_first)?;
    let out_first = this.project_index(&out, 0)?;
    this.write_scalar(first, &out_first)?;

    // The upper components pass through unchanged from `left`.
    for idx in 1..out_len {
        let src = this.project_index(&lhs, idx)?;
        let dst = this.project_index(&out, idx)?;
        this.copy_op(&src, &dst, /*allow_transmute*/ false)?;
    }

    Ok(())
}
/// Applies `which` to each pair of corresponding components of `left` and
/// `right`, storing every result in the matching component of `dest`.
///
/// Panics if the three vectors do not have the same number of components.
fn bin_op_pd<'tcx>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatBinOp,
    left: &OpTy<'tcx, Provenance>,
    right: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (lhs, lhs_len) = this.operand_to_simd(left)?;
    let (rhs, rhs_len) = this.operand_to_simd(right)?;
    let (out, out_len) = this.place_to_simd(dest)?;

    assert_eq!(out_len, lhs_len);
    assert_eq!(out_len, rhs_len);

    for idx in 0..out_len {
        let lhs_elem = this.read_immediate(&this.project_index(&lhs, idx)?)?;
        let rhs_elem = this.read_immediate(&this.project_index(&rhs, idx)?)?;
        let out_elem = this.project_index(&out, idx)?;

        let value = bin_op_f64(which, &lhs_elem, &rhs_elem)?;
        this.write_scalar(value, &out_elem)?;
    }

    Ok(())
}