diff --git a/src/ast.rs b/src/ast.rs index 6691831..418c45a 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -530,179 +530,135 @@ pub mod tree_visitor { children: Vec, } + #[derive(Debug, PartialEq, Eq)] + enum End { + Open, + Inclusive(Node), + Exclusive(Node), + } + + #[derive(Debug, PartialEq, Eq, Clone, Copy)] enum PrePost { Pre(Node), Post(Node), } + impl PrePost { + fn node(self) -> Node { + match self { + PrePost::Pre(n) => n, + PrePost::Post(n) => n, + } + } + } + /// Don't modify `node` in `pre()` /// Don't modify `children` in `pre()` - pub struct Visitor<'a, F1, F2> { - tree: &'a mut Tree, + pub struct Visitor { frames: Vec, + current_node: Option, + end: End, pre: F1, post: F2, } - impl<'a, F1, F2> Visitor<'a, F1, F2> { - pub fn new(tree: &'a mut Tree, start: Node, pre: F1, post: F2) -> Visitor<'a, F1, F2> - where - F1: FnMut(&mut Tree, Node) -> T, - F2: FnMut(&mut Tree, Node) -> U, - { + impl Visitor { + pub fn new(start: Node, pre: F1, post: F2) -> Self { + Self::new_inner(start, End::Open, pre, post) + } + pub fn new_range(start: Node, end: Node, pre: F1, post: F2) -> Self { + Self::new_inner(start, End::Exclusive(end), pre, post) + } + pub fn new_range_inclusive(start: Node, end: Node, pre: F1, post: F2) -> Self { + Self::new_inner(start, End::Inclusive(end), pre, post) + } + } + + impl Visitor { + fn new_inner(start: Node, end: End, pre: F1, post: F2) -> Self { let frame = Frame { node: Node::MAX, children: vec![start], }; Self { frames: vec![frame], - tree, + current_node: None, + end, pre, post, } } - fn get_children(&self, node: Node) -> Vec { - match self.tree.nodes.get_node(node) { - super::Tag::FunctionProto { - name, - parameters, - return_type, - } => { - if let Some(params) = parameters { - vec![*name, *params, *return_type] - } else { - vec![*name, *return_type] - } - } - super::Tag::ParameterList { parameters } => parameters.clone(), - super::Tag::Parameter { name, ty } => { - vec![*name, *ty] - } - super::Tag::Pointer { pointee } => { - vec![*pointee] - } - super::Tag::FunctionDecl { proto, body } => { - vec![*proto, *body] - } - super::Tag::Block { - statements, - trailing_expr, - } => { - let mut children = statements.clone(); - if let Some(expr) = trailing_expr { - children.push(*expr); - } - children - } - super::Tag::ReturnStmt { expr } => expr.into_iter().cloned().collect::>(), - &super::Tag::ExprStmt { expr } => { - vec![expr] - } - super::Tag::VarDecl { - name, - explicit_type, - .. - } => { - if let Some(ty) = *explicit_type { - vec![*name, ty] - } else { - vec![*name] - } - } - super::Tag::GlobalDecl { - name, - explicit_type, - .. - } => { - if let Some(ty) = *explicit_type { - vec![*name, ty] - } else { - vec![*name] - } - } - &super::Tag::CallExpr { lhs, rhs } => { - if let Some(rhs) = rhs { - vec![lhs, rhs] - } else { - vec![lhs] - } - } - super::Tag::ArgumentList { parameters } => parameters.clone(), - &super::Tag::Argument { name, expr } => { - if let Some(name) = name { - vec![name, expr] - } else { - vec![expr] - } - } - &super::Tag::ExplicitCast { lhs, typename } => { - vec![lhs, typename] - } - super::Tag::Deref { lhs } - | super::Tag::Ref { lhs } - | super::Tag::Not { lhs } - | super::Tag::Negate { lhs } => { - vec![*lhs] - } - super::Tag::Or { lhs, rhs } - | super::Tag::And { lhs, rhs } - | super::Tag::BitOr { lhs, rhs } - | super::Tag::BitAnd { lhs, rhs } - | super::Tag::BitXOr { lhs, rhs } - | super::Tag::Eq { lhs, rhs } - | super::Tag::NEq { lhs, rhs } - | super::Tag::Lt { lhs, rhs } - | super::Tag::Gt { lhs, rhs } - | super::Tag::Le { lhs, rhs } - | super::Tag::Ge { lhs, rhs } - | super::Tag::Shl { lhs, rhs } - | super::Tag::Shr { lhs, rhs } - | super::Tag::Add { lhs, rhs } - | super::Tag::Sub { lhs, rhs } - | super::Tag::Mul { lhs, rhs } - | super::Tag::Rem { lhs, rhs } - | super::Tag::Div { lhs, rhs } - | super::Tag::Assign { lhs, rhs } => { - vec![*lhs, *rhs] - } - _ => vec![], + fn next_node(&mut self, tree: &Tree) -> Option { + if let Some(node) = self.current_node.take() { + let mut children = tree.get_node_children(node); + children.reverse(); + self.frames.push(Frame { node, children }); } - } - - fn next_node(&mut self) -> Option { - loop { + let node = { let frame = self.frames.last_mut()?; if let Some(node) = frame.children.pop() { - return Some(PrePost::Pre(node)); + if self.end == End::Exclusive(node) { + self.frames.clear(); + None + } else { + self.current_node = Some(node); + Some(PrePost::Pre(node)) + } } else { let frame = self.frames.pop()?; - if frame.node != Node::MAX { - return Some(PrePost::Post(frame.node)); + let node = frame.node; + + if node == Node::MAX { + self.frames.clear(); + None + } else { + if self.end == End::Inclusive(node) { + self.frames.clear(); + } + Some(PrePost::Post(node)) } } + }; + + node + } + + pub fn skip_until(mut self, tree: &Tree, node: Node) -> Self { + self.find(tree, node); + self + } + + pub fn find(&mut self, tree: &Tree, needle: Node) { + loop { + let Some(node) = self.next_node(tree) else { + break; + }; + if node == PrePost::Pre(needle) { + self.frames.last_mut().unwrap().children.push(needle); + break; + } } } - pub fn visit_ok(mut self) -> core::result::Result + /// short-circuits on the first E + pub fn visit_ok(mut self, tree: &Tree) -> core::result::Result where - F1: FnMut(&mut Tree, Node) -> core::result::Result, - F2: FnMut(&mut Tree, Node) -> core::result::Result, + F1: FnMut(&Tree, Node) -> core::result::Result, + F2: FnMut(&Tree, Node) -> core::result::Result, { let mut t = None; loop { - let Some(node) = self.next_node() else { + let Some(node) = self.next_node(tree) else { break; }; match node { PrePost::Pre(node) => { - t = Some((self.pre)(self.tree, node)?); - let children = self.get_children(node); - self.frames.push(Frame { node, children }); + t = Some((self.pre)(tree, node)?); } PrePost::Post(node) => { - t = Some((self.post)(self.tree, node)?); + t = Some((self.post)(tree, node)?); } } } @@ -710,24 +666,43 @@ pub mod tree_visitor { Ok(t.unwrap()) } - pub fn visit(mut self) + pub fn visit(mut self, tree: &Tree) where - F1: FnMut(&mut Tree, Node) -> T, - F2: FnMut(&mut Tree, Node) -> U, + F1: FnMut(&Tree, Node) -> T, + F2: FnMut(&Tree, Node) -> U, { loop { - let Some(node) = self.next_node() else { + let Some(node) = self.next_node(tree) else { break; }; match node { PrePost::Pre(node) => { - (self.pre)(self.tree, node); - let children = self.get_children(node); - self.frames.push(Frame { node, children }); + (self.pre)(tree, node); } PrePost::Post(node) => { - (self.post)(self.tree, node); + (self.post)(tree, node); + } + } + } + } + + pub fn visit_mut(mut self, tree: &mut Tree) + where + F1: FnMut(&mut Tree, Node) -> T, + F2: FnMut(&mut Tree, Node) -> U, + { + loop { + let Some(node) = self.next_node(tree) else { + break; + }; + + match node { + PrePost::Pre(node) => { + (self.pre)(tree, node); + } + PrePost::Post(node) => { + (self.post)(tree, node); } } } diff --git a/src/parser.rs b/src/parser.rs index 42c7af5..2e34d55 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -12,6 +12,7 @@ use crate::{ string_table::{ImmOrIndex, Index, StringTable}, symbol_table::{SymbolKind, SymbolTable}, tokens::Token, + variant, }; #[derive(Debug, thiserror::Error)] @@ -81,6 +82,10 @@ impl Nodes { fn reserve_node(&mut self) -> Node { self.push_tag(Tag::Undefined) } + + fn swap_nodes(&mut self, lhs: Node, rhs: Node) { + self.inner.swap(lhs.get() as usize, rhs.get() as usize); + } } // TODO: add a string-table which stores strings and maybe other bytes and @@ -973,6 +978,113 @@ impl Tree { } } +impl Tree { + pub fn get_node_children(&self, node: Node) -> Vec { + match self.nodes.get_node(node) { + Tag::FunctionProto { + name, + parameters, + return_type, + } => { + if let Some(params) = parameters { + vec![*name, *params, *return_type] + } else { + vec![*name, *return_type] + } + } + Tag::ParameterList { parameters } => parameters.clone(), + Tag::Parameter { name, ty } => { + vec![*name, *ty] + } + Tag::Pointer { pointee } => { + vec![*pointee] + } + Tag::FunctionDecl { proto, body } => { + vec![*proto, *body] + } + Tag::Block { + statements, + trailing_expr, + } => { + let mut children = statements.clone(); + if let Some(expr) = trailing_expr { + children.push(*expr); + } + children + } + Tag::ReturnStmt { expr } => expr.into_iter().cloned().collect::>(), + &Tag::ExprStmt { expr } => { + vec![expr] + } + Tag::VarDecl { + name, + explicit_type, + .. + } => { + if let Some(ty) = *explicit_type { + vec![*name, ty] + } else { + vec![*name] + } + } + Tag::GlobalDecl { + name, + explicit_type, + .. + } => { + if let Some(ty) = *explicit_type { + vec![*name, ty] + } else { + vec![*name] + } + } + &Tag::CallExpr { lhs, rhs } => { + if let Some(rhs) = rhs { + vec![lhs, rhs] + } else { + vec![lhs] + } + } + Tag::ArgumentList { parameters } => parameters.clone(), + &Tag::Argument { name, expr } => { + if let Some(name) = name { + vec![name, expr] + } else { + vec![expr] + } + } + &Tag::ExplicitCast { lhs, typename } => { + vec![lhs, typename] + } + Tag::Deref { lhs } | Tag::Ref { lhs } | Tag::Not { lhs } | Tag::Negate { lhs } => { + vec![*lhs] + } + Tag::Or { lhs, rhs } + | Tag::And { lhs, rhs } + | Tag::BitOr { lhs, rhs } + | Tag::BitAnd { lhs, rhs } + | Tag::BitXOr { lhs, rhs } + | Tag::Eq { lhs, rhs } + | Tag::NEq { lhs, rhs } + | Tag::Lt { lhs, rhs } + | Tag::Gt { lhs, rhs } + | Tag::Le { lhs, rhs } + | Tag::Ge { lhs, rhs } + | Tag::Shl { lhs, rhs } + | Tag::Shr { lhs, rhs } + | Tag::Add { lhs, rhs } + | Tag::Sub { lhs, rhs } + | Tag::Mul { lhs, rhs } + | Tag::Rem { lhs, rhs } + | Tag::Div { lhs, rhs } + | Tag::Assign { lhs, rhs } => { + vec![*lhs, *rhs] + } + _ => vec![], + } + } +} + impl Tree { fn render_node( &mut self, @@ -1605,15 +1717,62 @@ impl Tree { fn fold_comptime_with_visitor(&mut self, decl: Node) { ast::tree_visitor::Visitor::new( - self, decl, - |_, node| { - eprint!("%{node} "); - }, - |tree, node| { - if let Ok(value) = tree.fold_comptime_inner(node) { + |_: &mut Tree, _| {}, + |tree: &mut Tree, node| { + + let value_node = if let &Tag::DeclRef(lhs) = tree.nodes.get_node(node) { + let start = lhs; + let end = node; + let mut is_comptime = true; + let mut last_value = None; + eprintln!( + "checking if %{}, referencing %{} is comptime-evaluable", + node.get(), + lhs.get() + ); + ast::tree_visitor::Visitor::new_range_inclusive( + decl, + end, + |_: &Tree, _| { + }, + |tree: &Tree, node| match tree.nodes.get_node(node) { + &Tag::Assign { lhs, rhs } => { + if lhs == start || matches!(tree.nodes.get_node(lhs), &Tag::DeclRef(decl) if decl == start) { + eprintln!("found assignment at %{}", node.get()); + is_comptime &= tree.is_node_comptime(rhs); + if is_comptime { + last_value = Some(rhs); + } + } + } + &Tag::Ref { lhs } if lhs == start => { + // recursively checking for derefs would get very complicated. + is_comptime = false; + } + _ => {} + }, + ) + .skip_until(tree, start) + .visit(tree); + + eprintln!( + "%{} is {}comptime-evaluable.", + node.get(), + if is_comptime { "" } else { "not " } + ); + + eprintln!("%{node} comptime-value is %{last_value:?}"); + + is_comptime.then_some(last_value).flatten().unwrap_or(node) + }else { + node + }; + + if let Ok(value) = tree.fold_comptime_inner(value_node) { let (bytes, ty) = value.into_bytes_and_type(); let idx = tree.strings.insert(bytes); + *tree.nodes.get_node_mut(node) = Tag::Constant { bytes: ImmOrIndex::Index(idx), ty, @@ -1621,11 +1780,10 @@ impl Tree { } }, ) - .visit(); + .visit_mut(self); } fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result { - // if self.is_node_comptime(decl) { match self.nodes.get_node(decl) { Tag::Constant { bytes, ty } => { @@ -1679,9 +1837,6 @@ impl Tree { _ => unimplemented!(), }; } - Tag::DeclRef(lhs) => { - return self.fold_comptime_inner(*lhs); - } Tag::Not { lhs } => { let lhs = self.fold_comptime_inner(*lhs)?; return lhs.not(); @@ -1812,6 +1967,28 @@ impl Tree { return lhs.div(rhs); } + &Tag::DeclRef(lhs) => { + variant!(self.nodes.get_node(lhs) => &Tag::VarDecl { assignment, .. }); + + let start = assignment.unwrap_or(lhs); + let end = decl; + let mut last_value = None; + ast::tree_visitor::Visitor::new_range( + start, + end, + |_: &Tree, _| {}, + |tree: &Tree, node| match tree.nodes.get_node(node) { + &Tag::Assign { lhs, rhs } if lhs == start => { + last_value = Some(rhs); + } + _ => {} + }, + ) + .visit(self); + + return self + .fold_comptime_inner(last_value.ok_or(comptime::Error::NotComptime)?); + } _ => { unreachable!() } @@ -2061,7 +2238,9 @@ const global: u32 = 42u32; fn comptime() { let src = " fn main() -> void { -let a = 3 * 49573 << 4; +let x: u32; +x = 666u32; +let a = x + 3 * 49573 << 4; } "; let tokens = Tokenizer::new(src.as_bytes()).unwrap();