diff --git a/compiler/rustc_trait_selection/src/traits/vtable.rs b/compiler/rustc_trait_selection/src/traits/vtable.rs index a2760fe6049..5aed55ef2d2 100644 --- a/compiler/rustc_trait_selection/src/traits/vtable.rs +++ b/compiler/rustc_trait_selection/src/traits/vtable.rs @@ -2,6 +2,8 @@ use std::fmt::Debug; use std::ops::ControlFlow; use rustc_hir::def_id::DefId; +use rustc_infer::infer::{BoundRegionConversionTime, TyCtxtInferExt}; +use rustc_infer::traits::ObligationCause; use rustc_infer::traits::util::PredicateSet; use rustc_middle::bug; use rustc_middle::query::Providers; @@ -13,7 +15,7 @@ use smallvec::{SmallVec, smallvec}; use tracing::debug; use crate::errors::DumpVTableEntries; -use crate::traits::{impossible_predicates, is_vtable_safe_method}; +use crate::traits::{ObligationCtxt, impossible_predicates, is_vtable_safe_method}; #[derive(Clone, Debug)] pub enum VtblSegment<'tcx> { @@ -383,17 +385,37 @@ pub(crate) fn supertrait_vtable_slot<'tcx>( let ty::Dynamic(target, _, _) = *target.kind() else { bug!(); }; - let target_principal = tcx - .normalize_erasing_regions(ty::ParamEnv::reveal_all(), target.principal()?) - .with_self_ty(tcx, tcx.types.trait_object_dummy_self); + let target_principal = target.principal()?.with_self_ty(tcx, tcx.types.trait_object_dummy_self); // Given that we have a target principal, it is a bug for there not to be a source principal. let ty::Dynamic(source, _, _) = *source.kind() else { bug!(); }; - let source_principal = tcx - .normalize_erasing_regions(ty::ParamEnv::reveal_all(), source.principal().unwrap()) - .with_self_ty(tcx, tcx.types.trait_object_dummy_self); + let source_principal = + source.principal().unwrap().with_self_ty(tcx, tcx.types.trait_object_dummy_self); + + let infcx = tcx.infer_ctxt().build(); + let param_env = ty::ParamEnv::reveal_all(); + let trait_refs_are_compatible = + |source: ty::PolyTraitRef<'tcx>, target: ty::PolyTraitRef<'tcx>| { + infcx.probe(|_| { + let ocx = ObligationCtxt::new(&infcx); + let source = ocx.normalize(&ObligationCause::dummy(), param_env, source); + let target = ocx.normalize(&ObligationCause::dummy(), param_env, target); + infcx.enter_forall(target, |target| { + let source = infcx.instantiate_binder_with_fresh_vars( + DUMMY_SP, + BoundRegionConversionTime::HigherRankedType, + source, + ); + let Ok(()) = ocx.eq(&ObligationCause::dummy(), param_env, target, source) + else { + return false; + }; + ocx.select_all_or_error().is_empty() + }) + }) + }; let vtable_segment_callback = { let mut vptr_offset = 0; @@ -404,9 +426,7 @@ pub(crate) fn supertrait_vtable_slot<'tcx>( } VtblSegment::TraitOwnEntries { trait_ref, emit_vptr } => { vptr_offset += tcx.own_existential_vtable_entries(trait_ref.def_id()).len(); - if tcx.normalize_erasing_regions(ty::ParamEnv::reveal_all(), trait_ref) - == target_principal - { + if trait_refs_are_compatible(trait_ref, target_principal) { if emit_vptr { return ControlFlow::Break(Some(vptr_offset)); } else { diff --git a/tests/ui/traits/trait-upcasting/higher-ranked-upcasting-ok.rs b/tests/ui/traits/trait-upcasting/higher-ranked-upcasting-ok.rs index 7f793e1269f..c4c070e49fd 100644 --- a/tests/ui/traits/trait-upcasting/higher-ranked-upcasting-ok.rs +++ b/tests/ui/traits/trait-upcasting/higher-ranked-upcasting-ok.rs @@ -1,21 +1,22 @@ //@ revisions: current next //@ ignore-compare-mode-next-solver (explicit revisions) //@[next] compile-flags: -Znext-solver -//@ check-pass +//@ build-pass -// We should be able to instantiate a binder during trait upcasting. -// This test could be `check-pass`, but we should make sure that we -// do so in both trait solvers. +// Check that we are able to instantiate a binder during trait upcasting, +// and that it doesn't cause any issues with codegen either. #![feature(trait_upcasting)] trait Supertrait<'a, 'b> {} trait Subtrait<'a, 'b>: Supertrait<'a, 'b> {} -impl<'a> Supertrait<'a, 'a> for () {} -impl<'a> Subtrait<'a, 'a> for () {} +impl Supertrait<'_, '_> for () {} +impl Subtrait<'_, '_> for () {} fn ok(x: &dyn for<'a, 'b> Subtrait<'a, 'b>) -> &dyn for<'a> Supertrait<'a, 'a> { x } -fn main() {} +fn main() { + ok(&()); +}