From aa0a916c81936ba725b7efb68804a4217b09b43a Mon Sep 17 00:00:00 2001 From: Maybe Waffle Date: Sun, 14 Apr 2024 18:07:40 +0000 Subject: [PATCH] Add a lint against never type fallback affecting unsafe code --- compiler/rustc_hir_typeck/messages.ftl | 4 + compiler/rustc_hir_typeck/src/errors.rs | 6 +- compiler/rustc_hir_typeck/src/fallback.rs | 135 ++++++++++++++++-- compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs | 4 +- compiler/rustc_lint_defs/src/builtin.rs | 44 ++++++ ...never-type-fallback-flowing-into-unsafe.rs | 35 +++++ ...r-type-fallback-flowing-into-unsafe.stderr | 23 +++ 7 files changed, 241 insertions(+), 10 deletions(-) create mode 100644 tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs create mode 100644 tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr diff --git a/compiler/rustc_hir_typeck/messages.ftl b/compiler/rustc_hir_typeck/messages.ftl index 07b4948872d..0caebf44a19 100644 --- a/compiler/rustc_hir_typeck/messages.ftl +++ b/compiler/rustc_hir_typeck/messages.ftl @@ -99,6 +99,10 @@ hir_typeck_lossy_provenance_ptr2int = hir_typeck_missing_parentheses_in_range = can't call method `{$method_name}` on type `{$ty_str}` +hir_typeck_never_type_fallback_flowing_into_unsafe = + never type fallback affects this call to an `unsafe` function + .help = specify the type explicitly + hir_typeck_no_associated_item = no {$item_kind} named `{$item_name}` found for {$ty_prefix} `{$ty_str}`{$trait_missing_method -> [true] {""} *[other] {" "}in the current scope diff --git a/compiler/rustc_hir_typeck/src/errors.rs b/compiler/rustc_hir_typeck/src/errors.rs index 1c4d5657b17..fcad88f829e 100644 --- a/compiler/rustc_hir_typeck/src/errors.rs +++ b/compiler/rustc_hir_typeck/src/errors.rs @@ -164,6 +164,11 @@ pub struct MissingParenthesesInRange { pub add_missing_parentheses: Option, } +#[derive(LintDiagnostic)] +#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe)] +#[help] +pub struct NeverTypeFallbackFlowingIntoUnsafe {} + #[derive(Subdiagnostic)] #[multipart_suggestion( hir_typeck_add_missing_parentheses_in_range, @@ -632,7 +637,6 @@ pub enum SuggestBoxingForReturnImplTrait { ends: Vec, }, } - #[derive(LintDiagnostic)] #[diag(hir_typeck_dereferencing_mut_binding)] pub struct DereferencingMutBinding { diff --git a/compiler/rustc_hir_typeck/src/fallback.rs b/compiler/rustc_hir_typeck/src/fallback.rs index 3b00c7353e5..86a75aa4d78 100644 --- a/compiler/rustc_hir_typeck/src/fallback.rs +++ b/compiler/rustc_hir_typeck/src/fallback.rs @@ -1,10 +1,15 @@ -use crate::FnCtxt; +use std::cell::OnceCell; + +use crate::{errors, FnCtxt}; use rustc_data_structures::{ graph::{self, iterate::DepthFirstSearch, vec_graph::VecGraph}, unord::{UnordBag, UnordMap, UnordSet}, }; +use rustc_hir::HirId; use rustc_infer::infer::{DefineOpaqueTypes, InferOk}; -use rustc_middle::ty::{self, Ty}; +use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitable}; +use rustc_session::lint; +use rustc_span::Span; use rustc_span::DUMMY_SP; #[derive(Copy, Clone)] @@ -335,6 +340,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { // reach a member of N. If so, it falls back to `()`. Else // `!`. let mut diverging_fallback = UnordMap::with_capacity(diverging_vids.len()); + let unsafe_infer_vars = OnceCell::new(); for &diverging_vid in &diverging_vids { let diverging_ty = Ty::new_var(self.tcx, diverging_vid); let root_vid = self.root_var(diverging_vid); @@ -354,11 +360,35 @@ impl<'tcx> FnCtxt<'_, 'tcx> { output: infer_var_infos.items().any(|info| info.output), }; + let mut fallback_to = |ty| { + let unsafe_infer_vars = unsafe_infer_vars.get_or_init(|| { + let unsafe_infer_vars = compute_unsafe_infer_vars(self.root_ctxt, self.body_id); + debug!(?unsafe_infer_vars); + unsafe_infer_vars + }); + + let affected_unsafe_infer_vars = + graph::depth_first_search_as_undirected(&coercion_graph, root_vid) + .filter_map(|x| unsafe_infer_vars.get(&x).copied()) + .collect::>(); + + for (hir_id, span) in affected_unsafe_infer_vars { + self.tcx.emit_node_span_lint( + lint::builtin::NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE, + hir_id, + span, + errors::NeverTypeFallbackFlowingIntoUnsafe {}, + ); + } + + diverging_fallback.insert(diverging_ty, ty); + }; + use DivergingFallbackBehavior::*; match behavior { FallbackToUnit => { debug!("fallback to () - legacy: {:?}", diverging_vid); - diverging_fallback.insert(diverging_ty, self.tcx.types.unit); + fallback_to(self.tcx.types.unit); } FallbackToNiko => { if found_infer_var_info.self_in_trait && found_infer_var_info.output { @@ -387,13 +417,13 @@ impl<'tcx> FnCtxt<'_, 'tcx> { // set, see the relationship finding module in // compiler/rustc_trait_selection/src/traits/relationships.rs. debug!("fallback to () - found trait and projection: {:?}", diverging_vid); - diverging_fallback.insert(diverging_ty, self.tcx.types.unit); + fallback_to(self.tcx.types.unit); } else if can_reach_non_diverging { debug!("fallback to () - reached non-diverging: {:?}", diverging_vid); - diverging_fallback.insert(diverging_ty, self.tcx.types.unit); + fallback_to(self.tcx.types.unit); } else { debug!("fallback to ! - all diverging: {:?}", diverging_vid); - diverging_fallback.insert(diverging_ty, self.tcx.types.never); + fallback_to(self.tcx.types.never); } } FallbackToNever => { @@ -401,7 +431,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { "fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}", diverging_vid ); - diverging_fallback.insert(diverging_ty, self.tcx.types.never); + fallback_to(self.tcx.types.never); } NoFallback => { debug!( @@ -417,7 +447,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { /// Returns a graph whose nodes are (unresolved) inference variables and where /// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`. - fn create_coercion_graph(&self) -> VecGraph { + fn create_coercion_graph(&self) -> VecGraph { let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations(); debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations); let coercion_edges: Vec<(ty::TyVid, ty::TyVid)> = pending_obligations @@ -451,6 +481,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { .collect(); debug!("create_coercion_graph: coercion_edges={:?}", coercion_edges); let num_ty_vars = self.num_ty_vars(); + VecGraph::new(num_ty_vars, coercion_edges) } @@ -459,3 +490,91 @@ impl<'tcx> FnCtxt<'_, 'tcx> { Some(self.root_var(self.shallow_resolve(ty).ty_vid()?)) } } + +/// Finds all type variables which are passed to an `unsafe` function. +/// +/// For example, for this function `f`: +/// ```ignore (demonstrative) +/// fn f() { +/// unsafe { +/// let x /* ?X */ = core::mem::zeroed(); +/// // ^^^^^^^^^^^^^^^^^^^ -- hir_id, span +/// +/// let y = core::mem::zeroed::>(); +/// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span +/// } +/// } +/// ``` +/// +/// Will return `{ id(?X) -> (hir_id, span) }` +fn compute_unsafe_infer_vars<'a, 'tcx>( + root_ctxt: &'a crate::TypeckRootCtxt<'tcx>, + body_id: rustc_span::def_id::LocalDefId, +) -> UnordMap { + use rustc_hir as hir; + + let tcx = root_ctxt.infcx.tcx; + let body_id = tcx.hir().maybe_body_owned_by(body_id).unwrap(); + let body = tcx.hir().body(body_id); + let mut res = <_>::default(); + + struct UnsafeInferVarsVisitor<'a, 'tcx, 'r> { + root_ctxt: &'a crate::TypeckRootCtxt<'tcx>, + res: &'r mut UnordMap, + } + + use hir::intravisit::Visitor; + impl hir::intravisit::Visitor<'_> for UnsafeInferVarsVisitor<'_, '_, '_> { + fn visit_expr(&mut self, ex: &'_ hir::Expr<'_>) { + // FIXME: method calls + if let hir::ExprKind::Call(func, ..) = ex.kind { + let typeck_results = self.root_ctxt.typeck_results.borrow(); + + let func_ty = typeck_results.expr_ty(func); + + // `is_fn` is required to ignore closures (which can't be unsafe) + if func_ty.is_fn() + && let sig = func_ty.fn_sig(self.root_ctxt.infcx.tcx) + && let hir::Unsafety::Unsafe = sig.unsafety() + { + let mut collector = + InferVarCollector { hir_id: ex.hir_id, call_span: ex.span, res: self.res }; + + // Collect generic arguments of the function which are inference variables + typeck_results + .node_args(ex.hir_id) + .types() + .for_each(|t| t.visit_with(&mut collector)); + + // Also check the return type, for cases like `(unsafe_fn::<_> as unsafe fn() -> _)()` + sig.output().visit_with(&mut collector); + } + } + + hir::intravisit::walk_expr(self, ex); + } + } + + struct InferVarCollector<'r> { + hir_id: HirId, + call_span: Span, + res: &'r mut UnordMap, + } + + impl<'tcx> ty::TypeVisitor> for InferVarCollector<'_> { + fn visit_ty(&mut self, t: Ty<'tcx>) { + if let Some(vid) = t.ty_vid() { + self.res.insert(vid, (self.hir_id, self.call_span)); + } else { + use ty::TypeSuperVisitable as _; + t.super_visit_with(self) + } + } + } + + UnsafeInferVarsVisitor { root_ctxt, res: &mut res }.visit_expr(&body.value); + + debug!(?res, "collected the following unsafe vars for {body_id:?}"); + + res +} diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs index 2f96cf9e373..794b854ca5f 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs @@ -5,12 +5,14 @@ mod checks; mod inspect_obligations; mod suggestions; +use rustc_errors::ErrorGuaranteed; + use crate::coercion::DynamicCoerceMany; use crate::fallback::DivergingFallbackBehavior; use crate::fn_ctxt::checks::DivergingBlockBehavior; use crate::{CoroutineTypes, Diverges, EnclosingBreakables, TypeckRootCtxt}; use hir::def_id::CRATE_DEF_ID; -use rustc_errors::{DiagCtxt, ErrorGuaranteed}; +use rustc_errors::DiagCtxt; use rustc_hir as hir; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer; diff --git a/compiler/rustc_lint_defs/src/builtin.rs b/compiler/rustc_lint_defs/src/builtin.rs index 86a0f33a8d1..664c63da0fc 100644 --- a/compiler/rustc_lint_defs/src/builtin.rs +++ b/compiler/rustc_lint_defs/src/builtin.rs @@ -69,6 +69,7 @@ declare_lint_pass! { MISSING_FRAGMENT_SPECIFIER, MUST_NOT_SUSPEND, NAMED_ARGUMENTS_USED_POSITIONALLY, + NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE, NON_CONTIGUOUS_RANGE_ENDPOINTS, NON_EXHAUSTIVE_OMITTED_PATTERNS, ORDER_DEPENDENT_TRAIT_OBJECTS, @@ -4245,6 +4246,49 @@ declare_lint! { "named arguments in format used positionally" } +declare_lint! { + /// The `never_type_fallback_flowing_into_unsafe` lint detects cases where never type fallback + /// affects unsafe function calls. + /// + /// ### Example + /// + /// ```rust,compile_fail + /// #![deny(never_type_fallback_flowing_into_unsafe)] + /// fn main() { + /// if true { + /// // return has type `!` (never) which, is some cases, causes never type fallback + /// return + /// } else { + /// // `zeroed` is an unsafe function, which returns an unbounded type + /// unsafe { std::mem::zeroed() } + /// }; + /// // depending on the fallback, `zeroed` may create `()` (which is completely sound), + /// // or `!` (which is instant undefined behavior) + /// } + /// ``` + /// + /// {{produces}} + /// + /// ### Explanation + /// + /// Due to historic reasons never type fallback were `()`, meaning that `!` got spontaneously + /// coerced to `()`. There are plans to change that, but they may make the code such as above + /// unsound. Instead of depending on the fallback, you should specify the type explicitly: + /// ``` + /// if true { + /// return + /// } else { + /// // type is explicitly specified, fallback can't hurt us no more + /// unsafe { std::mem::zeroed::<()>() } + /// }; + /// ``` + /// + /// See [Tracking Issue for making `!` fall back to `!`](https://github.com/rust-lang/rust/issues/123748). + pub NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE, + Warn, + "never type fallback affecting unsafe function calls" +} + declare_lint! { /// The `byte_slice_in_packed_struct_with_derive` lint detects cases where a byte slice field /// (`[u8]`) or string slice field (`str`) is used in a `packed` struct that derives one or diff --git a/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs new file mode 100644 index 00000000000..f13e20cc0f2 --- /dev/null +++ b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.rs @@ -0,0 +1,35 @@ +//@ check-pass +use std::mem; + +fn main() { + if false { + unsafe { mem::zeroed() } + //~^ warn: never type fallback affects this call to an `unsafe` function + } else { + return; + }; + + // no ; -> type is inferred without fallback + if true { unsafe { mem::zeroed() } } else { return } +} + +// Minimization of the famous `objc` crate issue +fn _objc() { + pub unsafe fn send_message() -> Result { + Ok(unsafe { core::mem::zeroed() }) + } + + macro_rules! msg_send { + () => { + match send_message::<_ /* ?0 */>() { + //~^ warn: never type fallback affects this call to an `unsafe` function + Ok(x) => x, + Err(_) => loop {}, + } + }; + } + + unsafe { + msg_send!(); + } +} diff --git a/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr new file mode 100644 index 00000000000..1610804c29b --- /dev/null +++ b/tests/ui/never_type/lint-never-type-fallback-flowing-into-unsafe.stderr @@ -0,0 +1,23 @@ +warning: never type fallback affects this call to an `unsafe` function + --> $DIR/lint-never-type-fallback-flowing-into-unsafe.rs:6:18 + | +LL | unsafe { mem::zeroed() } + | ^^^^^^^^^^^^^ + | + = help: specify the type explicitly + = note: `#[warn(never_type_fallback_flowing_into_unsafe)]` on by default + +warning: never type fallback affects this call to an `unsafe` function + --> $DIR/lint-never-type-fallback-flowing-into-unsafe.rs:24:19 + | +LL | match send_message::<_ /* ?0 */>() { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +... +LL | msg_send!(); + | ----------- in this macro invocation + | + = help: specify the type explicitly + = note: this warning originates in the macro `msg_send` (in Nightly builds, run with -Z macro-backtrace for more info) + +warning: 2 warnings emitted +