Attribute drop to parent expression of the consume point

This is needed to handle cases like `[a, b.await, c]`. `ExprUseVisitor`
considers `a` to be consumed when it is passed to the array, but the
array is not quite live yet at that point. This means we were missing
the `a` value across the await point. Attributing drops to the parent
expression means we do not consider the value consumed until the
consuming expression has finished.

Issue #57478
This commit is contained in:
Eric Holk 2021-10-25 17:01:24 -07:00
parent f664cfc47c
commit f246c0b116
5 changed files with 87 additions and 31 deletions

View file

@ -6,7 +6,7 @@
use crate::expr_use_visitor::{self, ExprUseVisitor};
use super::FnCtxt;
use hir::HirIdMap;
use hir::{HirIdMap, Node};
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_errors::pluralize;
use rustc_hir as hir;
@ -15,6 +15,7 @@ use rustc_hir::def_id::DefId;
use rustc_hir::hir_id::HirIdSet;
use rustc_hir::intravisit::{self, Visitor};
use rustc_hir::{Arm, Expr, ExprKind, Guard, HirId, Pat, PatKind};
use rustc_middle::hir::map::Map;
use rustc_middle::hir::place::{Place, PlaceBase};
use rustc_middle::middle::region::{self, YieldData};
use rustc_middle::ty::{self, Ty, TyCtxt};
@ -225,6 +226,7 @@ pub fn resolve_interior<'a, 'tcx>(
let mut visitor = {
let mut drop_range_visitor = DropRangeVisitor {
hir: fcx.tcx.hir(),
consumed_places: <_>::default(),
borrowed_places: <_>::default(),
drop_ranges: vec![<_>::default()],
@ -664,19 +666,28 @@ fn check_must_not_suspend_def(
}
/// This struct facilitates computing the ranges for which a place is uninitialized.
struct DropRangeVisitor {
consumed_places: HirIdSet,
struct DropRangeVisitor<'tcx> {
hir: Map<'tcx>,
/// Maps a HirId to a set of HirIds that are dropped by that node.
consumed_places: HirIdMap<HirIdSet>,
borrowed_places: HirIdSet,
drop_ranges: Vec<HirIdMap<DropRange>>,
expr_count: usize,
}
impl DropRangeVisitor {
impl DropRangeVisitor<'tcx> {
fn mark_consumed(&mut self, consumer: HirId, target: HirId) {
if !self.consumed_places.contains_key(&consumer) {
self.consumed_places.insert(consumer, <_>::default());
}
self.consumed_places.get_mut(&consumer).map(|places| places.insert(target));
}
fn record_drop(&mut self, hir_id: HirId) {
let drop_ranges = self.drop_ranges.last_mut().unwrap();
if self.borrowed_places.contains(&hir_id) {
debug!("not marking {:?} as dropped because it is borrowed at some point", hir_id);
} else if self.consumed_places.contains(&hir_id) {
} else {
debug!("marking {:?} as dropped at {}", hir_id, self.expr_count);
drop_ranges.insert(hir_id, DropRange { dropped_at: self.expr_count });
}
@ -700,15 +711,24 @@ impl DropRangeVisitor {
/// ExprUseVisitor's consume callback doesn't go deep enough for our purposes in all
/// expressions. This method consumes a little deeper into the expression when needed.
fn consume_expr(&mut self, expr: &hir::Expr<'_>) {
self.record_drop(expr.hir_id);
match expr.kind {
hir::ExprKind::Path(hir::QPath::Resolved(
_,
hir::Path { res: hir::def::Res::Local(hir_id), .. },
)) => {
self.record_drop(*hir_id);
debug!("consuming expr {:?}, count={}", expr.hir_id, self.expr_count);
let places = self
.consumed_places
.get(&expr.hir_id)
.map_or(vec![], |places| places.iter().cloned().collect());
for place in places {
self.record_drop(place);
if let Some(Node::Expr(expr)) = self.hir.find(place) {
match expr.kind {
hir::ExprKind::Path(hir::QPath::Resolved(
_,
hir::Path { res: hir::def::Res::Local(hir_id), .. },
)) => {
self.record_drop(*hir_id);
}
_ => (),
}
}
_ => (),
}
}
}
@ -721,15 +741,19 @@ fn place_hir_id(place: &Place<'_>) -> Option<HirId> {
}
}
impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor {
impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor<'tcx> {
fn consume(
&mut self,
place_with_id: &expr_use_visitor::PlaceWithHirId<'tcx>,
diag_expr_id: hir::HirId,
) {
debug!("consume {:?}; diag_expr_id={:?}", place_with_id, diag_expr_id);
self.consumed_places.insert(place_with_id.hir_id);
place_hir_id(&place_with_id.place).map(|place| self.consumed_places.insert(place));
let parent = match self.hir.find_parent_node(place_with_id.hir_id) {
Some(parent) => parent,
None => place_with_id.hir_id,
};
debug!("consume {:?}; diag_expr_id={:?}, using parent {:?}", place_with_id, diag_expr_id, parent);
self.mark_consumed(parent, place_with_id.hir_id);
place_hir_id(&place_with_id.place).map(|place| self.mark_consumed(parent, place));
}
fn borrow(
@ -757,7 +781,7 @@ impl<'tcx> expr_use_visitor::Delegate<'tcx> for DropRangeVisitor {
}
}
impl<'tcx> Visitor<'tcx> for DropRangeVisitor {
impl<'tcx> Visitor<'tcx> for DropRangeVisitor<'tcx> {
type Map = intravisit::ErasedMap<'tcx>;
fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
@ -766,20 +790,20 @@ impl<'tcx> Visitor<'tcx> for DropRangeVisitor {
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
match expr.kind {
ExprKind::AssignOp(_, lhs, rhs) => {
ExprKind::AssignOp(_op, lhs, rhs) => {
// These operations are weird because their order of evaluation depends on whether
// the operator is overloaded. In a perfect world, we'd just ask the type checker
// whether this is a method call, but we also need to match the expression IDs
// from RegionResolutionVisitor. RegionResolutionVisitor doesn't know the order,
// so it runs both orders and picks the most conservative. We'll mirror that here.
let mut old_count = self.expr_count;
intravisit::walk_expr(self, lhs);
intravisit::walk_expr(self, rhs);
self.visit_expr(lhs);
self.visit_expr(rhs);
self.push_drop_scope();
std::mem::swap(&mut old_count, &mut self.expr_count);
intravisit::walk_expr(self, rhs);
intravisit::walk_expr(self, lhs);
self.visit_expr(rhs);
self.visit_expr(lhs);
// We should have visited the same number of expressions in either order.
assert_eq!(old_count, self.expr_count);

View file

@ -8,8 +8,16 @@ async fn bar<T>() -> () {}
async fn foo() {
bar().await;
//~^ ERROR type inside `async fn` body must be known in this context
//~| ERROR type inside `async fn` body must be known in this context
//~| ERROR type inside `async fn` body must be known in this context
//~| NOTE cannot infer type for type parameter `T`
//~| NOTE cannot infer type for type parameter `T`
//~| NOTE cannot infer type for type parameter `T`
//~| NOTE the type is part of the `async fn` body because of this `await`
//~| NOTE the type is part of the `async fn` body because of this `await`
//~| NOTE the type is part of the `async fn` body because of this `await`
//~| NOTE in this expansion of desugaring of `await`
//~| NOTE in this expansion of desugaring of `await`
//~| NOTE in this expansion of desugaring of `await`
}
fn main() {}

View file

@ -10,6 +10,30 @@ note: the type is part of the `async fn` body because of this `await`
LL | bar().await;
| ^^^^^^
error: aborting due to previous error
error[E0698]: type inside `async fn` body must be known in this context
--> $DIR/unresolved_type_param.rs:9:5
|
LL | bar().await;
| ^^^ cannot infer type for type parameter `T` declared on the function `bar`
|
note: the type is part of the `async fn` body because of this `await`
--> $DIR/unresolved_type_param.rs:9:5
|
LL | bar().await;
| ^^^^^^^^^^^
error[E0698]: type inside `async fn` body must be known in this context
--> $DIR/unresolved_type_param.rs:9:5
|
LL | bar().await;
| ^^^ cannot infer type for type parameter `T` declared on the function `bar`
|
note: the type is part of the `async fn` body because of this `await`
--> $DIR/unresolved_type_param.rs:9:5
|
LL | bar().await;
| ^^^^^^^^^^^
error: aborting due to 3 previous errors
For more information about this error, try `rustc --explain E0698`.

View file

@ -13,7 +13,7 @@ async fn wheeee<T>(t: T) {
}
async fn yes() {
wheeee(No {}).await; //~ ERROR `No` held across
wheeee(&No {}).await; //~ ERROR `No` held across
}
fn main() {

View file

@ -1,8 +1,8 @@
error: `No` held across a suspend point, but should not be
--> $DIR/dedup.rs:16:12
--> $DIR/dedup.rs:16:13
|
LL | wheeee(No {}).await;
| ^^^^^ ------ the value is held across this suspend point
LL | wheeee(&No {}).await;
| --------^^^^^------- the value is held across this suspend point
|
note: the lint level is defined here
--> $DIR/dedup.rs:3:9
@ -10,10 +10,10 @@ note: the lint level is defined here
LL | #![deny(must_not_suspend)]
| ^^^^^^^^^^^^^^^^
help: consider using a block (`{ ... }`) to shrink the value's scope, ending before the suspend point
--> $DIR/dedup.rs:16:12
--> $DIR/dedup.rs:16:13
|
LL | wheeee(No {}).await;
| ^^^^^
LL | wheeee(&No {}).await;
| ^^^^^
error: aborting due to previous error