From 270162850dcaa8fb2977ee4b418caec8074f25f8 Mon Sep 17 00:00:00 2001 From: Janis Date: Tue, 27 Aug 2024 18:49:28 +0200 Subject: [PATCH] comptime folding works!!!!! --- src/ast.rs | 215 ++++++++++++++++++++++++++++++++++++++++++++++++ src/comptime.rs | 68 ++++++++++----- src/parser.rs | 68 ++++++++++++--- 3 files changed, 322 insertions(+), 29 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index e4b312e..6691831 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -519,3 +519,218 @@ impl ToString for PrimitiveType { } } } + +pub mod tree_visitor { + + use crate::ast::Node; + use crate::parser::Tree; + + struct Frame { + node: Node, + children: Vec, + } + + enum PrePost { + Pre(Node), + Post(Node), + } + + /// Don't modify `node` in `pre()` + /// Don't modify `children` in `pre()` + pub struct Visitor<'a, F1, F2> { + tree: &'a mut Tree, + frames: Vec, + pre: F1, + post: F2, + } + + impl<'a, F1, F2> Visitor<'a, F1, F2> { + pub fn new(tree: &'a mut Tree, start: Node, pre: F1, post: F2) -> Visitor<'a, F1, F2> + where + F1: FnMut(&mut Tree, Node) -> T, + F2: FnMut(&mut Tree, Node) -> U, + { + let frame = Frame { + node: Node::MAX, + children: vec![start], + }; + Self { + frames: vec![frame], + tree, + pre, + post, + } + } + + fn get_children(&self, node: Node) -> Vec { + match self.tree.nodes.get_node(node) { + super::Tag::FunctionProto { + name, + parameters, + return_type, + } => { + if let Some(params) = parameters { + vec![*name, *params, *return_type] + } else { + vec![*name, *return_type] + } + } + super::Tag::ParameterList { parameters } => parameters.clone(), + super::Tag::Parameter { name, ty } => { + vec![*name, *ty] + } + super::Tag::Pointer { pointee } => { + vec![*pointee] + } + super::Tag::FunctionDecl { proto, body } => { + vec![*proto, *body] + } + super::Tag::Block { + statements, + trailing_expr, + } => { + let mut children = statements.clone(); + if let Some(expr) = trailing_expr { + children.push(*expr); + } + children + } + super::Tag::ReturnStmt { expr } => expr.into_iter().cloned().collect::>(), + &super::Tag::ExprStmt { expr } => { + vec![expr] + } + super::Tag::VarDecl { + name, + explicit_type, + .. + } => { + if let Some(ty) = *explicit_type { + vec![*name, ty] + } else { + vec![*name] + } + } + super::Tag::GlobalDecl { + name, + explicit_type, + .. + } => { + if let Some(ty) = *explicit_type { + vec![*name, ty] + } else { + vec![*name] + } + } + &super::Tag::CallExpr { lhs, rhs } => { + if let Some(rhs) = rhs { + vec![lhs, rhs] + } else { + vec![lhs] + } + } + super::Tag::ArgumentList { parameters } => parameters.clone(), + &super::Tag::Argument { name, expr } => { + if let Some(name) = name { + vec![name, expr] + } else { + vec![expr] + } + } + &super::Tag::ExplicitCast { lhs, typename } => { + vec![lhs, typename] + } + super::Tag::Deref { lhs } + | super::Tag::Ref { lhs } + | super::Tag::Not { lhs } + | super::Tag::Negate { lhs } => { + vec![*lhs] + } + super::Tag::Or { lhs, rhs } + | super::Tag::And { lhs, rhs } + | super::Tag::BitOr { lhs, rhs } + | super::Tag::BitAnd { lhs, rhs } + | super::Tag::BitXOr { lhs, rhs } + | super::Tag::Eq { lhs, rhs } + | super::Tag::NEq { lhs, rhs } + | super::Tag::Lt { lhs, rhs } + | super::Tag::Gt { lhs, rhs } + | super::Tag::Le { lhs, rhs } + | super::Tag::Ge { lhs, rhs } + | super::Tag::Shl { lhs, rhs } + | super::Tag::Shr { lhs, rhs } + | super::Tag::Add { lhs, rhs } + | super::Tag::Sub { lhs, rhs } + | super::Tag::Mul { lhs, rhs } + | super::Tag::Rem { lhs, rhs } + | super::Tag::Div { lhs, rhs } + | super::Tag::Assign { lhs, rhs } => { + vec![*lhs, *rhs] + } + _ => vec![], + } + } + + fn next_node(&mut self) -> Option { + loop { + let frame = self.frames.last_mut()?; + if let Some(node) = frame.children.pop() { + return Some(PrePost::Pre(node)); + } else { + let frame = self.frames.pop()?; + if frame.node != Node::MAX { + return Some(PrePost::Post(frame.node)); + } + } + } + } + + pub fn visit_ok(mut self) -> core::result::Result + where + F1: FnMut(&mut Tree, Node) -> core::result::Result, + F2: FnMut(&mut Tree, Node) -> core::result::Result, + { + let mut t = None; + loop { + let Some(node) = self.next_node() else { + break; + }; + + match node { + PrePost::Pre(node) => { + t = Some((self.pre)(self.tree, node)?); + let children = self.get_children(node); + self.frames.push(Frame { node, children }); + } + PrePost::Post(node) => { + t = Some((self.post)(self.tree, node)?); + } + } + } + + Ok(t.unwrap()) + } + + pub fn visit(mut self) + where + F1: FnMut(&mut Tree, Node) -> T, + F2: FnMut(&mut Tree, Node) -> U, + { + loop { + let Some(node) = self.next_node() else { + break; + }; + + match node { + PrePost::Pre(node) => { + (self.pre)(self.tree, node); + let children = self.get_children(node); + self.frames.push(Frame { node, children }); + } + PrePost::Post(node) => { + (self.post)(self.tree, node); + } + } + } + } + } +} diff --git a/src/comptime.rs b/src/comptime.rs index e6522ef..f625ffa 100644 --- a/src/comptime.rs +++ b/src/comptime.rs @@ -13,7 +13,7 @@ pub mod bigint { impl BigInt { pub fn parse_digits>(text: C, radix: Radix) -> BigInt { - parse_bigint(text.into_iter(), radix) + Self(parse_bigint(text.into_iter(), radix)) } pub fn from_u32(v: u32) -> BigInt { Self(vec![v]) @@ -1022,7 +1022,7 @@ pub mod bigint { carry } - fn parse_bigint(text: impl Iterator, radix: Radix) -> BigInt { + pub fn parse_bigint(text: impl Iterator, radix: Radix) -> Vec { let digits = text .filter_map(|c| match c { '_' => None, @@ -1080,7 +1080,8 @@ pub mod bigint { } } } - BigInt(data) + + data } #[cfg(test)] @@ -1089,53 +1090,53 @@ pub mod bigint { #[test] fn parse() { - let bigint = super::parse_bigint("2_cafe_babe_dead_beef".chars(), Radix::Hex); + let bigint = BigInt::parse_digits("2_cafe_babe_dead_beef".chars(), Radix::Hex); println!("{:#x?}", bigint); - let bigint = super::parse_bigint("f".chars(), Radix::Hex); + let bigint = BigInt::parse_digits("f".chars(), Radix::Hex); println!("{:#x?}", bigint); } #[test] fn add() { - let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); + let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex); println!("{:#x?}", a); - let b = super::parse_bigint("cafebabe".chars(), Radix::Hex); + let b = BigInt::parse_digits("cafebabe".chars(), Radix::Hex); println!("{:#x?}", b); let sum = a + b; println!("{:#x?}", sum); } #[test] fn sub() { - let a = super::parse_bigint("deadbeef".chars(), Radix::Hex); + let a = BigInt::parse_digits("deadbeef".chars(), Radix::Hex); println!("{:#x?}", a); - let b = super::parse_bigint("56d2c".chars(), Radix::Hex); + let b = BigInt::parse_digits("56d2c".chars(), Radix::Hex); println!("{:#x?}", b); let sum = a - b; println!("{:#x?}", sum); } #[test] fn overflowing_sub() { - let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); + let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex); println!("{:#x?}", a); - let b = super::parse_bigint("ffff_ffff".chars(), Radix::Hex); + let b = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex); println!("{:#x?}", b); let sum = b - a; println!("{:#x?}", sum); } #[test] fn shr() { - let mut a = super::parse_bigint("cafe_babe_0000".chars(), Radix::Hex); + let mut a = BigInt::parse_digits("cafe_babe_0000".chars(), Radix::Hex); print!("{:0>8x?} >> 32 ", a); shr_bitint(&mut a.0, 32); println!("{:0>8x?}", a); - let mut a = super::parse_bigint("11110000".chars(), Radix::Bin); + let mut a = BigInt::parse_digits("11110000".chars(), Radix::Bin); print!("{:0>8x?} >> 32 ", a); shr_bitint(&mut a.0, 3); println!("{:0>8x?}", a); } #[test] fn shl() { - let mut a = super::parse_bigint("ffff_ffff".chars(), Radix::Hex); + let mut a = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex); a.0.extend([0; 4]); println!("{:0>8x?}", a); shl_bitint(&mut a.0, 40); @@ -1143,8 +1144,8 @@ pub mod bigint { } #[test] fn div() { - let a = super::parse_bigint("cafebabe".chars(), Radix::Hex); - let b = super::parse_bigint("dead".chars(), Radix::Hex); + let a = BigInt::parse_digits("cafebabe".chars(), Radix::Hex); + let b = BigInt::parse_digits("dead".chars(), Radix::Hex); let (div, rem) = div_rem_bigint(a, b); println!("div: {:0>8x?}", div); println!("rem: {:0>8x?}", rem); @@ -1501,13 +1502,13 @@ pub mod bigsint { use std::{ cmp::Ordering, - ops::{Add, BitAnd, BitOr, BitXor, Not}, + ops::{BitAnd, BitOr, BitXor, Not}, }; use num_bigint::{BigInt, BigUint, Sign}; -use num_traits::cast::ToPrimitive; +use num_traits::{cast::ToPrimitive, ToBytes}; -use crate::ast::{FloatingType, IntegralType}; +use crate::ast::{FloatingType, IntegralType, Type}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -1521,6 +1522,8 @@ pub enum Error { UnsignedNegation, #[error("Incomparable floats.")] FloatingCmp, + #[error("Not a comptime expression.")] + NotComptime, } pub type Result = core::result::Result; @@ -2263,4 +2266,31 @@ impl ComptimeNumber { FloatingType::Binary64 => Ok(f.into()), } } + + pub fn into_bytes_and_type(self) -> (Vec, Type) { + match self { + ComptimeNumber::Integral(i) => match i { + ComptimeInt::Native { bits, ty } => { + (bits.to_le_bytes().to_vec(), Type::Integer(ty)) + } + ComptimeInt::BigInt { bits, ty } => { + (bits.to_le_bytes().to_vec(), Type::Integer(ty)) + } + ComptimeInt::Comptime(bits) => { + (bits.to_le_bytes().to_vec(), Type::comptime_number()) + } + }, + ComptimeNumber::Bool(b) => (vec![b as u8], Type::bool()), + ComptimeNumber::Floating(f) => match f { + ComptimeFloat::Binary32(f) => ( + f.to_le_bytes().to_vec(), + Type::Floating(FloatingType::Binary32), + ), + ComptimeFloat::Binary64(f) => ( + f.to_le_bytes().to_vec(), + Type::Floating(FloatingType::Binary64), + ), + }, + } + } } diff --git a/src/parser.rs b/src/parser.rs index 6996195..42c7af5 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Display}; use itertools::Itertools; -use num_bigint::BigInt; +use num_bigint::{BigInt, BigUint}; use crate::{ - ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type}, + ast::{self, FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type}, common::NextIf, comptime::{self, ComptimeNumber}, error::{AnalysisError, AnalysisErrorTag}, @@ -252,7 +252,7 @@ impl Tree { .map(|(_, c)| c) .collect::>(); - let value = BigInt::parse_digits(digits, radix); + let value = comptime::bigint::parse_bigint(digits.into_iter(), radix); let ty = match iter.clone().next() { Some((_, 'u')) | Some((_, 'i')) => { @@ -261,7 +261,10 @@ impl Tree { _ => None, }; - (value, ty) + ( + BigInt::from_biguint(num_bigint::Sign::Plus, BigUint::new(value)), + ty, + ) } fn parse_floating_constant(_token: Token, lexeme: &str) -> (u64, FloatingType) { @@ -860,7 +863,7 @@ impl Tree { | Token::IntegerConstant => { _ = tokens.next(); let (bits, ty) = Self::parse_integral_constant(token.token(), token.lexeme()); - let bytes = bits.into_bytes_le(); + let (_, bytes) = bits.to_bytes_le(); const BUF_SIZE: usize = core::mem::size_of::(); let mut buf = [0u8; BUF_SIZE]; @@ -1596,10 +1599,31 @@ impl Tree { fn try_fold_comptime_inner(&mut self, node: Node) { if self.is_node_comptime(node) { - self.fold_comptime_inner(node); + _ = self.fold_comptime_inner(node); } } + fn fold_comptime_with_visitor(&mut self, decl: Node) { + ast::tree_visitor::Visitor::new( + self, + decl, + |_, node| { + eprint!("%{node} "); + }, + |tree, node| { + if let Ok(value) = tree.fold_comptime_inner(node) { + let (bytes, ty) = value.into_bytes_and_type(); + let idx = tree.strings.insert(bytes); + *tree.nodes.get_node_mut(node) = Tag::Constant { + bytes: ImmOrIndex::Index(idx), + ty, + }; + } + }, + ) + .visit(); + } + fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result { // if self.is_node_comptime(decl) { @@ -1792,18 +1816,19 @@ impl Tree { unreachable!() } } + } else { + Err(comptime::Error::NotComptime) } - todo!() } pub fn fold_comptime(&mut self) { for decl in self.global_decls.clone() { match self.nodes.get_node(decl) { Tag::FunctionDecl { body, .. } => { - self.fold_comptime_inner(*body); + _ = self.fold_comptime_inner(*body); } Tag::GlobalDecl { assignment, .. } => { - self.fold_comptime_inner(*assignment); + _ = self.fold_comptime_inner(*assignment); } _ => unreachable!(), } @@ -2031,4 +2056,27 @@ const global: u32 = 42u32; tree.render(&mut buf).unwrap(); println!("{buf}"); } + + #[test] + fn comptime() { + let src = " +fn main() -> void { +let a = 3 * 49573 << 4; +} +"; + let tokens = Tokenizer::new(src.as_bytes()).unwrap(); + + let mut tree = Tree::new(); + tree.parse(tokens.iter()).unwrap(); + + let mut buf = String::new(); + tree.render(&mut buf).unwrap(); + println!("{buf}"); + + tree.fold_comptime_with_visitor(tree.global_decls.first().cloned().unwrap()); + + let mut buf = String::new(); + tree.render(&mut buf).unwrap(); + println!("{buf}"); + } }