folding comptime math

This commit is contained in:
Janis 2024-08-27 17:22:47 +02:00
parent a64658995a
commit 18e29f1fa1
2 changed files with 456 additions and 4 deletions

View file

@ -1499,12 +1499,15 @@ pub mod bigsint {
}
}
use std::ops::{Add, BitAnd, BitOr, BitXor, Not};
use std::{
cmp::Ordering,
ops::{Add, BitAnd, BitOr, BitXor, Not},
};
use num_bigint::{BigInt, Sign};
use num_bigint::{BigInt, BigUint, Sign};
use num_traits::cast::ToPrimitive;
use crate::ast::IntegralType;
use crate::ast::{FloatingType, IntegralType};
#[derive(Debug, thiserror::Error)]
pub enum Error {
@ -1516,10 +1519,13 @@ pub enum Error {
ShiftTooLarge,
#[error("Cannot negate unsigned integer")]
UnsignedNegation,
#[error("Incomparable floats.")]
FloatingCmp,
}
pub type Result<T> = core::result::Result<T, Error>;
#[derive(Debug, PartialEq, Eq)]
pub enum ComptimeInt {
Native { bits: u128, ty: IntegralType },
BigInt { bits: BigInt, ty: IntegralType },
@ -1710,6 +1716,19 @@ impl ComptimeInt {
}
}
}
pub fn cmp(self, other: Self) -> Result<Ordering> {
let (a, b) = self.coalesce(other)?;
let ord = match (a, b) {
(ComptimeInt::Native { bits: a, .. }, ComptimeInt::Native { bits: b, .. }) => a.cmp(&b),
(ComptimeInt::BigInt { bits: a, .. }, ComptimeInt::BigInt { bits: b, .. }) => a.cmp(&b),
(ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => a.cmp(&b),
_ => {
unreachable!()
}
};
Ok(ord)
}
pub fn shl(self, other: Self) -> Result<Self> {
use core::ops::Shl;
@ -1884,6 +1903,7 @@ impl ComptimeInt {
}
}
#[derive(Debug, PartialEq)]
pub enum ComptimeFloat {
Binary32(f32),
Binary64(f64),
@ -1931,6 +1951,17 @@ impl ComptimeFloat {
ComptimeFloat::Binary64(a) => Ok(Self::Binary64(-a)),
}
}
pub fn cmp(self, other: Self) -> Result<Ordering> {
let ord = match (self, other) {
(ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => a.partial_cmp(&b),
(ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => a.partial_cmp(&b),
_ => {
return Err(Error::IncompatibleTypes);
}
};
ord.ok_or(Error::FloatingCmp)
}
}
pub enum ComptimeNumber {
@ -1939,6 +1970,41 @@ pub enum ComptimeNumber {
Floating(ComptimeFloat),
}
impl From<bool> for ComptimeNumber {
fn from(value: bool) -> Self {
Self::Bool(value)
}
}
impl From<f32> for ComptimeNumber {
fn from(value: f32) -> Self {
Self::Floating(ComptimeFloat::Binary32(value))
}
}
impl From<f64> for ComptimeNumber {
fn from(value: f64) -> Self {
Self::Floating(ComptimeFloat::Binary64(value))
}
}
impl From<BigInt> for ComptimeNumber {
fn from(value: BigInt) -> Self {
Self::Integral(ComptimeInt::Comptime(value))
}
}
impl From<(BigInt, IntegralType)> for ComptimeNumber {
fn from((bits, ty): (BigInt, IntegralType)) -> Self {
Self::Integral(ComptimeInt::BigInt { bits, ty })
}
}
impl From<(u128, IntegralType)> for ComptimeNumber {
fn from((bits, ty): (u128, IntegralType)) -> Self {
Self::Integral(ComptimeInt::Native { bits, ty })
}
}
impl ComptimeNumber {
pub fn add(self, other: Self) -> Result<Self> {
match (self, other) {
@ -2076,4 +2142,125 @@ impl ComptimeNumber {
_ => Err(Error::IncompatibleTypes),
}
}
pub fn or(self, other: Self) -> Result<Self> {
match (self, other) {
// (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => {
// Ok(Self::Integral(a.shr(b)?))
// }
// (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => {
// Ok(Self::Floating(a.bitxor(b)?))
// }
(ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a || b)),
_ => Err(Error::IncompatibleTypes),
}
}
pub fn and(self, other: Self) -> Result<Self> {
match (self, other) {
// (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => {
// Ok(Self::Integral(a.shr(b)?))
// }
// (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => {
// Ok(Self::Floating(a.bitxor(b)?))
// }
(ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a && b)),
_ => Err(Error::IncompatibleTypes),
}
}
pub fn eq(self, other: Self) -> Result<Self> {
match (self, other) {
(ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => Ok(Self::Bool(a == b)),
(ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => Ok(Self::Bool(a == b)),
(ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a == b)),
_ => Err(Error::IncompatibleTypes),
}
}
pub fn cmp(self, other: Self) -> Result<Ordering> {
let ord = match (self, other) {
(ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => a.cmp(b)?,
(ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => a.cmp(b)?,
(ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => a.cmp(&b),
_ => {
return Err(Error::IncompatibleTypes);
}
};
Ok(ord)
}
pub fn lt(self, other: Self) -> Result<Self> {
Ok(Self::Bool(self.cmp(other)? == Ordering::Less))
}
pub fn gt(self, other: Self) -> Result<Self> {
Ok(Self::Bool(self.cmp(other)? == Ordering::Greater))
}
pub fn ge(self, other: Self) -> Result<Self> {
Ok(Self::Bool(self.cmp(other)? != Ordering::Less))
}
pub fn le(self, other: Self) -> Result<Self> {
Ok(Self::Bool(self.cmp(other)? != Ordering::Greater))
}
pub fn into_bool(self) -> Result<Self> {
match self {
ComptimeNumber::Integral(i) => match i {
ComptimeInt::Native { bits, .. } => Ok((bits != 0).into()),
ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => {
Ok((bits.sign() != Sign::NoSign).into())
}
},
ComptimeNumber::Floating(ComptimeFloat::Binary32(f)) => Ok((f != 0.0).into()),
ComptimeNumber::Floating(ComptimeFloat::Binary64(f)) => Ok((f != 0.0).into()),
a => Ok(a),
}
}
pub fn into_int(self, ty: IntegralType) -> Result<Self> {
match self {
ComptimeNumber::Integral(i) => match i {
ComptimeInt::Native { bits, .. } => Ok((bits & ty.u128_bitmask(), ty).into()),
ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => {
let max = BigUint::from(2u32).pow((ty.bits - ty.signed as u16) as u32);
let (sign, data) = bits.into_parts();
let data = data.clamp(BigUint::ZERO, max);
Ok((BigInt::from_biguint(sign, data), ty).into())
}
},
ComptimeNumber::Bool(b) => Ok((b as u128 & ty.u128_bitmask(), ty).into()),
ComptimeNumber::Floating(f) => match f {
ComptimeFloat::Binary32(f) => Ok((f as u128 & ty.u128_bitmask(), ty).into()),
ComptimeFloat::Binary64(f) => Ok((f as u128 & ty.u128_bitmask(), ty).into()),
},
}
}
pub fn into_float(self, ty: FloatingType) -> Result<Self> {
let f = match self {
ComptimeNumber::Integral(i) => match i {
ComptimeInt::Native { bits, .. } => bits as f64,
ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => {
bits.to_f64().unwrap_or(f64::NAN)
}
},
ComptimeNumber::Bool(b) => {
if b {
1.0f64
} else {
0.0f64
}
}
ComptimeNumber::Floating(f) => match f {
ComptimeFloat::Binary32(f) => f as f64,
ComptimeFloat::Binary64(f) => f as f64,
},
};
match ty {
FloatingType::Binary32 => Ok((f as f32).into()),
FloatingType::Binary64 => Ok(f.into()),
}
}
}

View file

@ -1,11 +1,12 @@
use std::collections::HashMap;
use itertools::Itertools;
use num_bigint::BigInt;
use crate::{
ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type},
common::NextIf,
comptime::bigint::{self, BigInt},
comptime::{self, ComptimeNumber},
error::{AnalysisError, AnalysisErrorTag},
lexer::{Radix, TokenIterator},
string_table::{ImmOrIndex, Index, StringTable},
@ -1546,6 +1547,270 @@ impl Tree {
}
}
// simplify tree with compile-time math
impl Tree {
fn is_node_comptime(&self, node: Node) -> bool {
match self.nodes.get_node(node) {
Tag::Block {
statements,
trailing_expr,
} => statements
.iter()
.chain(trailing_expr.into_iter())
.all(|n| self.is_node_comptime(*n)),
Tag::Constant { .. } => true,
Tag::ExplicitCast { lhs, typename } => {
self.is_node_comptime(*lhs)
&& match self.type_of_node(*typename) {
Type::Bool
| Type::ComptimeNumber
| Type::Integer(_)
| Type::Floating(_) => true,
_ => false,
}
}
Tag::DeclRef(lhs) | Tag::Not { lhs } | Tag::Negate { lhs } => {
self.is_node_comptime(*lhs)
}
Tag::Or { lhs, rhs }
| Tag::And { lhs, rhs }
| Tag::BitOr { lhs, rhs }
| Tag::BitAnd { lhs, rhs }
| Tag::BitXOr { lhs, rhs }
| Tag::Eq { lhs, rhs }
| Tag::NEq { lhs, rhs }
| Tag::Lt { lhs, rhs }
| Tag::Gt { lhs, rhs }
| Tag::Le { lhs, rhs }
| Tag::Ge { lhs, rhs }
| Tag::Shl { lhs, rhs }
| Tag::Shr { lhs, rhs }
| Tag::Add { lhs, rhs }
| Tag::Sub { lhs, rhs }
| Tag::Mul { lhs, rhs }
| Tag::Rem { lhs, rhs }
| Tag::Div { lhs, rhs } => self.is_node_comptime(*lhs) && self.is_node_comptime(*rhs),
_ => false,
}
}
fn try_fold_comptime_inner(&mut self, node: Node) {
if self.is_node_comptime(node) {
self.fold_comptime_inner(node);
}
}
fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result<ComptimeNumber> {
//
if self.is_node_comptime(decl) {
match self.nodes.get_node(decl) {
Tag::Constant { bytes, ty } => {
let bytes = match bytes {
ImmOrIndex::U64(v) => &v.to_le_bytes()[..],
ImmOrIndex::U32(v) => &v.to_le_bytes()[..],
ImmOrIndex::Index(idx) => self.strings.get_bytes(*idx),
};
let number: ComptimeNumber = match ty {
Type::Bool => (bytes[0] != 0).into(),
Type::ComptimeNumber => {
BigInt::from_bytes_le(num_bigint::Sign::Plus, bytes).into()
}
Type::Integer(ty) => {
if bytes.len() > core::mem::size_of::<u128>() {
let bits = BigInt::from_bytes_le(num_bigint::Sign::Plus, bytes);
(bits, *ty).into()
} else {
let mut buf = [0u8; core::mem::size_of::<u128>()];
buf[..bytes.len()].copy_from_slice(bytes);
let bits = u128::from_le_bytes(buf);
(bits, *ty).into()
}
}
Type::Floating(ty) => match ty {
FloatingType::Binary32 => {
(f32::from_le_bytes((&bytes[..4]).try_into().unwrap())).into()
}
FloatingType::Binary64 => {
(f64::from_le_bytes((&bytes[..8]).try_into().unwrap())).into()
}
},
_ => unimplemented!(),
};
return Ok(number);
}
Tag::Negate { lhs } => {
let lhs = self.fold_comptime_inner(*lhs)?;
return Ok(lhs.neg()?);
}
Tag::ExplicitCast { lhs, typename } => {
let ty = self.type_of_node(*typename);
let lhs = self.fold_comptime_inner(*lhs)?;
return match ty {
Type::Bool => lhs.into_bool(),
Type::Integer(ty) => lhs.into_int(ty),
Type::Floating(ty) => lhs.into_float(ty),
_ => unimplemented!(),
};
}
Tag::DeclRef(lhs) => {
return self.fold_comptime_inner(*lhs);
}
Tag::Not { lhs } => {
let lhs = self.fold_comptime_inner(*lhs)?;
return lhs.not();
}
Tag::Or { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.or(rhs);
}
Tag::And { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.and(rhs);
}
Tag::Eq { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.eq(rhs);
}
Tag::NEq { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.eq(rhs)?.not();
}
Tag::Lt { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.lt(rhs);
}
Tag::Gt { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.gt(rhs);
}
Tag::Le { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.le(rhs);
}
Tag::Ge { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.ge(rhs);
}
Tag::BitOr { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.bitor(rhs);
}
Tag::BitAnd { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.bitand(rhs);
}
Tag::BitXOr { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.bitxor(rhs);
}
Tag::Shl { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.shl(rhs);
}
Tag::Shr { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.shr(rhs);
}
Tag::Add { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.add(rhs);
}
Tag::Sub { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.sub(rhs);
}
Tag::Mul { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.mul(rhs);
}
Tag::Rem { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.rem(rhs);
}
Tag::Div { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?;
let rhs = self.fold_comptime_inner(rhs)?;
return lhs.div(rhs);
}
_ => {
unreachable!()
}
}
}
todo!()
}
pub fn fold_comptime(&mut self) {
for decl in self.global_decls.clone() {
match self.nodes.get_node(decl) {
Tag::FunctionDecl { body, .. } => {
self.fold_comptime_inner(*body);
}
Tag::GlobalDecl { assignment, .. } => {
self.fold_comptime_inner(*assignment);
}
_ => unreachable!(),
}
}
}
}
impl Tree {
/// type-checks and inserts appropriate explicit-cast nodes.
pub fn typecheck(&mut self) {