comptime folding works!!!!!
This commit is contained in:
parent
18e29f1fa1
commit
270162850d
215
src/ast.rs
215
src/ast.rs
|
@ -519,3 +519,218 @@ impl ToString for PrimitiveType {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod tree_visitor {
|
||||
|
||||
use crate::ast::Node;
|
||||
use crate::parser::Tree;
|
||||
|
||||
struct Frame {
|
||||
node: Node,
|
||||
children: Vec<Node>,
|
||||
}
|
||||
|
||||
enum PrePost {
|
||||
Pre(Node),
|
||||
Post(Node),
|
||||
}
|
||||
|
||||
/// Don't modify `node` in `pre()`
|
||||
/// Don't modify `children` in `pre()`
|
||||
pub struct Visitor<'a, F1, F2> {
|
||||
tree: &'a mut Tree,
|
||||
frames: Vec<Frame>,
|
||||
pre: F1,
|
||||
post: F2,
|
||||
}
|
||||
|
||||
impl<'a, F1, F2> Visitor<'a, F1, F2> {
|
||||
pub fn new<T, U>(tree: &'a mut Tree, start: Node, pre: F1, post: F2) -> Visitor<'a, F1, F2>
|
||||
where
|
||||
F1: FnMut(&mut Tree, Node) -> T,
|
||||
F2: FnMut(&mut Tree, Node) -> U,
|
||||
{
|
||||
let frame = Frame {
|
||||
node: Node::MAX,
|
||||
children: vec![start],
|
||||
};
|
||||
Self {
|
||||
frames: vec![frame],
|
||||
tree,
|
||||
pre,
|
||||
post,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_children(&self, node: Node) -> Vec<Node> {
|
||||
match self.tree.nodes.get_node(node) {
|
||||
super::Tag::FunctionProto {
|
||||
name,
|
||||
parameters,
|
||||
return_type,
|
||||
} => {
|
||||
if let Some(params) = parameters {
|
||||
vec![*name, *params, *return_type]
|
||||
} else {
|
||||
vec![*name, *return_type]
|
||||
}
|
||||
}
|
||||
super::Tag::ParameterList { parameters } => parameters.clone(),
|
||||
super::Tag::Parameter { name, ty } => {
|
||||
vec![*name, *ty]
|
||||
}
|
||||
super::Tag::Pointer { pointee } => {
|
||||
vec![*pointee]
|
||||
}
|
||||
super::Tag::FunctionDecl { proto, body } => {
|
||||
vec![*proto, *body]
|
||||
}
|
||||
super::Tag::Block {
|
||||
statements,
|
||||
trailing_expr,
|
||||
} => {
|
||||
let mut children = statements.clone();
|
||||
if let Some(expr) = trailing_expr {
|
||||
children.push(*expr);
|
||||
}
|
||||
children
|
||||
}
|
||||
super::Tag::ReturnStmt { expr } => expr.into_iter().cloned().collect::<Vec<_>>(),
|
||||
&super::Tag::ExprStmt { expr } => {
|
||||
vec![expr]
|
||||
}
|
||||
super::Tag::VarDecl {
|
||||
name,
|
||||
explicit_type,
|
||||
..
|
||||
} => {
|
||||
if let Some(ty) = *explicit_type {
|
||||
vec![*name, ty]
|
||||
} else {
|
||||
vec![*name]
|
||||
}
|
||||
}
|
||||
super::Tag::GlobalDecl {
|
||||
name,
|
||||
explicit_type,
|
||||
..
|
||||
} => {
|
||||
if let Some(ty) = *explicit_type {
|
||||
vec![*name, ty]
|
||||
} else {
|
||||
vec![*name]
|
||||
}
|
||||
}
|
||||
&super::Tag::CallExpr { lhs, rhs } => {
|
||||
if let Some(rhs) = rhs {
|
||||
vec![lhs, rhs]
|
||||
} else {
|
||||
vec![lhs]
|
||||
}
|
||||
}
|
||||
super::Tag::ArgumentList { parameters } => parameters.clone(),
|
||||
&super::Tag::Argument { name, expr } => {
|
||||
if let Some(name) = name {
|
||||
vec![name, expr]
|
||||
} else {
|
||||
vec![expr]
|
||||
}
|
||||
}
|
||||
&super::Tag::ExplicitCast { lhs, typename } => {
|
||||
vec![lhs, typename]
|
||||
}
|
||||
super::Tag::Deref { lhs }
|
||||
| super::Tag::Ref { lhs }
|
||||
| super::Tag::Not { lhs }
|
||||
| super::Tag::Negate { lhs } => {
|
||||
vec![*lhs]
|
||||
}
|
||||
super::Tag::Or { lhs, rhs }
|
||||
| super::Tag::And { lhs, rhs }
|
||||
| super::Tag::BitOr { lhs, rhs }
|
||||
| super::Tag::BitAnd { lhs, rhs }
|
||||
| super::Tag::BitXOr { lhs, rhs }
|
||||
| super::Tag::Eq { lhs, rhs }
|
||||
| super::Tag::NEq { lhs, rhs }
|
||||
| super::Tag::Lt { lhs, rhs }
|
||||
| super::Tag::Gt { lhs, rhs }
|
||||
| super::Tag::Le { lhs, rhs }
|
||||
| super::Tag::Ge { lhs, rhs }
|
||||
| super::Tag::Shl { lhs, rhs }
|
||||
| super::Tag::Shr { lhs, rhs }
|
||||
| super::Tag::Add { lhs, rhs }
|
||||
| super::Tag::Sub { lhs, rhs }
|
||||
| super::Tag::Mul { lhs, rhs }
|
||||
| super::Tag::Rem { lhs, rhs }
|
||||
| super::Tag::Div { lhs, rhs }
|
||||
| super::Tag::Assign { lhs, rhs } => {
|
||||
vec![*lhs, *rhs]
|
||||
}
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn next_node(&mut self) -> Option<PrePost> {
|
||||
loop {
|
||||
let frame = self.frames.last_mut()?;
|
||||
if let Some(node) = frame.children.pop() {
|
||||
return Some(PrePost::Pre(node));
|
||||
} else {
|
||||
let frame = self.frames.pop()?;
|
||||
if frame.node != Node::MAX {
|
||||
return Some(PrePost::Post(frame.node));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_ok<T, E>(mut self) -> core::result::Result<T, E>
|
||||
where
|
||||
F1: FnMut(&mut Tree, Node) -> core::result::Result<T, E>,
|
||||
F2: FnMut(&mut Tree, Node) -> core::result::Result<T, E>,
|
||||
{
|
||||
let mut t = None;
|
||||
loop {
|
||||
let Some(node) = self.next_node() else {
|
||||
break;
|
||||
};
|
||||
|
||||
match node {
|
||||
PrePost::Pre(node) => {
|
||||
t = Some((self.pre)(self.tree, node)?);
|
||||
let children = self.get_children(node);
|
||||
self.frames.push(Frame { node, children });
|
||||
}
|
||||
PrePost::Post(node) => {
|
||||
t = Some((self.post)(self.tree, node)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(t.unwrap())
|
||||
}
|
||||
|
||||
pub fn visit<T, U>(mut self)
|
||||
where
|
||||
F1: FnMut(&mut Tree, Node) -> T,
|
||||
F2: FnMut(&mut Tree, Node) -> U,
|
||||
{
|
||||
loop {
|
||||
let Some(node) = self.next_node() else {
|
||||
break;
|
||||
};
|
||||
|
||||
match node {
|
||||
PrePost::Pre(node) => {
|
||||
(self.pre)(self.tree, node);
|
||||
let children = self.get_children(node);
|
||||
self.frames.push(Frame { node, children });
|
||||
}
|
||||
PrePost::Post(node) => {
|
||||
(self.post)(self.tree, node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ pub mod bigint {
|
|||
|
||||
impl BigInt {
|
||||
pub fn parse_digits<C: IntoIterator<Item = char>>(text: C, radix: Radix) -> BigInt {
|
||||
parse_bigint(text.into_iter(), radix)
|
||||
Self(parse_bigint(text.into_iter(), radix))
|
||||
}
|
||||
pub fn from_u32(v: u32) -> BigInt {
|
||||
Self(vec![v])
|
||||
|
@ -1022,7 +1022,7 @@ pub mod bigint {
|
|||
carry
|
||||
}
|
||||
|
||||
fn parse_bigint(text: impl Iterator<Item = char>, radix: Radix) -> BigInt {
|
||||
pub fn parse_bigint(text: impl Iterator<Item = char>, radix: Radix) -> Vec<u32> {
|
||||
let digits = text
|
||||
.filter_map(|c| match c {
|
||||
'_' => None,
|
||||
|
@ -1080,7 +1080,8 @@ pub mod bigint {
|
|||
}
|
||||
}
|
||||
}
|
||||
BigInt(data)
|
||||
|
||||
data
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1089,53 +1090,53 @@ pub mod bigint {
|
|||
|
||||
#[test]
|
||||
fn parse() {
|
||||
let bigint = super::parse_bigint("2_cafe_babe_dead_beef".chars(), Radix::Hex);
|
||||
let bigint = BigInt::parse_digits("2_cafe_babe_dead_beef".chars(), Radix::Hex);
|
||||
println!("{:#x?}", bigint);
|
||||
let bigint = super::parse_bigint("f".chars(), Radix::Hex);
|
||||
let bigint = BigInt::parse_digits("f".chars(), Radix::Hex);
|
||||
println!("{:#x?}", bigint);
|
||||
}
|
||||
#[test]
|
||||
fn add() {
|
||||
let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex);
|
||||
let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex);
|
||||
println!("{:#x?}", a);
|
||||
let b = super::parse_bigint("cafebabe".chars(), Radix::Hex);
|
||||
let b = BigInt::parse_digits("cafebabe".chars(), Radix::Hex);
|
||||
println!("{:#x?}", b);
|
||||
let sum = a + b;
|
||||
println!("{:#x?}", sum);
|
||||
}
|
||||
#[test]
|
||||
fn sub() {
|
||||
let a = super::parse_bigint("deadbeef".chars(), Radix::Hex);
|
||||
let a = BigInt::parse_digits("deadbeef".chars(), Radix::Hex);
|
||||
println!("{:#x?}", a);
|
||||
let b = super::parse_bigint("56d2c".chars(), Radix::Hex);
|
||||
let b = BigInt::parse_digits("56d2c".chars(), Radix::Hex);
|
||||
println!("{:#x?}", b);
|
||||
let sum = a - b;
|
||||
println!("{:#x?}", sum);
|
||||
}
|
||||
#[test]
|
||||
fn overflowing_sub() {
|
||||
let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex);
|
||||
let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex);
|
||||
println!("{:#x?}", a);
|
||||
let b = super::parse_bigint("ffff_ffff".chars(), Radix::Hex);
|
||||
let b = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex);
|
||||
println!("{:#x?}", b);
|
||||
let sum = b - a;
|
||||
println!("{:#x?}", sum);
|
||||
}
|
||||
#[test]
|
||||
fn shr() {
|
||||
let mut a = super::parse_bigint("cafe_babe_0000".chars(), Radix::Hex);
|
||||
let mut a = BigInt::parse_digits("cafe_babe_0000".chars(), Radix::Hex);
|
||||
print!("{:0>8x?} >> 32 ", a);
|
||||
shr_bitint(&mut a.0, 32);
|
||||
println!("{:0>8x?}", a);
|
||||
|
||||
let mut a = super::parse_bigint("11110000".chars(), Radix::Bin);
|
||||
let mut a = BigInt::parse_digits("11110000".chars(), Radix::Bin);
|
||||
print!("{:0>8x?} >> 32 ", a);
|
||||
shr_bitint(&mut a.0, 3);
|
||||
println!("{:0>8x?}", a);
|
||||
}
|
||||
#[test]
|
||||
fn shl() {
|
||||
let mut a = super::parse_bigint("ffff_ffff".chars(), Radix::Hex);
|
||||
let mut a = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex);
|
||||
a.0.extend([0; 4]);
|
||||
println!("{:0>8x?}", a);
|
||||
shl_bitint(&mut a.0, 40);
|
||||
|
@ -1143,8 +1144,8 @@ pub mod bigint {
|
|||
}
|
||||
#[test]
|
||||
fn div() {
|
||||
let a = super::parse_bigint("cafebabe".chars(), Radix::Hex);
|
||||
let b = super::parse_bigint("dead".chars(), Radix::Hex);
|
||||
let a = BigInt::parse_digits("cafebabe".chars(), Radix::Hex);
|
||||
let b = BigInt::parse_digits("dead".chars(), Radix::Hex);
|
||||
let (div, rem) = div_rem_bigint(a, b);
|
||||
println!("div: {:0>8x?}", div);
|
||||
println!("rem: {:0>8x?}", rem);
|
||||
|
@ -1501,13 +1502,13 @@ pub mod bigsint {
|
|||
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
ops::{Add, BitAnd, BitOr, BitXor, Not},
|
||||
ops::{BitAnd, BitOr, BitXor, Not},
|
||||
};
|
||||
|
||||
use num_bigint::{BigInt, BigUint, Sign};
|
||||
use num_traits::cast::ToPrimitive;
|
||||
use num_traits::{cast::ToPrimitive, ToBytes};
|
||||
|
||||
use crate::ast::{FloatingType, IntegralType};
|
||||
use crate::ast::{FloatingType, IntegralType, Type};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
|
@ -1521,6 +1522,8 @@ pub enum Error {
|
|||
UnsignedNegation,
|
||||
#[error("Incomparable floats.")]
|
||||
FloatingCmp,
|
||||
#[error("Not a comptime expression.")]
|
||||
NotComptime,
|
||||
}
|
||||
|
||||
pub type Result<T> = core::result::Result<T, Error>;
|
||||
|
@ -2263,4 +2266,31 @@ impl ComptimeNumber {
|
|||
FloatingType::Binary64 => Ok(f.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bytes_and_type(self) -> (Vec<u8>, Type) {
|
||||
match self {
|
||||
ComptimeNumber::Integral(i) => match i {
|
||||
ComptimeInt::Native { bits, ty } => {
|
||||
(bits.to_le_bytes().to_vec(), Type::Integer(ty))
|
||||
}
|
||||
ComptimeInt::BigInt { bits, ty } => {
|
||||
(bits.to_le_bytes().to_vec(), Type::Integer(ty))
|
||||
}
|
||||
ComptimeInt::Comptime(bits) => {
|
||||
(bits.to_le_bytes().to_vec(), Type::comptime_number())
|
||||
}
|
||||
},
|
||||
ComptimeNumber::Bool(b) => (vec![b as u8], Type::bool()),
|
||||
ComptimeNumber::Floating(f) => match f {
|
||||
ComptimeFloat::Binary32(f) => (
|
||||
f.to_le_bytes().to_vec(),
|
||||
Type::Floating(FloatingType::Binary32),
|
||||
),
|
||||
ComptimeFloat::Binary64(f) => (
|
||||
f.to_le_bytes().to_vec(),
|
||||
Type::Floating(FloatingType::Binary64),
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, fmt::Display};
|
||||
|
||||
use itertools::Itertools;
|
||||
use num_bigint::BigInt;
|
||||
use num_bigint::{BigInt, BigUint};
|
||||
|
||||
use crate::{
|
||||
ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type},
|
||||
ast::{self, FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type},
|
||||
common::NextIf,
|
||||
comptime::{self, ComptimeNumber},
|
||||
error::{AnalysisError, AnalysisErrorTag},
|
||||
|
@ -252,7 +252,7 @@ impl Tree {
|
|||
.map(|(_, c)| c)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let value = BigInt::parse_digits(digits, radix);
|
||||
let value = comptime::bigint::parse_bigint(digits.into_iter(), radix);
|
||||
|
||||
let ty = match iter.clone().next() {
|
||||
Some((_, 'u')) | Some((_, 'i')) => {
|
||||
|
@ -261,7 +261,10 @@ impl Tree {
|
|||
_ => None,
|
||||
};
|
||||
|
||||
(value, ty)
|
||||
(
|
||||
BigInt::from_biguint(num_bigint::Sign::Plus, BigUint::new(value)),
|
||||
ty,
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_floating_constant(_token: Token, lexeme: &str) -> (u64, FloatingType) {
|
||||
|
@ -860,7 +863,7 @@ impl Tree {
|
|||
| Token::IntegerConstant => {
|
||||
_ = tokens.next();
|
||||
let (bits, ty) = Self::parse_integral_constant(token.token(), token.lexeme());
|
||||
let bytes = bits.into_bytes_le();
|
||||
let (_, bytes) = bits.to_bytes_le();
|
||||
|
||||
const BUF_SIZE: usize = core::mem::size_of::<u64>();
|
||||
let mut buf = [0u8; BUF_SIZE];
|
||||
|
@ -1596,10 +1599,31 @@ impl Tree {
|
|||
|
||||
fn try_fold_comptime_inner(&mut self, node: Node) {
|
||||
if self.is_node_comptime(node) {
|
||||
self.fold_comptime_inner(node);
|
||||
_ = self.fold_comptime_inner(node);
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_comptime_with_visitor(&mut self, decl: Node) {
|
||||
ast::tree_visitor::Visitor::new(
|
||||
self,
|
||||
decl,
|
||||
|_, node| {
|
||||
eprint!("%{node} ");
|
||||
},
|
||||
|tree, node| {
|
||||
if let Ok(value) = tree.fold_comptime_inner(node) {
|
||||
let (bytes, ty) = value.into_bytes_and_type();
|
||||
let idx = tree.strings.insert(bytes);
|
||||
*tree.nodes.get_node_mut(node) = Tag::Constant {
|
||||
bytes: ImmOrIndex::Index(idx),
|
||||
ty,
|
||||
};
|
||||
}
|
||||
},
|
||||
)
|
||||
.visit();
|
||||
}
|
||||
|
||||
fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result<ComptimeNumber> {
|
||||
//
|
||||
if self.is_node_comptime(decl) {
|
||||
|
@ -1792,18 +1816,19 @@ impl Tree {
|
|||
unreachable!()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(comptime::Error::NotComptime)
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn fold_comptime(&mut self) {
|
||||
for decl in self.global_decls.clone() {
|
||||
match self.nodes.get_node(decl) {
|
||||
Tag::FunctionDecl { body, .. } => {
|
||||
self.fold_comptime_inner(*body);
|
||||
_ = self.fold_comptime_inner(*body);
|
||||
}
|
||||
Tag::GlobalDecl { assignment, .. } => {
|
||||
self.fold_comptime_inner(*assignment);
|
||||
_ = self.fold_comptime_inner(*assignment);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
@ -2031,4 +2056,27 @@ const global: u32 = 42u32;
|
|||
tree.render(&mut buf).unwrap();
|
||||
println!("{buf}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn comptime() {
|
||||
let src = "
|
||||
fn main() -> void {
|
||||
let a = 3 * 49573 << 4;
|
||||
}
|
||||
";
|
||||
let tokens = Tokenizer::new(src.as_bytes()).unwrap();
|
||||
|
||||
let mut tree = Tree::new();
|
||||
tree.parse(tokens.iter()).unwrap();
|
||||
|
||||
let mut buf = String::new();
|
||||
tree.render(&mut buf).unwrap();
|
||||
println!("{buf}");
|
||||
|
||||
tree.fold_comptime_with_visitor(tree.global_decls.first().cloned().unwrap());
|
||||
|
||||
let mut buf = String::new();
|
||||
tree.render(&mut buf).unwrap();
|
||||
println!("{buf}");
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue