diff --git a/compiler/rustc_middle/src/traits/solve.rs b/compiler/rustc_middle/src/traits/solve.rs index 9e979620a44..888e3aec5ea 100644 --- a/compiler/rustc_middle/src/traits/solve.rs +++ b/compiler/rustc_middle/src/traits/solve.rs @@ -10,7 +10,7 @@ use crate::ty::{ mod cache; -pub use cache::{CacheData, EvaluationCache}; +pub use cache::EvaluationCache; pub type Goal<'tcx, P> = ir::solve::Goal, P>; pub type QueryInput<'tcx, P> = ir::solve::QueryInput, P>; diff --git a/compiler/rustc_middle/src/traits/solve/cache.rs b/compiler/rustc_middle/src/traits/solve/cache.rs index dc31114b2c4..72a8d4eb405 100644 --- a/compiler/rustc_middle/src/traits/solve/cache.rs +++ b/compiler/rustc_middle/src/traits/solve/cache.rs @@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock; use rustc_query_system::cache::WithDepNode; use rustc_query_system::dep_graph::DepNodeIndex; use rustc_session::Limit; +use rustc_type_ir::solve::CacheData; + /// The trait solver cache used by `-Znext-solver`. /// /// FIXME(@lcnr): link to some official documentation of how @@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> { map: Lock, CacheEntry<'tcx>>>, } -#[derive(Debug, PartialEq, Eq)] -pub struct CacheData<'tcx> { - pub result: QueryResult<'tcx>, - pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep>>, - pub additional_depth: usize, - pub encountered_overflow: bool, -} - -impl<'tcx> EvaluationCache<'tcx> { +impl<'tcx> rustc_type_ir::inherent::EvaluationCache> for &'tcx EvaluationCache<'tcx> { /// Insert a final result into the global cache. - pub fn insert( + fn insert( &self, tcx: TyCtxt<'tcx>, key: CanonicalInput<'tcx>, @@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> { if cfg!(debug_assertions) { drop(map); let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow }; - let actual = self.get(tcx, key, [], Limit(additional_depth)); + let actual = self.get(tcx, key, [], additional_depth); if !actual.as_ref().is_some_and(|actual| expected == *actual) { bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}"); } @@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> { /// and handling root goals of coinductive cycles. /// /// If this returns `Some` the cache result can be used. - pub fn get( + fn get( &self, tcx: TyCtxt<'tcx>, key: CanonicalInput<'tcx>, stack_entries: impl IntoIterator>, - available_depth: Limit, - ) -> Option> { + available_depth: usize, + ) -> Option>> { let map = self.map.borrow(); let entry = map.get(&key)?; @@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> { } if let Some(ref success) = entry.success { - if available_depth.value_within_limit(success.additional_depth) { + if Limit(available_depth).value_within_limit(success.additional_depth) { let QueryData { result, proof_tree } = success.data.get(tcx); return Some(CacheData { result, @@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> { } } - entry.with_overflow.get(&available_depth.0).map(|e| { + entry.with_overflow.get(&available_depth).map(|e| { let QueryData { result, proof_tree } = e.get(tcx); CacheData { result, proof_tree, - additional_depth: available_depth.0, + additional_depth: available_depth, encountered_overflow: true, } }) diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index e2f15dac019..eec7fa8db1d 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -71,6 +71,7 @@ use rustc_target::abi::{FieldIdx, Layout, LayoutS, TargetDataLayout, VariantIdx} use rustc_target::spec::abi; use rustc_type_ir::fold::TypeFoldable; use rustc_type_ir::lang_items::TraitSolverLangItem; +use rustc_type_ir::solve::SolverMode; use rustc_type_ir::TyKind::*; use rustc_type_ir::{CollectAndApply, Interner, TypeFlags, WithCachedTypeInfo}; use tracing::{debug, instrument}; @@ -139,10 +140,30 @@ impl<'tcx> Interner for TyCtxt<'tcx> { type Clause = Clause<'tcx>; type Clauses = ty::Clauses<'tcx>; + type DepNodeIndex = DepNodeIndex; + fn with_cached_task(self, task: impl FnOnce() -> T) -> (T, DepNodeIndex) { + self.dep_graph.with_anon_task(self, crate::dep_graph::dep_kinds::TraitSelect, task) + } + + type EvaluationCache = &'tcx solve::EvaluationCache<'tcx>; + fn evaluation_cache(self, mode: SolverMode) -> &'tcx solve::EvaluationCache<'tcx> { + match mode { + SolverMode::Normal => &self.new_solver_evaluation_cache, + SolverMode::Coherence => &self.new_solver_coherence_evaluation_cache, + } + } + fn expand_abstract_consts>>(self, t: T) -> T { self.expand_abstract_consts(t) } + fn mk_external_constraints( + self, + data: ExternalConstraintsData, + ) -> ExternalConstraints<'tcx> { + self.mk_external_constraints(data) + } + fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo]) -> Self::CanonicalVars { self.mk_canonical_var_infos(infos) } diff --git a/compiler/rustc_trait_selection/src/solve/mod.rs b/compiler/rustc_trait_selection/src/solve/mod.rs index 4f1be5cbc85..7b6e525370c 100644 --- a/compiler/rustc_trait_selection/src/solve/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/mod.rs @@ -14,12 +14,11 @@ //! FIXME(@lcnr): Write that section. If you read this before then ask me //! about it on zulip. use rustc_hir::def_id::DefId; -use rustc_infer::infer::canonical::{Canonical, CanonicalVarValues}; +use rustc_infer::infer::canonical::Canonical; use rustc_infer::infer::InferCtxt; use rustc_infer::traits::query::NoSolution; use rustc_macros::extension; use rustc_middle::bug; -use rustc_middle::infer::canonical::CanonicalVarInfos; use rustc_middle::traits::solve::{ CanonicalResponse, Certainty, ExternalConstraintsData, Goal, GoalSource, QueryResult, Response, }; @@ -27,6 +26,8 @@ use rustc_middle::ty::{ self, AliasRelationDirection, CoercePredicate, RegionOutlivesPredicate, SubtypePredicate, Ty, TyCtxt, TypeOutlivesPredicate, UniverseIndex, }; +use rustc_type_ir::solve::SolverMode; +use rustc_type_ir::{self as ir, Interner}; mod alias_relate; mod assembly; @@ -57,19 +58,6 @@ pub use select::InferCtxtSelectExt; /// recursion limit again. However, this feels very unlikely. const FIXPOINT_STEP_LIMIT: usize = 8; -#[derive(Debug, Clone, Copy)] -enum SolverMode { - /// Ordinary trait solving, using everywhere except for coherence. - Normal, - /// Trait solving during coherence. There are a few notable differences - /// between coherence and ordinary trait solving. - /// - /// Most importantly, trait solving during coherence must not be incomplete, - /// i.e. return `Err(NoSolution)` for goals for which a solution exists. - /// This means that we must not make any guesses or arbitrary choices. - Coherence, -} - #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum GoalEvaluationKind { Root, @@ -314,17 +302,17 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> { } } -fn response_no_constraints_raw<'tcx>( - tcx: TyCtxt<'tcx>, +fn response_no_constraints_raw( + tcx: I, max_universe: UniverseIndex, - variables: CanonicalVarInfos<'tcx>, + variables: I::CanonicalVars, certainty: Certainty, -) -> CanonicalResponse<'tcx> { - Canonical { +) -> ir::solve::CanonicalResponse { + ir::Canonical { max_universe, variables, value: Response { - var_values: CanonicalVarValues::make_identity(tcx, variables), + var_values: ir::CanonicalVarValues::make_identity(tcx, variables), // FIXME: maybe we should store the "no response" version in tcx, like // we do for tcx.types and stuff. external_constraints: tcx.mk_external_constraints(ExternalConstraintsData::default()), diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs index 84878fea101..681061c25aa 100644 --- a/compiler/rustc_trait_selection/src/solve/search_graph.rs +++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs @@ -3,14 +3,11 @@ use std::mem; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_index::Idx; use rustc_index::IndexVec; -use rustc_infer::infer::InferCtxt; -use rustc_middle::dep_graph::dep_kinds; -use rustc_middle::traits::solve::CacheData; -use rustc_middle::traits::solve::EvaluationCache; -use rustc_middle::ty::TyCtxt; +use rustc_next_trait_solver::solve::CacheData; use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult}; use rustc_session::Limit; use rustc_type_ir::inherent::*; +use rustc_type_ir::InferCtxtLike; use rustc_type_ir::Interner; use super::inspect; @@ -240,34 +237,26 @@ impl SearchGraph { !entry.is_empty() }); } -} -impl<'tcx> SearchGraph> { /// The trait solver behavior is different for coherence /// so we use a separate cache. Alternatively we could use /// a single cache and share it between coherence and ordinary /// trait solving. - pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> { - match self.mode { - SolverMode::Normal => &tcx.new_solver_evaluation_cache, - SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache, - } + pub(super) fn global_cache(&self, tcx: I) -> I::EvaluationCache { + tcx.evaluation_cache(self.mode) } /// Probably the most involved method of the whole solver. /// /// Given some goal which is proven via the `prove_goal` closure, this /// handles caching, overflow, and coinductive cycles. - pub(super) fn with_new_goal( + pub(super) fn with_new_goal>( &mut self, - tcx: TyCtxt<'tcx>, - input: CanonicalInput>, - inspect: &mut ProofTreeBuilder>, - mut prove_goal: impl FnMut( - &mut Self, - &mut ProofTreeBuilder>, - ) -> QueryResult>, - ) -> QueryResult> { + tcx: I, + input: CanonicalInput, + inspect: &mut ProofTreeBuilder, + mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder) -> QueryResult, + ) -> QueryResult { self.check_invariants(); // Check for overflow. let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else { @@ -361,21 +350,20 @@ impl<'tcx> SearchGraph> { // not tracked by the cache key and from outside of this anon task, it // must not be added to the global cache. Notably, this is the case for // trait solver cycles participants. - let ((final_entry, result), dep_node) = - tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || { - for _ in 0..FIXPOINT_STEP_LIMIT { - match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) { - StepResult::Done(final_entry, result) => return (final_entry, result), - StepResult::HasChanged => debug!("fixpoint changed provisional results"), - } + let ((final_entry, result), dep_node) = tcx.with_cached_task(|| { + for _ in 0..FIXPOINT_STEP_LIMIT { + match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) { + StepResult::Done(final_entry, result) => return (final_entry, result), + StepResult::HasChanged => debug!("fixpoint changed provisional results"), } + } - debug!("canonical cycle overflow"); - let current_entry = self.pop_stack(); - debug_assert!(current_entry.has_been_used.is_empty()); - let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false)); - (current_entry, result) - }); + debug!("canonical cycle overflow"); + let current_entry = self.pop_stack(); + debug_assert!(current_entry.has_been_used.is_empty()); + let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false)); + (current_entry, result) + }); let proof_tree = inspect.finalize_canonical_goal_evaluation(tcx); @@ -423,16 +411,17 @@ impl<'tcx> SearchGraph> { /// Try to fetch a previously computed result from the global cache, /// making sure to only do so if it would match the result of reevaluating /// this goal. - fn lookup_global_cache( + fn lookup_global_cache>( &mut self, - tcx: TyCtxt<'tcx>, - input: CanonicalInput>, + tcx: I, + input: CanonicalInput, available_depth: Limit, - inspect: &mut ProofTreeBuilder>, - ) -> Option>> { + inspect: &mut ProofTreeBuilder, + ) -> Option> { let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self .global_cache(tcx) - .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?; + // TODO: Awkward `Limit -> usize -> Limit`. + .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth.0)?; // If we're building a proof tree and the current cache entry does not // contain a proof tree, we do not use the entry but instead recompute @@ -465,21 +454,22 @@ enum StepResult { HasChanged, } -impl<'tcx> SearchGraph> { +impl SearchGraph { /// When we encounter a coinductive cycle, we have to fetch the /// result of that cycle while we are still computing it. Because /// of this we continuously recompute the cycle until the result /// of the previous iteration is equal to the final result, at which /// point we are done. - fn fixpoint_step_in_task( + fn fixpoint_step_in_task( &mut self, - tcx: TyCtxt<'tcx>, - input: CanonicalInput>, - inspect: &mut ProofTreeBuilder>, + tcx: I, + input: CanonicalInput, + inspect: &mut ProofTreeBuilder, prove_goal: &mut F, - ) -> StepResult> + ) -> StepResult where - F: FnMut(&mut Self, &mut ProofTreeBuilder>) -> QueryResult>, + Infcx: InferCtxtLike, + F: FnMut(&mut Self, &mut ProofTreeBuilder) -> QueryResult, { let result = prove_goal(self, inspect); let stack_entry = self.pop_stack(); @@ -533,15 +523,13 @@ impl<'tcx> SearchGraph> { } fn response_no_constraints( - tcx: TyCtxt<'tcx>, - goal: CanonicalInput>, + tcx: I, + goal: CanonicalInput, certainty: Certainty, - ) -> QueryResult> { + ) -> QueryResult { Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty)) } -} -impl SearchGraph { #[allow(rustc::potential_query_instability)] fn check_invariants(&self) { if !cfg!(debug_assertions) { diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index 6b84592978a..4afb9a2339b 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -8,9 +8,11 @@ use std::hash::Hash; use std::ops::Deref; use rustc_ast_ir::Mutability; +use rustc_data_structures::fx::FxHashSet; use crate::fold::{TypeFoldable, TypeSuperFoldable}; use crate::relate::Relate; +use crate::solve::{CacheData, CanonicalInput, QueryResult}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; use crate::{self as ty, CollectAndApply, Interner, UpcastFrom}; @@ -363,3 +365,30 @@ pub trait Features: Copy { fn coroutine_clone(self) -> bool; } + +pub trait EvaluationCache { + /// Insert a final result into the global cache. + fn insert( + &self, + tcx: I, + key: CanonicalInput, + proof_tree: Option, + additional_depth: usize, + encountered_overflow: bool, + cycle_participants: FxHashSet>, + dep_node: I::DepNodeIndex, + result: QueryResult, + ); + + /// Try to fetch a cached result, checking the recursion limit + /// and handling root goals of coinductive cycles. + /// + /// If this returns `Some` the cache result can be used. + fn get( + &self, + tcx: I, + key: CanonicalInput, + stack_entries: impl IntoIterator>, + available_depth: usize, + ) -> Option>; +} diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 11c1f73fef3..b099f63d382 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -10,6 +10,7 @@ use crate::ir_print::IrPrint; use crate::lang_items::TraitSolverLangItem; use crate::relate::Relate; use crate::solve::inspect::CanonicalGoalEvaluationStep; +use crate::solve::{ExternalConstraintsData, SolverMode}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; use crate::{self as ty}; @@ -45,16 +46,26 @@ pub trait Interner: + Default; type BoundVarKind: Copy + Debug + Hash + Eq; - type CanonicalVars: Copy + Debug + Hash + Eq + IntoIterator>; type PredefinedOpaques: Copy + Debug + Hash + Eq; type DefiningOpaqueTypes: Copy + Debug + Hash + Default + Eq + TypeVisitable; - type ExternalConstraints: Copy + Debug + Hash + Eq; type CanonicalGoalEvaluationStepRef: Copy + Debug + Hash + Eq + Deref>; + type CanonicalVars: Copy + Debug + Hash + Eq + IntoIterator>; + fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo]) -> Self::CanonicalVars; + + type ExternalConstraints: Copy + Debug + Hash + Eq; + fn mk_external_constraints( + self, + data: ExternalConstraintsData, + ) -> Self::ExternalConstraints; + + type DepNodeIndex; + fn with_cached_task(self, task: impl FnOnce() -> T) -> (T, Self::DepNodeIndex); + // Kinds of tys type Ty: Ty; type Tys: Tys; @@ -97,9 +108,10 @@ pub trait Interner: type Clause: Clause; type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable + Flags; - fn expand_abstract_consts>(self, t: T) -> T; + type EvaluationCache: EvaluationCache; + fn evaluation_cache(self, mode: SolverMode) -> Self::EvaluationCache; - fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo]) -> Self::CanonicalVars; + fn expand_abstract_consts>(self, t: T) -> T; type GenericsOf: GenericsOf; fn generics_of(self, def_id: Self::DefId) -> Self::GenericsOf; diff --git a/compiler/rustc_type_ir/src/solve.rs b/compiler/rustc_type_ir/src/solve.rs index 99d2fa74494..fc4df7ede9d 100644 --- a/compiler/rustc_type_ir/src/solve.rs +++ b/compiler/rustc_type_ir/src/solve.rs @@ -57,6 +57,19 @@ pub enum Reveal { All, } +#[derive(Debug, Clone, Copy)] +pub enum SolverMode { + /// Ordinary trait solving, using everywhere except for coherence. + Normal, + /// Trait solving during coherence. There are a few notable differences + /// between coherence and ordinary trait solving. + /// + /// Most importantly, trait solving during coherence must not be incomplete, + /// i.e. return `Err(NoSolution)` for goals for which a solution exists. + /// This means that we must not make any guesses or arbitrary choices. + Coherence, +} + pub type CanonicalInput::Predicate> = Canonical>; pub type CanonicalResponse = Canonical>; /// The result of evaluating a canonical query. @@ -356,3 +369,12 @@ impl MaybeCause { } } } + +#[derive(derivative::Derivative)] +#[derivative(PartialEq(bound = ""), Eq(bound = ""), Debug(bound = ""))] +pub struct CacheData { + pub result: QueryResult, + pub proof_tree: Option, + pub additional_depth: usize, + pub encountered_overflow: bool, +}