Support non-scalar constants.

This commit is contained in:
Camille GILLOT 2023-05-13 12:30:40 +00:00
parent 68c2f5ba0f
commit 6ad6b4381c
12 changed files with 259 additions and 22 deletions

View file

@ -445,9 +445,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
#[inline(always)]
pub fn cur_span(&self) -> Span {
// This deliberately does *not* honor `requires_caller_location` since it is used for much
// more than just panics.
self.stack().last().map_or(self.tcx.span, |f| f.current_span())
M::cur_span(self)
}
#[inline(always)]

View file

@ -11,6 +11,7 @@ use rustc_middle::mir;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::def_id::DefId;
use rustc_span::Span;
use rustc_target::abi::{Align, Size};
use rustc_target::spec::abi::Abi as CallAbi;
@ -440,6 +441,15 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
frame: Frame<'mir, 'tcx, Self::Provenance>,
) -> InterpResult<'tcx, Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>>;
fn cur_span(ecx: &InterpCx<'mir, 'tcx, Self>) -> Span
where
'tcx: 'mir,
{
// This deliberately does *not* honor `requires_caller_location` since it is used for much
// more than just panics.
Self::stack(ecx).last().map_or(ecx.tcx.span, |f| f.current_span())
}
/// Borrow the current thread's stack.
fn stack<'a>(
ecx: &'a InterpCx<'mir, 'tcx, Self>,

View file

@ -532,7 +532,7 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
/// places that are non-overlapping or identical.
///
/// The target place must have been flooded before calling this method.
fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
pub fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
let StateData::Reachable(values) = &mut self.0 else { return };
// If both places are tracked, we copy the value to the target.
@ -928,6 +928,31 @@ impl Map {
f(v)
}
}
/// Invoke a function on each value in the given place and all descendants.
pub fn for_each_projection_value<O>(
&self,
root: PlaceIndex,
value: O,
project: &mut impl FnMut(TrackElem, &O) -> Option<O>,
f: &mut impl FnMut(PlaceIndex, &O),
) {
// Fast path is there is nothing to do.
if self.inner_values[root].is_empty() {
return;
}
if self.places[root].value_index.is_some() {
f(root, &value)
}
for child in self.children(root) {
let elem = self.places[child].proj_elem.unwrap();
if let Some(value) = project(elem, &value) {
self.for_each_projection_value(child, value, project, f);
}
}
}
}
/// This is the information tracked for every [`PlaceIndex`] and is stored by [`Map`].

View file

