Simplify ast_transform

This commit is contained in:
Aleksey Kladov 2020-10-02 20:52:48 +02:00
parent 673e1ddb9a
commit 3290bb4112

View File

@ -5,12 +5,13 @@ use hir::{HirDisplay, PathResolution, SemanticsScope};
use syntax::{ use syntax::{
algo::SyntaxRewriter, algo::SyntaxRewriter,
ast::{self, AstNode}, ast::{self, AstNode},
SyntaxNode,
}; };
pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N { pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N {
SyntaxRewriter::from_fn(|element| match element { SyntaxRewriter::from_fn(|element| match element {
syntax::SyntaxElement::Node(n) => { syntax::SyntaxElement::Node(n) => {
let replacement = transformer.get_substitution(&n)?; let replacement = transformer.get_substitution(&n, transformer)?;
Some(replacement.into()) Some(replacement.into())
} }
_ => None, _ => None,
@ -47,32 +48,35 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N {
/// We'd want to somehow express this concept simpler, but so far nobody got to /// We'd want to somehow express this concept simpler, but so far nobody got to
/// simplifying this! /// simplifying this!
pub trait AstTransform<'a> { pub trait AstTransform<'a> {
fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode>; fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode>;
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a>;
fn or<T: AstTransform<'a> + 'a>(self, other: T) -> Box<dyn AstTransform<'a> + 'a> fn or<T: AstTransform<'a> + 'a>(self, other: T) -> Box<dyn AstTransform<'a> + 'a>
where where
Self: Sized + 'a, Self: Sized + 'a,
{ {
self.chain_before(Box::new(other)) Box::new(Or(Box::new(self), Box::new(other)))
} }
} }
struct NullTransformer; struct Or<'a>(Box<dyn AstTransform<'a> + 'a>, Box<dyn AstTransform<'a> + 'a>);
impl<'a> AstTransform<'a> for NullTransformer { impl<'a> AstTransform<'a> for Or<'a> {
fn get_substitution(&self, _node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> { fn get_substitution(
None &self,
} node: &SyntaxNode,
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a> { recur: &dyn AstTransform<'a>,
other ) -> Option<SyntaxNode> {
self.0.get_substitution(node, recur).or_else(|| self.1.get_substitution(node, recur))
} }
} }
pub struct SubstituteTypeParams<'a> { pub struct SubstituteTypeParams<'a> {
source_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>,
substs: FxHashMap<hir::TypeParam, ast::Type>, substs: FxHashMap<hir::TypeParam, ast::Type>,
previous: Box<dyn AstTransform<'a> + 'a>,
} }
impl<'a> SubstituteTypeParams<'a> { impl<'a> SubstituteTypeParams<'a> {
@ -111,11 +115,7 @@ impl<'a> SubstituteTypeParams<'a> {
} }
}) })
.collect(); .collect();
return SubstituteTypeParams { return SubstituteTypeParams { source_scope, substs: substs_by_param };
source_scope,
substs: substs_by_param,
previous: Box::new(NullTransformer),
};
// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the // FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
// trait ref, and then go from the types in the substs back to the syntax). // trait ref, and then go from the types in the substs back to the syntax).
@ -140,7 +140,14 @@ impl<'a> SubstituteTypeParams<'a> {
Some(result) Some(result)
} }
} }
fn get_substitution_inner(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> { }
impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
_recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
let type_ref = ast::Type::cast(node.clone())?; let type_ref = ast::Type::cast(node.clone())?;
let path = match &type_ref { let path = match &type_ref {
ast::Type::PathType(path_type) => path_type.path()?, ast::Type::PathType(path_type) => path_type.path()?,
@ -154,27 +161,23 @@ impl<'a> SubstituteTypeParams<'a> {
} }
} }
impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> {
fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
self.get_substitution_inner(node).or_else(|| self.previous.get_substitution(node))
}
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a> {
Box::new(SubstituteTypeParams { previous: other, ..self })
}
}
pub struct QualifyPaths<'a> { pub struct QualifyPaths<'a> {
target_scope: &'a SemanticsScope<'a>, target_scope: &'a SemanticsScope<'a>,
source_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>,
previous: Box<dyn AstTransform<'a> + 'a>,
} }
impl<'a> QualifyPaths<'a> { impl<'a> QualifyPaths<'a> {
pub fn new(target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>) -> Self { pub fn new(target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>) -> Self {
Self { target_scope, source_scope, previous: Box::new(NullTransformer) } Self { target_scope, source_scope }
} }
}
fn get_substitution_inner(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> { impl<'a> AstTransform<'a> for QualifyPaths<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
// FIXME handle value ns? // FIXME handle value ns?
let from = self.target_scope.module()?; let from = self.target_scope.module()?;
let p = ast::Path::cast(node.clone())?; let p = ast::Path::cast(node.clone())?;
@ -191,7 +194,7 @@ impl<'a> QualifyPaths<'a> {
let type_args = p let type_args = p
.segment() .segment()
.and_then(|s| s.generic_arg_list()) .and_then(|s| s.generic_arg_list())
.map(|arg_list| apply(self, arg_list)); .map(|arg_list| apply(recur, arg_list));
if let Some(type_args) = type_args { if let Some(type_args) = type_args {
let last_segment = path.segment().unwrap(); let last_segment = path.segment().unwrap();
path = path.with_segment(last_segment.with_generic_args(type_args)) path = path.with_segment(last_segment.with_generic_args(type_args))
@ -208,15 +211,6 @@ impl<'a> QualifyPaths<'a> {
} }
} }
impl<'a> AstTransform<'a> for QualifyPaths<'a> {
fn get_substitution(&self, node: &syntax::SyntaxNode) -> Option<syntax::SyntaxNode> {
self.get_substitution_inner(node).or_else(|| self.previous.get_substitution(node))
}
fn chain_before(self, other: Box<dyn AstTransform<'a> + 'a>) -> Box<dyn AstTransform<'a> + 'a> {
Box::new(QualifyPaths { previous: other, ..self })
}
}
pub(crate) fn path_to_ast(path: hir::ModPath) -> ast::Path { pub(crate) fn path_to_ast(path: hir::ModPath) -> ast::Path {
let parse = ast::SourceFile::parse(&path.to_string()); let parse = ast::SourceFile::parse(&path.to_string());
parse parse