From 5d23d8bc299b718e7f026a7e6c1363dde3342817 Mon Sep 17 00:00:00 2001
From: vsrs <vit@conrlab.com>
Date: Sat, 27 Feb 2021 17:59:52 +0300
Subject: [PATCH] Add runnables::related_tests

---
 crates/cfg/src/cfg_expr.rs  |   2 +-
 crates/ide/src/lib.rs       |   9 +
 crates/ide/src/runnables.rs | 338 +++++++++++++++++++++++++++++++++++-
 crates/ide_db/src/search.rs |   4 +
 4 files changed, 346 insertions(+), 7 deletions(-)

diff --git a/crates/cfg/src/cfg_expr.rs b/crates/cfg/src/cfg_expr.rs
index 42327f1e147..069fc01d0c7 100644
--- a/crates/cfg/src/cfg_expr.rs
+++ b/crates/cfg/src/cfg_expr.rs
@@ -49,7 +49,7 @@ impl fmt::Display for CfgAtom {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CfgExpr {
     Invalid,
     Atom(CfgAtom),
diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs
index b600178ee15..baa80cf43a3 100644
--- a/crates/ide/src/lib.rs
+++ b/crates/ide/src/lib.rs
@@ -445,6 +445,15 @@ impl Analysis {
         self.with_db(|db| runnables::runnables(db, file_id))
     }
 
+    /// Returns the set of tests for the given file position.
+    pub fn related_tests(
+        &self,
+        position: FilePosition,
+        search_scope: Option<SearchScope>,
+    ) -> Cancelable<Vec<Runnable>> {
+        self.with_db(|db| runnables::related_tests(db, position, search_scope))
+    }
+
     /// Computes syntax highlighting for the given file
     pub fn highlight(&self, file_id: FileId) -> Cancelable<Vec<HlRange>> {
         self.with_db(|db| syntax_highlighting::highlight(db, file_id, None, false))
diff --git a/crates/ide/src/runnables.rs b/crates/ide/src/runnables.rs
index 1e7baed2046..ce3a2e7baa8 100644
--- a/crates/ide/src/runnables.rs
+++ b/crates/ide/src/runnables.rs
@@ -1,10 +1,14 @@
 use std::fmt;
 
+use ast::NameOwner;
 use cfg::CfgExpr;
 use hir::{AsAssocItem, HasAttrs, HasSource, Semantics};
 use ide_assists::utils::test_related_attribute;
-use ide_db::{defs::Definition, RootDatabase, SymbolKind};
+use ide_db::{
+    base_db::FilePosition, defs::Definition, search::SearchScope, RootDatabase, SymbolKind,
+};
 use itertools::Itertools;
+use rustc_hash::FxHashSet;
 use syntax::{
     ast::{self, AstNode, AttrsOwner},
     match_ast, SyntaxNode,
@@ -13,17 +17,17 @@ use test_utils::mark;
 
 use crate::{
     display::{ToNav, TryToNav},
-    FileId, NavigationTarget,
+    references, FileId, NavigationTarget,
 };
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub struct Runnable {
     pub nav: NavigationTarget,
     pub kind: RunnableKind,
     pub cfg: Option<CfgExpr>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub enum TestId {
     Name(String),
     Path(String),
@@ -38,7 +42,7 @@ impl fmt::Display for TestId {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub enum RunnableKind {
     Test { test_id: TestId, attr: TestAttr },
     TestMod { path: String },
@@ -106,6 +110,102 @@ pub(crate) fn runnables(db: &RootDatabase, file_id: FileId) -> Vec<Runnable> {
     res
 }
 
+// Feature: Run Test
+//
+// Shows a popup suggesting to run a test in which the item **at the current cursor
+// location** is used (if any).
+//
+// |===
+// | Editor  | Action Name
+//
+// | VS Code | **Rust Analyzer: Run Test**
+// |===
+pub(crate) fn related_tests(
+    db: &RootDatabase,
+    position: FilePosition,
+    search_scope: Option<SearchScope>,
+) -> Vec<Runnable> {
+    let sema = Semantics::new(db);
+    let mut res: FxHashSet<Runnable> = FxHashSet::default();
+
+    find_related_tests(&sema, position, search_scope, &mut res);
+
+    res.into_iter().collect_vec()
+}
+
+fn find_related_tests(
+    sema: &Semantics<RootDatabase>,
+    position: FilePosition,
+    search_scope: Option<SearchScope>,
+    tests: &mut FxHashSet<Runnable>,
+) {
+    if let Some(refs) = references::find_all_refs(&sema, position, search_scope) {
+        for (file_id, refs) in refs.references {
+            let file = sema.parse(file_id);
+            let file = file.syntax();
+            let functions = refs.iter().filter_map(|(range, _)| {
+                let token = file.token_at_offset(range.start()).next()?;
+                let token = sema.descend_into_macros(token);
+                let syntax = token.parent();
+                syntax.ancestors().find_map(ast::Fn::cast)
+            });
+
+            for fn_def in functions {
+                if let Some(runnable) = as_test_runnable(&sema, &fn_def) {
+                    // direct test
+                    tests.insert(runnable);
+                } else if let Some(module) = parent_test_module(&sema, &fn_def) {
+                    // indirect test
+                    find_related_tests_in_module(sema, &fn_def, &module, tests);
+                }
+            }
+        }
+    }
+}
+
+fn find_related_tests_in_module(
+    sema: &Semantics<RootDatabase>,
+    fn_def: &ast::Fn,
+    parent_module: &hir::Module,
+    tests: &mut FxHashSet<Runnable>,
+) {
+    if let Some(fn_name) = fn_def.name() {
+        let mod_source = parent_module.definition_source(sema.db);
+        let range = match mod_source.value {
+            hir::ModuleSource::Module(m) => m.syntax().text_range(),
+            hir::ModuleSource::BlockExpr(b) => b.syntax().text_range(),
+            hir::ModuleSource::SourceFile(f) => f.syntax().text_range(),
+        };
+
+        let file_id = mod_source.file_id.original_file(sema.db);
+        let mod_scope = SearchScope::file_part(file_id, range);
+        let fn_pos = FilePosition { file_id, offset: fn_name.syntax().text_range().start() };
+        find_related_tests(sema, fn_pos, Some(mod_scope), tests)
+    }
+}
+
+fn as_test_runnable(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Option<Runnable> {
+    if test_related_attribute(&fn_def).is_some() {
+        let function = sema.to_def(fn_def)?;
+        runnable_fn(sema, function)
+    } else {
+        None
+    }
+}
+
+fn parent_test_module(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Option<hir::Module> {
+    fn_def.syntax().ancestors().find_map(|node| {
+        let module = ast::Module::cast(node)?;
+        let module = sema.to_def(&module)?;
+
+        if has_test_function_or_multiple_test_submodules(sema, &module) {
+            Some(module)
+        } else {
+            None
+        }
+    })
+}
+
 fn runnables_mod(sema: &Semantics<RootDatabase>, acc: &mut Vec<Runnable>, module: hir::Module) {
     acc.extend(module.declarations(sema.db).into_iter().filter_map(|def| {
         let runnable = match def {
@@ -255,7 +355,7 @@ fn module_def_doctest(sema: &Semantics<RootDatabase>, def: hir::ModuleDef) -> Op
     Some(res)
 }
 
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
 pub struct TestAttr {
     pub ignore: bool,
 }
@@ -349,6 +449,12 @@ mod tests {
         );
     }
 
+    fn check_tests(ra_fixture: &str, expect: Expect) {
+        let (analysis, position) = fixture::position(ra_fixture);
+        let tests = analysis.related_tests(position, None).unwrap();
+        expect.assert_debug_eq(&tests);
+    }
+
     #[test]
     fn test_runnables() {
         check(
@@ -1074,4 +1180,224 @@ mod tests {
             "#]],
         );
     }
+
+    #[test]
+    fn find_no_tests() {
+        check_tests(
+            r#"
+//- /lib.rs
+fn foo$0() {  };
+"#,
+            expect![[r#"
+                []
+            "#]],
+        );
+    }
+
+    #[test]
+    fn find_direct_fn_test() {
+        check_tests(
+            r#"
+//- /lib.rs
+fn foo$0() { };
+
+mod tests {
+    #[test]
+    fn foo_test() {
+        super::foo()
+    }
+}
+"#,
+            expect![[r#"
+                [
+                    Runnable {
+                        nav: NavigationTarget {
+                            file_id: FileId(
+                                0,
+                            ),
+                            full_range: 31..85,
+                            focus_range: 46..54,
+                            name: "foo_test",
+                            kind: Function,
+                        },
+                        kind: Test {
+                            test_id: Path(
+                                "tests::foo_test",
+                            ),
+                            attr: TestAttr {
+                                ignore: false,
+                            },
+                        },
+                        cfg: None,
+                    },
+                ]
+            "#]],
+        );
+    }
+
+    #[test]
+    fn find_direct_struct_test() {
+        check_tests(
+            r#"
+//- /lib.rs
+struct Fo$0o;
+fn foo(arg: &Foo) { };
+
+mod tests {
+    use super::*;
+
+    #[test]
+    fn foo_test() {
+        foo(Foo);
+    }
+}
+"#,
+            expect![[r#"
+            [
+                Runnable {
+                    nav: NavigationTarget {
+                        file_id: FileId(
+                            0,
+                        ),
+                        full_range: 71..122,
+                        focus_range: 86..94,
+                        name: "foo_test",
+                        kind: Function,
+                    },
+                    kind: Test {
+                        test_id: Path(
+                            "tests::foo_test",
+                        ),
+                        attr: TestAttr {
+                            ignore: false,
+                        },
+                    },
+                    cfg: None,
+                },
+            ]
+            "#]],
+        );
+    }
+
+    #[test]
+    fn find_indirect_fn_test() {
+        check_tests(
+            r#"
+//- /lib.rs
+fn foo$0() { };
+
+mod tests {
+    use super::foo;
+
+    fn check1() {
+        check2()
+    }
+
+    fn check2() {
+        foo()
+    }
+
+    #[test]
+    fn foo_test() {
+        check1()
+    }
+}
+"#,
+            expect![[r#"
+                [
+                    Runnable {
+                        nav: NavigationTarget {
+                            file_id: FileId(
+                                0,
+                            ),
+                            full_range: 133..183,
+                            focus_range: 148..156,
+                            name: "foo_test",
+                            kind: Function,
+                        },
+                        kind: Test {
+                            test_id: Path(
+                                "tests::foo_test",
+                            ),
+                            attr: TestAttr {
+                                ignore: false,
+                            },
+                        },
+                        cfg: None,
+                    },
+                ]
+            "#]],
+        );
+    }
+
+    #[test]
+    fn tests_are_unique() {
+        check_tests(
+            r#"
+//- /lib.rs
+fn foo$0() { };
+
+mod tests {
+    use super::foo;
+
+    #[test]
+    fn foo_test() {
+        foo();
+        foo();
+    }
+
+    #[test]
+    fn foo2_test() {
+        foo();
+        foo();
+    }
+
+}
+"#,
+            expect![[r#"
+            [
+                Runnable {
+                    nav: NavigationTarget {
+                        file_id: FileId(
+                            0,
+                        ),
+                        full_range: 52..115,
+                        focus_range: 67..75,
+                        name: "foo_test",
+                        kind: Function,
+                    },
+                    kind: Test {
+                        test_id: Path(
+                            "tests::foo_test",
+                        ),
+                        attr: TestAttr {
+                            ignore: false,
+                        },
+                    },
+                    cfg: None,
+                },
+                Runnable {
+                    nav: NavigationTarget {
+                        file_id: FileId(
+                            0,
+                        ),
+                        full_range: 121..185,
+                        focus_range: 136..145,
+                        name: "foo2_test",
+                        kind: Function,
+                    },
+                    kind: Test {
+                        test_id: Path(
+                            "tests::foo2_test",
+                        ),
+                        attr: TestAttr {
+                            ignore: false,
+                        },
+                    },
+                    cfg: None,
+                },
+            ]
+            "#]],
+        );
+    }
 }
diff --git a/crates/ide_db/src/search.rs b/crates/ide_db/src/search.rs
index ddcfbd3f3ff..8b211256e26 100644
--- a/crates/ide_db/src/search.rs
+++ b/crates/ide_db/src/search.rs
@@ -86,6 +86,10 @@ impl SearchScope {
         SearchScope::new(std::iter::once((file, None)).collect())
     }
 
+    pub fn file_part(file: FileId, range: TextRange) -> SearchScope {
+        SearchScope::new(std::iter::once((file, Some(range))).collect())
+    }
+
     pub fn files(files: &[FileId]) -> SearchScope {
         SearchScope::new(files.iter().map(|f| (*f, None)).collect())
     }