comptime folding works!!!!!

This commit is contained in:
Janis 2024-08-27 18:49:28 +02:00
parent 18e29f1fa1
commit 270162850d
3 changed files with 322 additions and 29 deletions

View file

@ -519,3 +519,218 @@ impl ToString for PrimitiveType {
}
}
}
pub mod tree_visitor {
use crate::ast::Node;
use crate::parser::Tree;
struct Frame {
node: Node,
children: Vec<Node>,
}
enum PrePost {
Pre(Node),
Post(Node),
}
/// Don't modify `node` in `pre()`
/// Don't modify `children` in `pre()`
pub struct Visitor<'a, F1, F2> {
tree: &'a mut Tree,
frames: Vec<Frame>,
pre: F1,
post: F2,
}
impl<'a, F1, F2> Visitor<'a, F1, F2> {
pub fn new<T, U>(tree: &'a mut Tree, start: Node, pre: F1, post: F2) -> Visitor<'a, F1, F2>
where
F1: FnMut(&mut Tree, Node) -> T,
F2: FnMut(&mut Tree, Node) -> U,
{
let frame = Frame {
node: Node::MAX,
children: vec![start],
};
Self {
frames: vec![frame],
tree,
pre,
post,
}
}
fn get_children(&self, node: Node) -> Vec<Node> {
match self.tree.nodes.get_node(node) {
super::Tag::FunctionProto {
name,
parameters,
return_type,
} => {
if let Some(params) = parameters {
vec![*name, *params, *return_type]
} else {
vec![*name, *return_type]
}
}
super::Tag::ParameterList { parameters } => parameters.clone(),
super::Tag::Parameter { name, ty } => {
vec![*name, *ty]
}
super::Tag::Pointer { pointee } => {
vec![*pointee]
}
super::Tag::FunctionDecl { proto, body } => {
vec![*proto, *body]
}
super::Tag::Block {
statements,
trailing_expr,
} => {
let mut children = statements.clone();
if let Some(expr) = trailing_expr {
children.push(*expr);
}
children
}
super::Tag::ReturnStmt { expr } => expr.into_iter().cloned().collect::<Vec<_>>(),
&super::Tag::ExprStmt { expr } => {
vec![expr]
}
super::Tag::VarDecl {
name,
explicit_type,
..
} => {
if let Some(ty) = *explicit_type {
vec![*name, ty]
} else {
vec![*name]
}
}
super::Tag::GlobalDecl {
name,
explicit_type,
..
} => {
if let Some(ty) = *explicit_type {
vec![*name, ty]
} else {
vec![*name]
}
}
&super::Tag::CallExpr { lhs, rhs } => {
if let Some(rhs) = rhs {
vec![lhs, rhs]
} else {
vec![lhs]
}
}
super::Tag::ArgumentList { parameters } => parameters.clone(),
&super::Tag::Argument { name, expr } => {
if let Some(name) = name {
vec![name, expr]
} else {
vec![expr]
}
}
&super::Tag::ExplicitCast { lhs, typename } => {
vec![lhs, typename]
}
super::Tag::Deref { lhs }
| super::Tag::Ref { lhs }
| super::Tag::Not { lhs }
| super::Tag::Negate { lhs } => {
vec![*lhs]
}
super::Tag::Or { lhs, rhs }
| super::Tag::And { lhs, rhs }
| super::Tag::BitOr { lhs, rhs }
| super::Tag::BitAnd { lhs, rhs }
| super::Tag::BitXOr { lhs, rhs }
| super::Tag::Eq { lhs, rhs }
| super::Tag::NEq { lhs, rhs }
| super::Tag::Lt { lhs, rhs }
| super::Tag::Gt { lhs, rhs }
| super::Tag::Le { lhs, rhs }
| super::Tag::Ge { lhs, rhs }
| super::Tag::Shl { lhs, rhs }
| super::Tag::Shr { lhs, rhs }
| super::Tag::Add { lhs, rhs }
| super::Tag::Sub { lhs, rhs }
| super::Tag::Mul { lhs, rhs }
| super::Tag::Rem { lhs, rhs }
| super::Tag::Div { lhs, rhs }
| super::Tag::Assign { lhs, rhs } => {
vec![*lhs, *rhs]
}
_ => vec![],
}
}
fn next_node(&mut self) -> Option<PrePost> {
loop {
let frame = self.frames.last_mut()?;
if let Some(node) = frame.children.pop() {
return Some(PrePost::Pre(node));
} else {
let frame = self.frames.pop()?;
if frame.node != Node::MAX {
return Some(PrePost::Post(frame.node));
}
}
}
}
pub fn visit_ok<T, E>(mut self) -> core::result::Result<T, E>
where
F1: FnMut(&mut Tree, Node) -> core::result::Result<T, E>,
F2: FnMut(&mut Tree, Node) -> core::result::Result<T, E>,
{
let mut t = None;
loop {
let Some(node) = self.next_node() else {
break;
};
match node {
PrePost::Pre(node) => {
t = Some((self.pre)(self.tree, node)?);
let children = self.get_children(node);
self.frames.push(Frame { node, children });
}
PrePost::Post(node) => {
t = Some((self.post)(self.tree, node)?);
}
}
}
Ok(t.unwrap())
}
pub fn visit<T, U>(mut self)
where
F1: FnMut(&mut Tree, Node) -> T,
F2: FnMut(&mut Tree, Node) -> U,
{
loop {
let Some(node) = self.next_node() else {
break;
};
match node {
PrePost::Pre(node) => {
(self.pre)(self.tree, node);
let children = self.get_children(node);
self.frames.push(Frame { node, children });
}
PrePost::Post(node) => {
(self.post)(self.tree, node);
}
}
}
}
}
}

View file

@ -13,7 +13,7 @@ pub mod bigint {
impl BigInt {
pub fn parse_digits<C: IntoIterator<Item = char>>(text: C, radix: Radix) -> BigInt {
parse_bigint(text.into_iter(), radix)
Self(parse_bigint(text.into_iter(), radix))
}
pub fn from_u32(v: u32) -> BigInt {
Self(vec![v])
@ -1022,7 +1022,7 @@ pub mod bigint {
carry
}
fn parse_bigint(text: impl Iterator<Item = char>, radix: Radix) -> BigInt {
pub fn parse_bigint(text: impl Iterator<Item = char>, radix: Radix) -> Vec<u32> {
let digits = text
.filter_map(|c| match c {
'_' => None,
@ -1080,7 +1080,8 @@ pub mod bigint {
}
}
}
BigInt(data)
data
}
#[cfg(test)]
@ -1089,53 +1090,53 @@ pub mod bigint {
#[test]
fn parse() {
let bigint = super::parse_bigint("2_cafe_babe_dead_beef".chars(), Radix::Hex);
let bigint = BigInt::parse_digits("2_cafe_babe_dead_beef".chars(), Radix::Hex);
println!("{:#x?}", bigint);
let bigint = super::parse_bigint("f".chars(), Radix::Hex);
let bigint = BigInt::parse_digits("f".chars(), Radix::Hex);
println!("{:#x?}", bigint);
}
#[test]
fn add() {
let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex);
let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex);
println!("{:#x?}", a);
let b = super::parse_bigint("cafebabe".chars(), Radix::Hex);
let b = BigInt::parse_digits("cafebabe".chars(), Radix::Hex);
println!("{:#x?}", b);
let sum = a + b;
println!("{:#x?}", sum);
}
#[test]
fn sub() {
let a = super::parse_bigint("deadbeef".chars(), Radix::Hex);
let a = BigInt::parse_digits("deadbeef".chars(), Radix::Hex);
println!("{:#x?}", a);
let b = super::parse_bigint("56d2c".chars(), Radix::Hex);
let b = BigInt::parse_digits("56d2c".chars(), Radix::Hex);
println!("{:#x?}", b);
let sum = a - b;
println!("{:#x?}", sum);
}
#[test]
fn overflowing_sub() {
let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex);
let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex);
println!("{:#x?}", a);
let b = super::parse_bigint("ffff_ffff".chars(), Radix::Hex);
let b = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex);
println!("{:#x?}", b);
let sum = b - a;
println!("{:#x?}", sum);
}
#[test]
fn shr() {
let mut a = super::parse_bigint("cafe_babe_0000".chars(), Radix::Hex);
let mut a = BigInt::parse_digits("cafe_babe_0000".chars(), Radix::Hex);
print!("{:0>8x?} >> 32 ", a);
shr_bitint(&mut a.0, 32);
println!("{:0>8x?}", a);
let mut a = super::parse_bigint("11110000".chars(), Radix::Bin);
let mut a = BigInt::parse_digits("11110000".chars(), Radix::Bin);
print!("{:0>8x?} >> 32 ", a);
shr_bitint(&mut a.0, 3);
println!("{:0>8x?}", a);
}
#[test]
fn shl() {
let mut a = super::parse_bigint("ffff_ffff".chars(), Radix::Hex);
let mut a = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex);
a.0.extend([0; 4]);
println!("{:0>8x?}", a);
shl_bitint(&mut a.0, 40);
@ -1143,8 +1144,8 @@ pub mod bigint {
}
#[test]
fn div() {
let a = super::parse_bigint("cafebabe".chars(), Radix::Hex);
let b = super::parse_bigint("dead".chars(), Radix::Hex);
let a = BigInt::parse_digits("cafebabe".chars(), Radix::Hex);
let b = BigInt::parse_digits("dead".chars(), Radix::Hex);
let (div, rem) = div_rem_bigint(a, b);
println!("div: {:0>8x?}", div);
println!("rem: {:0>8x?}", rem);
@ -1501,13 +1502,13 @@ pub mod bigsint {
use std::{
cmp::Ordering,
ops::{Add, BitAnd, BitOr, BitXor, Not},
ops::{BitAnd, BitOr, BitXor, Not},
};
use num_bigint::{BigInt, BigUint, Sign};
use num_traits::cast::ToPrimitive;
use num_traits::{cast::ToPrimitive, ToBytes};
use crate::ast::{FloatingType, IntegralType};
use crate::ast::{FloatingType, IntegralType, Type};
#[derive(Debug, thiserror::Error)]
pub enum Error {
@ -1521,6 +1522,8 @@ pub enum Error {
UnsignedNegation,
#[error("Incomparable floats.")]
FloatingCmp,
#[error("Not a comptime expression.")]
NotComptime,
}
pub type Result<T> = core::result::Result<T, Error>;
@ -2263,4 +2266,31 @@ impl ComptimeNumber {
FloatingType::Binary64 => Ok(f.into()),
}
}
pub fn into_bytes_and_type(self) -> (Vec<u8>, Type) {
match self {
ComptimeNumber::Integral(i) => match i {
ComptimeInt::Native { bits, ty } => {
(bits.to_le_bytes().to_vec(), Type::Integer(ty))
}
ComptimeInt::BigInt { bits, ty } => {
(bits.to_le_bytes().to_vec(), Type::Integer(ty))
}
ComptimeInt::Comptime(bits) => {
(bits.to_le_bytes().to_vec(), Type::comptime_number())
}
},
ComptimeNumber::Bool(b) => (vec![b as u8], Type::bool()),
ComptimeNumber::Floating(f) => match f {
ComptimeFloat::Binary32(f) => (
f.to_le_bytes().to_vec(),
Type::Floating(FloatingType::Binary32),
),
ComptimeFloat::Binary64(f) => (
f.to_le_bytes().to_vec(),
Type::Floating(FloatingType::Binary64),
),
},
}
}
}

View file

@ -1,10 +1,10 @@
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Display};
use itertools::Itertools;
use num_bigint::BigInt;
use num_bigint::{BigInt, BigUint};
use crate::{
ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type},
ast::{self, FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type},
common::NextIf,
comptime::{self, ComptimeNumber},
error::{AnalysisError, AnalysisErrorTag},
@ -252,7 +252,7 @@ impl Tree {
.map(|(_, c)| c)
.collect::<Vec<_>>();
let value = BigInt::parse_digits(digits, radix);
let value = comptime::bigint::parse_bigint(digits.into_iter(), radix);
let ty = match iter.clone().next() {
Some((_, 'u')) | Some((_, 'i')) => {
@ -261,7 +261,10 @@ impl Tree {
_ => None,
};
(value, ty)
(
BigInt::from_biguint(num_bigint::Sign::Plus, BigUint::new(value)),
ty,
)
}
fn parse_floating_constant(_token: Token, lexeme: &str) -> (u64, FloatingType) {
@ -860,7 +863,7 @@ impl Tree {
| Token::IntegerConstant => {
_ = tokens.next();
let (bits, ty) = Self::parse_integral_constant(token.token(), token.lexeme());
let bytes = bits.into_bytes_le();
let (_, bytes) = bits.to_bytes_le();
const BUF_SIZE: usize = core::mem::size_of::<u64>();
let mut buf = [0u8; BUF_SIZE];
@ -1596,10 +1599,31 @@ impl Tree {
fn try_fold_comptime_inner(&mut self, node: Node) {
if self.is_node_comptime(node) {
self.fold_comptime_inner(node);
_ = self.fold_comptime_inner(node);
}
}
fn fold_comptime_with_visitor(&mut self, decl: Node) {
ast::tree_visitor::Visitor::new(
self,
decl,
|_, node| {
eprint!("%{node} ");
},
|tree, node| {
if let Ok(value) = tree.fold_comptime_inner(node) {
let (bytes, ty) = value.into_bytes_and_type();
let idx = tree.strings.insert(bytes);
*tree.nodes.get_node_mut(node) = Tag::Constant {
bytes: ImmOrIndex::Index(idx),
ty,
};
}
},
)
.visit();
}
fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result<ComptimeNumber> {
//
if self.is_node_comptime(decl) {
@ -1792,18 +1816,19 @@ impl Tree {
unreachable!()
}
}
} else {
Err(comptime::Error::NotComptime)
}
todo!()
}
pub fn fold_comptime(&mut self) {
for decl in self.global_decls.clone() {
match self.nodes.get_node(decl) {
Tag::FunctionDecl { body, .. } => {
self.fold_comptime_inner(*body);
_ = self.fold_comptime_inner(*body);
}
Tag::GlobalDecl { assignment, .. } => {
self.fold_comptime_inner(*assignment);
_ = self.fold_comptime_inner(*assignment);
}
_ => unreachable!(),
}
@ -2031,4 +2056,27 @@ const global: u32 = 42u32;
tree.render(&mut buf).unwrap();
println!("{buf}");
}
#[test]
fn comptime() {
let src = "
fn main() -> void {
let a = 3 * 49573 << 4;
}
";
let tokens = Tokenizer::new(src.as_bytes()).unwrap();
let mut tree = Tree::new();
tree.parse(tokens.iter()).unwrap();
let mut buf = String::new();
tree.render(&mut buf).unwrap();
println!("{buf}");
tree.fold_comptime_with_visitor(tree.global_decls.first().cloned().unwrap());
let mut buf = String::new();
tree.render(&mut buf).unwrap();
println!("{buf}");
}
}