diff --git a/src/parser.rs b/src/parser.rs index 47fc7de..dedf160 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -319,12 +319,11 @@ impl Tree { let name_str = self.nodes.get_ident_str(name).unwrap().to_owned(); let node = if global { - let node = match self.st.root_mut().find_orderless_symbol(&name_str) { + let node = match self.st.find_root_symbol(&name_str) { Some(r) => r.node(), None => self .st - .root_mut() - .insert_orderless_symbol(&name_str, self.nodes.reserve_node()) + .insert_root_symbol(&name_str, self.nodes.reserve_node()) .node(), }; node @@ -940,7 +939,7 @@ impl Tree { trailing_expr, } => { writeln_indented!(indent, writer, "%{} = {{", node.get())?; - self.st.into_find_child(node); + self.st.into_child(node); for stmt in statements { self.render_node(writer, stmt, indent + 1)?; } diff --git a/src/symbol_table.rs b/src/symbol_table.rs index 74fcf42..8335b6b 100644 --- a/src/symbol_table.rs +++ b/src/symbol_table.rs @@ -1,4 +1,7 @@ -use std::collections::{BTreeMap, HashMap}; +use std::{ + collections::{BTreeMap, HashMap}, + ptr::NonNull, +}; use crate::ast::Node as AstNode; @@ -24,63 +27,75 @@ pub enum SymbolKind { } #[derive(Debug, Default)] -pub struct SymbolTable { - // this is a `Vec<_>` because order matters. Some symbols such as functions - // cannot be shadowed, but I really like shadowing variables and function - // parameters, so any `x` may be redefined. +struct InnerSymbolTable { ordered_identifiers: Vec, orderless_identifiers: HashMap, - children: BTreeMap, SymbolTable>, - scope: Option, - parent: Option>, + children: BTreeMap>, + parent: Option>, } - -impl SymbolTable { - pub fn new() -> SymbolTable { +impl InnerSymbolTable { + fn new() -> NonNull { + Self::new_with(Self::new_inner) + } + fn new_with(gen: G) -> NonNull + where + G: FnOnce() -> Self, + { + NonNull::new(Box::leak(Box::new(gen())) as *mut _).unwrap() + } + fn new_inner() -> InnerSymbolTable { Self { + parent: None, ordered_identifiers: Vec::new(), orderless_identifiers: HashMap::new(), children: BTreeMap::new(), - scope: None, - parent: None, } } - pub fn root_ref(&self) -> &SymbolTable { - match self.parent.as_ref() { - Some(parent) => parent.root_ref(), - None => self, - } + fn make_child(&self) -> NonNull { + Self::new_with(|| Self { + parent: NonNull::new(self.as_ptr()), + ordered_identifiers: Vec::new(), + orderless_identifiers: HashMap::new(), + children: BTreeMap::new(), + }) } - pub fn root_mut(&mut self) -> &mut SymbolTable { - let this = self as *mut Self; - unsafe { - match (&mut *this).parent.as_mut() { - Some(parent) => parent.root_mut(), - None => self, + fn parent(&self) -> Option> { + self.parent + } + + fn parent_ref(&self) -> Option<&InnerSymbolTable> { + unsafe { self.parent.map(|p| p.as_ref()) } + } + + fn parent_mut(&mut self) -> Option<&mut InnerSymbolTable> { + unsafe { self.parent.map(|mut p| p.as_mut()) } + } + + fn as_ptr(&self) -> *mut Self { + self as *const _ as *mut _ + } + + fn root(&self) -> NonNull { + self.parent() + .map(|p| unsafe { p.as_ref().root() }) + .unwrap_or(NonNull::new(self.as_ptr()).unwrap()) + } +} + +impl Drop for InnerSymbolTable { + fn drop(&mut self) { + for child in self.children.values() { + unsafe { + _ = Box::from_raw(child.as_ptr()); } } } +} - pub fn parent_ref(&self) -> &SymbolTable { - match self.parent.as_ref() { - Some(parent) => Box::as_ref(parent), - None => self, - } - } - - pub fn parent_mut(&mut self) -> &mut SymbolTable { - let this = self as *mut Self; - unsafe { - match (&mut *this).parent.as_mut() { - Some(parent) => Box::as_mut(parent), - None => self, - } - } - } - - pub fn insert_symbol(&mut self, name: &str, node: AstNode, kind: SymbolKind) -> &SymbolRecord { +impl InnerSymbolTable { + fn insert_symbol(&mut self, name: &str, node: AstNode, kind: SymbolKind) -> &SymbolRecord { match kind { SymbolKind::Var => { self.ordered_identifiers.push(SymbolRecord { @@ -89,20 +104,11 @@ impl SymbolTable { }); self.ordered_identifiers.last().unwrap() } - _ => { - self.orderless_identifiers.insert( - name.to_owned(), - SymbolRecord { - name: name.to_owned(), - decl: node, - }, - ); - self.orderless_identifiers.get(name).unwrap() - } + _ => self.insert_orderless_symbol(name, node), } } - pub fn insert_orderless_symbol(&mut self, name: &str, node: AstNode) -> &SymbolRecord { + fn insert_orderless_symbol(&mut self, name: &str, node: AstNode) -> &SymbolRecord { self.orderless_identifiers.insert( name.to_owned(), SymbolRecord { @@ -113,7 +119,7 @@ impl SymbolTable { self.orderless_identifiers.get(name).unwrap() } - pub fn find_symbol_or_insert_with<'a, F>(&'a mut self, name: &str, cb: F) -> &'a SymbolRecord + fn find_symbol_or_insert_with<'a, F>(&'a mut self, name: &str, cb: F) -> &'a SymbolRecord where F: FnOnce() -> (AstNode, SymbolKind), { @@ -126,7 +132,7 @@ impl SymbolTable { } } - pub fn find_symbol_by_decl(&self, decl: AstNode) -> Option<&SymbolRecord> { + fn find_symbol_by_decl(&self, decl: AstNode) -> Option<&SymbolRecord> { self.ordered_identifiers .iter() .find(|r| r.decl == decl) @@ -136,53 +142,148 @@ impl SymbolTable { .find(|(_, v)| v.decl == decl) .map(|(_, v)| v) }) - .or_else(|| { - self.parent - .as_ref() - .and_then(|p| p.find_symbol_by_decl(decl)) - }) + .or_else(|| self.parent_ref().and_then(|p| p.find_symbol_by_decl(decl))) } - pub fn find_symbol(&self, name: &str) -> Option<&SymbolRecord> { + fn find_symbol(&self, name: &str) -> Option<&SymbolRecord> { self.ordered_identifiers .iter() .find(|r| r.name.as_str() == name) .or_else(|| self.orderless_identifiers.get(name)) - .or_else(|| self.parent.as_ref().and_then(|p| p.find_symbol(name))) + .or_else(|| self.parent_ref().and_then(|p| p.find_symbol(name))) } - pub fn find_orderless_symbol(&self, name: &str) -> Option<&SymbolRecord> { + fn find_orderless_symbol(&self, name: &str) -> Option<&SymbolRecord> { self.orderless_identifiers.get(name).or_else(|| { - self.parent - .as_ref() + self.parent_ref() .and_then(|p| p.find_orderless_symbol(name)) }) } - pub fn into_find_child(&mut self, scope: AstNode) -> Option<()> { - if let Some(mut parent) = self.children.remove(&Some(scope)) { - core::mem::swap(self, &mut parent); - self.parent = Some(Box::new(parent)); - Some(()) - } else { - None + fn extend_orderless(&mut self, iter: I) + where + I: IntoIterator, + { + self.orderless_identifiers.extend(iter) + } + + fn extract_orderless_if( + &mut self, + pred: F, + ) -> std::collections::hash_map::ExtractIf + where + F: FnMut(&String, &mut SymbolRecord) -> bool, + { + self.orderless_identifiers.extract_if(pred) + } +} + +#[derive(Debug)] +pub struct SymbolTableWrapper { + current: NonNull, +} + +impl Drop for SymbolTableWrapper { + fn drop(&mut self) { + unsafe { + _ = Box::from_raw(self.current.as_ref().root().as_ptr()); + } + } +} + +impl SymbolTableWrapper { + pub fn new() -> SymbolTableWrapper { + Self { + current: InnerSymbolTable::new(), } } + fn current(&self) -> &InnerSymbolTable { + unsafe { self.current.as_ref() } + } + + fn current_mut(&mut self) -> &mut InnerSymbolTable { + unsafe { self.current.as_mut() } + } + + #[allow(dead_code)] + fn root_ref(&self) -> &InnerSymbolTable { + unsafe { self.current().root().as_ref() } + } + + fn root_mut(&mut self) -> &mut InnerSymbolTable { + unsafe { self.current_mut().root().as_mut() } + } + + #[allow(dead_code)] + fn parent_ref(&self) -> Option<&InnerSymbolTable> { + self.current().parent_ref() + } + + #[allow(dead_code)] + fn parent_mut(&mut self) -> Option<&mut InnerSymbolTable> { + self.current_mut().parent_mut() + } + pub fn into_child(&mut self, scope: AstNode) { - let mut parent = Self { - scope: Some(scope), - ..Default::default() + let child = if let Some(child) = self.current().children.get(&scope) { + *child + } else { + let child = self.current().make_child(); + self.current_mut().children.insert(scope, child); + child }; - core::mem::swap(self, &mut parent); - self.parent = Some(Box::new(parent)); + self.current = child; + } + + pub fn into_parent(&mut self) { + if let Some(parent) = self.current().parent() { + self.current = parent; + } + } +} + +impl SymbolTableWrapper { + pub fn insert_symbol(&mut self, name: &str, node: AstNode, kind: SymbolKind) -> &SymbolRecord { + self.current_mut().insert_symbol(name, node, kind) + } + + pub fn find_root_symbol(&mut self, name: &str) -> Option<&SymbolRecord> { + self.root_mut().find_orderless_symbol(name) + } + + pub fn insert_root_symbol(&mut self, name: &str, node: AstNode) -> &SymbolRecord { + self.root_mut().insert_orderless_symbol(name, node) + } + + pub fn insert_orderless_symbol(&mut self, name: &str, node: AstNode) -> &SymbolRecord { + self.current_mut().insert_orderless_symbol(name, node) + } + + pub fn find_symbol_or_insert_with<'a, F>(&'a mut self, name: &str, cb: F) -> &'a SymbolRecord + where + F: FnOnce() -> (AstNode, SymbolKind), + { + self.current_mut().find_symbol_or_insert_with(name, cb) + } + + pub fn find_symbol_by_decl(&self, decl: AstNode) -> Option<&SymbolRecord> { + self.current().find_symbol_by_decl(decl) + } + + pub fn find_symbol(&self, name: &str) -> Option<&SymbolRecord> { + self.current().find_symbol(name) + } + + pub fn find_orderless_symbol(&self, name: &str) -> Option<&SymbolRecord> { + self.current().find_orderless_symbol(name) } pub fn extend_orderless(&mut self, iter: I) where I: IntoIterator, { - self.orderless_identifiers.extend(iter) + self.current_mut().extend_orderless(iter) } pub fn extract_orderless_if( @@ -192,15 +293,8 @@ impl SymbolTable { where F: FnMut(&String, &mut SymbolRecord) -> bool, { - self.orderless_identifiers.extract_if(pred) - } - - /// returns `self` if `self.parent` was `Some(_)`. - pub fn into_parent(&mut self) { - if let Some(child) = self.parent.take() { - let mut child = Box::into_inner(child); - core::mem::swap(self, &mut child); - self.children.insert(child.scope, child); - } + self.current_mut().extract_orderless_if(pred) } } + +pub type SymbolTable = SymbolTableWrapper;