diff --git a/compiler/rustc_sanitizers/src/cfi/typeid/itanium_cxx_abi/transform.rs b/compiler/rustc_sanitizers/src/cfi/typeid/itanium_cxx_abi/transform.rs index f0f2d1fefd2..9b05576d721 100644 --- a/compiler/rustc_sanitizers/src/cfi/typeid/itanium_cxx_abi/transform.rs +++ b/compiler/rustc_sanitizers/src/cfi/typeid/itanium_cxx_abi/transform.rs @@ -9,10 +9,10 @@ use rustc_hir::LangItem; use rustc_middle::bug; use rustc_middle::ty::fold::{TypeFolder, TypeSuperFoldable}; use rustc_middle::ty::{ - self, ExistentialPredicateStableCmpExt as _, Instance, IntTy, List, Ty, TyCtxt, TypeFoldable, - TypeVisitableExt, UintTy, + self, ExistentialPredicateStableCmpExt as _, Instance, InstanceKind, IntTy, List, TraitRef, Ty, + TyCtxt, TypeFoldable, TypeVisitableExt, UintTy, }; -use rustc_span::sym; +use rustc_span::{def_id::DefId, sym}; use rustc_trait_selection::traits; use std::iter; use tracing::{debug, instrument}; @@ -360,41 +360,29 @@ pub fn transform_instance<'tcx>( if !options.contains(TransformTyOptions::USE_CONCRETE_SELF) { // Perform type erasure for calls on trait objects by transforming self into a trait object // of the trait that defines the method. - if let Some(impl_id) = tcx.impl_of_method(instance.def_id()) - && let Some(trait_ref) = tcx.impl_trait_ref(impl_id) - { - let impl_method = tcx.associated_item(instance.def_id()); - let method_id = impl_method - .trait_item_def_id - .expect("Part of a trait implementation, but not linked to the def_id?"); - let trait_method = tcx.associated_item(method_id); - let trait_id = trait_ref.skip_binder().def_id; - if traits::is_vtable_safe_method(tcx, trait_id, trait_method) - && tcx.is_object_safe(trait_id) - { - // Trait methods will have a Self polymorphic parameter, where the concreteized - // implementatation will not. We need to walk back to the more general trait method - let trait_ref = tcx.instantiate_and_normalize_erasing_regions( - instance.args, - ty::ParamEnv::reveal_all(), - trait_ref, - ); - let invoke_ty = trait_object_ty(tcx, ty::Binder::dummy(trait_ref)); + if let Some((trait_ref, method_id, ancestor)) = implemented_method(tcx, instance) { + // Trait methods will have a Self polymorphic parameter, where the concreteized + // implementatation will not. We need to walk back to the more general trait method + let trait_ref = tcx.instantiate_and_normalize_erasing_regions( + instance.args, + ty::ParamEnv::reveal_all(), + trait_ref, + ); + let invoke_ty = trait_object_ty(tcx, ty::Binder::dummy(trait_ref)); - // At the call site, any call to this concrete function through a vtable will be - // `Virtual(method_id, idx)` with appropriate arguments for the method. Since we have the - // original method id, and we've recovered the trait arguments, we can make the callee - // instance we're computing the alias set for match the caller instance. - // - // Right now, our code ignores the vtable index everywhere, so we use 0 as a placeholder. - // If we ever *do* start encoding the vtable index, we will need to generate an alias set - // based on which vtables we are putting this method into, as there will be more than one - // index value when supertraits are involved. - instance.def = ty::InstanceKind::Virtual(method_id, 0); - let abstract_trait_args = - tcx.mk_args_trait(invoke_ty, trait_ref.args.into_iter().skip(1)); - instance.args = instance.args.rebase_onto(tcx, impl_id, abstract_trait_args); - } + // At the call site, any call to this concrete function through a vtable will be + // `Virtual(method_id, idx)` with appropriate arguments for the method. Since we have the + // original method id, and we've recovered the trait arguments, we can make the callee + // instance we're computing the alias set for match the caller instance. + // + // Right now, our code ignores the vtable index everywhere, so we use 0 as a placeholder. + // If we ever *do* start encoding the vtable index, we will need to generate an alias set + // based on which vtables we are putting this method into, as there will be more than one + // index value when supertraits are involved. + instance.def = ty::InstanceKind::Virtual(method_id, 0); + let abstract_trait_args = + tcx.mk_args_trait(invoke_ty, trait_ref.args.into_iter().skip(1)); + instance.args = instance.args.rebase_onto(tcx, ancestor, abstract_trait_args); } else if tcx.is_closure_like(instance.def_id()) { // We're either a closure or a coroutine. Our goal is to find the trait we're defined on, // instantiate it, and take the type of its only method as our own. @@ -452,3 +440,36 @@ pub fn transform_instance<'tcx>( instance } + +fn implemented_method<'tcx>( + tcx: TyCtxt<'tcx>, + instance: Instance<'tcx>, +) -> Option<(ty::EarlyBinder<'tcx, TraitRef<'tcx>>, DefId, DefId)> { + let trait_ref; + let method_id; + let trait_id; + let trait_method; + let ancestor = if let Some(impl_id) = tcx.impl_of_method(instance.def_id()) { + // Implementation in an `impl` block + trait_ref = tcx.impl_trait_ref(impl_id)?; + let impl_method = tcx.associated_item(instance.def_id()); + method_id = impl_method.trait_item_def_id?; + trait_method = tcx.associated_item(method_id); + trait_id = trait_ref.skip_binder().def_id; + impl_id + } else if let InstanceKind::Item(def_id) = instance.def + && let Some(trait_method_bound) = tcx.opt_associated_item(def_id) + { + // Provided method in a `trait` block + trait_method = trait_method_bound; + method_id = instance.def_id(); + trait_id = tcx.trait_of_item(method_id)?; + trait_ref = ty::EarlyBinder::bind(TraitRef::from_method(tcx, trait_id, instance.args)); + trait_id + } else { + return None; + }; + let vtable_possible = + traits::is_vtable_safe_method(tcx, trait_id, trait_method) && tcx.is_object_safe(trait_id); + vtable_possible.then_some((trait_ref, method_id, ancestor)) +} diff --git a/tests/ui/sanitizer/cfi-supertraits.rs b/tests/ui/sanitizer/cfi-supertraits.rs index ed3d722ebb7..4bb6177577f 100644 --- a/tests/ui/sanitizer/cfi-supertraits.rs +++ b/tests/ui/sanitizer/cfi-supertraits.rs @@ -16,6 +16,9 @@ trait Parent1 { type P1; fn p1(&self) -> Self::P1; + fn d(&self) -> i32 { + 42 + } } trait Parent2 { @@ -60,14 +63,17 @@ fn main() { x.c(); x.p1(); x.p2(); + x.d(); // Parents can be created and access their methods. let y = &Foo as &dyn Parent1; y.p1(); + y.d(); let z = &Foo as &dyn Parent2; z.p2(); // Trait upcasting works let x1 = x as &dyn Parent1; x1.p1(); + x1.d(); let x2 = x as &dyn Parent2; x2.p2(); }