diff --git a/src/comptime.rs b/src/comptime.rs index 1a07284..e6522ef 100644 --- a/src/comptime.rs +++ b/src/comptime.rs @@ -1499,12 +1499,15 @@ pub mod bigsint { } } -use std::ops::{Add, BitAnd, BitOr, BitXor, Not}; +use std::{ + cmp::Ordering, + ops::{Add, BitAnd, BitOr, BitXor, Not}, +}; -use num_bigint::{BigInt, Sign}; +use num_bigint::{BigInt, BigUint, Sign}; use num_traits::cast::ToPrimitive; -use crate::ast::IntegralType; +use crate::ast::{FloatingType, IntegralType}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -1516,10 +1519,13 @@ pub enum Error { ShiftTooLarge, #[error("Cannot negate unsigned integer")] UnsignedNegation, + #[error("Incomparable floats.")] + FloatingCmp, } pub type Result = core::result::Result; +#[derive(Debug, PartialEq, Eq)] pub enum ComptimeInt { Native { bits: u128, ty: IntegralType }, BigInt { bits: BigInt, ty: IntegralType }, @@ -1710,6 +1716,19 @@ impl ComptimeInt { } } } + pub fn cmp(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + let ord = match (a, b) { + (ComptimeInt::Native { bits: a, .. }, ComptimeInt::Native { bits: b, .. }) => a.cmp(&b), + (ComptimeInt::BigInt { bits: a, .. }, ComptimeInt::BigInt { bits: b, .. }) => a.cmp(&b), + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => a.cmp(&b), + _ => { + unreachable!() + } + }; + + Ok(ord) + } pub fn shl(self, other: Self) -> Result { use core::ops::Shl; @@ -1884,6 +1903,7 @@ impl ComptimeInt { } } +#[derive(Debug, PartialEq)] pub enum ComptimeFloat { Binary32(f32), Binary64(f64), @@ -1931,6 +1951,17 @@ impl ComptimeFloat { ComptimeFloat::Binary64(a) => Ok(Self::Binary64(-a)), } } + pub fn cmp(self, other: Self) -> Result { + let ord = match (self, other) { + (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => a.partial_cmp(&b), + (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => a.partial_cmp(&b), + _ => { + return Err(Error::IncompatibleTypes); + } + }; + + ord.ok_or(Error::FloatingCmp) + } } pub enum ComptimeNumber { @@ -1939,6 +1970,41 @@ pub enum ComptimeNumber { Floating(ComptimeFloat), } +impl From for ComptimeNumber { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} + +impl From for ComptimeNumber { + fn from(value: f32) -> Self { + Self::Floating(ComptimeFloat::Binary32(value)) + } +} + +impl From for ComptimeNumber { + fn from(value: f64) -> Self { + Self::Floating(ComptimeFloat::Binary64(value)) + } +} + +impl From for ComptimeNumber { + fn from(value: BigInt) -> Self { + Self::Integral(ComptimeInt::Comptime(value)) + } +} + +impl From<(BigInt, IntegralType)> for ComptimeNumber { + fn from((bits, ty): (BigInt, IntegralType)) -> Self { + Self::Integral(ComptimeInt::BigInt { bits, ty }) + } +} +impl From<(u128, IntegralType)> for ComptimeNumber { + fn from((bits, ty): (u128, IntegralType)) -> Self { + Self::Integral(ComptimeInt::Native { bits, ty }) + } +} + impl ComptimeNumber { pub fn add(self, other: Self) -> Result { match (self, other) { @@ -2076,4 +2142,125 @@ impl ComptimeNumber { _ => Err(Error::IncompatibleTypes), } } + pub fn or(self, other: Self) -> Result { + match (self, other) { + // (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + // Ok(Self::Integral(a.shr(b)?)) + // } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.bitxor(b)?)) + // } + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a || b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn and(self, other: Self) -> Result { + match (self, other) { + // (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + // Ok(Self::Integral(a.shr(b)?)) + // } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.bitxor(b)?)) + // } + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a && b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn eq(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => Ok(Self::Bool(a == b)), + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => Ok(Self::Bool(a == b)), + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a == b)), + _ => Err(Error::IncompatibleTypes), + } + } + + pub fn cmp(self, other: Self) -> Result { + let ord = match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => a.cmp(b)?, + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => a.cmp(b)?, + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => a.cmp(&b), + _ => { + return Err(Error::IncompatibleTypes); + } + }; + + Ok(ord) + } + + pub fn lt(self, other: Self) -> Result { + Ok(Self::Bool(self.cmp(other)? == Ordering::Less)) + } + + pub fn gt(self, other: Self) -> Result { + Ok(Self::Bool(self.cmp(other)? == Ordering::Greater)) + } + + pub fn ge(self, other: Self) -> Result { + Ok(Self::Bool(self.cmp(other)? != Ordering::Less)) + } + + pub fn le(self, other: Self) -> Result { + Ok(Self::Bool(self.cmp(other)? != Ordering::Greater)) + } + + pub fn into_bool(self) -> Result { + match self { + ComptimeNumber::Integral(i) => match i { + ComptimeInt::Native { bits, .. } => Ok((bits != 0).into()), + ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => { + Ok((bits.sign() != Sign::NoSign).into()) + } + }, + ComptimeNumber::Floating(ComptimeFloat::Binary32(f)) => Ok((f != 0.0).into()), + ComptimeNumber::Floating(ComptimeFloat::Binary64(f)) => Ok((f != 0.0).into()), + a => Ok(a), + } + } + + pub fn into_int(self, ty: IntegralType) -> Result { + match self { + ComptimeNumber::Integral(i) => match i { + ComptimeInt::Native { bits, .. } => Ok((bits & ty.u128_bitmask(), ty).into()), + ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => { + let max = BigUint::from(2u32).pow((ty.bits - ty.signed as u16) as u32); + let (sign, data) = bits.into_parts(); + let data = data.clamp(BigUint::ZERO, max); + + Ok((BigInt::from_biguint(sign, data), ty).into()) + } + }, + ComptimeNumber::Bool(b) => Ok((b as u128 & ty.u128_bitmask(), ty).into()), + ComptimeNumber::Floating(f) => match f { + ComptimeFloat::Binary32(f) => Ok((f as u128 & ty.u128_bitmask(), ty).into()), + ComptimeFloat::Binary64(f) => Ok((f as u128 & ty.u128_bitmask(), ty).into()), + }, + } + } + pub fn into_float(self, ty: FloatingType) -> Result { + let f = match self { + ComptimeNumber::Integral(i) => match i { + ComptimeInt::Native { bits, .. } => bits as f64, + ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => { + bits.to_f64().unwrap_or(f64::NAN) + } + }, + ComptimeNumber::Bool(b) => { + if b { + 1.0f64 + } else { + 0.0f64 + } + } + ComptimeNumber::Floating(f) => match f { + ComptimeFloat::Binary32(f) => f as f64, + ComptimeFloat::Binary64(f) => f as f64, + }, + }; + + match ty { + FloatingType::Binary32 => Ok((f as f32).into()), + FloatingType::Binary64 => Ok(f.into()), + } + } } diff --git a/src/parser.rs b/src/parser.rs index f81583f..6996195 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,11 +1,12 @@ use std::collections::HashMap; use itertools::Itertools; +use num_bigint::BigInt; use crate::{ ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type}, common::NextIf, - comptime::bigint::{self, BigInt}, + comptime::{self, ComptimeNumber}, error::{AnalysisError, AnalysisErrorTag}, lexer::{Radix, TokenIterator}, string_table::{ImmOrIndex, Index, StringTable}, @@ -1546,6 +1547,270 @@ impl Tree { } } +// simplify tree with compile-time math +impl Tree { + fn is_node_comptime(&self, node: Node) -> bool { + match self.nodes.get_node(node) { + Tag::Block { + statements, + trailing_expr, + } => statements + .iter() + .chain(trailing_expr.into_iter()) + .all(|n| self.is_node_comptime(*n)), + Tag::Constant { .. } => true, + Tag::ExplicitCast { lhs, typename } => { + self.is_node_comptime(*lhs) + && match self.type_of_node(*typename) { + Type::Bool + | Type::ComptimeNumber + | Type::Integer(_) + | Type::Floating(_) => true, + _ => false, + } + } + Tag::DeclRef(lhs) | Tag::Not { lhs } | Tag::Negate { lhs } => { + self.is_node_comptime(*lhs) + } + Tag::Or { lhs, rhs } + | Tag::And { lhs, rhs } + | Tag::BitOr { lhs, rhs } + | Tag::BitAnd { lhs, rhs } + | Tag::BitXOr { lhs, rhs } + | Tag::Eq { lhs, rhs } + | Tag::NEq { lhs, rhs } + | Tag::Lt { lhs, rhs } + | Tag::Gt { lhs, rhs } + | Tag::Le { lhs, rhs } + | Tag::Ge { lhs, rhs } + | Tag::Shl { lhs, rhs } + | Tag::Shr { lhs, rhs } + | Tag::Add { lhs, rhs } + | Tag::Sub { lhs, rhs } + | Tag::Mul { lhs, rhs } + | Tag::Rem { lhs, rhs } + | Tag::Div { lhs, rhs } => self.is_node_comptime(*lhs) && self.is_node_comptime(*rhs), + _ => false, + } + } + + fn try_fold_comptime_inner(&mut self, node: Node) { + if self.is_node_comptime(node) { + self.fold_comptime_inner(node); + } + } + + fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result { + // + if self.is_node_comptime(decl) { + match self.nodes.get_node(decl) { + Tag::Constant { bytes, ty } => { + let bytes = match bytes { + ImmOrIndex::U64(v) => &v.to_le_bytes()[..], + + ImmOrIndex::U32(v) => &v.to_le_bytes()[..], + ImmOrIndex::Index(idx) => self.strings.get_bytes(*idx), + }; + + let number: ComptimeNumber = match ty { + Type::Bool => (bytes[0] != 0).into(), + Type::ComptimeNumber => { + BigInt::from_bytes_le(num_bigint::Sign::Plus, bytes).into() + } + Type::Integer(ty) => { + if bytes.len() > core::mem::size_of::() { + let bits = BigInt::from_bytes_le(num_bigint::Sign::Plus, bytes); + (bits, *ty).into() + } else { + let mut buf = [0u8; core::mem::size_of::()]; + buf[..bytes.len()].copy_from_slice(bytes); + let bits = u128::from_le_bytes(buf); + (bits, *ty).into() + } + } + Type::Floating(ty) => match ty { + FloatingType::Binary32 => { + (f32::from_le_bytes((&bytes[..4]).try_into().unwrap())).into() + } + FloatingType::Binary64 => { + (f64::from_le_bytes((&bytes[..8]).try_into().unwrap())).into() + } + }, + _ => unimplemented!(), + }; + return Ok(number); + } + Tag::Negate { lhs } => { + let lhs = self.fold_comptime_inner(*lhs)?; + return Ok(lhs.neg()?); + } + Tag::ExplicitCast { lhs, typename } => { + let ty = self.type_of_node(*typename); + let lhs = self.fold_comptime_inner(*lhs)?; + + return match ty { + Type::Bool => lhs.into_bool(), + Type::Integer(ty) => lhs.into_int(ty), + Type::Floating(ty) => lhs.into_float(ty), + _ => unimplemented!(), + }; + } + Tag::DeclRef(lhs) => { + return self.fold_comptime_inner(*lhs); + } + Tag::Not { lhs } => { + let lhs = self.fold_comptime_inner(*lhs)?; + return lhs.not(); + } + Tag::Or { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.or(rhs); + } + Tag::And { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.and(rhs); + } + Tag::Eq { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.eq(rhs); + } + Tag::NEq { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.eq(rhs)?.not(); + } + Tag::Lt { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.lt(rhs); + } + Tag::Gt { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.gt(rhs); + } + Tag::Le { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.le(rhs); + } + Tag::Ge { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.ge(rhs); + } + Tag::BitOr { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.bitor(rhs); + } + Tag::BitAnd { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.bitand(rhs); + } + Tag::BitXOr { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.bitxor(rhs); + } + Tag::Shl { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.shl(rhs); + } + Tag::Shr { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.shr(rhs); + } + Tag::Add { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.add(rhs); + } + Tag::Sub { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.sub(rhs); + } + Tag::Mul { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.mul(rhs); + } + Tag::Rem { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.rem(rhs); + } + Tag::Div { lhs, rhs } => { + let (lhs, rhs) = (*lhs, *rhs); + let lhs = self.fold_comptime_inner(lhs)?; + let rhs = self.fold_comptime_inner(rhs)?; + + return lhs.div(rhs); + } + _ => { + unreachable!() + } + } + } + todo!() + } + + pub fn fold_comptime(&mut self) { + for decl in self.global_decls.clone() { + match self.nodes.get_node(decl) { + Tag::FunctionDecl { body, .. } => { + self.fold_comptime_inner(*body); + } + Tag::GlobalDecl { assignment, .. } => { + self.fold_comptime_inner(*assignment); + } + _ => unreachable!(), + } + } + } +} + impl Tree { /// type-checks and inserts appropriate explicit-cast nodes. pub fn typecheck(&mut self) {