diff --git a/crates/libsyntax2/src/ast/mod.rs b/crates/libsyntax2/src/ast/mod.rs index 881f380f37f..0b6868547e5 100644 --- a/crates/libsyntax2/src/ast/mod.rs +++ b/crates/libsyntax2/src/ast/mod.rs @@ -1,10 +1,13 @@ mod generated; +use std::marker::PhantomData; + use itertools::Itertools; use smol_str::SmolStr; use { SyntaxNodeRef, SyntaxKind::*, + yellow::{RefRoot, SyntaxNodeChildren}, }; pub use self::generated::*; @@ -33,8 +36,8 @@ fn arg_list(self) -> Option> { } pub trait FnDefOwner<'a>: AstNode<'a> { - fn functions(self) -> Box> + 'a> { - Box::new(children(self)) + fn functions(self) -> AstNodeChildren<'a, FnDef<'a>> { + children(self) } } @@ -49,8 +52,8 @@ fn where_clause(self) -> Option> { } pub trait AttrsOwner<'a>: AstNode<'a> { - fn attrs(self) -> Box> + 'a> { - Box::new(children(self)) + fn attrs(self) -> AstNodeChildren<'a, Attr<'a>> { + children(self) } } @@ -155,7 +158,7 @@ pub fn then_branch(self) -> Option> { pub fn else_branch(self) -> Option> { self.blocks().nth(1) } - fn blocks(self) -> impl Iterator> { + fn blocks(self) -> AstNodeChildren<'a, Block<'a>> { children(self) } } @@ -164,8 +167,34 @@ fn child_opt<'a, P: AstNode<'a>, C: AstNode<'a>>(parent: P) -> Option { children(parent).next() } -fn children<'a, P: AstNode<'a>, C: AstNode<'a>>(parent: P) -> impl Iterator + 'a { - parent.syntax() - .children() - .filter_map(C::cast) +fn children<'a, P: AstNode<'a>, C: AstNode<'a>>(parent: P) -> AstNodeChildren<'a, C> { + AstNodeChildren::new(parent.syntax()) +} + + +#[derive(Debug)] +pub struct AstNodeChildren<'a, N> { + inner: SyntaxNodeChildren>, + ph: PhantomData, +} + +impl<'a, N> AstNodeChildren<'a, N> { + fn new(parent: SyntaxNodeRef<'a>) -> Self { + AstNodeChildren { + inner: parent.children(), + ph: PhantomData, + } + } +} + +impl<'a, N: AstNode<'a>> Iterator for AstNodeChildren<'a, N> { + type Item = N; + fn next(&mut self) -> Option { + loop { + match N::cast(self.inner.next()?) { + Some(n) => return Some(n), + None => (), + } + } + } } diff --git a/crates/libsyntax2/src/yellow/mod.rs b/crates/libsyntax2/src/yellow/mod.rs index 82eda79d6d3..0596e702f89 100644 --- a/crates/libsyntax2/src/yellow/mod.rs +++ b/crates/libsyntax2/src/yellow/mod.rs @@ -8,7 +8,7 @@ sync::Arc, ptr, }; -pub use self::syntax::{SyntaxNode, SyntaxNodeRef, SyntaxError}; +pub use self::syntax::{SyntaxNode, SyntaxNodeRef, SyntaxError, SyntaxNodeChildren}; pub(crate) use self::{ builder::GreenBuilder, green::GreenNode, diff --git a/crates/libsyntax2/src/yellow/syntax.rs b/crates/libsyntax2/src/yellow/syntax.rs index 444dbeb3092..1d99cab4a94 100644 --- a/crates/libsyntax2/src/yellow/syntax.rs +++ b/crates/libsyntax2/src/yellow/syntax.rs @@ -1,6 +1,7 @@ use std::{ fmt, sync::Arc, hash::{Hasher, Hash}, + ops::Range, }; use smol_str::SmolStr; @@ -93,17 +94,11 @@ pub fn text(&self) -> SyntaxText { SyntaxText::new(self.borrowed()) } - pub fn children(&self) -> impl Iterator> { - let red = self.red; - let n_children = self.red().n_children(); - let root = self.root.clone(); - (0..n_children).map(move |i| { - let red = unsafe { red.get(root.syntax_root()) }; - SyntaxNode { - root: root.clone(), - red: red.get_child(i).unwrap(), - } - }) + pub fn children(&self) -> SyntaxNodeChildren { + SyntaxNodeChildren { + parent: self.clone(), + iter: (0..self.red().n_children()) + } } pub fn parent(&self) -> Option> { @@ -192,6 +187,26 @@ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { } } +#[derive(Debug)] +pub struct SyntaxNodeChildren { + parent: SyntaxNode, + iter: Range, +} + +impl Iterator for SyntaxNodeChildren { + type Item = SyntaxNode; + + fn next(&mut self) -> Option> { + self.iter.next().map(|i| { + let red = self.parent.red(); + SyntaxNode { + root: self.parent.root.clone(), + red: red.get_child(i).unwrap(), + } + }) + } +} + fn has_short_text(kind: SyntaxKind) -> bool { match kind { IDENT | LIFETIME | INT_NUMBER | FLOAT_NUMBER => true,