Auto merge of #127036 - cjgillot:sparse-state, r=oli-obk
Make jump threading state sparse Continuation of https://github.com/rust-lang/rust/pull/127024 Both dataflow const-prop and jump threading involve cloning the state vector a lot. This PR replaces the data structure by a sparse vector, considering: - that jump threading state is typically very sparse (at most 1 or 2 set entries); - that dataflow const-prop is disabled by default; - that place/value map is very eager, and prone to creating an overly large state. The first commit is shared with the previous PR to avoid needless conflicts. r? `@oli-obk`
This commit is contained in:
commit
2b90614e94
3 changed files with 141 additions and 80 deletions
|
@ -76,6 +76,8 @@ pub trait MeetSemiLattice: Eq {
|
|||
/// A set that has a "bottom" element, which is less than or equal to any other element.
|
||||
pub trait HasBottom {
|
||||
const BOTTOM: Self;
|
||||
|
||||
fn is_bottom(&self) -> bool;
|
||||
}
|
||||
|
||||
/// A set that has a "top" element, which is greater than or equal to any other element.
|
||||
|
@ -114,6 +116,10 @@ impl MeetSemiLattice for bool {
|
|||
|
||||
impl HasBottom for bool {
|
||||
const BOTTOM: Self = false;
|
||||
|
||||
fn is_bottom(&self) -> bool {
|
||||
!self
|
||||
}
|
||||
}
|
||||
|
||||
impl HasTop for bool {
|
||||
|
@ -267,6 +273,10 @@ impl<T: Clone + Eq> MeetSemiLattice for FlatSet<T> {
|
|||
|
||||
impl<T> HasBottom for FlatSet<T> {
|
||||
const BOTTOM: Self = Self::Bottom;
|
||||
|
||||
fn is_bottom(&self) -> bool {
|
||||
matches!(self, Self::Bottom)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> HasTop for FlatSet<T> {
|
||||
|
@ -291,6 +301,10 @@ impl<T> MaybeReachable<T> {
|
|||
|
||||
impl<T> HasBottom for MaybeReachable<T> {
|
||||
const BOTTOM: Self = MaybeReachable::Unreachable;
|
||||
|
||||
fn is_bottom(&self) -> bool {
|
||||
matches!(self, Self::Unreachable)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: HasTop> HasTop for MaybeReachable<T> {
|
||||
|
|
|
@ -36,10 +36,10 @@ use std::collections::VecDeque;
|
|||
use std::fmt::{Debug, Formatter};
|
||||
use std::ops::Range;
|
||||
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_data_structures::fx::{FxHashMap, StdEntry};
|
||||
use rustc_data_structures::stack::ensure_sufficient_stack;
|
||||
use rustc_index::bit_set::BitSet;
|
||||
use rustc_index::{IndexSlice, IndexVec};
|
||||
use rustc_index::IndexVec;
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
|
||||
use rustc_middle::mir::*;
|
||||
|
@ -336,14 +336,13 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper
|
|||
const NAME: &'static str = T::NAME;
|
||||
|
||||
fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain {
|
||||
State(StateData::Unreachable)
|
||||
State::Unreachable
|
||||
}
|
||||
|
||||
fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) {
|
||||
// The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥.
|
||||
assert!(matches!(state.0, StateData::Unreachable));
|
||||
let values = IndexVec::from_elem_n(T::Value::BOTTOM, self.0.map().value_count);
|
||||
*state = State(StateData::Reachable(values));
|
||||
assert!(matches!(state, State::Unreachable));
|
||||
*state = State::new_reachable();
|
||||
for arg in body.args_iter() {
|
||||
state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map());
|
||||
}
|
||||
|
@ -415,27 +414,54 @@ rustc_index::newtype_index!(
|
|||
|
||||
/// See [`State`].
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
enum StateData<V> {
|
||||
Reachable(IndexVec<ValueIndex, V>),
|
||||
Unreachable,
|
||||
pub struct StateData<V> {
|
||||
bottom: V,
|
||||
/// This map only contains values that are not `⊥`.
|
||||
map: FxHashMap<ValueIndex, V>,
|
||||
}
|
||||
|
||||
impl<V: HasBottom> StateData<V> {
|
||||
fn new() -> StateData<V> {
|
||||
StateData { bottom: V::BOTTOM, map: FxHashMap::default() }
|
||||
}
|
||||
|
||||
fn get(&self, idx: ValueIndex) -> &V {
|
||||
self.map.get(&idx).unwrap_or(&self.bottom)
|
||||
}
|
||||
|
||||
fn insert(&mut self, idx: ValueIndex, elem: V) {
|
||||
if elem.is_bottom() {
|
||||
self.map.remove(&idx);
|
||||
} else {
|
||||
self.map.insert(idx, elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Clone> Clone for StateData<V> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Reachable(x) => Self::Reachable(x.clone()),
|
||||
Self::Unreachable => Self::Unreachable,
|
||||
}
|
||||
StateData { bottom: self.bottom.clone(), map: self.map.clone() }
|
||||
}
|
||||
|
||||
fn clone_from(&mut self, source: &Self) {
|
||||
match (&mut *self, source) {
|
||||
(Self::Reachable(x), Self::Reachable(y)) => {
|
||||
// We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
|
||||
x.raw.clone_from(&y.raw);
|
||||
self.map.clone_from(&source.map)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for StateData<V> {
|
||||
fn join(&mut self, other: &Self) -> bool {
|
||||
let mut changed = false;
|
||||
#[allow(rustc::potential_query_instability)]
|
||||
for (i, v) in other.map.iter() {
|
||||
match self.map.entry(*i) {
|
||||
StdEntry::Vacant(e) => {
|
||||
e.insert(v.clone());
|
||||
changed = true
|
||||
}
|
||||
StdEntry::Occupied(e) => changed |= e.into_mut().join(v),
|
||||
}
|
||||
_ => *self = source.clone(),
|
||||
}
|
||||
changed
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -450,33 +476,47 @@ impl<V: Clone> Clone for StateData<V> {
|
|||
///
|
||||
/// Flooding means assigning a value (by default `⊤`) to all tracked projections of a given place.
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub struct State<V>(StateData<V>);
|
||||
pub enum State<V> {
|
||||
Unreachable,
|
||||
Reachable(StateData<V>),
|
||||
}
|
||||
|
||||
impl<V: Clone> Clone for State<V> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
match self {
|
||||
Self::Reachable(x) => Self::Reachable(x.clone()),
|
||||
Self::Unreachable => Self::Unreachable,
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_from(&mut self, source: &Self) {
|
||||
self.0.clone_from(&source.0);
|
||||
match (&mut *self, source) {
|
||||
(Self::Reachable(x), Self::Reachable(y)) => {
|
||||
x.clone_from(&y);
|
||||
}
|
||||
_ => *self = source.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Clone> State<V> {
|
||||
pub fn new(init: V, map: &Map) -> State<V> {
|
||||
let values = IndexVec::from_elem_n(init, map.value_count);
|
||||
State(StateData::Reachable(values))
|
||||
impl<V: Clone + HasBottom> State<V> {
|
||||
pub fn new_reachable() -> State<V> {
|
||||
State::Reachable(StateData::new())
|
||||
}
|
||||
|
||||
pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
|
||||
match self.0 {
|
||||
StateData::Unreachable => true,
|
||||
StateData::Reachable(ref values) => values.iter().all(f),
|
||||
pub fn all_bottom(&self) -> bool {
|
||||
match self {
|
||||
State::Unreachable => false,
|
||||
State::Reachable(ref values) =>
|
||||
{
|
||||
#[allow(rustc::potential_query_instability)]
|
||||
values.map.values().all(V::is_bottom)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_reachable(&self) -> bool {
|
||||
matches!(&self.0, StateData::Reachable(_))
|
||||
matches!(self, State::Reachable(_))
|
||||
}
|
||||
|
||||
/// Assign `value` to all places that are contained in `place` or may alias one.
|
||||
|
@ -519,10 +559,8 @@ impl<V: Clone> State<V> {
|
|||
map: &Map,
|
||||
value: V,
|
||||
) {
|
||||
let StateData::Reachable(values) = &mut self.0 else { return };
|
||||
map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
|
||||
values[vi] = value.clone();
|
||||
});
|
||||
let State::Reachable(values) = self else { return };
|
||||
map.for_each_aliasing_place(place, tail_elem, &mut |vi| values.insert(vi, value.clone()));
|
||||
}
|
||||
|
||||
/// Low-level method that assigns to a place.
|
||||
|
@ -541,9 +579,9 @@ impl<V: Clone> State<V> {
|
|||
///
|
||||
/// The target place must have been flooded before calling this method.
|
||||
pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
|
||||
let StateData::Reachable(values) = &mut self.0 else { return };
|
||||
let State::Reachable(values) = self else { return };
|
||||
if let Some(value_index) = map.places[target].value_index {
|
||||
values[value_index] = value;
|
||||
values.insert(value_index, value)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -555,14 +593,14 @@ impl<V: Clone> State<V> {
|
|||
///
|
||||
/// The target place must have been flooded before calling this method.
|
||||
pub fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
|
||||
let StateData::Reachable(values) = &mut self.0 else { return };
|
||||
let State::Reachable(values) = self else { return };
|
||||
|
||||
// If both places are tracked, we copy the value to the target.
|
||||
// If the target is tracked, but the source is not, we do nothing, as invalidation has
|
||||
// already been performed.
|
||||
if let Some(target_value) = map.places[target].value_index {
|
||||
if let Some(source_value) = map.places[source].value_index {
|
||||
values[target_value] = values[source_value].clone();
|
||||
values.insert(target_value, values.get(source_value).clone());
|
||||
}
|
||||
}
|
||||
for target_child in map.children(target) {
|
||||
|
@ -616,11 +654,11 @@ impl<V: Clone> State<V> {
|
|||
|
||||
/// Retrieve the value stored for a place index, or `None` if it is not tracked.
|
||||
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
|
||||
match &self.0 {
|
||||
StateData::Reachable(values) => {
|
||||
map.places[place].value_index.map(|v| values[v].clone())
|
||||
match self {
|
||||
State::Reachable(values) => {
|
||||
map.places[place].value_index.map(|v| values.get(v).clone())
|
||||
}
|
||||
StateData::Unreachable => None,
|
||||
State::Unreachable => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -631,10 +669,10 @@ impl<V: Clone> State<V> {
|
|||
where
|
||||
V: HasBottom + HasTop,
|
||||
{
|
||||
match &self.0 {
|
||||
StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
|
||||
match self {
|
||||
State::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
|
||||
// Because this is unreachable, we can return any value we want.
|
||||
StateData::Unreachable => V::BOTTOM,
|
||||
State::Unreachable => V::BOTTOM,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -645,10 +683,10 @@ impl<V: Clone> State<V> {
|
|||
where
|
||||
V: HasBottom + HasTop,
|
||||
{
|
||||
match &self.0 {
|
||||
StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
|
||||
match self {
|
||||
State::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
|
||||
// Because this is unreachable, we can return any value we want.
|
||||
StateData::Unreachable => V::BOTTOM,
|
||||
State::Unreachable => V::BOTTOM,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -659,10 +697,10 @@ impl<V: Clone> State<V> {
|
|||
where
|
||||
V: HasBottom + HasTop,
|
||||
{
|
||||
match &self.0 {
|
||||
StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
|
||||
match self {
|
||||
State::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
|
||||
// Because this is unreachable, we can return any value we want.
|
||||
StateData::Unreachable => V::BOTTOM,
|
||||
State::Unreachable => V::BOTTOM,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -673,11 +711,11 @@ impl<V: Clone> State<V> {
|
|||
where
|
||||
V: HasBottom + HasTop,
|
||||
{
|
||||
match &self.0 {
|
||||
StateData::Reachable(values) => {
|
||||
map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP)
|
||||
match self {
|
||||
State::Reachable(values) => {
|
||||
map.places[place].value_index.map(|v| values.get(v).clone()).unwrap_or(V::TOP)
|
||||
}
|
||||
StateData::Unreachable => {
|
||||
State::Unreachable => {
|
||||
// Because this is unreachable, we can return any value we want.
|
||||
V::BOTTOM
|
||||
}
|
||||
|
@ -685,15 +723,15 @@ impl<V: Clone> State<V> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> {
|
||||
impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for State<V> {
|
||||
fn join(&mut self, other: &Self) -> bool {
|
||||
match (&mut self.0, &other.0) {
|
||||
(_, StateData::Unreachable) => false,
|
||||
(StateData::Unreachable, _) => {
|
||||
match (&mut *self, other) {
|
||||
(_, State::Unreachable) => false,
|
||||
(State::Unreachable, _) => {
|
||||
*self = other.clone();
|
||||
true
|
||||
}
|
||||
(StateData::Reachable(this), StateData::Reachable(other)) => this.join(other),
|
||||
(State::Reachable(this), State::Reachable(ref other)) => this.join(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1194,9 +1232,9 @@ where
|
|||
T::Value: Debug,
|
||||
{
|
||||
fn fmt_with(&self, ctxt: &ValueAnalysisWrapper<T>, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.0 {
|
||||
StateData::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f),
|
||||
StateData::Unreachable => write!(f, "unreachable"),
|
||||
match self {
|
||||
State::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f),
|
||||
State::Unreachable => write!(f, "unreachable"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1206,8 +1244,8 @@ where
|
|||
ctxt: &ValueAnalysisWrapper<T>,
|
||||
f: &mut Formatter<'_>,
|
||||
) -> std::fmt::Result {
|
||||
match (&self.0, &old.0) {
|
||||
(StateData::Reachable(this), StateData::Reachable(old)) => {
|
||||
match (self, old) {
|
||||
(State::Reachable(this), State::Reachable(old)) => {
|
||||
debug_with_context(this, Some(old), ctxt.0.map(), f)
|
||||
}
|
||||
_ => Ok(()), // Consider printing something here.
|
||||
|
@ -1215,21 +1253,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn debug_with_context_rec<V: Debug + Eq>(
|
||||
fn debug_with_context_rec<V: Debug + Eq + HasBottom>(
|
||||
place: PlaceIndex,
|
||||
place_str: &str,
|
||||
new: &IndexSlice<ValueIndex, V>,
|
||||
old: Option<&IndexSlice<ValueIndex, V>>,
|
||||
new: &StateData<V>,
|
||||
old: Option<&StateData<V>>,
|
||||
map: &Map,
|
||||
f: &mut Formatter<'_>,
|
||||
) -> std::fmt::Result {
|
||||
if let Some(value) = map.places[place].value_index {
|
||||
match old {
|
||||
None => writeln!(f, "{}: {:?}", place_str, new[value])?,
|
||||
None => writeln!(f, "{}: {:?}", place_str, new.get(value))?,
|
||||
Some(old) => {
|
||||
if new[value] != old[value] {
|
||||
writeln!(f, "\u{001f}-{}: {:?}", place_str, old[value])?;
|
||||
writeln!(f, "\u{001f}+{}: {:?}", place_str, new[value])?;
|
||||
if new.get(value) != old.get(value) {
|
||||
writeln!(f, "\u{001f}-{}: {:?}", place_str, old.get(value))?;
|
||||
writeln!(f, "\u{001f}+{}: {:?}", place_str, new.get(value))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1261,9 +1299,9 @@ fn debug_with_context_rec<V: Debug + Eq>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn debug_with_context<V: Debug + Eq>(
|
||||
new: &IndexSlice<ValueIndex, V>,
|
||||
old: Option<&IndexSlice<ValueIndex, V>>,
|
||||
fn debug_with_context<V: Debug + Eq + HasBottom>(
|
||||
new: &StateData<V>,
|
||||
old: Option<&StateData<V>>,
|
||||
map: &Map,
|
||||
f: &mut Formatter<'_>,
|
||||
) -> std::fmt::Result {
|
||||
|
|
|
@ -47,6 +47,7 @@ use rustc_middle::mir::visit::Visitor;
|
|||
use rustc_middle::mir::*;
|
||||
use rustc_middle::ty::layout::LayoutOf;
|
||||
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
|
||||
use rustc_mir_dataflow::lattice::HasBottom;
|
||||
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
|
||||
use rustc_span::DUMMY_SP;
|
||||
use rustc_target::abi::{TagEncoding, Variants};
|
||||
|
@ -158,9 +159,17 @@ impl Condition {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Default)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
struct ConditionSet<'a>(&'a [Condition]);
|
||||
|
||||
impl HasBottom for ConditionSet<'_> {
|
||||
const BOTTOM: Self = ConditionSet(&[]);
|
||||
|
||||
fn is_bottom(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ConditionSet<'a> {
|
||||
fn iter(self) -> impl Iterator<Item = Condition> + 'a {
|
||||
self.0.iter().copied()
|
||||
|
@ -177,7 +186,7 @@ impl<'a> ConditionSet<'a> {
|
|||
|
||||
impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
|
||||
state.all(|cs| cs.0.is_empty())
|
||||
state.all_bottom()
|
||||
}
|
||||
|
||||
/// Recursion entry point to find threading opportunities.
|
||||
|
@ -198,7 +207,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
|||
debug!(?discr);
|
||||
|
||||
let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
|
||||
let mut state = State::new(ConditionSet::default(), self.map);
|
||||
let mut state = State::new_reachable();
|
||||
|
||||
let conds = if let Some((value, then, else_)) = targets.as_static_if() {
|
||||
let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
|
||||
|
@ -255,7 +264,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
|||
// _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
|
||||
// _1 = 6
|
||||
if let Some((lhs, tail)) = self.mutated_statement(stmt) {
|
||||
state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default());
|
||||
state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::BOTTOM);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -609,7 +618,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
|||
// We can recurse through this terminator.
|
||||
let mut state = state();
|
||||
if let Some(place_to_flood) = place_to_flood {
|
||||
state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default());
|
||||
state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::BOTTOM);
|
||||
}
|
||||
self.find_opportunity(bb, state, cost.clone(), depth + 1);
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue