From 2db0dc32970a70bcf622e0854c944753a065cb15 Mon Sep 17 00:00:00 2001
From: Arpad Borsos <swatinem@swatinem.de>
Date: Sat, 26 Nov 2022 21:33:12 +0100
Subject: [PATCH] Simplify checking for `GeneratorKind::Async`

Adds a helper method around `generator_kind` that makes matching async constructs simpler.
---
 .../rustc_borrowck/src/diagnostics/region_errors.rs   |  7 +------
 compiler/rustc_lint/src/unused.rs                     |  5 +----
 compiler/rustc_middle/src/ty/context.rs               |  5 +++++
 .../src/traits/error_reporting/suggestions.rs         | 11 ++---------
 .../src/traits/select/candidate_assembly.rs           |  4 +---
 5 files changed, 10 insertions(+), 22 deletions(-)

diff --git a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs
index 534675f1dc0..7aa099433a7 100644
--- a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs
@@ -514,12 +514,7 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
             span: *span,
             ty_err: match output_ty.kind() {
                 ty::Closure(_, _) => FnMutReturnTypeErr::ReturnClosure { span: *span },
-                ty::Generator(def, ..)
-                    if matches!(
-                        self.infcx.tcx.generator_kind(def),
-                        Some(hir::GeneratorKind::Async(_))
-                    ) =>
-                {
+                ty::Generator(def, ..) if self.infcx.tcx.generator_is_async(*def) => {
                     FnMutReturnTypeErr::ReturnAsyncBlock { span: *span }
                 }
                 _ => FnMutReturnTypeErr::ReturnRef { span: *span },
diff --git a/compiler/rustc_lint/src/unused.rs b/compiler/rustc_lint/src/unused.rs
index 43864ed45fa..88ad4c67d93 100644
--- a/compiler/rustc_lint/src/unused.rs
+++ b/compiler/rustc_lint/src/unused.rs
@@ -322,10 +322,7 @@ impl<'tcx> LateLintPass<'tcx> for UnusedResults {
                 ty::Closure(..) => Some(MustUsePath::Closure(span)),
                 ty::Generator(def_id, ..) => {
                     // async fn should be treated as "implementor of `Future`"
-                    let must_use = if matches!(
-                        cx.tcx.generator_kind(def_id),
-                        Some(hir::GeneratorKind::Async(..))
-                    ) {
+                    let must_use = if cx.tcx.generator_is_async(def_id) {
                         let def_id = cx.tcx.lang_items().future_trait().unwrap();
                         is_def_must_use(cx, def_id, span)
                             .map(|inner| MustUsePath::Opaque(Box::new(inner)))
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index bf30a403d9b..1628cca638e 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -1360,6 +1360,11 @@ impl<'tcx> TyCtxt<'tcx> {
         self.diagnostic_items(did.krate).name_to_id.get(&name) == Some(&did)
     }
 
+    /// Returns `true` if the node pointed to by `def_id` is a generator for an async construct.
+    pub fn generator_is_async(self, def_id: DefId) -> bool {
+        matches!(self.generator_kind(def_id), Some(hir::GeneratorKind::Async(_)))
+    }
+
     pub fn stability(self) -> &'tcx stability::Index {
         self.stability_index(())
     }
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index 992ea175516..eeb4693eec3 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -1988,11 +1988,6 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
             .as_local()
             .and_then(|def_id| hir.maybe_body_owned_by(def_id))
             .map(|body_id| hir.body(body_id));
-        let is_async = self
-            .tcx
-            .generator_kind(generator_did)
-            .map(|generator_kind| matches!(generator_kind, hir::GeneratorKind::Async(..)))
-            .unwrap_or(false);
         let mut visitor = AwaitsVisitor::default();
         if let Some(body) = generator_body {
             visitor.visit_body(body);
@@ -2069,6 +2064,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
 
         debug!(?interior_or_upvar_span);
         if let Some(interior_or_upvar_span) = interior_or_upvar_span {
+            let is_async = self.tcx.generator_is_async(generator_did);
             let typeck_results = match generator_data {
                 GeneratorData::Local(typeck_results) => Some(typeck_results),
                 GeneratorData::Foreign(_) => None,
@@ -2641,10 +2637,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
                                 if is_future
                                     && obligated_types.last().map_or(false, |ty| match ty.kind() {
                                         ty::Generator(last_def_id, ..) => {
-                                            matches!(
-                                                tcx.generator_kind(last_def_id),
-                                                Some(GeneratorKind::Async(..))
-                                            )
+                                            tcx.generator_is_async(*last_def_id)
                                         }
                                         _ => false,
                                     })
diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
index 10854ede652..627ed4674b0 100644
--- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
@@ -430,9 +430,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
     ) {
         let self_ty = obligation.self_ty().skip_binder();
         if let ty::Generator(did, ..) = self_ty.kind() {
-            if let Some(rustc_hir::GeneratorKind::Async(_generator_kind)) =
-                self.tcx().generator_kind(did)
-            {
+            if self.tcx().generator_is_async(*did) {
                 debug!(?self_ty, ?obligation, "assemble_future_candidates",);
 
                 candidates.vec.push(FutureCandidate);