@ -3,18 +3,19 @@
//! Currently, this pass only propagates scalar values.
use rustc_const_eval::const_eval::CheckAlignment;
use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{ConstValue, Scalar};
use rustc_middle::mir::visit::{MutVisitor, NonMutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_mir_dataflow::value_analysis::{
Map, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
};
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
use rustc_span::DUMMY_SP;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{Align, FieldIdx, VariantIdx};
use crate::MirPass;
@ -111,6 +112,12 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
state: &mut State<Self::Value>,
) {
match rvalue {
Rvalue::Use(operand) => {
state.flood(target.as_ref(), self.map());
if let Some(target) = self.map.find(target.as_ref()) {
self.assign_operand(state, target, operand);
}
}
Rvalue::Aggregate(kind, operands) => {
// If we assign `target = Enum::Variant#0(operand)`,
// we must make sure that all `target as Variant#i` are `Top`.
@ -138,8 +145,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
variant_target_idx,
TrackElem::Field(FieldIdx::from_usize(field_index)),
) {
let result = self.handle_operand(operand, state);
state.insert_idx(field, result, self.map());
self.assign_operand(state, field, operand);
}
}
}
@ -330,6 +336,86 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
}
}
/// The caller must have flooded `place`.
fn assign_operand(
&self,
state: &mut State<FlatSet<ScalarInt>>,
place: PlaceIndex,
operand: &Operand<'tcx>,
) {
match operand {
Operand::Copy(rhs) | Operand::Move(rhs) => {
if let Some(rhs) = self.map.find(rhs.as_ref()) {
state.insert_place_idx(place, rhs, &self.map)
}
}
Operand::Constant(box constant) => {
if let Ok(constant) = self.ecx.eval_mir_constant(&constant.literal, None, None) {
self.assign_constant(state, place, constant, &[]);
}
}
}
}
/// The caller must have flooded `place`.
///
/// Perform: `place = operand.projection`.
#[instrument(level = "trace", skip(self, state))]
fn assign_constant(
&self,
state: &mut State<FlatSet<ScalarInt>>,
place: PlaceIndex,
mut operand: OpTy<'tcx>,
projection: &[PlaceElem<'tcx>],
) -> Option<!> {
for &(mut proj_elem) in projection {
if let PlaceElem::Index(index) = proj_elem {
if let FlatSet::Elem(index) = state.get(index.into(), &self.map)
&& let Ok(offset) = index.try_to_target_usize(self.tcx)
&& let Some(min_length) = offset.checked_add(1)
{
proj_elem = PlaceElem::ConstantIndex { offset, min_length, from_end: false };
} else {
return None;
}
}
operand = self.ecx.project(&operand, proj_elem).ok()?;
}
self.map.for_each_projection_value(
place,
operand,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let scalar = self.ecx.discriminant_for_variant(op.layout, variant).ok()?;
let discr_ty = op.layout.ty.discriminant_ty(self.tcx);
let layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
Some(ImmTy::from_scalar(scalar, layout).into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout =
self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(scalar)) = *imm
{
state.insert_value_idx(place, FlatSet::Elem(scalar), &self.map);
}
},
);
None
}
fn binary_op(
&self,
state: &mut State<FlatSet<ScalarInt>>,
@ -604,8 +690,16 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
type MemoryKind = !;
const PANIC_ON_ALLOC_FAIL: bool = true;
#[inline(always)]
fn cur_span(_ecx: &InterpCx<'mir, 'tcx, Self>) -> Span {
DUMMY_SP
}
#[inline(always)]
fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment {
unimplemented!()
// We do not check for alignment to avoid having to carry an `Align`
// in `ConstValue::ByRef`.
CheckAlignment::No
}
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {

View file

@ -0,0 +1,63 @@
- // MIR for `constant` before DataflowConstProp
+ // MIR for `constant` after DataflowConstProp
fn constant() -> () {
let mut _0: ();
let _1: E;
let mut _3: isize;
scope 1 {
debug e => _1;
let _2: i32;
let _4: i32;
let _5: i32;
scope 2 {
debug x => _2;
}
scope 3 {
debug x => _4;
}
scope 4 {
debug x => _5;
}
}
bb0: {
StorageLive(_1);
_1 = const _;
StorageLive(_2);
- _3 = discriminant(_1);
- switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2];
+ _3 = const 0_isize;
+ switchInt(const 0_isize) -> [0: bb3, 1: bb1, otherwise: bb2];
}
bb1: {
StorageLive(_5);
_5 = ((_1 as V2).0: i32);
_2 = _5;
StorageDead(_5);
goto -> bb4;
}
bb2: {
unreachable;
}
bb3: {
StorageLive(_4);
- _4 = ((_1 as V1).0: i32);
- _2 = _4;
+ _4 = const 0_i32;
+ _2 = const 0_i32;
StorageDead(_4);
goto -> bb4;
}
bb4: {
_0 = const ();
StorageDead(_2);
StorageDead(_1);
return;
}
}

View file

@ -15,6 +15,13 @@ fn simple() {
let x = match e { E::V1(x) => x, E::V2(x) => x };
}
// EMIT_MIR enum.constant.DataflowConstProp.diff
fn constant() {
const C: E = E::V1(0);
let e = C;
let x = match e { E::V1(x) => x, E::V2(x) => x };
}
#[rustc_layout_scalar_valid_range_start(1)]
#[rustc_nonnull_optimization_guaranteed]
struct NonZeroUsize(usize);
@ -63,6 +70,7 @@ fn multiple(x: bool, i: u8) {
fn main() {
simple();
constant();
mutate_discriminant();
multiple(false, 5);
}

View file

@ -55,11 +55,12 @@
_10 = const _;
StorageLive(_11);
_11 = const 1_usize;
_12 = Len((*_10));
- _12 = Len((*_10));
- _13 = Lt(_11, _12);
- assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, _11) -> [success: bb2, unwind unreachable];
+ _13 = Lt(const 1_usize, _12);
+ assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, const 1_usize) -> [success: bb2, unwind unreachable];
+ _12 = const 3_usize;
+ _13 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 3_usize, const 1_usize) -> [success: bb2, unwind unreachable];
}
bb2: {

View file

@ -55,11 +55,12 @@
_10 = const _;
StorageLive(_11);
_11 = const 1_usize;
_12 = Len((*_10));
- _12 = Len((*_10));
- _13 = Lt(_11, _12);
- assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, _11) -> [success: bb2, unwind continue];
+ _13 = Lt(const 1_usize, _12);
+ assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, const 1_usize) -> [success: bb2, unwind continue];
+ _12 = const 3_usize;
+ _13 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 3_usize, const 1_usize) -> [success: bb2, unwind continue];
}
bb2: {

View file

@ -55,11 +55,12 @@
_10 = const _;
StorageLive(_11);
_11 = const 1_usize;
_12 = Len((*_10));
- _12 = Len((*_10));
- _13 = Lt(_11, _12);
- assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, _11) -> [success: bb2, unwind unreachable];
+ _13 = Lt(const 1_usize, _12);
+ assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, const 1_usize) -> [success: bb2, unwind unreachable];
+ _12 = const 3_usize;
+ _13 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 3_usize, const 1_usize) -> [success: bb2, unwind unreachable];
}
bb2: {

View file

@ -55,11 +55,12 @@
_10 = const _;
StorageLive(_11);
_11 = const 1_usize;
_12 = Len((*_10));
- _12 = Len((*_10));
- _13 = Lt(_11, _12);
- assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, _11) -> [success: bb2, unwind continue];
+ _13 = Lt(const 1_usize, _12);
+ assert(move _13, "index out of bounds: the length is {} but the index is {}", move _12, const 1_usize) -> [success: bb2, unwind continue];
+ _12 = const 3_usize;
+ _13 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 3_usize, const 1_usize) -> [success: bb2, unwind continue];
}
bb2: {

View file

@ -7,6 +7,7 @@
let mut _3: i32;
let mut _5: i32;
let mut _6: i32;
let mut _11: BigStruct;
scope 1 {
debug s => _1;
let _2: i32;
@ -15,6 +16,16 @@
let _4: i32;
scope 3 {
debug b => _4;
let _7: S;
let _8: u8;
let _9: f32;
let _10: S;
scope 4 {
debug a => _7;
debug b => _8;
debug c => _9;
debug d => _10;
}
}
}
}
@ -41,7 +52,26 @@
+ _4 = const 6_i32;
StorageDead(_6);
StorageDead(_5);
StorageLive(_11);
_11 = const _;
StorageLive(_7);
- _7 = move (_11.0: S);
+ _7 = const S(1_i32);
StorageLive(_8);
- _8 = (_11.1: u8);
+ _8 = const 5_u8;
StorageLive(_9);
- _9 = (_11.2: f32);
+ _9 = const 7f32;
StorageLive(_10);
- _10 = move (_11.3: S);
+ _10 = const S(13_i32);
StorageDead(_11);
_0 = const ();
StorageDead(_10);
StorageDead(_9);
StorageDead(_8);
StorageDead(_7);
StorageDead(_4);
StorageDead(_2);
StorageDead(_1);

View file

@ -2,10 +2,15 @@
struct S(i32);
struct BigStruct(S, u8, f32, S);
// EMIT_MIR struct.main.DataflowConstProp.diff
fn main() {
let mut s = S(1);
let a = s.0 + 2;
s.0 = 3;
let b = a + s.0;
const VAL: BigStruct = BigStruct(S(1), 5, 7., S(13));
let BigStruct(a, b, c, d) = VAL;
}