diff --git a/src/ast.rs b/src/ast.rs index 0c987f9..ae47fc6 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,7 +1,5 @@ use std::num::NonZero; -use itertools::Itertools; - use crate::string_table::{self, ImmOrIndex}; pub type Node = NonZero; @@ -229,6 +227,9 @@ impl IntegralType { bits: 32, } } + pub fn u64_bitmask(self) -> u64 { + (1u64 << self.bits) - 1 + } } #[derive(Debug, Clone, Eq, Hash)] diff --git a/src/comptime.rs b/src/comptime.rs new file mode 100644 index 0000000..2eed83e --- /dev/null +++ b/src/comptime.rs @@ -0,0 +1,995 @@ +pub mod bigint { + + use crate::lexer::Radix; + /// A base-4_294_967_295 number. + #[derive(Clone)] + pub struct BigInt(Vec); + + impl BigInt { + pub fn parse_digits>(text: C, radix: Radix) -> BigInt { + parse_bigint(text.into_iter(), radix) + } + + pub fn one() -> BigInt { + Self(vec![1]) + } + pub fn zero() -> BigInt { + Self(vec![]) + } + + pub fn num_digits(&self) -> usize { + self.0.iter().rposition(|&d| d != 0).map_or(0, |i| i + 1) + } + + pub fn bit_width(&self) -> usize { + count_bits(&self.0) + } + + pub fn is_zero(&self) -> bool { + bigint_is_zero(&self.0) + } + + pub fn is_one(&self) -> bool { + bigint_is_one(&self.0) + } + + pub fn is_power_of_two(&self) -> bool { + is_power_of_two(&self.0) + } + + pub fn trailing_zeros(&self) -> usize { + trailing_zeros(&self.0) + } + + pub fn from_bytes_le(bytes: &[u8]) -> BigInt { + let data = bytes + .chunks(4) + .map(|chunk| { + let mut int = [0u8; 4]; + int[..chunk.len()].copy_from_slice(chunk); + u32::from_le_bytes(int) + }) + .collect::>(); + + BigInt(data) + } + + pub fn into_bytes_le(&self) -> Vec { + let mut bytes = Vec::::new(); + + for d in &self.0[..] { + bytes.extend(&d.to_le_bytes()); + } + + let count = bytes.iter().rev().take_while(|&&b| b == 0).count(); + bytes.truncate(bytes.len() - count); + + bytes + } + + pub fn normalise(&mut self) { + let len = self.0.iter().rposition(|&d| d != 0).map_or(0, |i| i + 1); + self.0.truncate(len); + } + + pub fn normalised(mut self) -> BigInt { + self.normalise(); + self + } + + pub fn into_u64(&self) -> Option { + if self.0.len() <= 2 { + let mut bytes = [0u8; 8]; + self.0 + .get(0) + .map(|&dw| bytes[..4].copy_from_slice(&dw.to_le_bytes())); + self.0 + .get(1) + .map(|&dw| bytes[4..].copy_from_slice(&dw.to_le_bytes())); + Some(u64::from_le_bytes(bytes)) + } else { + None + } + } + } + + impl core::cmp::PartialEq for BigInt { + fn eq(&self, other: &Self) -> bool { + cmp_bigint(&self.0, &other.0) == core::cmp::Ordering::Equal + } + } + + impl core::cmp::Eq for BigInt {} + + impl core::cmp::PartialOrd for BigInt { + fn partial_cmp(&self, other: &Self) -> Option { + Some(cmp_bigint(&self.0, &other.0)) + } + } + impl core::cmp::Ord for BigInt { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + cmp_bigint(&self.0, &other.0) + } + } + + impl core::ops::Shl for BigInt { + type Output = Self; + + fn shl(mut self, rhs: usize) -> Self::Output { + shl_bitint(&mut self.0, rhs); + self + } + } + impl core::ops::Shr for BigInt { + type Output = Self; + + fn shr(mut self, rhs: usize) -> Self::Output { + shr_bitint(&mut self.0, rhs); + self + } + } + + impl core::ops::Add for BigInt { + type Output = Self; + + fn add(mut self, mut rhs: Self) -> Self::Output { + let (mut digits, carry) = if self.0.len() > rhs.0.len() { + let c = add_bigint(&mut self.0, &rhs.0); + (self.0, c) + } else { + let c = add_bigint(&mut rhs.0, &self.0); + (rhs.0, c) + }; + + if carry { + digits.push(u32::from(carry)); + } + + BigInt(digits) + } + } + + impl core::ops::Sub for BigInt { + type Output = Self; + + fn sub(mut self, rhs: Self) -> Self::Output { + if self.0.len() < rhs.0.len() { + println!("extending self by {} zeroes", rhs.0.len() - self.0.len()); + self.0 + .extend(core::iter::repeat(0).take(rhs.0.len() - self.0.len())); + println!("self: {self:?}"); + } + sub_bigint(&mut self.0, &rhs.0); + + self + } + } + + impl core::ops::Mul for BigInt { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + BigInt(mul_bigint(&self.0, &rhs.0)) + } + } + + impl core::ops::Div for BigInt { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + div_rem_bigint(self, rhs).0 + } + } + + impl core::ops::Rem for BigInt { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + div_rem_bigint(self, rhs).1 + } + } + + impl core::ops::Not for BigInt { + type Output = Self; + + fn not(mut self) -> Self::Output { + self.0.iter_mut().for_each(|c| *c = !*c); + self + } + } + + impl core::fmt::Debug for BigInt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut list = f.debug_list(); + list.entries(self.0.iter().rev()).finish() + } + } + + /// counts used bits in a u32 slice, discards leading zeros in MSB. + /// `[0xff,0xff,0x00,0x00]` -> 16 + /// `[0xff,0xff,0x00]` -> 16 + /// `[0xff,0xff,0x0f]` -> 20 + pub fn count_bits(bytes: &[u32]) -> usize { + let mut bits = bytes.len() * u32::BITS as usize; + + for &d in bytes.iter().rev() { + if d == 0 { + bits -= u32::BITS as usize; + } else { + bits -= d.leading_zeros() as usize; + break; + } + } + + bits + } + + #[test] + fn test_count_bits() { + assert_eq!(count_bits(&[0xffffffff, 0x00, 0x00]), 32); + assert_eq!(count_bits(&[0x00, 0x00, 0x00]), 0); + assert_eq!(count_bits(&[]), 0); + assert_eq!(count_bits(&[0xffffffff, 0xff, 0x00]), 40); + assert_eq!(count_bits(&[0xffffffff, 0xff]), 40); + assert_eq!(count_bits(&[0xffffffff, 0xff, 0xffff]), 64 + 16); + } + #[test] + fn test_count_trailing_zeros() { + assert_eq!(trailing_zeros(&[0xffffffff, 0x00, 0x00]), 0); + assert_eq!(trailing_zeros(&[0x00, 0x00, 0x00]), 0); + assert_eq!(trailing_zeros(&[]), 0); + assert_eq!(trailing_zeros(&[0x00, 0xffffffff, 0xff]), 32); + assert_eq!(trailing_zeros(&[0x00, 0xffffff00, 0xff]), 40); + } + + #[allow(unused)] + /// lhs <=> rhs + fn cmp_bigint(lhs: &[u32], rhs: &[u32]) -> core::cmp::Ordering { + use core::cmp::Ordering; + let lhs_bits = count_bits(lhs); + let rhs_bits = count_bits(rhs); + + match lhs_bits.cmp(&rhs_bits) { + Ordering::Less => Ordering::Less, + Ordering::Greater => Ordering::Greater, + Ordering::Equal => { + for (a, b) in lhs[..(lhs_bits / u32::BITS as usize)] + .iter() + .zip(rhs[..(lhs_bits / u32::BITS as usize)].iter()) + .rev() + { + let ord = a.cmp(b); + if ord != Ordering::Equal { + return ord; + } + } + return Ordering::Equal; + } + } + } + + fn bigint_is_zero(lhs: &[u32]) -> bool { + if lhs.len() == 0 { + true + } else { + lhs.iter().all(|c| c == &0) + } + } + + #[allow(dead_code)] + fn bigint_is_one(lhs: &[u32]) -> bool { + lhs.len() > 0 && lhs[0] == 1 && lhs[1..].iter().all(|c| c == &0) + } + + #[allow(dead_code)] + fn bitnot_bigint(lhs: &mut [u32]) { + for d in lhs.iter_mut() { + *d = !*d; + } + } + + #[allow(dead_code)] + fn u32_mul_bigint(lhs: &mut Vec, scalar: u32) { + match scalar { + 0 => { + lhs.clear(); + lhs.push(0) + } + 1 => {} + _ => { + if scalar.is_power_of_two() { + lhs.push(0); + shl_bitint(lhs.as_mut_slice(), scalar.trailing_zeros() as usize); + } else { + let mut carry = 0; + for a in lhs.iter_mut() { + (*a, carry) = (*a).carrying_mul(scalar, carry); + } + if carry != 0 { + lhs.push(carry); + } + } + } + } + } + + #[allow(dead_code)] + fn u64_mul_bigint(lhs: &mut Vec, scalar: u64) { + let lo = scalar as u32; + let hi = (scalar >> 32) as u32; + u32_mul_bigint(lhs, lo); + shl_bitint(lhs, 32); + u32_mul_bigint(lhs, hi); + } + + #[allow(dead_code)] + fn mul_bigint(lhs: &[u32], rhs: &[u32]) -> Vec { + if bigint_is_zero(lhs) || bigint_is_zero(rhs) { + return vec![]; + } + + let len = lhs.len() + rhs.len() + 1; + let mut product = vec![0u32; len]; + + for (bth, &b) in rhs.iter().enumerate() { + let mut carry = 0u32; + + for (ath, &a) in lhs.iter().enumerate() { + let prod; + (prod, carry) = a.carrying_mul(b, carry); + let (digit, c) = product[ath + bth].carrying_add(prod, false); + carry += c as u32; + product[ath + bth] = digit; + } + + if carry != 0 { + product[bth + lhs.len()] += carry; + } + } + + product + } + + #[allow(dead_code)] + fn sum_digits(digits: &[u32]) -> u64 { + let mut sum = 0u64; + + let mut carry = false; + for &d in digits { + (sum, carry) = sum.carrying_add(d as u64, carry); + } + + sum + carry as u64 + } + + #[allow(dead_code)] + fn count_ones(lhs: &[u32]) -> usize { + lhs.iter() + .fold(0usize, |acc, c| acc + c.count_ones() as usize) + } + + #[allow(dead_code)] + fn trailing_zeros(lhs: &[u32]) -> usize { + lhs.iter() + .enumerate() + .find(|(_, &c)| c != 0) + .map(|(i, &c)| i * u32::BITS as usize + c.trailing_zeros() as usize) + .unwrap_or(0) + } + + #[allow(dead_code)] + fn is_power_of_two(lhs: &[u32]) -> bool { + count_ones(lhs) == 1 + } + + #[allow(dead_code)] + /// divident must be at least as wide as divisor + /// returns (quotient, remainder) + fn div_rem_bigint(divident: BigInt, divisor: BigInt) -> (BigInt, BigInt) { + let divident = divident.normalised(); + let mut divisor = divisor.normalised(); + + if bigint_is_zero(&divisor.0) { + panic!("divide by zero!"); + } + if bigint_is_zero(÷nt.0) { + return (BigInt::zero(), BigInt::zero()); + } + use core::cmp::Ordering; + match cmp_bigint(÷nt.0, &divisor.0) { + Ordering::Less => return (BigInt::zero(), divident), + Ordering::Equal => { + return (BigInt::one(), BigInt::zero()); + } + Ordering::Greater => {} + } + + if divisor.is_power_of_two() { + let exp = divisor.trailing_zeros(); + let (div, rem) = divident.0.split_at(exp.div_floor(u32::BITS as usize)); + let (mut div, mut rem) = (div.to_vec(), rem.to_vec()); + + shr_bitint(&mut div, exp % u32::BITS as usize); + let mask = (1u32 << exp as u32 % u32::BITS) - 1; + if let Some(last) = rem.last_mut() { + *last &= mask; + } + + return (BigInt(div), BigInt(rem)); + } + + if divisor.num_digits() == 1 { + if divisor.0[0] == 1 { + return (divident, BigInt::zero()); + } + + let (div, rem) = div_digit_bigint(divident, divisor.0[0]); + divisor.0.clear(); + divisor.0.push(rem); + return (div, divisor); + } + + let shift = divisor.0.last().unwrap().leading_zeros() as usize; + if shift == 0 { + div_rem_core(divident, &divisor.0) + } else { + let (q, r) = div_rem_core(divident << shift, &(divisor << shift).0); + + (q, r >> shift) + } + } + + fn scalar_div_wide(hi: u32, lo: u32, divisor: u32) -> (u32, u32) { + let (div, rem); + + unsafe { + core::arch::asm! { + "div {0:e}", + in(reg) divisor, + inout("dx") hi => rem, + inout("ax") lo => div, + } + } + + (div, rem) + } + + #[allow(dead_code)] + fn div_digit_bigint(mut divident: BigInt, divisor: u32) -> (BigInt, u32) { + assert!(divisor != 0); + let mut rem = 0; + + for d in divident.0.iter_mut().rev() { + (*d, rem) = scalar_div_wide(rem, *d, divisor); + } + + (divident.normalised(), rem) + } + + fn from_lo_hi(lo: u32, hi: u32) -> u64 { + lo as u64 | (hi as u64) << 32 + } + fn into_lo_hi(qword: u64) -> (u32, u32) { + (qword as u32, (qword >> 32) as u32) + } + + // from rust num_bigint + /// Subtract a multiple. + /// a -= b * c + /// Returns a borrow (if a < b then borrow > 0). + fn sub_mul_digit_same_len(a: &mut [u32], b: &[u32], c: u32) -> u32 { + assert!(a.len() == b.len()); + + // carry is between -big_digit::MAX and 0, so to avoid overflow we store + // offset_carry = carry + big_digit::MAX + let mut offset_carry = u32::MAX; + + for (x, y) in a.iter_mut().zip(b) { + // We want to calculate sum = x - y * c + carry. + // sum >= -(big_digit::MAX * big_digit::MAX) - big_digit::MAX + // sum <= big_digit::MAX + // Offsetting sum by (big_digit::MAX << big_digit::BITS) puts it in DoubleBigDigit range. + let offset_sum = from_lo_hi(u32::MAX, *x) - u32::MAX as u64 + offset_carry as u64 + - *y as u64 * c as u64; + + let (new_x, new_offset_carry) = into_lo_hi(offset_sum); + offset_carry = new_offset_carry; + *x = new_x; + } + + // Return the borrow. + u32::MAX - offset_carry + } + + // from rust num_bigint + fn div_rem_core(mut a: BigInt, b: &[u32]) -> (BigInt, BigInt) { + // sanity check on fast paths + assert!(a.0.len() >= b.len() && b.len() > 1); + + // a0 stores an additional extra most significant digit of the dividend, not stored in a. + let mut a0 = 0; + + // [b1, b0] are the two most significant digits of the divisor. They never change. + let b0 = b[b.len() - 1]; + let b1 = b[b.len() - 2]; + + let q_len = a.0.len() - b.len() + 1; + let mut q = BigInt(vec![0; q_len]); + + for j in (0..q_len).rev() { + assert!(a.0.len() == b.len() + j); + + let a1 = *a.0.last().unwrap(); + let a2 = a.0[a.0.len() - 2]; + + // The first q0 estimate is [a1,a0] / b0. It will never be too small, it may be too large + // by at most 2. + let (mut q0, mut r) = if a0 < b0 { + let (q0, r) = scalar_div_wide(a0, a1, b0); + (q0, r as u64) + } else { + assert!(a0 == b0); + // Avoid overflowing q0, we know the quotient fits in BigDigit. + // [a1,a0] = b0 * (1< a0 { + // q0 is too large. We need to add back one multiple of b. + q0 -= 1; + borrow -= add_bigint(&mut a.0[j..], b) as u32; + } + // The top digit of a, stored in a0, has now been zeroed. + assert!(borrow == a0); + + q.0[j] = q0; + + // Pop off the next top digit of a. + a0 = a.0.pop().unwrap(); + } + + a.0.push(a0); + a.normalise(); + + assert_eq!(cmp_bigint(&a.0, b), core::cmp::Ordering::Less); + + (q.normalised(), a) + } + + #[allow(unused)] + fn shr_bitint(lhs: &mut [u32], shift: usize) { + if bigint_is_zero(lhs) || shift == 0 { + return; + } + + let len = lhs.len(); + let digit_offset = shift / 32; + let bit_shift = shift % 32; + + if digit_offset != 0 { + lhs.copy_within(digit_offset..len, 0); + lhs[(len - digit_offset)..].fill(0); + } + if bit_shift != 0 { + let lo_mask = (1u32 << (u32::BITS as usize - bit_shift)) - 1; + let hi_mask = !lo_mask; + + eprintln!("lhs >> {shift}"); + eprintln!("\tdigit_offset: {digit_offset}"); + eprintln!("\tbit_shift: {bit_shift}"); + eprintln!("\tlo_mask: 0b{lo_mask:0>32b}"); + eprintln!("\thi_mask: 0b{hi_mask:0>32b}"); + + let mut carry = 0u32; + for i in 0..lhs.len() { + let digit = ((lhs[i] as u64) << 32) >> bit_shift; + let lo = digit as u32; + let hi = (digit >> 32) as u32; + + lhs[i] &= hi_mask; + lhs[i] |= hi; + + if i > 0 { + lhs[i - 1] &= lo_mask; + lhs[i - 1] |= lo; + } + } + } + } + + #[allow(unused)] + /// lhs must have shift / 32 + 1 digits past the last bit if shifted-past bits are desired. + fn shl_bitint(lhs: &mut [u32], shift: usize) { + if bigint_is_zero(lhs) || shift == 0 { + return; + } + + let len = lhs.len(); + let digit_offset = shift / 32; + let bit_shift = shift % 32; + + if digit_offset != 0 { + lhs.copy_within(0..(len - digit_offset), digit_offset); + lhs[..digit_offset].fill(0); + } + if bit_shift != 0 { + let hi_mask = (1u32 << bit_shift) - 1; + let lo_mask = !hi_mask; + + eprintln!("lhs << {shift}"); + eprintln!("\tdigit_offset: {digit_offset}"); + eprintln!("\tbit_shift: {bit_shift}"); + eprintln!("\tlo_mask: 0b{lo_mask:0>32b}"); + eprintln!("\thi_mask: 0b{hi_mask:0>32b}"); + + // example with u8 digits, shift = 3; + // hi_mask = 0b00000111 + // lo_mask = 0b11111000 + // + // lhs[i] as u16 = 0b00000000_01111000 + // digit = 0b00000011_11000000 + // hi = 0b00000011 + // lo = 0b11000000 + + let mut carry = 0u32; + for i in (digit_offset..len).rev() { + let digit = (lhs[i] as u64) << bit_shift; + let lo = digit as u32; + let hi = (digit >> 32) as u32; + + lhs[i] &= lo_mask; + lhs[i] |= lo; + + if i + 1 < len { + lhs[i + 1] &= hi_mask; + lhs[i + 1] |= hi; + } + } + } + } + + #[allow(unused)] + /// lhs must be bigger than rhs + fn sub_bigint(lhs: &mut [u32], rhs: &[u32]) { + if bigint_is_zero(rhs) { + return; + } + + let len = lhs.len().min(rhs.len()); + let (l_lo, l_hi) = lhs.split_at_mut(len); + let (r_lo, r_hi) = rhs.split_at(len); + + println!("lhs: {{ lo: {l_lo:?}, hi: {l_hi:?} }}"); + println!("rhs: {{ lo: {r_lo:?}, hi: {r_hi:?} }}"); + + let mut borrow = false; + for (lhs, rhs) in l_lo.iter_mut().zip(r_lo) { + (*lhs, borrow) = lhs.borrowing_sub(*rhs, borrow); + } + + if borrow { + for lhs in l_hi { + (*lhs, borrow) = lhs.borrowing_sub(0, borrow); + } + } + + if borrow || !r_hi.iter().all(|&v| v == 0) { + panic!("sub failed: borrow: {borrow}"); + } + } + + /// lhs must be bigger than rhs + /// returns carry + fn add_bigint(lhs: &mut [u32], rhs: &[u32]) -> bool { + if bigint_is_zero(rhs) { + return false; + } + + let (l_lo, l_hi) = lhs.split_at_mut(rhs.len()); + + let mut carry = false; + for (lhs, rhs) in l_lo.iter_mut().zip(rhs) { + (*lhs, carry) = lhs.carrying_add(*rhs, carry); + } + + if carry { + for d in l_hi.iter_mut() { + (*d, carry) = d.carrying_add(0, carry); + if !carry { + break; + } + } + } + + carry + } + + fn parse_bigint(text: impl Iterator, radix: Radix) -> BigInt { + let digits = text + .filter_map(|c| match c { + '_' => None, + c => Some(radix.map_digit(c)), + }) + .collect::>(); + + let (max, power) = { + let radix = radix.radix() as u64; + let mut power = 1; + let mut base = radix; + while let Some(b) = base.checked_mul(radix) { + if b > u32::MAX as u64 { + break; + } + base = b; + power += 1; + } + (base, power) + }; + let radix = radix.radix() as u32; + + let r = digits.len() % power; + let i = if r == 0 { power } else { r }; + let (head, tail) = digits.split_at(i); + + let first = head + .iter() + .fold(0, |acc, &digit| acc * radix + digit as u32); + let mut data = vec![first]; + + for chunk in tail.chunks(power) { + if data.last() != Some(&0) { + data.push(0); + } + let mut carry = 0u64; + for digit in data.iter_mut() { + carry += *digit as u64 * max as u64; + *digit = carry as u32; + carry >>= u32::BITS; + } + assert!(carry == 0); + let next = chunk + .iter() + .fold(0, |acc, &digit| acc * radix + digit as u32); + + let (res, mut carry) = data[0].carrying_add(next, false); + data[0] = res; + if carry { + for digit in data[1..].iter_mut() { + (*digit, carry) = digit.carrying_add(0, carry); + if !carry { + break; + } + } + } + } + BigInt(data) + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn parse() { + let bigint = super::parse_bigint("2_cafe_babe_dead_beef".chars(), Radix::Hex); + println!("{:#x?}", bigint); + let bigint = super::parse_bigint("f".chars(), Radix::Hex); + println!("{:#x?}", bigint); + } + #[test] + fn add() { + let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); + println!("{:#x?}", a); + let b = super::parse_bigint("cafebabe".chars(), Radix::Hex); + println!("{:#x?}", b); + let sum = a + b; + println!("{:#x?}", sum); + } + #[test] + fn sub() { + let a = super::parse_bigint("deadbeef".chars(), Radix::Hex); + println!("{:#x?}", a); + let b = super::parse_bigint("56d2c".chars(), Radix::Hex); + println!("{:#x?}", b); + let sum = a - b; + println!("{:#x?}", sum); + } + #[test] + fn overflowing_sub() { + let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); + println!("{:#x?}", a); + let b = super::parse_bigint("ffff_ffff".chars(), Radix::Hex); + println!("{:#x?}", b); + let sum = b - a; + println!("{:#x?}", sum); + } + #[test] + fn shr() { + let mut a = super::parse_bigint("cafe_babe_0000".chars(), Radix::Hex); + print!("{:0>8x?} >> 32 ", a); + shr_bitint(&mut a.0, 32); + println!("{:0>8x?}", a); + + let mut a = super::parse_bigint("11110000".chars(), Radix::Bin); + print!("{:0>8x?} >> 32 ", a); + shr_bitint(&mut a.0, 3); + println!("{:0>8x?}", a); + } + #[test] + fn shl() { + let mut a = super::parse_bigint("ffff_ffff".chars(), Radix::Hex); + a.0.extend([0; 4]); + println!("{:0>8x?}", a); + shl_bitint(&mut a.0, 40); + println!("{:0>8x?}", a); + } + #[test] + fn div() { + let a = super::parse_bigint("cafebabe".chars(), Radix::Hex); + let b = super::parse_bigint("dead".chars(), Radix::Hex); + let (div, rem) = div_rem_bigint(a, b); + println!("div: {:0>8x?}", div); + println!("rem: {:0>8x?}", rem); + } + } +} + +use std::ops::Add; + +use bigint::BigInt; + +use crate::ast::IntegralType; + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Incompatible Comptime Number variants.")] + IncompatibleTypes, + #[error("Integer overflow.")] + IntegerOverflow, +} + +type Result = core::result::Result; + +enum ComptimeInt { + Native { bits: u64, ty: IntegralType }, + BigInt { bits: BigInt, ty: IntegralType }, + Comptime(BigInt), +} + +impl ComptimeInt { + pub fn add(self, other: Self) -> Result { + match (self, other) { + (Self::Native { bits: a, ty: aty }, Self::Native { bits: b, ty: bty }) => { + let bits = if aty != bty { + return Err(Error::IncompatibleTypes); + } else if aty.signed { + (a as i64) + .checked_add(b as i64) + .ok_or(Error::IntegerOverflow)? as u64 + } else { + (a as u64) + .checked_add(b as u64) + .ok_or(Error::IntegerOverflow)? as u64 + }; + if bits & !aty.u64_bitmask() != 0 { + return Err(Error::IntegerOverflow); + } + Ok(Self::Native { bits, ty: aty }) + } + (Self::Comptime(a), Self::Comptime(b)) => Ok(Self::Comptime(a.clone().add(b.clone()))), + (Self::Comptime(a), Self::Native { bits, ty }) + | (Self::Native { bits, ty }, Self::Comptime(a)) => { + let b = a.into_u64().ok_or(Error::IncompatibleTypes)?; + + Self::Native { bits, ty }.add(Self::Native { bits: b, ty }) + } + _ => Err(Error::IncompatibleTypes), + } + } + pub fn sub(self, other: Self) -> Result { + match (self, other) { + (Self::Native { bits: a, ty: aty }, Self::Native { bits: b, ty: bty }) => { + let bits = if aty != bty { + return Err(Error::IncompatibleTypes); + } else if aty.signed { + (a as i64) + .checked_sub(b as i64) + .ok_or(Error::IntegerOverflow)? as u64 + } else { + (a as u64) + .checked_sub(b as u64) + .ok_or(Error::IntegerOverflow)? as u64 + }; + if bits & !aty.u64_bitmask() != 0 { + return Err(Error::IntegerOverflow); + } + Ok(Self::Native { bits, ty: aty }) + } + (Self::Comptime(a), Self::Comptime(b)) => Ok(Self::Comptime(a - b)), + (Self::Comptime(a), b @ Self::Native { ty, .. }) => { + let a = a.into_u64().ok_or(Error::IncompatibleTypes)?; + Self::Native { bits: a, ty }.sub(b) + } + (a @ Self::Native { ty, .. }, Self::Comptime(b)) => { + let b = b.into_u64().ok_or(Error::IncompatibleTypes)?; + + a.sub(Self::Native { bits: b, ty }) + } + _ => Err(Error::IncompatibleTypes), + } + } +} + +enum ComptimeFloat { + Binary32(f32), + Binary64(f64), +} + +impl ComptimeFloat { + pub fn add(&self, other: &Self) -> Result { + match (self, other) { + (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a + b)), + (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a + b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn sub(&self, other: &Self) -> Result { + match (self, other) { + (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a - b)), + (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a - b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn mul(&self, other: &Self) -> Result { + match (self, other) { + (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a * b)), + (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a * b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn div(&self, other: &Self) -> Result { + match (self, other) { + (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a / b)), + (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a / b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn rem(&self, other: &Self) -> Result { + match (self, other) { + (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a % b)), + (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a % b)), + _ => Err(Error::IncompatibleTypes), + } + } + pub fn neg(&self) -> Result { + match self { + ComptimeFloat::Binary32(a) => Ok(Self::Binary32(-a)), + ComptimeFloat::Binary64(a) => Ok(Self::Binary64(-a)), + } + } +} + +enum ComptimeNumber { + Integral(ComptimeInt), + Bool(bool), + Floating(ComptimeFloat), +} + +impl ComptimeNumber {} diff --git a/src/lexer.rs b/src/lexer.rs index b86f02c..359fd65 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -582,266 +582,6 @@ fn try_parse_integral_type(source: &mut Chars) -> Result> { Ok(Some(())) } -pub mod bigint { - - use super::Radix; - pub struct BigInt(Vec); - - impl BigInt { - pub fn parse_digits>(text: C, radix: Radix) -> BigInt { - parse_bigint(text.into_iter(), radix) - } - - pub fn bit_width(&self) -> u32 { - count_bits(&self.0) - } - - pub fn from_bytes_le(bytes: &[u8]) -> BigInt { - let data = bytes - .chunks(4) - .map(|chunk| { - let mut int = [0u8; 4]; - int[..chunk.len()].copy_from_slice(chunk); - u32::from_le_bytes(int) - }) - .collect::>(); - - BigInt(data) - } - - pub fn into_bytes_le(&self) -> Vec { - let mut bytes = Vec::::new(); - - for d in &self.0[..] { - bytes.extend(&d.to_le_bytes()); - } - - let count = bytes.iter().rev().take_while(|&&b| b == 0).count(); - bytes.truncate(bytes.len() - count); - - bytes - } - } - - impl core::ops::Add for BigInt { - type Output = Self; - - fn add(mut self, mut rhs: Self) -> Self::Output { - let (mut digits, carry) = if self.0.len() > rhs.0.len() { - let c = add_bigint(&mut self.0, &rhs.0); - (self.0, c) - } else { - let c = add_bigint(&mut rhs.0, &self.0); - (rhs.0, c) - }; - - if carry { - digits.push(u32::from(carry)); - } - - BigInt(digits) - } - } - - impl core::ops::Sub for BigInt { - type Output = Self; - - fn sub(mut self, rhs: Self) -> Self::Output { - if self.0.len() < rhs.0.len() { - println!("extending self by {} zeroes", rhs.0.len() - self.0.len()); - self.0 - .extend(core::iter::repeat(0).take(rhs.0.len() - self.0.len())); - println!("self: {self:?}"); - } - sub_bigint(&mut self.0, &rhs.0); - - self - } - } - - impl core::fmt::Debug for BigInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut list = f.debug_list(); - list.entries(self.0.iter()).finish() - } - } - - /// counts used bits in a u32 slice, discards leading zeros in MSB. - /// `[0xff,0xff,0x00,0x00]` -> 16 - /// `[0xff,0xff,0x00]` -> 16 - /// `[0xff,0xff,0x0f]` -> 20 - pub fn count_bits(bytes: &[u32]) -> u32 { - let mut bits = bytes.len() as u32; - - for d in bytes.iter().rev() { - if *d == 0 { - bits -= u32::BITS; - } else { - bits -= d.leading_zeros(); - break; - } - } - - bits - } - - #[test] - fn test_count_bits() { - assert_eq!(count_bits(&[0xffffffff, 0x00, 0x00]), 32); - assert_eq!(count_bits(&[0xffffffff, 0xff, 0x00]), 40); - assert_eq!(count_bits(&[0xffffffff, 0xff]), 40); - assert_eq!(count_bits(&[0xffffffff, 0xff, 0xffff]), 64 + 16); - } - - #[allow(unused)] - /// lhs must be bigger than rhs - fn sub_bigint(lhs: &mut [u32], rhs: &[u32]) { - let len = lhs.len().min(rhs.len()); - let (l_lo, l_hi) = lhs.split_at_mut(len); - let (r_lo, r_hi) = rhs.split_at(len); - - println!("lhs: {{ lo: {l_lo:?}, hi: {l_hi:?} }}"); - println!("rhs: {{ lo: {r_lo:?}, hi: {r_hi:?} }}"); - - let mut borrow = false; - for (lhs, rhs) in l_lo.iter_mut().zip(r_lo) { - (*lhs, borrow) = lhs.borrowing_sub(*rhs, borrow); - } - - if borrow { - for lhs in l_hi { - (*lhs, borrow) = lhs.borrowing_sub(0, borrow); - } - } - - if borrow || !r_hi.iter().all(|&v| v == 0) { - panic!("sub failed: borrow: {borrow}"); - } - } - - /// lhs must be bigger than rhs - fn add_bigint(lhs: &mut [u32], rhs: &[u32]) -> bool { - let (l_lo, l_hi) = lhs.split_at_mut(rhs.len()); - - let mut carry = false; - for (lhs, rhs) in l_lo.iter_mut().zip(rhs) { - (*lhs, carry) = lhs.carrying_add(*rhs, carry); - } - - if carry { - for d in l_hi.iter_mut() { - (*d, carry) = d.carrying_add(0, carry); - if !carry { - break; - } - } - } - - carry - } - - fn parse_bigint(text: impl Iterator, radix: Radix) -> BigInt { - let digits = text - .filter_map(|c| match c { - '_' => None, - c => Some(radix.map_digit(c)), - }) - .collect::>(); - - let (max, power) = { - let radix = radix.radix() as u64; - let mut power = 1; - let mut base = radix; - while let Some(b) = base.checked_mul(radix) { - if b > u32::MAX as u64 { - break; - } - base = b; - power += 1; - } - (base, power) - }; - let radix = radix.radix() as u32; - - let r = digits.len() % power; - let i = if r == 0 { power } else { r }; - let (head, tail) = digits.split_at(i); - - let first = head - .iter() - .fold(0, |acc, &digit| acc * radix + digit as u32); - let mut data = vec![first]; - - for chunk in tail.chunks(power) { - if data.last() != Some(&0) { - data.push(0); - } - let mut carry = 0u64; - for digit in data.iter_mut() { - carry += *digit as u64 * max as u64; - *digit = carry as u32; - carry >>= u32::BITS; - } - assert!(carry == 0); - let next = chunk - .iter() - .fold(0, |acc, &digit| acc * radix + digit as u32); - - let (res, mut carry) = data[0].carrying_add(next, false); - data[0] = res; - if carry { - for digit in data[1..].iter_mut() { - (*digit, carry) = digit.carrying_add(0, carry); - if !carry { - break; - } - } - } - } - BigInt(data) - } - - #[cfg(test)] - mod tests { - use super::*; - - #[test] - fn parse() { - let bigint = super::parse_bigint("2_cafe_babe_dead_beef".chars(), Radix::Hex); - println!("{:#x?}", bigint); - let bigint = super::parse_bigint("f".chars(), Radix::Hex); - println!("{:#x?}", bigint); - } - #[test] - fn add() { - let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); - println!("{:#x?}", a); - let b = super::parse_bigint("cafebabe".chars(), Radix::Hex); - println!("{:#x?}", b); - let sum = a + b; - println!("{:#x?}", sum); - } - #[test] - fn sub() { - let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); - println!("{:#x?}", a); - let b = super::parse_bigint("ffff_ffff".chars(), Radix::Hex); - println!("{:#x?}", b); - let sum = a - b; - println!("{:#x?}", sum); - } - #[test] - fn overflowing_sub() { - let a = super::parse_bigint("2_0000_0000_0000_0000".chars(), Radix::Hex); - println!("{:#x?}", a); - let b = super::parse_bigint("ffff_ffff".chars(), Radix::Hex); - println!("{:#x?}", b); - let sum = b - a; - println!("{:#x?}", sum); - } - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Radix { Hex, @@ -873,7 +613,7 @@ impl Radix { } #[allow(unused)] - fn radix(self) -> u8 { + pub fn radix(self) -> u8 { match self { Radix::Hex => 16, Radix::Bin => 2, diff --git a/src/lib.rs b/src/lib.rs index 2679f9d..e63e914 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,16 @@ box_into_inner, hash_extract_if, bigint_helper_methods, - map_try_insert + map_try_insert, + iter_intersperse, + int_roundings )] #![allow(unused_macros)] pub mod asm; pub mod ast; pub mod common; +pub mod comptime; pub mod error; pub mod lexer; pub mod mir; diff --git a/src/parser.rs b/src/parser.rs index ce9cd04..f81583f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -5,12 +5,12 @@ use itertools::Itertools; use crate::{ ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type}, common::NextIf, + comptime::bigint::{self, BigInt}, error::{AnalysisError, AnalysisErrorTag}, - lexer::{bigint::BigInt, Radix, TokenIterator}, + lexer::{Radix, TokenIterator}, string_table::{ImmOrIndex, Index, StringTable}, symbol_table::{SymbolKind, SymbolTable}, tokens::Token, - variant, }; #[derive(Debug, thiserror::Error)] @@ -251,7 +251,7 @@ impl Tree { .map(|(_, c)| c) .collect::>(); - let value = crate::lexer::bigint::BigInt::parse_digits(digits, radix); + let value = BigInt::parse_digits(digits, radix); let ty = match iter.clone().next() { Some((_, 'u')) | Some((_, 'i')) => { diff --git a/src/string_table.rs b/src/string_table.rs index bc837cd..88197f8 100644 --- a/src/string_table.rs +++ b/src/string_table.rs @@ -57,7 +57,7 @@ impl StringTable { let ints = unsafe { core::slice::from_raw_parts(bytes.as_ptr().cast::(), bytes.len() / 4) }; - crate::lexer::bigint::count_bits(ints) + bigint::count_bits(ints) as u32 } } } @@ -151,3 +151,5 @@ mod display { } pub use display::ImmOrIndexDisplay; + +use crate::comptime::bigint;