Use correct HirFileId in find_related_test

This commit is contained in:
Lukas Wirth 2021-08-20 13:49:28 +02:00
parent 7342dcf0b0
commit 557df6ff3f
3 changed files with 14 additions and 16 deletions

View File

@ -173,6 +173,10 @@ pub fn descend_node_at_offset<N: ast::AstNode>(
self.imp.descend_node_at_offset(node, offset).find_map(N::cast) self.imp.descend_node_at_offset(node, offset).find_map(N::cast)
} }
pub fn hir_file_for(&self, syntax_node: &SyntaxNode) -> HirFileId {
self.imp.find_file(syntax_node.clone()).file_id
}
pub fn original_range(&self, node: &SyntaxNode) -> FileRange { pub fn original_range(&self, node: &SyntaxNode) -> FileRange {
self.imp.original_range(node) self.imp.original_range(node)
} }

View File

@ -22,7 +22,7 @@
pub use mbe::{ExpandError, ExpandResult}; pub use mbe::{ExpandError, ExpandResult};
pub use parser::FragmentKind; pub use parser::FragmentKind;
use std::{hash::Hash, sync::Arc}; use std::{hash::Hash, iter, sync::Arc};
use base_db::{impl_intern_key, salsa, CrateId, FileId, FileRange}; use base_db::{impl_intern_key, salsa, CrateId, FileId, FileRange};
use syntax::{ use syntax::{
@ -454,7 +454,7 @@ pub fn ancestors_with_macros(
self, self,
db: &dyn db::AstDatabase, db: &dyn db::AstDatabase,
) -> impl Iterator<Item = InFile<SyntaxNode>> + '_ { ) -> impl Iterator<Item = InFile<SyntaxNode>> + '_ {
std::iter::successors(Some(self), move |node| match node.value.parent() { iter::successors(Some(self), move |node| match node.value.parent() {
Some(parent) => Some(node.with_value(parent)), Some(parent) => Some(node.with_value(parent)),
None => { None => {
let parent_node = node.file_id.call_node(db)?; let parent_node = node.file_id.call_node(db)?;
@ -570,19 +570,14 @@ pub fn nodes_with_attributes<'db>(
where where
N: 'db, N: 'db,
{ {
std::iter::successors(Some(self), move |node| { iter::successors(Some(self), move |node| {
let InFile { file_id, value } = node.file_id.call_node(db)?; let InFile { file_id, value } = node.file_id.call_node(db)?;
N::cast(value).map(|n| InFile::new(file_id, n)) N::cast(value).map(|n| InFile::new(file_id, n))
}) })
} }
pub fn node_with_attributes(self, db: &dyn db::AstDatabase) -> InFile<N> { pub fn node_with_attributes(self, db: &dyn db::AstDatabase) -> InFile<N> {
std::iter::successors(Some(self), move |node| { self.nodes_with_attributes(db).last().unwrap()
let InFile { file_id, value } = node.file_id.call_node(db)?;
N::cast(value).map(|n| InFile::new(file_id, n))
})
.last()
.unwrap()
} }
} }

View File

@ -3,7 +3,7 @@
use ast::NameOwner; use ast::NameOwner;
use cfg::CfgExpr; use cfg::CfgExpr;
use either::Either; use either::Either;
use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, Semantics}; use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, InFile, Semantics};
use ide_assists::utils::test_related_attribute; use ide_assists::utils::test_related_attribute;
use ide_db::{ use ide_db::{
base_db::{FilePosition, FileRange}, base_db::{FilePosition, FileRange},
@ -232,22 +232,21 @@ fn find_related_tests(
let functions = refs.iter().filter_map(|(range, _)| { let functions = refs.iter().filter_map(|(range, _)| {
let token = file.token_at_offset(range.start()).next()?; let token = file.token_at_offset(range.start()).next()?;
let token = sema.descend_into_macros(token); let token = sema.descend_into_macros(token);
// FIXME: This is the wrong file_id
token token
.ancestors() .ancestors()
.find_map(ast::Fn::cast) .find_map(ast::Fn::cast)
.map(|f| hir::InFile::new(file_id.into(), f)) .map(|f| hir::InFile::new(sema.hir_file_for(f.syntax()), f))
}); });
for fn_def in functions { for fn_def in functions {
// #[test/bench] expands to just the item causing us to lose the attribute, so recover them by going out of the attribute // #[test/bench] expands to just the item causing us to lose the attribute, so recover them by going out of the attribute
let fn_def = fn_def.node_with_attributes(sema.db); let InFile { value: fn_def, .. } = &fn_def.node_with_attributes(sema.db);
if let Some(runnable) = as_test_runnable(sema, &fn_def.value) { if let Some(runnable) = as_test_runnable(sema, fn_def) {
// direct test // direct test
tests.insert(runnable); tests.insert(runnable);
} else if let Some(module) = parent_test_module(sema, &fn_def.value) { } else if let Some(module) = parent_test_module(sema, fn_def) {
// indirect test // indirect test
find_related_tests_in_module(sema, &fn_def.value, &module, tests); find_related_tests_in_module(sema, fn_def, &module, tests);
} }
} }
} }