From a64658995a9c6b1f36b26e6f49f3116fa2292e93 Mon Sep 17 00:00:00 2001 From: Janis Date: Tue, 27 Aug 2024 16:21:39 +0200 Subject: [PATCH] comptime math --- Cargo.toml | 2 + src/ast.rs | 3 + src/comptime.rs | 1238 ++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 1166 insertions(+), 77 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 34257e4..0425972 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,8 @@ ansi_term = "0.12.1" clap = "4.5.14" itertools = "0.13.0" log = "0.4.22" +num-bigint = "0.4.6" +num-traits = "0.2.19" petgraph = "0.6.5" thiserror = "1.0.63" unicode-xid = "0.2.4" diff --git a/src/ast.rs b/src/ast.rs index ae47fc6..e4b312e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -230,6 +230,9 @@ impl IntegralType { pub fn u64_bitmask(self) -> u64 { (1u64 << self.bits) - 1 } + pub fn u128_bitmask(self) -> u128 { + (1u128 << self.bits) - 1 + } } #[derive(Debug, Clone, Eq, Hash)] diff --git a/src/comptime.rs b/src/comptime.rs index 2eed83e..1a07284 100644 --- a/src/comptime.rs +++ b/src/comptime.rs @@ -1,5 +1,11 @@ pub mod bigint { + use core::{ + cmp::{Eq, Ord, Ordering, PartialOrd}, + fmt::Debug, + ops::{Add, AddAssign, Div, Mul, Not, Rem, Shl, Shr, Sub, SubAssign}, + }; + use crate::lexer::Radix; /// A base-4_294_967_295 number. #[derive(Clone)] @@ -9,6 +15,13 @@ pub mod bigint { pub fn parse_digits>(text: C, radix: Radix) -> BigInt { parse_bigint(text.into_iter(), radix) } + pub fn from_u32(v: u32) -> BigInt { + Self(vec![v]) + } + pub fn from_u64(v: u64) -> BigInt { + let (lo, hi) = into_lo_hi(v); + Self(vec![lo, hi]) + } pub fn one() -> BigInt { Self(vec![1]) @@ -93,26 +106,52 @@ pub mod bigint { } } - impl core::cmp::PartialEq for BigInt { + impl PartialEq for BigInt { fn eq(&self, other: &Self) -> bool { - cmp_bigint(&self.0, &other.0) == core::cmp::Ordering::Equal + cmp_bigint(&self.0, &other.0) == Ordering::Equal } } - impl core::cmp::Eq for BigInt {} + impl PartialEq for BigInt { + fn eq(&self, other: &u32) -> bool { + self.num_digits() == 1 && self.0[0] == *other + } + } - impl core::cmp::PartialOrd for BigInt { + impl PartialEq for BigInt { + fn eq(&self, other: &u64) -> bool { + let (lo, hi) = into_lo_hi(*other); + cmp_bigint(&self.0, &[lo, hi]) == Ordering::Equal + } + } + + impl PartialOrd for BigInt { + fn partial_cmp(&self, other: &u32) -> Option { + (self.num_digits() == 1).then(|| self.0[0].cmp(other)) + } + } + + impl PartialOrd for BigInt { + fn partial_cmp(&self, other: &u64) -> Option { + let (lo, hi) = into_lo_hi(*other); + Some(cmp_bigint(&self.0, &[lo, hi])) + } + } + + impl Eq for BigInt {} + + impl PartialOrd for BigInt { fn partial_cmp(&self, other: &Self) -> Option { Some(cmp_bigint(&self.0, &other.0)) } } - impl core::cmp::Ord for BigInt { + impl Ord for BigInt { fn cmp(&self, other: &Self) -> std::cmp::Ordering { cmp_bigint(&self.0, &other.0) } } - impl core::ops::Shl for BigInt { + impl Shl for BigInt { type Output = Self; fn shl(mut self, rhs: usize) -> Self::Output { @@ -120,7 +159,7 @@ pub mod bigint { self } } - impl core::ops::Shr for BigInt { + impl Shr for BigInt { type Output = Self; fn shr(mut self, rhs: usize) -> Self::Output { @@ -129,7 +168,7 @@ pub mod bigint { } } - impl core::ops::Add for BigInt { + impl Add for BigInt { type Output = Self; fn add(mut self, mut rhs: Self) -> Self::Output { @@ -149,7 +188,51 @@ pub mod bigint { } } - impl core::ops::Sub for BigInt { + impl Add for BigInt { + type Output = Self; + + fn add(mut self, rhs: u32) -> Self::Output { + self += rhs; + self + } + } + + impl AddAssign for BigInt { + fn add_assign(&mut self, rhs: u32) { + let carry = add_bigint_scalar(&mut self.0, rhs); + if carry { + self.0.push(carry as u32); + } + } + } + + impl Add for BigInt { + type Output = Self; + + fn add(mut self, rhs: u64) -> Self::Output { + self += rhs; + self + } + } + + impl AddAssign for BigInt { + fn add_assign(&mut self, rhs: u64) { + let (lo, hi) = into_lo_hi(rhs); + if hi == 0 { + *self += lo; + } else { + while self.num_digits() < 2 { + self.0.push(0); + } + let carry = add_bigint(&mut self.0, &[lo, hi]); + if carry { + self.0.push(carry as u32); + } + } + } + } + + impl Sub for BigInt { type Output = Self; fn sub(mut self, rhs: Self) -> Self::Output { @@ -165,7 +248,70 @@ pub mod bigint { } } - impl core::ops::Mul for BigInt { + impl Sub for BigInt { + type Output = Self; + + fn sub(mut self, rhs: u32) -> Self::Output { + self -= rhs; + self + } + } + + impl Sub for u32 { + type Output = BigInt; + + fn sub(self, mut rhs: BigInt) -> Self::Output { + if rhs.0.is_empty() { + rhs.0.push(self); + } else { + sub_bigint_in_right(&[self], &mut rhs.0); + } + rhs.normalised() + } + } + + impl Sub for u64 { + type Output = BigInt; + + fn sub(self, mut rhs: BigInt) -> Self::Output { + while rhs.num_digits() < 2 { + rhs.0.push(0); + } + + let (lo, hi) = into_lo_hi(self); + sub_bigint_in_right(&[lo, hi], &mut rhs.0); + + rhs.normalised() + } + } + + impl SubAssign for BigInt { + fn sub_assign(&mut self, rhs: u32) { + sub_bigint_scalar(&mut self.0, rhs); + } + } + + impl Sub for BigInt { + type Output = Self; + + fn sub(mut self, rhs: u64) -> Self::Output { + self -= rhs; + self + } + } + + impl SubAssign for BigInt { + fn sub_assign(&mut self, rhs: u64) { + let (lo, hi) = into_lo_hi(rhs); + while self.num_digits() < 2 { + self.0.push(0); + } + sub_bigint(&mut self.0, &[lo, hi]); + self.normalise(); + } + } + + impl Mul for BigInt { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { @@ -173,7 +319,25 @@ pub mod bigint { } } - impl core::ops::Div for BigInt { + impl Mul for BigInt { + type Output = Self; + + fn mul(mut self, rhs: u32) -> Self::Output { + u32_mul_bigint(&mut self.0, rhs); + self + } + } + + impl Mul for BigInt { + type Output = Self; + + fn mul(self, rhs: u64) -> Self::Output { + let (lo, hi) = into_lo_hi(rhs); + BigInt(mul_bigint(&self.0, &[lo, hi])) + } + } + + impl Div for BigInt { type Output = Self; fn div(self, rhs: Self) -> Self::Output { @@ -181,7 +345,24 @@ pub mod bigint { } } - impl core::ops::Rem for BigInt { + impl Div for BigInt { + type Output = Self; + + fn div(self, rhs: u32) -> Self::Output { + div_digit_bigint(self, rhs).0 + } + } + + impl Div for BigInt { + type Output = Self; + + fn div(self, rhs: u64) -> Self::Output { + let (lo, hi) = into_lo_hi(rhs); + div_rem_bigint(self, BigInt([lo, hi].to_vec())).0 + } + } + + impl Rem for BigInt { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { @@ -189,7 +370,24 @@ pub mod bigint { } } - impl core::ops::Not for BigInt { + impl Rem for BigInt { + type Output = Self; + + fn rem(self, rhs: u32) -> Self::Output { + BigInt::zero() + div_digit_bigint(self, rhs).1 + } + } + + impl Rem for BigInt { + type Output = Self; + + fn rem(self, rhs: u64) -> Self::Output { + let (lo, hi) = into_lo_hi(rhs); + div_rem_bigint(self, BigInt([lo, hi].to_vec())).1 + } + } + + impl Not for BigInt { type Output = Self; fn not(mut self) -> Self::Output { @@ -198,8 +396,8 @@ pub mod bigint { } } - impl core::fmt::Debug for BigInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl Debug for BigInt { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut list = f.debug_list(); list.entries(self.0.iter().rev()).finish() } @@ -244,7 +442,7 @@ pub mod bigint { #[allow(unused)] /// lhs <=> rhs - fn cmp_bigint(lhs: &[u32], rhs: &[u32]) -> core::cmp::Ordering { + fn cmp_bigint(lhs: &[u32], rhs: &[u32]) -> Ordering { use core::cmp::Ordering; let lhs_bits = count_bits(lhs); let rhs_bits = count_bits(rhs); @@ -385,7 +583,60 @@ pub mod bigint { #[allow(dead_code)] /// divident must be at least as wide as divisor /// returns (quotient, remainder) - fn div_rem_bigint(divident: BigInt, divisor: BigInt) -> (BigInt, BigInt) { + pub fn div_rem_bigint_ref(divident: &BigInt, divisor: &BigInt) -> (BigInt, BigInt) { + if bigint_is_zero(&divisor.0) { + panic!("divide by zero!"); + } + if bigint_is_zero(÷nt.0) { + return (BigInt::zero(), BigInt::zero()); + } + use core::cmp::Ordering; + match cmp_bigint(÷nt.0, &divisor.0) { + Ordering::Less => return (BigInt::zero(), divident.clone()), + Ordering::Equal => { + return (BigInt::one(), BigInt::zero()); + } + Ordering::Greater => {} + } + + if divisor.is_power_of_two() { + let exp = divisor.trailing_zeros(); + let (div, rem) = divident.0.split_at(exp.div_floor(u32::BITS as usize)); + let (mut div, mut rem) = (div.to_vec(), rem.to_vec()); + + shr_bitint(&mut div, exp % u32::BITS as usize); + let mask = (1u32 << exp as u32 % u32::BITS) - 1; + if let Some(last) = rem.last_mut() { + *last &= mask; + } + + return (BigInt(div), BigInt(rem)); + } + + if divisor.num_digits() == 1 { + if divisor.0[0] == 1 { + return (divident.clone(), BigInt::zero()); + } + + let (div, rem) = div_digit_bigint(divident.clone(), divisor.0[0]); + let rem = BigInt::zero() + rem; + return (div, rem); + } + + let shift = divisor.0.last().unwrap().leading_zeros() as usize; + if shift == 0 { + div_rem_core(divident.clone(), &divisor.0) + } else { + let (q, r) = div_rem_core(divident.clone() << shift, &(divisor.clone() << shift).0); + + (q, r >> shift) + } + } + + #[allow(dead_code)] + /// divident must be at least as wide as divisor + /// returns (quotient, remainder) + pub fn div_rem_bigint(divident: BigInt, divisor: BigInt) -> (BigInt, BigInt) { let divident = divident.normalised(); let mut divisor = divisor.normalised(); @@ -695,6 +946,56 @@ pub mod bigint { } } + fn sub_bigint_in_right_simple(lhs: &[u32], rhs: &mut [u32]) -> bool { + assert!(lhs.len() == rhs.len()); + let mut borrow = false; + for (l, r) in lhs.iter().zip(rhs) { + (*r, borrow) = l.borrowing_sub(*r, borrow); + } + + borrow + } + + fn sub_bigint_in_right(lhs: &[u32], rhs: &mut [u32]) { + assert!(rhs.len() >= lhs.len()); + + let min_len = lhs.len().min(rhs.len()); + let (r_lo, r_hi) = rhs.split_at_mut(min_len); + let (l_lo, l_hi) = lhs.split_at(min_len); + + let borrow = sub_bigint_in_right_simple(l_lo, r_lo); + + assert!(l_hi.is_empty()); + assert!(!borrow); + assert!(r_hi.iter().all(|&d| d == 0)); + } + + fn sub_bigint_scalar(lhs: &mut [u32], rhs: u32) { + let mut rhs = Some(rhs); + let mut borrow = false; + for lhs in lhs.iter_mut() { + (*lhs, borrow) = lhs.borrowing_sub(rhs.take().unwrap_or(0), borrow); + if !borrow { + break; + } + } + if borrow { + panic!("sub failed: borrow: {borrow}"); + } + } + + fn add_bigint_scalar(lhs: &mut [u32], rhs: u32) -> bool { + let mut rhs = Some(rhs); + let mut carry = false; + for d in lhs.iter_mut() { + (*d, carry) = (*d).carrying_add(rhs.take().unwrap_or(0), carry); + if !carry { + break; + } + } + carry + } + /// lhs must be bigger than rhs /// returns carry fn add_bigint(lhs: &mut [u32], rhs: &[u32]) -> bool { @@ -851,134 +1152,780 @@ pub mod bigint { } } -use std::ops::Add; +pub mod bigsint { + use std::{ + cmp::Ordering, + ops::{Add, AddAssign, Div, Mul, Neg, Not, Rem, Shl, Shr, Sub, SubAssign}, + }; -use bigint::BigInt; + use super::bigint::{self, *}; + + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + enum Sign { + Negative = 0, + None = 1, + Positive = 2, + } + + impl Neg for Sign { + type Output = Self; + + fn neg(self) -> Self::Output { + match self { + Sign::Negative => Self::Positive, + Sign::None => Self::None, + Sign::Positive => Self::Negative, + } + } + } + + /// A base-4_294_967_295 number. + #[derive(Clone, Debug, Eq, Ord)] + pub struct BigSInt { + sign: Sign, + bigint: BigInt, + } + + impl BigSInt { + pub fn zero() -> BigSInt { + Self { + sign: Sign::None, + bigint: BigInt::zero(), + } + } + pub fn one() -> BigSInt { + Self { + sign: Sign::Positive, + bigint: BigInt::one(), + } + } + + pub fn positive(bigint: BigInt) -> BigSInt { + Self { + sign: Sign::Positive, + bigint, + } + } + + pub fn from_u32(v: u32) -> BigSInt { + let sign = core::num::NonZero::new(v) + .map(|_| Sign::Positive) + .unwrap_or(Sign::None); + Self { + sign, + bigint: BigInt::from_u32(v), + } + } + pub fn from_u64(v: u64) -> BigSInt { + let sign = core::num::NonZero::new(v) + .map(|_| Sign::Positive) + .unwrap_or(Sign::None); + Self { + sign, + bigint: BigInt::from_u64(v), + } + } + pub fn from_i32(v: i32) -> BigSInt { + if v >= 0 { + Self::from_u32(v as u32) + } else { + let v = u32::MAX - (v as u32) + 1; + Self { + sign: Sign::Negative, + bigint: BigInt::from_u32(v), + } + } + } + pub fn from_i64(v: i64) -> BigSInt { + if v >= 0 { + Self::from_u64(v as u64) + } else { + let v = u64::MAX - (v as u64) + 1; + Self { + sign: Sign::Negative, + bigint: BigInt::from_u64(v), + } + } + } + + pub fn is_negative(&self) -> bool { + self.sign == Sign::Negative + } + } + + impl PartialEq for BigSInt { + fn eq(&self, other: &Self) -> bool { + self.sign == other.sign && self.bigint == other.bigint + } + } + + impl PartialOrd for BigSInt { + fn partial_cmp(&self, other: &Self) -> Option { + match self.sign.partial_cmp(&other.sign) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.bigint.partial_cmp(&other.bigint) + } + } + + impl Not for BigSInt { + type Output = Self; + + fn not(mut self) -> Self::Output { + match self.sign { + Sign::Negative => { + self.bigint -= 1u32; + self.sign = if self.bigint.is_zero() { + Sign::None + } else { + Sign::Positive + }; + } + Sign::None | Sign::Positive => { + self.bigint += 1u32; + self.sign = Sign::Negative; + } + } + self + } + } + + impl Add for BigSInt { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + match (self.sign, rhs.sign) { + (_, Sign::None) => self, + (Sign::None, _) => rhs, + (Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative) => Self { + sign: self.sign, + bigint: self.bigint + rhs.bigint, + }, + (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive) => { + match self.bigint.cmp(&rhs.bigint) { + Ordering::Less => Self { + sign: rhs.sign, + bigint: rhs.bigint - self.bigint, + }, + Ordering::Equal => Self::zero(), + Ordering::Greater => Self { + sign: self.sign, + bigint: self.bigint - rhs.bigint, + }, + } + } + } + } + } + + impl Add for BigSInt { + type Output = BigSInt; + + fn add(self, rhs: u32) -> Self::Output { + match self.sign { + Sign::Negative => match self.bigint.partial_cmp(&rhs).unwrap() { + Ordering::Less => Self::positive(rhs - self.bigint), + Ordering::Equal => Self::zero(), + Ordering::Greater => -Self::positive(self.bigint - rhs), + }, + Sign::None => Self::from_u32(rhs), + Sign::Positive => Self::positive(self.bigint + rhs), + } + } + } + + impl Add for BigSInt { + type Output = BigSInt; + + fn add(self, rhs: u64) -> Self::Output { + match self.sign { + Sign::Negative => match self.bigint.partial_cmp(&rhs).unwrap() { + Ordering::Less => Self::positive(rhs - self.bigint), + Ordering::Equal => Self::zero(), + Ordering::Greater => -Self::positive(self.bigint - rhs), + }, + Sign::None => Self::from_u64(rhs), + Sign::Positive => Self::positive(self.bigint + rhs), + } + } + } + + impl AddAssign for BigSInt { + fn add_assign(&mut self, rhs: Self) { + let n = core::mem::replace(self, Self::zero()); + *self = n + rhs; + } + } + + impl Sub for BigSInt { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + match (self.sign, rhs.sign) { + (_, Sign::None) => self, + (Sign::None, _) => -rhs, + (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive) => Self { + sign: self.sign, + bigint: self.bigint + rhs.bigint, + }, + (Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative) => { + match self.bigint.cmp(&rhs.bigint) { + Ordering::Less => Self { + sign: -self.sign, + bigint: rhs.bigint - self.bigint, + }, + Ordering::Equal => Self::zero(), + Ordering::Greater => Self { + sign: self.sign, + bigint: self.bigint - rhs.bigint, + }, + } + } + } + } + } + + impl SubAssign for BigSInt { + fn sub_assign(&mut self, rhs: Self) { + let n = core::mem::replace(self, Self::zero()); + *self = n - rhs; + } + } + + impl Shl for BigSInt { + type Output = Self; + + fn shl(self, rhs: usize) -> Self::Output { + Self { + sign: self.sign, + bigint: self.bigint << rhs, + } + } + } + + impl Shr for BigSInt { + type Output = Self; + + fn shr(self, rhs: usize) -> Self::Output { + let rounding = shr_rounding(&self, rhs); + let mut out = Self { + sign: self.sign, + bigint: self.bigint >> rhs, + }; + + if rounding { + out.bigint += 1u32; + } + + out + } + } + + impl Mul for Sign { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + use Sign::*; + match (self, rhs) { + (Negative, Negative) | (Positive, Positive) => Positive, + (None, _) | (_, None) => todo!(), + (Negative, Positive) | (Positive, Negative) => Negative, + } + } + } + + impl Mul for BigSInt { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self { + sign: self.sign * rhs.sign, + bigint: self.bigint * rhs.bigint, + } + } + } + + impl Div for BigSInt { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + div_rem_bigsint(&self, &rhs).0 + } + } + + impl Rem for BigSInt { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + div_rem_bigsint(&self, &rhs).1 + } + } + + fn div_rem_bigsint(lhs: &BigSInt, rhs: &BigSInt) -> (BigSInt, BigSInt) { + let (q, r) = bigint::div_rem_bigint_ref(&lhs.bigint, &rhs.bigint); + let q = BigSInt { + sign: lhs.sign, + bigint: q, + }; + let r = BigSInt { + sign: lhs.sign, + bigint: r, + }; + + if rhs.is_negative() { + (-q, r) + } else { + (q, r) + } + } + + fn shr_rounding(lhs: &BigSInt, shift: usize) -> bool { + if lhs.is_negative() { + let ctz = lhs.bigint.trailing_zeros(); + shift > 0 && ctz < shift + } else { + false + } + } + + impl Neg for BigSInt { + type Output = Self; + + fn neg(mut self) -> Self::Output { + self.sign = -self.sign; + self + } + } +} + +use std::ops::{Add, BitAnd, BitOr, BitXor, Not}; + +use num_bigint::{BigInt, Sign}; +use num_traits::cast::ToPrimitive; use crate::ast::IntegralType; #[derive(Debug, thiserror::Error)] -enum Error { +pub enum Error { #[error("Incompatible Comptime Number variants.")] IncompatibleTypes, #[error("Integer overflow.")] IntegerOverflow, + #[error("Shift cannot fit into u32.")] + ShiftTooLarge, + #[error("Cannot negate unsigned integer")] + UnsignedNegation, } -type Result = core::result::Result; +pub type Result = core::result::Result; -enum ComptimeInt { - Native { bits: u64, ty: IntegralType }, +pub enum ComptimeInt { + Native { bits: u128, ty: IntegralType }, BigInt { bits: BigInt, ty: IntegralType }, Comptime(BigInt), } impl ComptimeInt { pub fn add(self, other: Self) -> Result { - match (self, other) { - (Self::Native { bits: a, ty: aty }, Self::Native { bits: b, ty: bty }) => { - let bits = if aty != bty { - return Err(Error::IncompatibleTypes); - } else if aty.signed { - (a as i64) - .checked_add(b as i64) - .ok_or(Error::IntegerOverflow)? as u64 - } else { - (a as u64) - .checked_add(b as u64) - .ok_or(Error::IntegerOverflow)? as u64 - }; - if bits & !aty.u64_bitmask() != 0 { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.checked_add(b).ok_or(Error::IntegerOverflow)?; + if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } - Ok(Self::Native { bits, ty: aty }) + Ok(Self::Native { bits, ty }) } - (Self::Comptime(a), Self::Comptime(b)) => Ok(Self::Comptime(a.clone().add(b.clone()))), - (Self::Comptime(a), Self::Native { bits, ty }) - | (Self::Native { bits, ty }, Self::Comptime(a)) => { - let b = a.into_u64().ok_or(Error::IncompatibleTypes)?; - - Self::Native { bits, ty }.add(Self::Native { bits: b, ty }) + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let width = ty.bits - ty.signed as u16; + let bits = a + b; + if bits.bits() > width as u64 { + Err(Error::IntegerOverflow) + } else { + Ok(Self::BigInt { bits, ty }) + } + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a + b)), + _ => { + unreachable!() } - _ => Err(Error::IncompatibleTypes), } } + pub fn sub(self, other: Self) -> Result { - match (self, other) { - (Self::Native { bits: a, ty: aty }, Self::Native { bits: b, ty: bty }) => { - let bits = if aty != bty { - return Err(Error::IncompatibleTypes); - } else if aty.signed { - (a as i64) - .checked_sub(b as i64) - .ok_or(Error::IntegerOverflow)? as u64 - } else { - (a as u64) - .checked_sub(b as u64) - .ok_or(Error::IntegerOverflow)? as u64 - }; - if bits & !aty.u64_bitmask() != 0 { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.checked_sub(b).ok_or(Error::IntegerOverflow)?; + if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } - Ok(Self::Native { bits, ty: aty }) + Ok(Self::Native { bits, ty }) } - (Self::Comptime(a), Self::Comptime(b)) => Ok(Self::Comptime(a - b)), - (Self::Comptime(a), b @ Self::Native { ty, .. }) => { - let a = a.into_u64().ok_or(Error::IncompatibleTypes)?; - Self::Native { bits: a, ty }.sub(b) + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let width = ty.bits - ty.signed as u16; + let bits = a - b; + if bits.bits() > width as u64 { + Err(Error::IntegerOverflow) + } else { + Ok(Self::BigInt { bits, ty }) + } } - (a @ Self::Native { ty, .. }, Self::Comptime(b)) => { - let b = b.into_u64().ok_or(Error::IncompatibleTypes)?; + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a - b)), + _ => { + unreachable!() + } + } + } - a.sub(Self::Native { bits: b, ty }) + pub fn mul(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.checked_mul(b).ok_or(Error::IntegerOverflow)?; + if bits & !ty.u128_bitmask() != 0 { + return Err(Error::IntegerOverflow); + } + Ok(Self::Native { bits, ty }) } - _ => Err(Error::IncompatibleTypes), + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let width = ty.bits - ty.signed as u16; + let bits = a * b; + if bits.bits() > width as u64 { + Err(Error::IntegerOverflow) + } else { + Ok(Self::BigInt { bits, ty }) + } + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a * b)), + _ => { + unreachable!() + } + } + } + + pub fn div(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.checked_div(b).ok_or(Error::IntegerOverflow)?; + if bits & !ty.u128_bitmask() != 0 { + return Err(Error::IntegerOverflow); + } + Ok(Self::Native { bits, ty }) + } + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let width = ty.bits - ty.signed as u16; + let bits = a / b; + if bits.bits() > width as u64 { + Err(Error::IntegerOverflow) + } else { + Ok(Self::BigInt { bits, ty }) + } + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a / b)), + _ => { + unreachable!() + } + } + } + + pub fn rem(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.checked_rem(b).ok_or(Error::IntegerOverflow)?; + if bits & !ty.u128_bitmask() != 0 { + return Err(Error::IntegerOverflow); + } + Ok(Self::Native { bits, ty }) + } + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let width = ty.bits - ty.signed as u16; + let bits = a % b; + if bits.bits() > width as u64 { + Err(Error::IntegerOverflow) + } else { + Ok(Self::BigInt { bits, ty }) + } + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a % b)), + _ => { + unreachable!() + } + } + } + + pub fn bitand(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.bitand(b); + Ok(Self::Native { bits, ty }) + } + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let bits = a & b; + Ok(Self::BigInt { bits, ty }) + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a & b)), + _ => { + unreachable!() + } + } + } + + pub fn bitor(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.bitor(b); + Ok(Self::Native { bits, ty }) + } + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let bits = a | b; + Ok(Self::BigInt { bits, ty }) + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a | b)), + _ => { + unreachable!() + } + } + } + + pub fn bitxor(self, other: Self) -> Result { + let (a, b) = self.coalesce(other)?; + match (a, b) { + (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { + let bits = a.bitxor(b); + Ok(Self::Native { bits, ty }) + } + (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { + let bits = a ^ b; + Ok(Self::BigInt { bits, ty }) + } + (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a ^ b)), + _ => { + unreachable!() + } + } + } + + pub fn shl(self, other: Self) -> Result { + use core::ops::Shl; + let shift = other.try_to_u32()?; + match self { + ComptimeInt::Native { bits, ty } => { + let bits = if ty.signed { + (bits as i128) + .checked_shl(shift) + .ok_or(Error::IntegerOverflow)? as u128 + } else { + (bits as u128) + .checked_shl(shift) + .ok_or(Error::IntegerOverflow)? as u128 + } & ty.u128_bitmask(); + + Ok(Self::Native { bits, ty }) + } + ComptimeInt::BigInt { bits, ty } => { + let mut bits = bits.shl(shift); + + for i in 0..shift as u16 { + bits.set_bit((i * ty.bits) as u64, false); + } + Ok(Self::BigInt { bits, ty }) + } + ComptimeInt::Comptime(bits) => Ok(Self::Comptime(bits.shl(shift))), + } + } + + pub fn shr(self, other: Self) -> Result { + use core::ops::Shr; + let shift = other.try_to_u32()?; + match self { + ComptimeInt::Native { bits, ty } => { + let bits = if ty.signed { + (bits as i128) + .checked_shr(shift) + .ok_or(Error::IntegerOverflow)? as u128 + } else { + (bits as u128) + .checked_shr(shift) + .ok_or(Error::IntegerOverflow)? as u128 + }; + + Ok(Self::Native { bits, ty }) + } + ComptimeInt::BigInt { bits, ty } => Ok(Self::BigInt { + bits: bits.shr(shift), + ty, + }), + ComptimeInt::Comptime(bits) => Ok(Self::Comptime(bits.shr(shift))), + } + } + + pub fn neg(self) -> Result { + match self { + Self::Native { bits: a, ty } => { + if ty.signed { + return Err(Error::UnsignedNegation); + } + let bits = (a as i128).checked_neg().ok_or(Error::IntegerOverflow)? as u128; + + if bits & !ty.u128_bitmask() != 0 { + return Err(Error::IntegerOverflow); + } + Ok(Self::Native { bits, ty }) + } + Self::Comptime(a) => Ok(Self::Comptime(-a)), + Self::BigInt { bits, ty } => Ok(Self::BigInt { bits: -bits, ty }), + } + } + + pub fn not(self) -> Result { + match self { + ComptimeInt::Native { bits, ty } => Ok(Self::Native { + bits: !bits | ty.u128_bitmask(), + ty, + }), + ComptimeInt::BigInt { bits, ty } => Ok(Self::BigInt { bits: !bits, ty }), + ComptimeInt::Comptime(bigint) => Ok(Self::Comptime(!bigint)), + } + } + + fn try_to_u32(&self) -> Result { + match self { + ComptimeInt::Native { bits, .. } => bits.to_u32(), + ComptimeInt::BigInt { bits, .. } => bits.to_u32(), + ComptimeInt::Comptime(bits) => bits.to_u32(), + } + .ok_or(Error::ShiftTooLarge) + } + + fn coalesce(self, other: Self) -> Result<(ComptimeInt, ComptimeInt)> { + match (self, other) { + (lhs @ ComptimeInt::Native { ty: a_ty, .. }, ComptimeInt::Comptime(b)) + | (lhs @ ComptimeInt::Native { ty: a_ty, .. }, ComptimeInt::BigInt { bits: b, .. }) => { + let b_signed = b.sign() == Sign::Minus; + if !a_ty.signed && b_signed { + return Err(Error::IncompatibleTypes); + } + + let bits = b.bits() + a_ty.signed as u64; + if bits as u16 > a_ty.bits { + return Err(Error::IncompatibleTypes); + } + let b = if b_signed { + b.to_i128().unwrap() as u128 + } else { + b.to_u128().unwrap() + }; + Ok((lhs, Self::Native { bits: b, ty: a_ty })) + } + (ComptimeInt::Comptime(b), rhs @ ComptimeInt::Native { ty: a_ty, .. }) + | (ComptimeInt::BigInt { bits: b, .. }, rhs @ ComptimeInt::Native { ty: a_ty, .. }) => { + let b_signed = b.sign() == Sign::Minus; + if !a_ty.signed && b_signed { + return Err(Error::IncompatibleTypes); + } + + let bits = b.bits() + a_ty.signed as u64; + if bits as u16 > a_ty.bits { + return Err(Error::IncompatibleTypes); + } + let b = if b_signed { + b.to_i128().unwrap() as u128 + } else { + b.to_u128().unwrap() + }; + Ok((Self::Native { bits: b, ty: a_ty }, rhs)) + } + (lhs @ ComptimeInt::BigInt { ty, .. }, ComptimeInt::Comptime(b)) => { + let b_signed = b.sign() == Sign::Minus; + if !ty.signed && b_signed { + return Err(Error::IncompatibleTypes); + } + + let bits = b.bits() + ty.signed as u64; + if bits as u16 > ty.bits { + return Err(Error::IncompatibleTypes); + } + Ok((lhs, Self::BigInt { bits: b, ty })) + } + (ComptimeInt::Comptime(b), rhs @ ComptimeInt::BigInt { ty, .. }) => { + let b_signed = b.sign() == Sign::Minus; + if !ty.signed && b_signed { + return Err(Error::IncompatibleTypes); + } + + let bits = b.bits() + ty.signed as u64; + if bits as u16 > ty.bits { + return Err(Error::IncompatibleTypes); + } + Ok((Self::BigInt { bits: b, ty }, rhs)) + } + (lhs @ ComptimeInt::Native { ty: a, .. }, rhs @ ComptimeInt::Native { ty: b, .. }) => { + if a == b { + Ok((lhs, rhs)) + } else { + Err(Error::IncompatibleTypes) + } + } + (lhs @ ComptimeInt::BigInt { ty: a, .. }, rhs @ ComptimeInt::BigInt { ty: b, .. }) => { + if a == b { + Ok((lhs, rhs)) + } else { + Err(Error::IncompatibleTypes) + } + } + (lhs, rhs) => Ok((lhs, rhs)), } } } -enum ComptimeFloat { +pub enum ComptimeFloat { Binary32(f32), Binary64(f64), } impl ComptimeFloat { - pub fn add(&self, other: &Self) -> Result { + pub fn add(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a + b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a + b)), _ => Err(Error::IncompatibleTypes), } } - pub fn sub(&self, other: &Self) -> Result { + pub fn sub(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a - b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a - b)), _ => Err(Error::IncompatibleTypes), } } - pub fn mul(&self, other: &Self) -> Result { + pub fn mul(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a * b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a * b)), _ => Err(Error::IncompatibleTypes), } } - pub fn div(&self, other: &Self) -> Result { + pub fn div(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a / b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a / b)), _ => Err(Error::IncompatibleTypes), } } - pub fn rem(&self, other: &Self) -> Result { + pub fn rem(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a % b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a % b)), _ => Err(Error::IncompatibleTypes), } } - pub fn neg(&self) -> Result { + pub fn neg(self) -> Result { match self { ComptimeFloat::Binary32(a) => Ok(Self::Binary32(-a)), ComptimeFloat::Binary64(a) => Ok(Self::Binary64(-a)), @@ -986,10 +1933,147 @@ impl ComptimeFloat { } } -enum ComptimeNumber { +pub enum ComptimeNumber { Integral(ComptimeInt), Bool(bool), Floating(ComptimeFloat), } -impl ComptimeNumber {} +impl ComptimeNumber { + pub fn add(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.add(b)?)) + } + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + Ok(Self::Floating(a.add(b)?)) + } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.add(b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn sub(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.sub(b)?)) + } + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + Ok(Self::Floating(a.sub(b)?)) + } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.sub(b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn mul(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.mul(b)?)) + } + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + Ok(Self::Floating(a.mul(b)?)) + } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.mul(b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn div(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.div(b)?)) + } + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + Ok(Self::Floating(a.div(b)?)) + } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.div(b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn rem(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.rem(b)?)) + } + (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + Ok(Self::Floating(a.rem(b)?)) + } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.rem(b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn neg(self) -> Result { + match self { + ComptimeNumber::Integral(a) => Ok(Self::Integral(a.neg()?)), + ComptimeNumber::Floating(a) => Ok(Self::Floating(a.neg()?)), + //ComptimeNumber::Bool(a) => todo!(), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn not(self) -> Result { + match self { + ComptimeNumber::Integral(a) => Ok(Self::Integral(a.not()?)), + // ComptimeNumber::Floating(a) => Ok(Self::Floating(a.not()?)), + ComptimeNumber::Bool(a) => Ok(Self::Bool(a.not())), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn bitand(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.bitand(b)?)) + } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.sub(b)?)) + // } + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitand(b))), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn bitor(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.bitor(b)?)) + } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.bitor(b)?)) + // } + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitor(b))), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn bitxor(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.bitxor(b)?)) + } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.bitxor(b)?)) + // } + (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitxor(b))), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn shl(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.shl(b)?)) + } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.bitxor(b)?)) + // } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitxor(b))), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn shr(self, other: Self) -> Result { + match (self, other) { + (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { + Ok(Self::Integral(a.shr(b)?)) + } + // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { + // Ok(Self::Floating(a.bitxor(b)?)) + // } + // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitxor(b))), + _ => Err(Error::IncompatibleTypes), + } + } +}