Make SearchGraph fully generic

This commit is contained in:
Michael Goulet 2024-06-14 18:25:31 -04:00
parent af3d1004c7
commit dba4147633
8 changed files with 149 additions and 95 deletions

View file

@ -10,7 +10,7 @@ use crate::ty::{
mod cache;
pub use cache::{CacheData, EvaluationCache};
pub use cache::EvaluationCache;
pub type Goal<'tcx, P> = ir::solve::Goal<TyCtxt<'tcx>, P>;
pub type QueryInput<'tcx, P> = ir::solve::QueryInput<TyCtxt<'tcx>, P>;

View file

@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock;
use rustc_query_system::cache::WithDepNode;
use rustc_query_system::dep_graph::DepNodeIndex;
use rustc_session::Limit;
use rustc_type_ir::solve::CacheData;
/// The trait solver cache used by `-Znext-solver`.
///
/// FIXME(@lcnr): link to some official documentation of how
@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> {
map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
}
#[derive(Debug, PartialEq, Eq)]
pub struct CacheData<'tcx> {
pub result: QueryResult<'tcx>,
pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>>,
pub additional_depth: usize,
pub encountered_overflow: bool,
}
impl<'tcx> EvaluationCache<'tcx> {
impl<'tcx> rustc_type_ir::inherent::EvaluationCache<TyCtxt<'tcx>> for &'tcx EvaluationCache<'tcx> {
/// Insert a final result into the global cache.
pub fn insert(
fn insert(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> {
if cfg!(debug_assertions) {
drop(map);
let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow };
let actual = self.get(tcx, key, [], Limit(additional_depth));
let actual = self.get(tcx, key, [], additional_depth);
if !actual.as_ref().is_some_and(|actual| expected == *actual) {
bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}");
}
@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> {
/// and handling root goals of coinductive cycles.
///
/// If this returns `Some` the cache result can be used.
pub fn get(
fn get(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>,
available_depth: Limit,
) -> Option<CacheData<'tcx>> {
available_depth: usize,
) -> Option<CacheData<TyCtxt<'tcx>>> {
let map = self.map.borrow();
let entry = map.get(&key)?;
@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> {
}
if let Some(ref success) = entry.success {
if available_depth.value_within_limit(success.additional_depth) {
if Limit(available_depth).value_within_limit(success.additional_depth) {
let QueryData { result, proof_tree } = success.data.get(tcx);
return Some(CacheData {
result,
@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> {
}
}
entry.with_overflow.get(&available_depth.0).map(|e| {
entry.with_overflow.get(&available_depth).map(|e| {
let QueryData { result, proof_tree } = e.get(tcx);
CacheData {
result,
proof_tree,
additional_depth: available_depth.0,
additional_depth: available_depth,
encountered_overflow: true,
}
})

View file

@ -71,6 +71,7 @@ use rustc_target::abi::{FieldIdx, Layout, LayoutS, TargetDataLayout, VariantIdx}
use rustc_target::spec::abi;
use rustc_type_ir::fold::TypeFoldable;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::solve::SolverMode;
use rustc_type_ir::TyKind::*;
use rustc_type_ir::{CollectAndApply, Interner, TypeFlags, WithCachedTypeInfo};
use tracing::{debug, instrument};
@ -139,10 +140,30 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
type Clause = Clause<'tcx>;
type Clauses = ty::Clauses<'tcx>;
type DepNodeIndex = DepNodeIndex;
fn with_cached_task<T>(self, task: impl FnOnce() -> T) -> (T, DepNodeIndex) {
self.dep_graph.with_anon_task(self, crate::dep_graph::dep_kinds::TraitSelect, task)
}
type EvaluationCache = &'tcx solve::EvaluationCache<'tcx>;
fn evaluation_cache(self, mode: SolverMode) -> &'tcx solve::EvaluationCache<'tcx> {
match mode {
SolverMode::Normal => &self.new_solver_evaluation_cache,
SolverMode::Coherence => &self.new_solver_coherence_evaluation_cache,
}
}
fn expand_abstract_consts<T: TypeFoldable<TyCtxt<'tcx>>>(self, t: T) -> T {
self.expand_abstract_consts(t)
}
fn mk_external_constraints(
self,
data: ExternalConstraintsData<Self>,
) -> ExternalConstraints<'tcx> {
self.mk_external_constraints(data)
}
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
self.mk_canonical_var_infos(infos)
}

View file

@ -14,12 +14,11 @@
//! FIXME(@lcnr): Write that section. If you read this before then ask me
//! about it on zulip.
use rustc_hir::def_id::DefId;
use rustc_infer::infer::canonical::{Canonical, CanonicalVarValues};
use rustc_infer::infer::canonical::Canonical;
use rustc_infer::infer::InferCtxt;
use rustc_infer::traits::query::NoSolution;
use rustc_macros::extension;
use rustc_middle::bug;
use rustc_middle::infer::canonical::CanonicalVarInfos;
use rustc_middle::traits::solve::{
CanonicalResponse, Certainty, ExternalConstraintsData, Goal, GoalSource, QueryResult, Response,
};
@ -27,6 +26,8 @@ use rustc_middle::ty::{
self, AliasRelationDirection, CoercePredicate, RegionOutlivesPredicate, SubtypePredicate, Ty,
TyCtxt, TypeOutlivesPredicate, UniverseIndex,
};
use rustc_type_ir::solve::SolverMode;
use rustc_type_ir::{self as ir, Interner};
mod alias_relate;
mod assembly;
@ -57,19 +58,6 @@ pub use select::InferCtxtSelectExt;
/// recursion limit again. However, this feels very unlikely.
const FIXPOINT_STEP_LIMIT: usize = 8;
#[derive(Debug, Clone, Copy)]
enum SolverMode {
/// Ordinary trait solving, using everywhere except for coherence.
Normal,
/// Trait solving during coherence. There are a few notable differences
/// between coherence and ordinary trait solving.
///
/// Most importantly, trait solving during coherence must not be incomplete,
/// i.e. return `Err(NoSolution)` for goals for which a solution exists.
/// This means that we must not make any guesses or arbitrary choices.
Coherence,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum GoalEvaluationKind {
Root,
@ -314,17 +302,17 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
}
}
fn response_no_constraints_raw<'tcx>(
tcx: TyCtxt<'tcx>,
fn response_no_constraints_raw<I: Interner>(
tcx: I,
max_universe: UniverseIndex,
variables: CanonicalVarInfos<'tcx>,
variables: I::CanonicalVars,
certainty: Certainty,
) -> CanonicalResponse<'tcx> {
Canonical {
) -> ir::solve::CanonicalResponse<I> {
ir::Canonical {
max_universe,
variables,
value: Response {
var_values: CanonicalVarValues::make_identity(tcx, variables),
var_values: ir::CanonicalVarValues::make_identity(tcx, variables),
// FIXME: maybe we should store the "no response" version in tcx, like
// we do for tcx.types and stuff.
external_constraints: tcx.mk_external_constraints(ExternalConstraintsData::default()),

View file

@ -3,14 +3,11 @@ use std::mem;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_index::Idx;
use rustc_index::IndexVec;
use rustc_infer::infer::InferCtxt;
use rustc_middle::dep_graph::dep_kinds;
use rustc_middle::traits::solve::CacheData;
use rustc_middle::traits::solve::EvaluationCache;
use rustc_middle::ty::TyCtxt;
use rustc_next_trait_solver::solve::CacheData;
use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult};
use rustc_session::Limit;
use rustc_type_ir::inherent::*;
use rustc_type_ir::InferCtxtLike;
use rustc_type_ir::Interner;
use super::inspect;
@ -240,34 +237,26 @@ impl<I: Interner> SearchGraph<I> {
!entry.is_empty()
});
}
}
impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
/// The trait solver behavior is different for coherence
/// so we use a separate cache. Alternatively we could use
/// a single cache and share it between coherence and ordinary
/// trait solving.
pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
match self.mode {
SolverMode::Normal => &tcx.new_solver_evaluation_cache,
SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
}
pub(super) fn global_cache(&self, tcx: I) -> I::EvaluationCache {
tcx.evaluation_cache(self.mode)
}
/// Probably the most involved method of the whole solver.
///
/// Given some goal which is proven via the `prove_goal` closure, this
/// handles caching, overflow, and coinductive cycles.
pub(super) fn with_new_goal(
pub(super) fn with_new_goal<Infcx: InferCtxtLike<Interner = I>>(
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<TyCtxt<'tcx>>,
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
mut prove_goal: impl FnMut(
&mut Self,
&mut ProofTreeBuilder<InferCtxt<'tcx>>,
) -> QueryResult<TyCtxt<'tcx>>,
) -> QueryResult<TyCtxt<'tcx>> {
tcx: I,
input: CanonicalInput<I>,
inspect: &mut ProofTreeBuilder<Infcx>,
mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>,
) -> QueryResult<I> {
self.check_invariants();
// Check for overflow.
let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
@ -361,21 +350,20 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
// not tracked by the cache key and from outside of this anon task, it
// must not be added to the global cache. Notably, this is the case for
// trait solver cycles participants.
let ((final_entry, result), dep_node) =
tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || {
for _ in 0..FIXPOINT_STEP_LIMIT {
match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) {
StepResult::Done(final_entry, result) => return (final_entry, result),
StepResult::HasChanged => debug!("fixpoint changed provisional results"),
}
let ((final_entry, result), dep_node) = tcx.with_cached_task(|| {
for _ in 0..FIXPOINT_STEP_LIMIT {
match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) {
StepResult::Done(final_entry, result) => return (final_entry, result),
StepResult::HasChanged => debug!("fixpoint changed provisional results"),
}
}
debug!("canonical cycle overflow");
let current_entry = self.pop_stack();
debug_assert!(current_entry.has_been_used.is_empty());
let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false));
(current_entry, result)
});
debug!("canonical cycle overflow");
let current_entry = self.pop_stack();
debug_assert!(current_entry.has_been_used.is_empty());
let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false));
(current_entry, result)
});
let proof_tree = inspect.finalize_canonical_goal_evaluation(tcx);
@ -423,16 +411,17 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
/// Try to fetch a previously computed result from the global cache,
/// making sure to only do so if it would match the result of reevaluating
/// this goal.
fn lookup_global_cache(
fn lookup_global_cache<Infcx: InferCtxtLike<Interner = I>>(
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<TyCtxt<'tcx>>,
tcx: I,
input: CanonicalInput<I>,
available_depth: Limit,
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
) -> Option<QueryResult<TyCtxt<'tcx>>> {
inspect: &mut ProofTreeBuilder<Infcx>,
) -> Option<QueryResult<I>> {
let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self
.global_cache(tcx)
.get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?;
// TODO: Awkward `Limit -> usize -> Limit`.
.get(tcx, input, self.stack.iter().map(|e| e.input), available_depth.0)?;
// If we're building a proof tree and the current cache entry does not
// contain a proof tree, we do not use the entry but instead recompute
@ -465,21 +454,22 @@ enum StepResult<I: Interner> {
HasChanged,
}
impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
impl<I: Interner> SearchGraph<I> {
/// When we encounter a coinductive cycle, we have to fetch the
/// result of that cycle while we are still computing it. Because
/// of this we continuously recompute the cycle until the result
/// of the previous iteration is equal to the final result, at which
/// point we are done.
fn fixpoint_step_in_task<F>(
fn fixpoint_step_in_task<Infcx, F>(
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<TyCtxt<'tcx>>,
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
tcx: I,
input: CanonicalInput<I>,
inspect: &mut ProofTreeBuilder<Infcx>,
prove_goal: &mut F,
) -> StepResult<TyCtxt<'tcx>>
) -> StepResult<I>
where
F: FnMut(&mut Self, &mut ProofTreeBuilder<InferCtxt<'tcx>>) -> QueryResult<TyCtxt<'tcx>>,
Infcx: InferCtxtLike<Interner = I>,
F: FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>,
{
let result = prove_goal(self, inspect);
let stack_entry = self.pop_stack();
@ -533,15 +523,13 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
}
fn response_no_constraints(
tcx: TyCtxt<'tcx>,
goal: CanonicalInput<TyCtxt<'tcx>>,
tcx: I,
goal: CanonicalInput<I>,
certainty: Certainty,
) -> QueryResult<TyCtxt<'tcx>> {
) -> QueryResult<I> {
Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty))
}
}
impl<I: Interner> SearchGraph<I> {
#[allow(rustc::potential_query_instability)]
fn check_invariants(&self) {
if !cfg!(debug_assertions) {

View file

@ -8,9 +8,11 @@ use std::hash::Hash;
use std::ops::Deref;
use rustc_ast_ir::Mutability;
use rustc_data_structures::fx::FxHashSet;
use crate::fold::{TypeFoldable, TypeSuperFoldable};
use crate::relate::Relate;
use crate::solve::{CacheData, CanonicalInput, QueryResult};
use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable};
use crate::{self as ty, CollectAndApply, Interner, UpcastFrom};
@ -363,3 +365,30 @@ pub trait Features<I: Interner>: Copy {
fn coroutine_clone(self) -> bool;
}
pub trait EvaluationCache<I: Interner> {
/// Insert a final result into the global cache.
fn insert(
&self,
tcx: I,
key: CanonicalInput<I>,
proof_tree: Option<I::CanonicalGoalEvaluationStepRef>,
additional_depth: usize,
encountered_overflow: bool,
cycle_participants: FxHashSet<CanonicalInput<I>>,
dep_node: I::DepNodeIndex,
result: QueryResult<I>,
);
/// Try to fetch a cached result, checking the recursion limit
/// and handling root goals of coinductive cycles.
///
/// If this returns `Some` the cache result can be used.
fn get(
&self,
tcx: I,
key: CanonicalInput<I>,
stack_entries: impl IntoIterator<Item = CanonicalInput<I>>,
available_depth: usize,
) -> Option<CacheData<I>>;
}

View file

@ -10,6 +10,7 @@ use crate::ir_print::IrPrint;
use crate::lang_items::TraitSolverLangItem;
use crate::relate::Relate;
use crate::solve::inspect::CanonicalGoalEvaluationStep;
use crate::solve::{ExternalConstraintsData, SolverMode};
use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable};
use crate::{self as ty};
@ -45,16 +46,26 @@ pub trait Interner:
+ Default;
type BoundVarKind: Copy + Debug + Hash + Eq;
type CanonicalVars: Copy + Debug + Hash + Eq + IntoIterator<Item = ty::CanonicalVarInfo<Self>>;
type PredefinedOpaques: Copy + Debug + Hash + Eq;
type DefiningOpaqueTypes: Copy + Debug + Hash + Default + Eq + TypeVisitable<Self>;
type ExternalConstraints: Copy + Debug + Hash + Eq;
type CanonicalGoalEvaluationStepRef: Copy
+ Debug
+ Hash
+ Eq
+ Deref<Target = CanonicalGoalEvaluationStep<Self>>;
type CanonicalVars: Copy + Debug + Hash + Eq + IntoIterator<Item = ty::CanonicalVarInfo<Self>>;
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars;
type ExternalConstraints: Copy + Debug + Hash + Eq;
fn mk_external_constraints(
self,
data: ExternalConstraintsData<Self>,
) -> Self::ExternalConstraints;
type DepNodeIndex;
fn with_cached_task<T>(self, task: impl FnOnce() -> T) -> (T, Self::DepNodeIndex);
// Kinds of tys
type Ty: Ty<Self>;
type Tys: Tys<Self>;
@ -97,9 +108,10 @@ pub trait Interner:
type Clause: Clause<Self>;
type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable<Self> + Flags;
fn expand_abstract_consts<T: TypeFoldable<Self>>(self, t: T) -> T;
type EvaluationCache: EvaluationCache<Self>;
fn evaluation_cache(self, mode: SolverMode) -> Self::EvaluationCache;
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars;
fn expand_abstract_consts<T: TypeFoldable<Self>>(self, t: T) -> T;
type GenericsOf: GenericsOf<Self>;
fn generics_of(self, def_id: Self::DefId) -> Self::GenericsOf;

View file

@ -57,6 +57,19 @@ pub enum Reveal {
All,
}
#[derive(Debug, Clone, Copy)]
pub enum SolverMode {
/// Ordinary trait solving, using everywhere except for coherence.
Normal,
/// Trait solving during coherence. There are a few notable differences
/// between coherence and ordinary trait solving.
///
/// Most importantly, trait solving during coherence must not be incomplete,
/// i.e. return `Err(NoSolution)` for goals for which a solution exists.
/// This means that we must not make any guesses or arbitrary choices.
Coherence,
}
pub type CanonicalInput<I, T = <I as Interner>::Predicate> = Canonical<I, QueryInput<I, T>>;
pub type CanonicalResponse<I> = Canonical<I, Response<I>>;
/// The result of evaluating a canonical query.
@ -356,3 +369,12 @@ impl MaybeCause {
}
}
}
#[derive(derivative::Derivative)]
#[derivative(PartialEq(bound = ""), Eq(bound = ""), Debug(bound = ""))]
pub struct CacheData<I: Interner> {
pub result: QueryResult<I>,
pub proof_tree: Option<I::CanonicalGoalEvaluationStepRef>,
pub additional_depth: usize,
pub encountered_overflow: bool,
}