From 8a77b3248f6e147d782ab29915dc1aa093ef7a7c Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Thu, 12 Oct 2023 23:29:16 +0200 Subject: [PATCH] Abstract over `PatRange` boundary value --- compiler/rustc_middle/src/thir.rs | 223 +++++++++++++++++- compiler/rustc_middle/src/ty/util.rs | 82 ++++--- .../rustc_mir_build/src/build/matches/mod.rs | 2 +- .../src/build/matches/simplify.rs | 44 +--- .../rustc_mir_build/src/build/matches/test.rs | 47 +--- .../src/thir/pattern/deconstruct_pat.rs | 24 +- .../rustc_mir_build/src/thir/pattern/mod.rs | 92 ++------ 7 files changed, 299 insertions(+), 215 deletions(-) diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs index f1747356139..fefdce8e77e 100644 --- a/compiler/rustc_middle/src/thir.rs +++ b/compiler/rustc_middle/src/thir.rs @@ -16,17 +16,19 @@ use rustc_hir::RangeEnd; use rustc_index::newtype_index; use rustc_index::IndexVec; use rustc_middle::middle::region; -use rustc_middle::mir::interpret::AllocId; +use rustc_middle::mir::interpret::{AllocId, Scalar}; use rustc_middle::mir::{self, BinOp, BorrowKind, FakeReadCause, Mutability, UnOp}; use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::layout::IntegerExt; use rustc_middle::ty::{ self, AdtDef, CanonicalUserType, CanonicalUserTypeAnnotation, FnSig, GenericArgsRef, List, Ty, - UpvarArgs, + TyCtxt, UpvarArgs, }; use rustc_span::def_id::LocalDefId; use rustc_span::{sym, ErrorGuaranteed, Span, Symbol, DUMMY_SP}; -use rustc_target::abi::{FieldIdx, VariantIdx}; +use rustc_target::abi::{FieldIdx, Integer, Size, VariantIdx}; use rustc_target::asm::InlineAsmRegOrRegClass; +use std::cmp::Ordering; use std::fmt; use std::ops::Index; @@ -810,12 +812,217 @@ pub enum PatKind<'tcx> { Error(ErrorGuaranteed), } +/// A range pattern. +/// The boundaries must be of the same type and that type must be numeric. #[derive(Clone, Debug, PartialEq, HashStable, TypeVisitable)] pub struct PatRange<'tcx> { - pub lo: mir::Const<'tcx>, - pub hi: mir::Const<'tcx>, + pub lo: PatRangeBoundary<'tcx>, + pub hi: PatRangeBoundary<'tcx>, #[type_visitable(ignore)] pub end: RangeEnd, + pub ty: Ty<'tcx>, +} + +impl<'tcx> PatRange<'tcx> { + /// Whether this range covers the full extent of possible values (best-effort, we ignore floats). + #[inline] + pub fn is_full_range(&self, tcx: TyCtxt<'tcx>) -> Option { + let (min, max, size, bias) = match *self.ty.kind() { + ty::Char => (0, std::char::MAX as u128, Size::from_bits(32), 0), + ty::Int(ity) => { + let size = Integer::from_int_ty(&tcx, ity).size(); + let max = size.truncate(u128::MAX); + let bias = 1u128 << (size.bits() - 1); + (0, max, size, bias) + } + ty::Uint(uty) => { + let size = Integer::from_uint_ty(&tcx, uty).size(); + let max = size.unsigned_int_max(); + (0, max, size, 0) + } + _ => return None, + }; + + // We want to compare ranges numerically, but the order of the bitwise representation of + // signed integers does not match their numeric order. Thus, to correct the ordering, we + // need to shift the range of signed integers to correct the comparison. This is achieved by + // XORing with a bias (see pattern/deconstruct_pat.rs for another pertinent example of this + // pattern). + // + // Also, for performance, it's important to only do the second `try_to_bits` if necessary. + let lo_is_min = match self.lo { + PatRangeBoundary::Finite(value) => { + let lo = value.try_to_bits(size).unwrap() ^ bias; + lo <= min + } + }; + if lo_is_min { + let hi_is_max = match self.hi { + PatRangeBoundary::Finite(value) => { + let hi = value.try_to_bits(size).unwrap() ^ bias; + hi > max || hi == max && self.end == RangeEnd::Included + } + }; + if hi_is_max { + return Some(true); + } + } + Some(false) + } + + #[inline] + pub fn contains( + &self, + value: mir::Const<'tcx>, + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + ) -> Option { + use Ordering::*; + debug_assert_eq!(self.ty, value.ty()); + let ty = self.ty; + let value = PatRangeBoundary::Finite(value); + // For performance, it's important to only do the second comparison if necessary. + Some( + match self.lo.compare_with(value, ty, tcx, param_env)? { + Less | Equal => true, + Greater => false, + } && match value.compare_with(self.hi, ty, tcx, param_env)? { + Less => true, + Equal => self.end == RangeEnd::Included, + Greater => false, + }, + ) + } + + #[inline] + pub fn overlaps( + &self, + other: &Self, + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + ) -> Option { + use Ordering::*; + debug_assert_eq!(self.ty, other.ty); + // For performance, it's important to only do the second comparison if necessary. + Some( + match other.lo.compare_with(self.hi, self.ty, tcx, param_env)? { + Less => true, + Equal => self.end == RangeEnd::Included, + Greater => false, + } && match self.lo.compare_with(other.hi, self.ty, tcx, param_env)? { + Less => true, + Equal => other.end == RangeEnd::Included, + Greater => false, + }, + ) + } +} + +impl<'tcx> fmt::Display for PatRange<'tcx> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let PatRangeBoundary::Finite(value) = &self.lo; + write!(f, "{value}")?; + write!(f, "{}", self.end)?; + let PatRangeBoundary::Finite(value) = &self.hi; + write!(f, "{value}")?; + Ok(()) + } +} + +/// A (possibly open) boundary of a range pattern. +/// If present, the const must be of a numeric type. +#[derive(Copy, Clone, Debug, PartialEq, HashStable, TypeVisitable)] +pub enum PatRangeBoundary<'tcx> { + Finite(mir::Const<'tcx>), +} + +impl<'tcx> PatRangeBoundary<'tcx> { + #[inline] + pub fn lower_bound(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Self { + // Unwrap is ok because the type is known to be numeric. + let c = ty.numeric_min_val(tcx).unwrap(); + let value = mir::Const::from_ty_const(c, tcx); + Self::Finite(value) + } + #[inline] + pub fn upper_bound(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Self { + // Unwrap is ok because the type is known to be numeric. + let c = ty.numeric_max_val(tcx).unwrap(); + let value = mir::Const::from_ty_const(c, tcx); + Self::Finite(value) + } + + #[inline] + pub fn to_const(self, _ty: Ty<'tcx>, _tcx: TyCtxt<'tcx>) -> mir::Const<'tcx> { + match self { + Self::Finite(value) => value, + } + } + pub fn eval_bits( + self, + _ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + ) -> u128 { + match self { + Self::Finite(value) => value.eval_bits(tcx, param_env), + } + } + + #[instrument(skip(tcx, param_env), level = "debug", ret)] + pub fn compare_with( + self, + other: Self, + ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + ) -> Option { + use PatRangeBoundary::*; + match (self, other) { + // This code is hot when compiling matches with many ranges. So we + // special-case extraction of evaluated scalars for speed, for types where + // raw data comparisons are appropriate. E.g. `unicode-normalization` has + // many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared + // in this way. + (Finite(mir::Const::Ty(a)), Finite(mir::Const::Ty(b))) + if matches!(ty.kind(), ty::Uint(_) | ty::Char) => + { + return Some(a.kind().cmp(&b.kind())); + } + ( + Finite(mir::Const::Val(mir::ConstValue::Scalar(Scalar::Int(a)), _)), + Finite(mir::Const::Val(mir::ConstValue::Scalar(Scalar::Int(b)), _)), + ) if matches!(ty.kind(), ty::Uint(_) | ty::Char) => return Some(a.cmp(&b)), + _ => {} + } + + let a = self.eval_bits(ty, tcx, param_env); + let b = other.eval_bits(ty, tcx, param_env); + + match ty.kind() { + ty::Float(ty::FloatTy::F32) => { + use rustc_apfloat::Float; + let a = rustc_apfloat::ieee::Single::from_bits(a); + let b = rustc_apfloat::ieee::Single::from_bits(b); + a.partial_cmp(&b) + } + ty::Float(ty::FloatTy::F64) => { + use rustc_apfloat::Float; + let a = rustc_apfloat::ieee::Double::from_bits(a); + let b = rustc_apfloat::ieee::Double::from_bits(b); + a.partial_cmp(&b) + } + ty::Int(ity) => { + use rustc_middle::ty::layout::IntegerExt; + let size = rustc_target::abi::Integer::from_int_ty(&tcx, *ity).size(); + let a = size.sign_extend(a) as i128; + let b = size.sign_extend(b) as i128; + Some(a.cmp(&b)) + } + ty::Uint(_) | ty::Char => Some(a.cmp(&b)), + _ => bug!(), + } + } } impl<'tcx> fmt::Display for Pat<'tcx> { @@ -944,11 +1151,7 @@ impl<'tcx> fmt::Display for Pat<'tcx> { PatKind::InlineConstant { def: _, ref subpattern } => { write!(f, "{} (from inline const)", subpattern) } - PatKind::Range(box PatRange { lo, hi, end }) => { - write!(f, "{lo}")?; - write!(f, "{end}")?; - write!(f, "{hi}") - } + PatKind::Range(ref range) => write!(f, "{range}"), PatKind::Slice { ref prefix, ref slice, ref suffix } | PatKind::Array { ref prefix, ref slice, ref suffix } => { write!(f, "[")?; diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs index be48c0e8926..a2a74dafc5f 100644 --- a/compiler/rustc_middle/src/ty/util.rs +++ b/compiler/rustc_middle/src/ty/util.rs @@ -19,7 +19,7 @@ use rustc_index::bit_set::GrowableBitSet; use rustc_macros::HashStable; use rustc_session::Limit; use rustc_span::sym; -use rustc_target::abi::{Integer, IntegerType, Size}; +use rustc_target::abi::{Integer, IntegerType, Primitive, Size}; use rustc_target::spec::abi::Abi; use smallvec::SmallVec; use std::{fmt, iter}; @@ -917,54 +917,62 @@ impl<'tcx> TypeFolder> for OpaqueTypeExpander<'tcx> { } impl<'tcx> Ty<'tcx> { + /// Returns the `Size` for primitive types (bool, uint, int, char, float). + pub fn primitive_size(self, tcx: TyCtxt<'tcx>) -> Size { + match *self.kind() { + ty::Bool => Size::from_bytes(1), + ty::Char => Size::from_bytes(4), + ty::Int(ity) => Integer::from_int_ty(&tcx, ity).size(), + ty::Uint(uty) => Integer::from_uint_ty(&tcx, uty).size(), + ty::Float(ty::FloatTy::F32) => Primitive::F32.size(&tcx), + ty::Float(ty::FloatTy::F64) => Primitive::F64.size(&tcx), + _ => bug!("non primitive type"), + } + } + pub fn int_size_and_signed(self, tcx: TyCtxt<'tcx>) -> (Size, bool) { - let (int, signed) = match *self.kind() { - ty::Int(ity) => (Integer::from_int_ty(&tcx, ity), true), - ty::Uint(uty) => (Integer::from_uint_ty(&tcx, uty), false), + match *self.kind() { + ty::Int(ity) => (Integer::from_int_ty(&tcx, ity).size(), true), + ty::Uint(uty) => (Integer::from_uint_ty(&tcx, uty).size(), false), _ => bug!("non integer discriminant"), - }; - (int.size(), signed) + } + } + + /// Returns the minimum and maximum values for the given numeric type (including `char`s) or + /// returns `None` if the type is not numeric. + pub fn numeric_min_and_max_as_bits(self, tcx: TyCtxt<'tcx>) -> Option<(u128, u128)> { + use rustc_apfloat::ieee::{Double, Single}; + Some(match self.kind() { + ty::Int(_) | ty::Uint(_) => { + let (size, signed) = self.int_size_and_signed(tcx); + let min = if signed { size.truncate(size.signed_int_min() as u128) } else { 0 }; + let max = + if signed { size.signed_int_max() as u128 } else { size.unsigned_int_max() }; + (min, max) + } + ty::Char => (0, std::char::MAX as u128), + ty::Float(ty::FloatTy::F32) => { + ((-Single::INFINITY).to_bits(), Single::INFINITY.to_bits()) + } + ty::Float(ty::FloatTy::F64) => { + ((-Double::INFINITY).to_bits(), Double::INFINITY.to_bits()) + } + _ => return None, + }) } /// Returns the maximum value for the given numeric type (including `char`s) /// or returns `None` if the type is not numeric. pub fn numeric_max_val(self, tcx: TyCtxt<'tcx>) -> Option> { - let val = match self.kind() { - ty::Int(_) | ty::Uint(_) => { - let (size, signed) = self.int_size_and_signed(tcx); - let val = - if signed { size.signed_int_max() as u128 } else { size.unsigned_int_max() }; - Some(val) - } - ty::Char => Some(std::char::MAX as u128), - ty::Float(fty) => Some(match fty { - ty::FloatTy::F32 => rustc_apfloat::ieee::Single::INFINITY.to_bits(), - ty::FloatTy::F64 => rustc_apfloat::ieee::Double::INFINITY.to_bits(), - }), - _ => None, - }; - - val.map(|v| ty::Const::from_bits(tcx, v, ty::ParamEnv::empty().and(self))) + self.numeric_min_and_max_as_bits(tcx) + .map(|(_, max)| ty::Const::from_bits(tcx, max, ty::ParamEnv::empty().and(self))) } /// Returns the minimum value for the given numeric type (including `char`s) /// or returns `None` if the type is not numeric. pub fn numeric_min_val(self, tcx: TyCtxt<'tcx>) -> Option> { - let val = match self.kind() { - ty::Int(_) | ty::Uint(_) => { - let (size, signed) = self.int_size_and_signed(tcx); - let val = if signed { size.truncate(size.signed_int_min() as u128) } else { 0 }; - Some(val) - } - ty::Char => Some(0), - ty::Float(fty) => Some(match fty { - ty::FloatTy::F32 => (-::rustc_apfloat::ieee::Single::INFINITY).to_bits(), - ty::FloatTy::F64 => (-::rustc_apfloat::ieee::Double::INFINITY).to_bits(), - }), - _ => None, - }; - - val.map(|v| ty::Const::from_bits(tcx, v, ty::ParamEnv::empty().and(self))) + self.numeric_min_and_max_as_bits(tcx) + .map(|(min, _)| ty::Const::from_bits(tcx, min, ty::ParamEnv::empty().and(self))) } /// Checks whether values of this type `T` are *moved* or *copied* diff --git a/compiler/rustc_mir_build/src/build/matches/mod.rs b/compiler/rustc_mir_build/src/build/matches/mod.rs index 1cf8c202ea4..2e30cdd2c39 100644 --- a/compiler/rustc_mir_build/src/build/matches/mod.rs +++ b/compiler/rustc_mir_build/src/build/matches/mod.rs @@ -1027,7 +1027,7 @@ enum TestKind<'tcx> { ty: Ty<'tcx>, }, - /// Test whether the value falls within an inclusive or exclusive range + /// Test whether the value falls within an inclusive or exclusive range. Range(Box>), /// Test that the length of the slice is equal to `len`. diff --git a/compiler/rustc_mir_build/src/build/matches/simplify.rs b/compiler/rustc_mir_build/src/build/matches/simplify.rs index 32573b4d53a..6a40c8d840b 100644 --- a/compiler/rustc_mir_build/src/build/matches/simplify.rs +++ b/compiler/rustc_mir_build/src/build/matches/simplify.rs @@ -15,11 +15,7 @@ use crate::build::expr::as_place::PlaceBuilder; use crate::build::matches::{Ascription, Binding, Candidate, MatchPair}; use crate::build::Builder; -use rustc_hir::RangeEnd; use rustc_middle::thir::{self, *}; -use rustc_middle::ty; -use rustc_middle::ty::layout::IntegerExt; -use rustc_target::abi::{Integer, Size}; use std::mem; @@ -148,7 +144,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { match_pair: MatchPair<'pat, 'tcx>, candidate: &mut Candidate<'pat, 'tcx>, ) -> Result<(), MatchPair<'pat, 'tcx>> { - let tcx = self.tcx; match match_pair.pattern.kind { PatKind::AscribeUserType { ref subpattern, @@ -210,41 +205,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { Ok(()) } - PatKind::Range(box PatRange { lo, hi, end }) => { - let (range, bias) = match *lo.ty().kind() { - ty::Char => { - (Some(('\u{0000}' as u128, '\u{10FFFF}' as u128, Size::from_bits(32))), 0) - } - ty::Int(ity) => { - let size = Integer::from_int_ty(&tcx, ity).size(); - let max = size.truncate(u128::MAX); - let bias = 1u128 << (size.bits() - 1); - (Some((0, max, size)), bias) - } - ty::Uint(uty) => { - let size = Integer::from_uint_ty(&tcx, uty).size(); - let max = size.truncate(u128::MAX); - (Some((0, max, size)), 0) - } - _ => (None, 0), - }; - if let Some((min, max, sz)) = range { - // We want to compare ranges numerically, but the order of the bitwise - // representation of signed integers does not match their numeric order. Thus, - // to correct the ordering, we need to shift the range of signed integers to - // correct the comparison. This is achieved by XORing with a bias (see - // pattern/_match.rs for another pertinent example of this pattern). - // - // Also, for performance, it's important to only do the second - // `try_to_bits` if necessary. - let lo = lo.try_to_bits(sz).unwrap() ^ bias; - if lo <= min { - let hi = hi.try_to_bits(sz).unwrap() ^ bias; - if hi > max || hi == max && end == RangeEnd::Included { - // Irrefutable pattern match. - return Ok(()); - } - } + PatKind::Range(ref range) => { + if let Some(true) = range.is_full_range(self.tcx) { + // Irrefutable pattern match. + return Ok(()); } Err(match_pair) } diff --git a/compiler/rustc_mir_build/src/build/matches/test.rs b/compiler/rustc_mir_build/src/build/matches/test.rs index 5e7db7413df..bdd4f2011eb 100644 --- a/compiler/rustc_mir_build/src/build/matches/test.rs +++ b/compiler/rustc_mir_build/src/build/matches/test.rs @@ -8,7 +8,6 @@ use crate::build::expr::as_place::PlaceBuilder; use crate::build::matches::{Candidate, MatchPair, Test, TestKind}; use crate::build::Builder; -use crate::thir::pattern::compare_const_vals; use rustc_data_structures::fx::FxIndexMap; use rustc_hir::{LangItem, RangeEnd}; use rustc_index::bit_set::BitSet; @@ -59,8 +58,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }, PatKind::Range(ref range) => { - assert_eq!(range.lo.ty(), match_pair.pattern.ty); - assert_eq!(range.hi.ty(), match_pair.pattern.ty); + assert_eq!(range.ty, match_pair.pattern.ty); Test { span: match_pair.pattern.span, kind: TestKind::Range(range.clone()) } } @@ -309,11 +307,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } - TestKind::Range(box PatRange { lo, hi, ref end }) => { + TestKind::Range(ref range) => { let lower_bound_success = self.cfg.start_new_block(); let target_blocks = make_target_blocks(self); // Test `val` by computing `lo <= val && val <= hi`, using primitive comparisons. + // FIXME: skip useless comparison when the range is half-open. + let lo = range.lo.to_const(range.ty, self.tcx); + let hi = range.hi.to_const(range.ty, self.tcx); let lo = self.literal_operand(test.span, lo); let hi = self.literal_operand(test.span, hi); let val = Operand::Copy(place); @@ -330,7 +331,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { lo, val.clone(), ); - let op = match *end { + let op = match range.end { RangeEnd::Included => BinOp::Le, RangeEnd::Excluded => BinOp::Lt, }; @@ -698,34 +699,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } (TestKind::Range(test), PatKind::Range(pat)) => { - use std::cmp::Ordering::*; - if test == pat { self.candidate_without_match_pair(match_pair_index, candidate); return Some(0); } - // For performance, it's important to only do the second - // `compare_const_vals` if necessary. - let no_overlap = if matches!( - (compare_const_vals(self.tcx, test.hi, pat.lo, self.param_env)?, test.end), - (Less, _) | (Equal, RangeEnd::Excluded) // test < pat - ) || matches!( - (compare_const_vals(self.tcx, test.lo, pat.hi, self.param_env)?, pat.end), - (Greater, _) | (Equal, RangeEnd::Excluded) // test > pat - ) { - Some(1) - } else { - None - }; - // If the testing range does not overlap with pattern range, // the pattern can be matched only if this test fails. - no_overlap + if !test.overlaps(pat, self.tcx, self.param_env)? { Some(1) } else { None } } (TestKind::Range(range), &PatKind::Constant { value }) => { - if let Some(false) = self.const_range_contains(&*range, value) { + if !range.contains(value, self.tcx, self.param_env)? { // `value` is not contained in the testing range, // so `value` can be matched only if this test fails. Some(1) @@ -817,27 +802,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { span_bug!(match_pair.pattern.span, "simplifiable pattern found: {:?}", match_pair.pattern) } - fn const_range_contains(&self, range: &PatRange<'tcx>, value: Const<'tcx>) -> Option { - use std::cmp::Ordering::*; - - // For performance, it's important to only do the second - // `compare_const_vals` if necessary. - Some( - matches!(compare_const_vals(self.tcx, range.lo, value, self.param_env)?, Less | Equal) - && matches!( - (compare_const_vals(self.tcx, value, range.hi, self.param_env)?, range.end), - (Less, _) | (Equal, RangeEnd::Included) - ), - ) - } - fn values_not_contained_in_range( &self, range: &PatRange<'tcx>, options: &FxIndexMap, u128>, ) -> Option { for &val in options.keys() { - if self.const_range_contains(range, val)? { + if range.contains(val, self.tcx, self.param_env)? { return Some(false); } } diff --git a/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs b/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs index 186c77795e4..f93daa29afb 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs @@ -57,7 +57,7 @@ use rustc_hir::RangeEnd; use rustc_index::Idx; use rustc_middle::middle::stability::EvalResult; use rustc_middle::mir; -use rustc_middle::thir::{FieldPat, Pat, PatKind, PatRange}; +use rustc_middle::thir::{FieldPat, Pat, PatKind, PatRange, PatRangeBoundary}; use rustc_middle::ty::layout::IntegerExt; use rustc_middle::ty::{self, Ty, TyCtxt, VariantDef}; use rustc_span::{Span, DUMMY_SP}; @@ -278,19 +278,20 @@ impl IntRange { let (lo, hi) = self.boundaries(); let bias = IntRange::signed_bias(tcx, ty); - let (lo, hi) = (lo ^ bias, hi ^ bias); + let (lo_bits, hi_bits) = (lo ^ bias, hi ^ bias); let env = ty::ParamEnv::empty().and(ty); - let lo_const = mir::Const::from_bits(tcx, lo, env); - let hi_const = mir::Const::from_bits(tcx, hi, env); + let lo_const = mir::Const::from_bits(tcx, lo_bits, env); + let hi_const = mir::Const::from_bits(tcx, hi_bits, env); - let kind = if lo == hi { + let kind = if lo_bits == hi_bits { PatKind::Constant { value: lo_const } } else { PatKind::Range(Box::new(PatRange { - lo: lo_const, - hi: hi_const, + lo: PatRangeBoundary::Finite(lo_const), + hi: PatRangeBoundary::Finite(hi_const), end: RangeEnd::Included, + ty, })) }; @@ -1387,11 +1388,12 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> { } } } - PatKind::Range(box PatRange { lo, hi, end }) => { + PatKind::Range(box PatRange { lo, hi, end, .. }) => { use rustc_apfloat::Float; - let ty = lo.ty(); - let lo = lo.try_eval_bits(cx.tcx, cx.param_env).unwrap(); - let hi = hi.try_eval_bits(cx.tcx, cx.param_env).unwrap(); + let ty = pat.ty; + // FIXME: handle half-open ranges + let lo = lo.eval_bits(ty, cx.tcx, cx.param_env); + let hi = hi.eval_bits(ty, cx.tcx, cx.param_env); ctor = match ty.kind() { ty::Char | ty::Int(_) | ty::Uint(_) => { IntRange(IntRange::from_range(cx.tcx, lo, hi, ty, *end)) diff --git a/compiler/rustc_mir_build/src/thir/pattern/mod.rs b/compiler/rustc_mir_build/src/thir/pattern/mod.rs index dd71ab1f8e5..c530c0a772c 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/mod.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/mod.rs @@ -17,11 +17,11 @@ use rustc_hir::def::{CtorOf, DefKind, Res}; use rustc_hir::pat_util::EnumerateAndAdjustIterator; use rustc_hir::RangeEnd; use rustc_index::Idx; -use rustc_middle::mir::interpret::{ - ErrorHandled, GlobalId, LitToConstError, LitToConstInput, Scalar, -}; +use rustc_middle::mir::interpret::{ErrorHandled, GlobalId, LitToConstError, LitToConstInput}; use rustc_middle::mir::{self, BorrowKind, Const, Mutability, UserTypeProjection}; -use rustc_middle::thir::{Ascription, BindingMode, FieldPat, LocalVarId, Pat, PatKind, PatRange}; +use rustc_middle::thir::{ + Ascription, BindingMode, FieldPat, LocalVarId, Pat, PatKind, PatRange, PatRangeBoundary, +}; use rustc_middle::ty::layout::IntegerExt; use rustc_middle::ty::{ self, AdtDef, CanonicalUserTypeAnnotation, GenericArg, GenericArgsRef, Region, Ty, TyCtxt, @@ -90,7 +90,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> { &mut self, expr: Option<&'tcx hir::Expr<'tcx>>, ) -> Result< - (Option>, Option>, Option), + (Option>, Option>, Option), ErrorGuaranteed, > { match expr { @@ -113,7 +113,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> { ); return Err(self.tcx.sess.delay_span_bug(expr.span, msg)); }; - Ok((Some(value), ascr, inline_const)) + Ok((Some(PatRangeBoundary::Finite(value)), ascr, inline_const)) } } } @@ -187,31 +187,23 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> { let (lo, lo_ascr, lo_inline) = self.lower_pattern_range_endpoint(lo_expr)?; let (hi, hi_ascr, hi_inline) = self.lower_pattern_range_endpoint(hi_expr)?; - let lo = lo.unwrap_or_else(|| { - // Unwrap is ok because the type is known to be numeric. - let lo = ty.numeric_min_val(self.tcx).unwrap(); - mir::Const::from_ty_const(lo, self.tcx) - }); - let hi = hi.unwrap_or_else(|| { - // Unwrap is ok because the type is known to be numeric. - let hi = ty.numeric_max_val(self.tcx).unwrap(); - mir::Const::from_ty_const(hi, self.tcx) - }); - assert_eq!(lo.ty(), ty); - assert_eq!(hi.ty(), ty); + let lo = lo.unwrap_or_else(|| PatRangeBoundary::lower_bound(ty, self.tcx)); + let hi = hi.unwrap_or_else(|| PatRangeBoundary::upper_bound(ty, self.tcx)); - let cmp = compare_const_vals(self.tcx, lo, hi, self.param_env); + let cmp = lo.compare_with(hi, ty, self.tcx, self.param_env); let mut kind = match (end, cmp) { // `x..y` where `x < y`. // Non-empty because the range includes at least `x`. (RangeEnd::Excluded, Some(Ordering::Less)) => { - PatKind::Range(Box::new(PatRange { lo, hi, end })) + PatKind::Range(Box::new(PatRange { lo, hi, end, ty })) } // `x..=y` where `x == y`. - (RangeEnd::Included, Some(Ordering::Equal)) => PatKind::Constant { value: lo }, + (RangeEnd::Included, Some(Ordering::Equal)) => { + PatKind::Constant { value: lo.to_const(ty, self.tcx) } + } // `x..=y` where `x < y`. (RangeEnd::Included, Some(Ordering::Less)) => { - PatKind::Range(Box::new(PatRange { lo, hi, end })) + PatKind::Range(Box::new(PatRange { lo, hi, end, ty })) } // `x..y` where `x >= y`, or `x..=y` where `x > y`. The range is empty => error. _ => { @@ -851,59 +843,3 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> { } } } - -#[instrument(skip(tcx), level = "debug")] -pub(crate) fn compare_const_vals<'tcx>( - tcx: TyCtxt<'tcx>, - a: mir::Const<'tcx>, - b: mir::Const<'tcx>, - param_env: ty::ParamEnv<'tcx>, -) -> Option { - assert_eq!(a.ty(), b.ty()); - - let ty = a.ty(); - - // This code is hot when compiling matches with many ranges. So we - // special-case extraction of evaluated scalars for speed, for types where - // raw data comparisons are appropriate. E.g. `unicode-normalization` has - // many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared - // in this way. - match ty.kind() { - ty::Float(_) | ty::Int(_) => {} // require special handling, see below - _ => match (a, b) { - ( - mir::Const::Val(mir::ConstValue::Scalar(Scalar::Int(a)), _a_ty), - mir::Const::Val(mir::ConstValue::Scalar(Scalar::Int(b)), _b_ty), - ) => return Some(a.cmp(&b)), - (mir::Const::Ty(a), mir::Const::Ty(b)) => { - return Some(a.kind().cmp(&b.kind())); - } - _ => {} - }, - } - - let a = a.eval_bits(tcx, param_env); - let b = b.eval_bits(tcx, param_env); - - use rustc_apfloat::Float; - match *ty.kind() { - ty::Float(ty::FloatTy::F32) => { - let a = rustc_apfloat::ieee::Single::from_bits(a); - let b = rustc_apfloat::ieee::Single::from_bits(b); - a.partial_cmp(&b) - } - ty::Float(ty::FloatTy::F64) => { - let a = rustc_apfloat::ieee::Double::from_bits(a); - let b = rustc_apfloat::ieee::Double::from_bits(b); - a.partial_cmp(&b) - } - ty::Int(ity) => { - use rustc_middle::ty::layout::IntegerExt; - let size = rustc_target::abi::Integer::from_int_ty(&tcx, ity).size(); - let a = size.sign_extend(a); - let b = size.sign_extend(b); - Some((a as i128).cmp(&(b as i128))) - } - _ => Some(a.cmp(&b)), - } -}