From 81bf9ae263ffbe75dc74b467b80e9ec4184d444a Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 15 Apr 2024 20:08:21 -0400
Subject: [PATCH] Make thir_tree and thir_flat into hooks

---
 compiler/rustc_middle/src/hooks/mod.rs     |  6 ++++++
 compiler/rustc_middle/src/query/mod.rs     | 14 --------------
 compiler/rustc_mir_build/src/lib.rs        |  4 ++--
 compiler/rustc_mir_build/src/thir/print.rs | 11 ++++++-----
 4 files changed, 14 insertions(+), 21 deletions(-)

diff --git a/compiler/rustc_middle/src/hooks/mod.rs b/compiler/rustc_middle/src/hooks/mod.rs
index f7ce15d0a8d..1872f568907 100644
--- a/compiler/rustc_middle/src/hooks/mod.rs
+++ b/compiler/rustc_middle/src/hooks/mod.rs
@@ -102,4 +102,10 @@ declare_hooks! {
     /// turn a deserialized `DefPathHash` into its current `DefId`.
     /// Will fetch a DefId from a DefPathHash for a foreign crate.
     hook def_path_hash_to_def_id_extern(hash: DefPathHash, stable_crate_id: StableCrateId) -> DefId;
+
+    /// Create a THIR tree for debugging.
+    hook thir_tree(key: LocalDefId) -> String;
+
+    /// Create a list-like THIR representation for debugging.
+    hook thir_flat(key: LocalDefId) -> String;
 }
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 394515f091f..2bef8f1fe3d 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -474,20 +474,6 @@ rustc_queries! {
         desc { |tcx| "building THIR for `{}`", tcx.def_path_str(key) }
     }
 
-    /// Create a THIR tree for debugging.
-    query thir_tree(key: LocalDefId) -> &'tcx String {
-        no_hash
-        arena_cache
-        desc { |tcx| "constructing THIR tree for `{}`", tcx.def_path_str(key) }
-    }
-
-    /// Create a list-like THIR representation for debugging.
-    query thir_flat(key: LocalDefId) -> &'tcx String {
-        no_hash
-        arena_cache
-        desc { |tcx| "constructing flat THIR representation for `{}`", tcx.def_path_str(key) }
-    }
-
     /// Set of all the `DefId`s in this crate that have MIR associated with
     /// them. This includes all the body owners, but also things like struct
     /// constructors.
diff --git a/compiler/rustc_mir_build/src/lib.rs b/compiler/rustc_mir_build/src/lib.rs
index 82fb7d1ae4a..442f5fa7d17 100644
--- a/compiler/rustc_mir_build/src/lib.rs
+++ b/compiler/rustc_mir_build/src/lib.rs
@@ -34,6 +34,6 @@ pub fn provide(providers: &mut Providers) {
         build::closure_saved_names_of_captured_variables;
     providers.check_unsafety = check_unsafety::check_unsafety;
     providers.thir_body = thir::cx::thir_body;
-    providers.thir_tree = thir::print::thir_tree;
-    providers.thir_flat = thir::print::thir_flat;
+    providers.hooks.thir_tree = thir::print::thir_tree;
+    providers.hooks.thir_flat = thir::print::thir_flat;
 }
diff --git a/compiler/rustc_mir_build/src/thir/print.rs b/compiler/rustc_mir_build/src/thir/print.rs
index ef15082a481..49e48427b65 100644
--- a/compiler/rustc_mir_build/src/thir/print.rs
+++ b/compiler/rustc_mir_build/src/thir/print.rs
@@ -1,10 +1,11 @@
+use rustc_middle::query::TyCtxtAt;
 use rustc_middle::thir::*;
-use rustc_middle::ty::{self, TyCtxt};
+use rustc_middle::ty;
 use rustc_span::def_id::LocalDefId;
 use std::fmt::{self, Write};
 
-pub(crate) fn thir_tree(tcx: TyCtxt<'_>, owner_def: LocalDefId) -> String {
-    match super::cx::thir_body(tcx, owner_def) {
+pub(crate) fn thir_tree(tcx: TyCtxtAt<'_>, owner_def: LocalDefId) -> String {
+    match super::cx::thir_body(*tcx, owner_def) {
         Ok((thir, _)) => {
             let thir = thir.steal();
             let mut printer = ThirPrinter::new(&thir);
@@ -15,8 +16,8 @@ pub(crate) fn thir_tree(tcx: TyCtxt<'_>, owner_def: LocalDefId) -> String {
     }
 }
 
-pub(crate) fn thir_flat(tcx: TyCtxt<'_>, owner_def: LocalDefId) -> String {
-    match super::cx::thir_body(tcx, owner_def) {
+pub(crate) fn thir_flat(tcx: TyCtxtAt<'_>, owner_def: LocalDefId) -> String {
+    match super::cx::thir_body(*tcx, owner_def) {
         Ok((thir, _)) => format!("{:#?}", thir.steal()),
         Err(_) => "error".into(),
     }