pub mod bigint { use core::{ cmp::{Eq, Ord, Ordering, PartialOrd}, fmt::Debug, ops::{Add, AddAssign, Div, Mul, Not, Rem, Shl, Shr, Sub, SubAssign}, }; use crate::lexer::Radix; /// A base-4_294_967_295 number. #[derive(Clone)] pub struct BigInt(Vec); impl BigInt { pub fn parse_digits>(text: C, radix: Radix) -> BigInt { Self(parse_bigint(text.into_iter(), radix)) } pub fn from_u32(v: u32) -> BigInt { Self(vec![v]) } pub fn from_u64(v: u64) -> BigInt { let (lo, hi) = into_lo_hi(v); Self(vec![lo, hi]) } pub fn one() -> BigInt { Self(vec![1]) } pub fn zero() -> BigInt { Self(vec![]) } pub fn num_digits(&self) -> usize { self.0.iter().rposition(|&d| d != 0).map_or(0, |i| i + 1) } pub fn bit_width(&self) -> usize { count_bits(&self.0) } pub fn is_zero(&self) -> bool { bigint_is_zero(&self.0) } pub fn is_one(&self) -> bool { bigint_is_one(&self.0) } pub fn is_power_of_two(&self) -> bool { is_power_of_two(&self.0) } pub fn trailing_zeros(&self) -> usize { trailing_zeros(&self.0) } pub fn from_bytes_le(bytes: &[u8]) -> BigInt { let data = bytes .chunks(4) .map(|chunk| { let mut int = [0u8; 4]; int[..chunk.len()].copy_from_slice(chunk); u32::from_le_bytes(int) }) .collect::>(); BigInt(data) } pub fn into_bytes_le(&self) -> Vec { let mut bytes = Vec::::new(); for d in &self.0[..] 
{ bytes.extend(&d.to_le_bytes()); } let count = bytes.iter().rev().take_while(|&&b| b == 0).count(); bytes.truncate(bytes.len() - count); bytes } pub fn normalise(&mut self) { let len = self.0.iter().rposition(|&d| d != 0).map_or(0, |i| i + 1); self.0.truncate(len); } pub fn normalised(mut self) -> BigInt { self.normalise(); self } pub fn into_u64(&self) -> Option { if self.0.len() <= 2 { let mut bytes = [0u8; 8]; self.0 .get(0) .map(|&dw| bytes[..4].copy_from_slice(&dw.to_le_bytes())); self.0 .get(1) .map(|&dw| bytes[4..].copy_from_slice(&dw.to_le_bytes())); Some(u64::from_le_bytes(bytes)) } else { None } } } impl PartialEq for BigInt { fn eq(&self, other: &Self) -> bool { cmp_bigint(&self.0, &other.0) == Ordering::Equal } } impl PartialEq for BigInt { fn eq(&self, other: &u32) -> bool { self.num_digits() == 1 && self.0[0] == *other } } impl PartialEq for BigInt { fn eq(&self, other: &u64) -> bool { let (lo, hi) = into_lo_hi(*other); cmp_bigint(&self.0, &[lo, hi]) == Ordering::Equal } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &u32) -> Option { (self.num_digits() == 1).then(|| self.0[0].cmp(other)) } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &u64) -> Option { let (lo, hi) = into_lo_hi(*other); Some(cmp_bigint(&self.0, &[lo, hi])) } } impl Eq for BigInt {} impl PartialOrd for BigInt { fn partial_cmp(&self, other: &Self) -> Option { Some(cmp_bigint(&self.0, &other.0)) } } impl Ord for BigInt { fn cmp(&self, other: &Self) -> std::cmp::Ordering { cmp_bigint(&self.0, &other.0) } } impl Shl for BigInt { type Output = Self; fn shl(mut self, rhs: usize) -> Self::Output { shl_bitint(&mut self.0, rhs); self } } impl Shr for BigInt { type Output = Self; fn shr(mut self, rhs: usize) -> Self::Output { shr_bitint(&mut self.0, rhs); self } } impl Add for BigInt { type Output = Self; fn add(mut self, mut rhs: Self) -> Self::Output { let (mut digits, carry) = if self.0.len() > rhs.0.len() { let c = add_bigint(&mut self.0, &rhs.0); (self.0, c) } 
else { let c = add_bigint(&mut rhs.0, &self.0); (rhs.0, c) }; if carry { digits.push(u32::from(carry)); } BigInt(digits) } } impl Add for BigInt { type Output = Self; fn add(mut self, rhs: u32) -> Self::Output { self += rhs; self } } impl AddAssign for BigInt { fn add_assign(&mut self, rhs: u32) { let carry = add_bigint_scalar(&mut self.0, rhs); if carry { self.0.push(carry as u32); } } } impl Add for BigInt { type Output = Self; fn add(mut self, rhs: u64) -> Self::Output { self += rhs; self } } impl AddAssign for BigInt { fn add_assign(&mut self, rhs: u64) { let (lo, hi) = into_lo_hi(rhs); if hi == 0 { *self += lo; } else { while self.num_digits() < 2 { self.0.push(0); } let carry = add_bigint(&mut self.0, &[lo, hi]); if carry { self.0.push(carry as u32); } } } } impl Sub for BigInt { type Output = Self; fn sub(mut self, rhs: Self) -> Self::Output { if self.0.len() < rhs.0.len() { println!("extending self by {} zeroes", rhs.0.len() - self.0.len()); self.0 .extend(core::iter::repeat(0).take(rhs.0.len() - self.0.len())); println!("self: {self:?}"); } sub_bigint(&mut self.0, &rhs.0); self } } impl Sub for BigInt { type Output = Self; fn sub(mut self, rhs: u32) -> Self::Output { self -= rhs; self } } impl Sub for u32 { type Output = BigInt; fn sub(self, mut rhs: BigInt) -> Self::Output { if rhs.0.is_empty() { rhs.0.push(self); } else { sub_bigint_in_right(&[self], &mut rhs.0); } rhs.normalised() } } impl Sub for u64 { type Output = BigInt; fn sub(self, mut rhs: BigInt) -> Self::Output { while rhs.num_digits() < 2 { rhs.0.push(0); } let (lo, hi) = into_lo_hi(self); sub_bigint_in_right(&[lo, hi], &mut rhs.0); rhs.normalised() } } impl SubAssign for BigInt { fn sub_assign(&mut self, rhs: u32) { sub_bigint_scalar(&mut self.0, rhs); } } impl Sub for BigInt { type Output = Self; fn sub(mut self, rhs: u64) -> Self::Output { self -= rhs; self } } impl SubAssign for BigInt { fn sub_assign(&mut self, rhs: u64) { let (lo, hi) = into_lo_hi(rhs); while self.num_digits() < 2 { 
self.0.push(0); } sub_bigint(&mut self.0, &[lo, hi]); self.normalise(); } } impl Mul for BigInt { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { BigInt(mul_bigint(&self.0, &rhs.0)) } } impl Mul for BigInt { type Output = Self; fn mul(mut self, rhs: u32) -> Self::Output { u32_mul_bigint(&mut self.0, rhs); self } } impl Mul for BigInt { type Output = Self; fn mul(self, rhs: u64) -> Self::Output { let (lo, hi) = into_lo_hi(rhs); BigInt(mul_bigint(&self.0, &[lo, hi])) } } impl Div for BigInt { type Output = Self; fn div(self, rhs: Self) -> Self::Output { div_rem_bigint(self, rhs).0 } } impl Div for BigInt { type Output = Self; fn div(self, rhs: u32) -> Self::Output { div_digit_bigint(self, rhs).0 } } impl Div for BigInt { type Output = Self; fn div(self, rhs: u64) -> Self::Output { let (lo, hi) = into_lo_hi(rhs); div_rem_bigint(self, BigInt([lo, hi].to_vec())).0 } } impl Rem for BigInt { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { div_rem_bigint(self, rhs).1 } } impl Rem for BigInt { type Output = Self; fn rem(self, rhs: u32) -> Self::Output { BigInt::zero() + div_digit_bigint(self, rhs).1 } } impl Rem for BigInt { type Output = Self; fn rem(self, rhs: u64) -> Self::Output { let (lo, hi) = into_lo_hi(rhs); div_rem_bigint(self, BigInt([lo, hi].to_vec())).1 } } impl Not for BigInt { type Output = Self; fn not(mut self) -> Self::Output { self.0.iter_mut().for_each(|c| *c = !*c); self } } impl Debug for BigInt { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut list = f.debug_list(); list.entries(self.0.iter().rev()).finish() } } /// counts used bits in a u32 slice, discards leading zeros in MSB. 
/// `[0xff,0xff,0x00,0x00]` -> 16 /// `[0xff,0xff,0x00]` -> 16 /// `[0xff,0xff,0x0f]` -> 20 pub fn count_bits(bytes: &[u32]) -> usize { let mut bits = bytes.len() * u32::BITS as usize; for &d in bytes.iter().rev() { if d == 0 { bits -= u32::BITS as usize; } else { bits -= d.leading_zeros() as usize; break; } } bits } #[test] fn test_count_bits() { assert_eq!(count_bits(&[0xffffffff, 0x00, 0x00]), 32); assert_eq!(count_bits(&[0x00, 0x00, 0x00]), 0); assert_eq!(count_bits(&[]), 0); assert_eq!(count_bits(&[0xffffffff, 0xff, 0x00]), 40); assert_eq!(count_bits(&[0xffffffff, 0xff]), 40); assert_eq!(count_bits(&[0xffffffff, 0xff, 0xffff]), 64 + 16); } #[test] fn test_count_trailing_zeros() { assert_eq!(trailing_zeros(&[0xffffffff, 0x00, 0x00]), 0); assert_eq!(trailing_zeros(&[0x00, 0x00, 0x00]), 0); assert_eq!(trailing_zeros(&[]), 0); assert_eq!(trailing_zeros(&[0x00, 0xffffffff, 0xff]), 32); assert_eq!(trailing_zeros(&[0x00, 0xffffff00, 0xff]), 40); } #[allow(unused)] /// lhs <=> rhs fn cmp_bigint(lhs: &[u32], rhs: &[u32]) -> Ordering { use core::cmp::Ordering; let lhs_bits = count_bits(lhs); let rhs_bits = count_bits(rhs); match lhs_bits.cmp(&rhs_bits) { Ordering::Less => Ordering::Less, Ordering::Greater => Ordering::Greater, Ordering::Equal => { for (a, b) in lhs[..(lhs_bits / u32::BITS as usize)] .iter() .zip(rhs[..(lhs_bits / u32::BITS as usize)].iter()) .rev() { let ord = a.cmp(b); if ord != Ordering::Equal { return ord; } } return Ordering::Equal; } } } fn bigint_is_zero(lhs: &[u32]) -> bool { if lhs.len() == 0 { true } else { lhs.iter().all(|c| c == &0) } } #[allow(dead_code)] fn bigint_is_one(lhs: &[u32]) -> bool { lhs.len() > 0 && lhs[0] == 1 && lhs[1..].iter().all(|c| c == &0) } #[allow(dead_code)] fn bitnot_bigint(lhs: &mut [u32]) { for d in lhs.iter_mut() { *d = !*d; } } #[allow(dead_code)] fn u32_mul_bigint(lhs: &mut Vec, scalar: u32) { match scalar { 0 => { lhs.clear(); lhs.push(0) } 1 => {} _ => { if scalar.is_power_of_two() { lhs.push(0); 
shl_bitint(lhs.as_mut_slice(), scalar.trailing_zeros() as usize); } else { let mut carry = 0; for a in lhs.iter_mut() { (*a, carry) = (*a).carrying_mul(scalar, carry); } if carry != 0 { lhs.push(carry); } } } } } #[allow(dead_code)] fn u64_mul_bigint(lhs: &mut Vec, scalar: u64) { let lo = scalar as u32; let hi = (scalar >> 32) as u32; u32_mul_bigint(lhs, lo); shl_bitint(lhs, 32); u32_mul_bigint(lhs, hi); } #[allow(dead_code)] fn mul_bigint(lhs: &[u32], rhs: &[u32]) -> Vec { if bigint_is_zero(lhs) || bigint_is_zero(rhs) { return vec![]; } let len = lhs.len() + rhs.len() + 1; let mut product = vec![0u32; len]; for (bth, &b) in rhs.iter().enumerate() { let mut carry = 0u32; for (ath, &a) in lhs.iter().enumerate() { let prod; (prod, carry) = a.carrying_mul(b, carry); let (digit, c) = product[ath + bth].carrying_add(prod, false); carry += c as u32; product[ath + bth] = digit; } if carry != 0 { product[bth + lhs.len()] += carry; } } product } #[allow(dead_code)] fn sum_digits(digits: &[u32]) -> u64 { let mut sum = 0u64; let mut carry = false; for &d in digits { (sum, carry) = sum.carrying_add(d as u64, carry); } sum + carry as u64 } #[allow(dead_code)] fn count_ones(lhs: &[u32]) -> usize { lhs.iter() .fold(0usize, |acc, c| acc + c.count_ones() as usize) } #[allow(dead_code)] fn trailing_zeros(lhs: &[u32]) -> usize { lhs.iter() .enumerate() .find(|(_, &c)| c != 0) .map(|(i, &c)| i * u32::BITS as usize + c.trailing_zeros() as usize) .unwrap_or(0) } #[allow(dead_code)] fn is_power_of_two(lhs: &[u32]) -> bool { count_ones(lhs) == 1 } #[allow(dead_code)] /// divident must be at least as wide as divisor /// returns (quotient, remainder) pub fn div_rem_bigint_ref(divident: &BigInt, divisor: &BigInt) -> (BigInt, BigInt) { if bigint_is_zero(&divisor.0) { panic!("divide by zero!"); } if bigint_is_zero(÷nt.0) { return (BigInt::zero(), BigInt::zero()); } use core::cmp::Ordering; match cmp_bigint(÷nt.0, &divisor.0) { Ordering::Less => return (BigInt::zero(), divident.clone()), 
Ordering::Equal => { return (BigInt::one(), BigInt::zero()); } Ordering::Greater => {} } if divisor.is_power_of_two() { let exp = divisor.trailing_zeros(); let (div, rem) = divident.0.split_at(exp.div_floor(u32::BITS as usize)); let (mut div, mut rem) = (div.to_vec(), rem.to_vec()); shr_bitint(&mut div, exp % u32::BITS as usize); let mask = (1u32 << exp as u32 % u32::BITS) - 1; if let Some(last) = rem.last_mut() { *last &= mask; } return (BigInt(div), BigInt(rem)); } if divisor.num_digits() == 1 { if divisor.0[0] == 1 { return (divident.clone(), BigInt::zero()); } let (div, rem) = div_digit_bigint(divident.clone(), divisor.0[0]); let rem = BigInt::zero() + rem; return (div, rem); } let shift = divisor.0.last().unwrap().leading_zeros() as usize; if shift == 0 { div_rem_core(divident.clone(), &divisor.0) } else { let (q, r) = div_rem_core(divident.clone() << shift, &(divisor.clone() << shift).0); (q, r >> shift) } } #[allow(dead_code)] /// divident must be at least as wide as divisor /// returns (quotient, remainder) pub fn div_rem_bigint(divident: BigInt, divisor: BigInt) -> (BigInt, BigInt) { let divident = divident.normalised(); let mut divisor = divisor.normalised(); if bigint_is_zero(&divisor.0) { panic!("divide by zero!"); } if bigint_is_zero(÷nt.0) { return (BigInt::zero(), BigInt::zero()); } use core::cmp::Ordering; match cmp_bigint(÷nt.0, &divisor.0) { Ordering::Less => return (BigInt::zero(), divident), Ordering::Equal => { return (BigInt::one(), BigInt::zero()); } Ordering::Greater => {} } if divisor.is_power_of_two() { let exp = divisor.trailing_zeros(); let (div, rem) = divident.0.split_at(exp.div_floor(u32::BITS as usize)); let (mut div, mut rem) = (div.to_vec(), rem.to_vec()); shr_bitint(&mut div, exp % u32::BITS as usize); let mask = (1u32 << exp as u32 % u32::BITS) - 1; if let Some(last) = rem.last_mut() { *last &= mask; } return (BigInt(div), BigInt(rem)); } if divisor.num_digits() == 1 { if divisor.0[0] == 1 { return (divident, BigInt::zero()); } 
let (div, rem) = div_digit_bigint(divident, divisor.0[0]); divisor.0.clear(); divisor.0.push(rem); return (div, divisor); } let shift = divisor.0.last().unwrap().leading_zeros() as usize; if shift == 0 { div_rem_core(divident, &divisor.0) } else { let (q, r) = div_rem_core(divident << shift, &(divisor << shift).0); (q, r >> shift) } } fn scalar_div_wide(hi: u32, lo: u32, divisor: u32) -> (u32, u32) { let (div, rem); unsafe { core::arch::asm! { "div {0:e}", in(reg) divisor, inout("dx") hi => rem, inout("ax") lo => div, } } (div, rem) } #[allow(dead_code)] fn div_digit_bigint(mut divident: BigInt, divisor: u32) -> (BigInt, u32) { assert!(divisor != 0); let mut rem = 0; for d in divident.0.iter_mut().rev() { (*d, rem) = scalar_div_wide(rem, *d, divisor); } (divident.normalised(), rem) } fn from_lo_hi(lo: u32, hi: u32) -> u64 { lo as u64 | (hi as u64) << 32 } fn into_lo_hi(qword: u64) -> (u32, u32) { (qword as u32, (qword >> 32) as u32) } // from rust num_bigint /// Subtract a multiple. /// a -= b * c /// Returns a borrow (if a < b then borrow > 0). fn sub_mul_digit_same_len(a: &mut [u32], b: &[u32], c: u32) -> u32 { assert!(a.len() == b.len()); // carry is between -big_digit::MAX and 0, so to avoid overflow we store // offset_carry = carry + big_digit::MAX let mut offset_carry = u32::MAX; for (x, y) in a.iter_mut().zip(b) { // We want to calculate sum = x - y * c + carry. // sum >= -(big_digit::MAX * big_digit::MAX) - big_digit::MAX // sum <= big_digit::MAX // Offsetting sum by (big_digit::MAX << big_digit::BITS) puts it in DoubleBigDigit range. let offset_sum = from_lo_hi(u32::MAX, *x) - u32::MAX as u64 + offset_carry as u64 - *y as u64 * c as u64; let (new_x, new_offset_carry) = into_lo_hi(offset_sum); offset_carry = new_offset_carry; *x = new_x; } // Return the borrow. 
u32::MAX - offset_carry } // from rust num_bigint fn div_rem_core(mut a: BigInt, b: &[u32]) -> (BigInt, BigInt) { // sanity check on fast paths assert!(a.0.len() >= b.len() && b.len() > 1); // a0 stores an additional extra most significant digit of the dividend, not stored in a. let mut a0 = 0; // [b1, b0] are the two most significant digits of the divisor. They never change. let b0 = b[b.len() - 1]; let b1 = b[b.len() - 2]; let q_len = a.0.len() - b.len() + 1; let mut q = BigInt(vec![0; q_len]); for j in (0..q_len).rev() { assert!(a.0.len() == b.len() + j); let a1 = *a.0.last().unwrap(); let a2 = a.0[a.0.len() - 2]; // The first q0 estimate is [a1,a0] / b0. It will never be too small, it may be too large // by at most 2. let (mut q0, mut r) = if a0 < b0 { let (q0, r) = scalar_div_wide(a0, a1, b0); (q0, r as u64) } else { assert!(a0 == b0); // Avoid overflowing q0, we know the quotient fits in BigDigit. // [a1,a0] = b0 * (1< a0 { // q0 is too large. We need to add back one multiple of b. q0 -= 1; borrow -= add_bigint(&mut a.0[j..], b) as u32; } // The top digit of a, stored in a0, has now been zeroed. assert!(borrow == a0); q.0[j] = q0; // Pop off the next top digit of a. 
a0 = a.0.pop().unwrap(); } a.0.push(a0); a.normalise(); assert_eq!(cmp_bigint(&a.0, b), core::cmp::Ordering::Less); (q.normalised(), a) } #[allow(unused)] fn shr_bitint(lhs: &mut [u32], shift: usize) { if bigint_is_zero(lhs) || shift == 0 { return; } let len = lhs.len(); let digit_offset = shift / 32; let bit_shift = shift % 32; if digit_offset != 0 { lhs.copy_within(digit_offset..len, 0); lhs[(len - digit_offset)..].fill(0); } if bit_shift != 0 { let lo_mask = (1u32 << (u32::BITS as usize - bit_shift)) - 1; let hi_mask = !lo_mask; eprintln!("lhs >> {shift}"); eprintln!("\tdigit_offset: {digit_offset}"); eprintln!("\tbit_shift: {bit_shift}"); eprintln!("\tlo_mask: 0b{lo_mask:0>32b}"); eprintln!("\thi_mask: 0b{hi_mask:0>32b}"); let mut carry = 0u32; for i in 0..lhs.len() { let digit = ((lhs[i] as u64) << 32) >> bit_shift; let lo = digit as u32; let hi = (digit >> 32) as u32; lhs[i] &= hi_mask; lhs[i] |= hi; if i > 0 { lhs[i - 1] &= lo_mask; lhs[i - 1] |= lo; } } } } #[allow(unused)] /// lhs must have shift / 32 + 1 digits past the last bit if shifted-past bits are desired. 
fn shl_bitint(lhs: &mut [u32], shift: usize) {
        if bigint_is_zero(lhs) || shift == 0 {
            return;
        }
        let len = lhs.len();
        let digit_offset = shift / u32::BITS as usize;
        let bit_shift = (shift % u32::BITS as usize) as u32;

        if digit_offset >= len {
            lhs.fill(0);
            return;
        }
        if digit_offset != 0 {
            lhs.copy_within(0..(len - digit_offset), digit_offset);
            lhs[..digit_offset].fill(0);
        }
        if bit_shift != 0 {
            // Descending order: digit i - 1 still holds its pre-shift value when read here.
            for i in (digit_offset..len).rev() {
                let below = if i > 0 { lhs[i - 1] } else { 0 };
                lhs[i] = (lhs[i] << bit_shift) | (below >> (u32::BITS - bit_shift));
            }
        }
    }

    #[allow(unused)]
    /// Subtracts `rhs` from `lhs` in place (`lhs -= rhs`).
    ///
    /// `rhs` may be longer than `lhs` only if its extra digits are zero.
    ///
    /// # Panics
    /// With "sub failed" if `rhs > lhs`.
    fn sub_bigint(lhs: &mut [u32], rhs: &[u32]) {
        if bigint_is_zero(rhs) {
            return;
        }
        let len = lhs.len().min(rhs.len());
        let (l_lo, l_hi) = lhs.split_at_mut(len);
        let (r_lo, r_hi) = rhs.split_at(len);

        let mut borrow = false;
        for (l, &r) in l_lo.iter_mut().zip(r_lo) {
            (*l, borrow) = l.borrowing_sub(r, borrow);
        }
        // Ripple any remaining borrow through the higher digits of `lhs`.
        for l in l_hi.iter_mut() {
            if !borrow {
                break;
            }
            (*l, borrow) = l.borrowing_sub(0, borrow);
        }
        if borrow || r_hi.iter().any(|&v| v != 0) {
            panic!("sub failed: borrow: {borrow}");
        }
    }

    /// Computes `rhs = lhs - rhs` over equal-length slices; returns the final borrow.
    fn sub_bigint_in_right_simple(lhs: &[u32], rhs: &mut [u32]) -> bool {
        assert!(lhs.len() == rhs.len());
        let mut borrow = false;
        for (l, r) in lhs.iter().zip(rhs) {
            (*r, borrow) = l.borrowing_sub(*r, borrow);
        }
        borrow
    }

    /// Computes `rhs = lhs - rhs` in place.
    ///
    /// # Panics
    /// If `rhs` is shorter than `lhs`, or if the result would be negative.
    fn sub_bigint_in_right(lhs: &[u32], rhs: &mut [u32]) {
        assert!(rhs.len() >= lhs.len());
        let (r_lo, r_hi) = rhs.split_at_mut(lhs.len());
        let borrow = sub_bigint_in_right_simple(lhs, r_lo);
        assert!(!borrow);
        // Any non-zero high digit of `rhs` would make `lhs - rhs` negative.
        assert!(r_hi.iter().all(|&d| d == 0));
    }

    /// Subtracts a single digit from `lhs` in place.
    ///
    /// # Panics
    /// With "sub failed" if `rhs` exceeds the value of `lhs` (including a non-zero
    /// `rhs` with an empty `lhs`).
    fn sub_bigint_scalar(lhs: &mut [u32], rhs: u32) {
        // Amount still to subtract at the current digit; only the lowest digit gets `rhs`.
        let mut pending = rhs;
        let mut borrow = false;
        for digit in lhs.iter_mut() {
            (*digit, borrow) = digit.borrowing_sub(pending, borrow);
            pending = 0;
            if !borrow {
                break;
            }
        }
        if borrow || pending != 0 {
            panic!("sub failed: borrow: {borrow}");
        }
    }

    /// Adds a single digit to `lhs` in place and returns the carry out of the top digit.
    ///
    /// `lhs` must be non-empty: with no digits there is nowhere to put `rhs`, so it is
    /// dropped and `false` is returned.
    fn add_bigint_scalar(lhs: &mut [u32], rhs: u32) -> bool {
        let mut rhs = Some(rhs);
        let mut carry = false;
        for d in lhs.iter_mut() {
            (*d, carry) = (*d).carrying_add(rhs.take().unwrap_or(0), carry);
            if !carry {
                break;
            }
        }
        carry
    }

    /// Adds `rhs` to `lhs` in place and returns the carry out of the top digit.
    ///
    /// # Panics
    /// If `lhs` is shorter than `rhs` (unless `rhs` is zero).
    fn add_bigint(lhs: &mut [u32], rhs: &[u32]) -> bool {
        if bigint_is_zero(rhs) {
            return false;
        }
        let (l_lo, l_hi) = lhs.split_at_mut(rhs.len());
        let mut carry = false;
        for (lhs, rhs) in l_lo.iter_mut().zip(rhs) {
            (*lhs, carry) = lhs.carrying_add(*rhs, carry);
        }
        if carry {
            for d in l_hi.iter_mut() {
                (*d, carry) = d.carrying_add(0, carry);
                if !carry {
                    break;
                }
            }
        }
        carry
    }

    /// Parses `text` as digits of `radix` (skipping `_` separators) into little-endian
    /// base-2^32 digits. Input with no digits yields an empty vector (zero).
    pub fn parse_bigint(text: impl Iterator<Item = char>, radix: Radix) -> Vec<u32> {
        let digits = text
            .filter_map(|c| match c {
                '_' => None,
                c => Some(radix.map_digit(c)),
            })
            .collect::<Vec<_>>();
        if digits.is_empty() {
            return Vec::new();
        }

        // `max = radix^power` is the largest power of the radix that fits in a u32, so
        // `power` source digits can be folded into one u32 chunk at a time.
        let (max, power) = {
            let radix = radix.radix() as u64;
            let mut power = 1;
            let mut base = radix;
            while let Some(b) = base.checked_mul(radix) {
                if b > u32::MAX as u64 {
                    break;
                }
                base = b;
                power += 1;
            }
            (base, power)
        };
        let radix = radix.radix() as u32;

        // The leading chunk takes the leftover digits so every later chunk has exactly `power`.
        let r = digits.len() % power;
        let i = if r == 0 { power } else { r };
        let (head, tail) = digits.split_at(i);
        let first = head
            .iter()
            .fold(0, |acc, &digit| acc * radix + digit as u32);
        let mut data = vec![first];

        for chunk in tail.chunks(power) {
            // data = data * max + next; a zero top digit guarantees room for the product.
            if data.last() != Some(&0) {
                data.push(0);
            }
            let mut carry = 0u64;
            for digit in data.iter_mut() {
                carry += *digit as u64 * max;
                *digit = carry as u32;
                carry >>= u32::BITS;
            }
            assert!(carry == 0);

            let next = chunk
                .iter()
                .fold(0, |acc, &digit| acc * radix + digit as u32);
            let (res, mut carry) = data[0].carrying_add(next, false);
            data[0] = res;
            if carry {
                for digit in data[1..].iter_mut() {
                    (*digit, carry) = digit.carrying_add(0, carry);
                    if !carry {
                        break;
                    }
                }
            }
        }
        data
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn parse() {
            let bigint = BigInt::parse_digits("2_cafe_babe_dead_beef".chars(), Radix::Hex);
            println!("{:#x?}", bigint);
            assert_eq!(bigint.0, vec![0xdead_beef, 0xcafe_babe, 0x2]);

            let bigint = BigInt::parse_digits("f".chars(), Radix::Hex);
            println!("{:#x?}", bigint);
            assert_eq!(bigint.0, vec![0xf]);
        }

        #[test]
        fn add() {
            let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex);
            println!("{:#x?}", a);
            let b = BigInt::parse_digits("cafebabe".chars(), Radix::Hex);
            println!("{:#x?}", b);
            let sum = a + b;
            println!("{:#x?}", sum);
            assert_eq!(sum, BigInt(vec![0xcafe_babe, 0, 2]));
        }

        #[test]
        fn sub() {
            let a = BigInt::parse_digits("deadbeef".chars(), Radix::Hex);
            println!("{:#x?}", a);
            let b = BigInt::parse_digits("56d2c".chars(), Radix::Hex);
            println!("{:#x?}", b);
            let diff = a - b;
            println!("{:#x?}", diff);
            assert_eq!(diff, BigInt::from_u32(0xdea8_51c3));
        }

        #[test]
        #[should_panic(expected = "sub failed")]
        fn overflowing_sub() {
            let a = BigInt::parse_digits("2_0000_0000_0000_0000".chars(), Radix::Hex);
            println!("{:#x?}", a);
            let b = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex);
            println!("{:#x?}", b);
            // b < a, so the unsigned subtraction must panic.
            let diff = b - a;
            println!("{:#x?}", diff);
        }

        #[test]
        fn shr() {
            let mut a = BigInt::parse_digits("cafe_babe_0000".chars(), Radix::Hex);
            print!("{:0>8x?} >> 32 ", a);
            shr_bitint(&mut a.0, 32);
            println!("{:0>8x?}", a);
            assert_eq!(a, BigInt::from_u32(0xcafe));

            let mut a = BigInt::parse_digits("11110000".chars(), Radix::Bin);
            print!("{:0>8x?} >> 3 ", a);
            shr_bitint(&mut a.0, 3);
            println!("{:0>8x?}", a);
            assert_eq!(a, BigInt::from_u32(0b11110));
        }

        #[test]
        fn shl() {
            let mut a = BigInt::parse_digits("ffff_ffff".chars(), Radix::Hex);
            a.0.extend([0; 4]);
            println!("{:0>8x?}", a);
            shl_bitint(&mut a.0, 40);
            println!("{:0>8x?}", a);
            assert_eq!(a, BigInt(vec![0, 0xffff_ff00, 0xff]));
        }

        #[test]
        fn div() {
            let a = BigInt::parse_digits("cafebabe".chars(), Radix::Hex);
            let b = BigInt::parse_digits("dead".chars(), Radix::Hex);
            let (div, rem) = div_rem_bigint(a, b);
            println!("div: {:0>8x?}", div);
            println!("rem: {:0>8x?}", rem);
            // 0xcafebabe = 0xdead * 0xe95f + 0xa38b
            assert_eq!(div, BigInt::from_u32(0xe95f));
            assert_eq!(rem, BigInt::from_u32(0xa38b));
        }
    }
}

pub mod bigsint {
    use std::{
        cmp::Ordering,
        ops::{Add, AddAssign, Div, Mul, Neg, Not, Rem, Shl, Shr, Sub, SubAssign},
    };

    use super::bigint::{self, *};

    /// Sign of a [`BigSInt`]; `None` is used for zero.
    ///
    /// The variant order (`Negative < None < Positive`) is what `Ord for BigSInt`
    /// compares first.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    enum Sign {
        Negative = 0,
        None = 1,
        Positive = 2,
    }

    impl Neg for Sign {
        type Output = Self;

        fn neg(self) -> Self::Output {
            match self {
                Sign::Negative => Self::Positive,
                Sign::None => Self::None,
                Sign::Positive => Self::Negative,
            }
        }
    }

    impl Mul for Sign {
        type Output = Self;

        /// Sign of a product; anything times zero (`None`) is zero.
        fn mul(self, rhs: Self) -> Self::Output {
            use Sign::*;
            match (self, rhs) {
                (None, _) | (_, None) => None,
                (Negative, Negative) | (Positive, Positive) => Positive,
                (Negative, Positive) | (Positive, Negative) => Negative,
            }
        }
    }

    /// An arbitrary-precision signed integer in sign-magnitude form over [`BigInt`].
    #[derive(Clone, Debug, Eq)]
    pub struct BigSInt {
        sign: Sign,
        bigint: BigInt,
    }

    impl BigSInt {
        /// Builds a value from sign and magnitude, forcing `Sign::None` for a zero
        /// magnitude so that every zero compares equal to [`BigSInt::zero`].
        fn from_parts(sign: Sign, bigint: BigInt) -> BigSInt {
            let sign = if bigint.is_zero() { Sign::None } else { sign };
            Self { sign, bigint }
        }

        pub fn zero() -> BigSInt {
            Self {
                sign: Sign::None,
                bigint: BigInt::zero(),
            }
        }

        pub fn one() -> BigSInt {
            Self {
                sign: Sign::Positive,
                bigint: BigInt::one(),
            }
        }

        /// Non-negative value with the given magnitude (zero stays zero).
        pub fn positive(bigint: BigInt) -> BigSInt {
            Self::from_parts(Sign::Positive, bigint)
        }

        pub fn from_u32(v: u32) -> BigSInt {
            Self::from_parts(Sign::Positive, BigInt::from_u32(v))
        }

        pub fn from_u64(v: u64) -> BigSInt {
            Self::from_parts(Sign::Positive, BigInt::from_u64(v))
        }

        pub fn from_i32(v: i32) -> BigSInt {
            if v >= 0 {
                Self::from_u32(v as u32)
            } else {
                Self::from_parts(Sign::Negative, BigInt::from_u32(v.unsigned_abs()))
            }
        }

        pub fn from_i64(v: i64) -> BigSInt {
            if v >= 0 {
                Self::from_u64(v as u64)
            } else {
                Self::from_parts(Sign::Negative, BigInt::from_u64(v.unsigned_abs()))
            }
        }

        pub fn is_negative(&self) -> bool {
            self.sign == Sign::Negative
        }
    }

    impl PartialEq for BigSInt {
        fn eq(&self, other: &Self) -> bool {
            self.sign == other.sign && self.bigint == other.bigint
        }
    }

    impl Ord for BigSInt {
        /// Numeric order: compare signs first, then magnitudes, with the magnitude
        /// order reversed for negative values (a larger magnitude is a smaller number).
        fn cmp(&self, other: &Self) -> Ordering {
            self.sign.cmp(&other.sign).then_with(|| {
                let by_magnitude = self.bigint.cmp(&other.bigint);
                if self.sign == Sign::Negative {
                    by_magnitude.reverse()
                } else {
                    by_magnitude
                }
            })
        }
    }

    impl PartialOrd for BigSInt {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Not for BigSInt {
        type Output = Self;

        /// Two's-complement style bitwise not: `!x == -x - 1`.
        fn not(self) -> Self::Output {
            match self.sign {
                Sign::Negative => {
                    // !(-m) == m - 1; m >= 1 here, so this cannot underflow.
                    let mut magnitude = self.bigint;
                    magnitude -= 1u32;
                    Self::from_parts(Sign::Positive, magnitude)
                }
                // !m == -(m + 1)
                Sign::None | Sign::Positive => {
                    Self::from_parts(Sign::Negative, self.bigint + BigInt::one())
                }
            }
        }
    }

    impl Add for BigSInt {
        type Output = Self;

        fn add(self, rhs: Self) -> Self::Output {
            match (self.sign, rhs.sign) {
                (_, Sign::None) => self,
                (Sign::None, _) => rhs,
                (Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative) => Self {
                    sign: self.sign,
                    bigint: self.bigint + rhs.bigint,
                },
                (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive) => {
                    // Opposite signs: the larger magnitude decides the sign of the result.
                    match self.bigint.cmp(&rhs.bigint) {
                        Ordering::Less => Self {
                            sign: rhs.sign,
                            bigint: rhs.bigint - self.bigint,
                        },
                        Ordering::Equal => Self::zero(),
                        Ordering::Greater => Self {
                            sign: self.sign,
                            bigint: self.bigint - rhs.bigint,
                        },
                    }
                }
            }
        }
    }

    impl Add<u32> for BigSInt {
        type Output = BigSInt;

        fn add(self, rhs: u32) -> Self::Output {
            self + BigSInt::from_u32(rhs)
        }
    }

    impl Add<u64> for BigSInt {
        type Output = BigSInt;

        fn add(self, rhs: u64) -> Self::Output {
            self + BigSInt::from_u64(rhs)
        }
    }

    impl AddAssign for BigSInt {
        fn add_assign(&mut self, rhs: Self) {
            let n = core::mem::replace(self, Self::zero());
            *self = n + rhs;
        }
    }

    impl Sub for BigSInt {
        type Output = Self;

        fn sub(self, rhs: Self) -> Self::Output {
            match (self.sign, rhs.sign) {
                (_, Sign::None) => self,
                (Sign::None, _) => -rhs,
                (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive) => Self {
                    sign: self.sign,
                    bigint: self.bigint + rhs.bigint,
                },
                (Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative) => {
                    match self.bigint.cmp(&rhs.bigint) {
                        Ordering::Less => Self {
                            sign: -self.sign,
                            bigint: rhs.bigint - self.bigint,
                        },
                        Ordering::Equal => Self::zero(),
                        Ordering::Greater => Self {
                            sign: self.sign,
                            bigint: self.bigint - rhs.bigint,
                        },
                    }
                }
            }
        }
    }

    impl SubAssign for BigSInt {
        fn sub_assign(&mut self, rhs: Self) {
            let n = core::mem::replace(self, Self::zero());
            *self = n - rhs;
        }
    }

    impl Shl<usize> for BigSInt {
        type Output = Self;

        fn shl(self, rhs: usize) -> Self::Output {
            Self::from_parts(self.sign, self.bigint << rhs)
        }
    }

    impl Shr<usize> for BigSInt {
        type Output = Self;

        /// Arithmetic shift: negative values round toward negative infinity.
        fn shr(self, rhs: usize) -> Self::Output {
            let round_away = shr_rounding(&self, rhs);
            let mut magnitude = self.bigint >> rhs;
            if round_away {
                magnitude = magnitude + BigInt::one();
            }
            Self::from_parts(self.sign, magnitude)
        }
    }

    impl Mul for BigSInt {
        type Output = Self;

        fn mul(self, rhs: Self) -> Self::Output {
            Self::from_parts(self.sign * rhs.sign, self.bigint * rhs.bigint)
        }
    }

    impl Div for BigSInt {
        type Output = Self;

        fn div(self, rhs: Self) -> Self::Output {
            div_rem_bigsint(&self, &rhs).0
        }
    }

    impl Rem for BigSInt {
        type Output = Self;

        fn rem(self, rhs: Self) -> Self::Output {
            div_rem_bigsint(&self, &rhs).1
        }
    }

    /// Truncating division: the quotient's sign is the product of the operand signs,
    /// and the remainder takes the dividend's sign. Zero results get `Sign::None`.
    ///
    /// # Panics
    /// If `rhs` is zero.
    fn div_rem_bigsint(lhs: &BigSInt, rhs: &BigSInt) -> (BigSInt, BigSInt) {
        let (q, r) = bigint::div_rem_bigint_ref(&lhs.bigint, &rhs.bigint);
        let q_sign = if rhs.is_negative() { -lhs.sign } else { lhs.sign };
        (
            BigSInt::from_parts(q_sign, q),
            BigSInt::from_parts(lhs.sign, r),
        )
    }

    /// True if shifting a negative value right by `shift` drops any set bit, in which
    /// case the magnitude must be bumped to round toward negative infinity.
    fn shr_rounding(lhs: &BigSInt, shift: usize) -> bool {
        if lhs.is_negative() {
            let ctz = lhs.bigint.trailing_zeros();
            shift > 0 && ctz < shift
        } else {
            false
        }
    }

    impl Neg for BigSInt {
        type Output = Self;

        fn neg(mut self) -> Self::Output {
            self.sign = -self.sign;
            self
        }
    }
}

use std::{
    cmp::Ordering,
    ops::{BitAnd, BitOr, BitXor, Not},
};

use num_bigint::{BigInt, BigUint, Sign};
use num_traits::{cast::ToPrimitive, ToBytes};

use crate::ast::{FloatingType, IntegralType, Type};

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("Incompatible Comptime Number variants.")]
    IncompatibleTypes,
    #[error("Integer overflow.")]
    IntegerOverflow,
    #[error("Shift cannot fit into u32.")]
    ShiftTooLarge,
    #[error("Cannot negate unsigned integer")]
    UnsignedNegation,
    #[error("Incomparable floats.")]
    FloatingCmp,
    #[error("Not a comptime expression.")]
    NotComptime,
}

/// Result type for comptime-number operations.
pub type Result<T> = core::result::Result<T, Error>;

#[derive(Debug, PartialEq, Eq)]
pub enum ComptimeInt {
    Native { bits: u128, ty: IntegralType },
    BigInt { bits: BigInt, ty: IntegralType },
    Comptime(BigInt),
}

impl ComptimeInt {
    pub fn add(self, other: Self) -> Result {
        let (a, b) = self.coalesce(other)?;
        match (a, b) {
            (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => {
                let bits = a.checked_add(b).ok_or(Error::IntegerOverflow)?;
                if bits & !ty.u128_bitmask() != 0 {
                    return Err(Error::IntegerOverflow);
                }
                Ok(Self::Native { bits, ty })
            }
            (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => {
                let width = ty.bits - ty.signed as u16;
                let bits = a + b;
                if bits.bits() > width as u64 {
                    Err(Error::IntegerOverflow)
                } else {
                    Ok(Self::BigInt { bits, ty })
                }
            }
            (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a + b)),
            _ => {
                unreachable!()
            }
        }
    }
    pub fn sub(self, other: Self) -> Result {
        let (a, b) = self.coalesce(other)?;
        match (a, b) {
            (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, ..
}) => { let bits = a.checked_sub(b).ok_or(Error::IntegerOverflow)?; if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { let width = ty.bits - ty.signed as u16; let bits = a - b; if bits.bits() > width as u64 { Err(Error::IntegerOverflow) } else { Ok(Self::BigInt { bits, ty }) } } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a - b)), _ => { unreachable!() } } } pub fn mul(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; match (a, b) { (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { let bits = a.checked_mul(b).ok_or(Error::IntegerOverflow)?; if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { let width = ty.bits - ty.signed as u16; let bits = a * b; if bits.bits() > width as u64 { Err(Error::IntegerOverflow) } else { Ok(Self::BigInt { bits, ty }) } } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a * b)), _ => { unreachable!() } } } pub fn div(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; match (a, b) { (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { let bits = a.checked_div(b).ok_or(Error::IntegerOverflow)?; if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. 
}) => { let width = ty.bits - ty.signed as u16; let bits = a / b; if bits.bits() > width as u64 { Err(Error::IntegerOverflow) } else { Ok(Self::BigInt { bits, ty }) } } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a / b)), _ => { unreachable!() } } } pub fn rem(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; match (a, b) { (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { let bits = a.checked_rem(b).ok_or(Error::IntegerOverflow)?; if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { let width = ty.bits - ty.signed as u16; let bits = a % b; if bits.bits() > width as u64 { Err(Error::IntegerOverflow) } else { Ok(Self::BigInt { bits, ty }) } } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a % b)), _ => { unreachable!() } } } pub fn bitand(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; match (a, b) { (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { let bits = a.bitand(b); Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { let bits = a & b; Ok(Self::BigInt { bits, ty }) } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a & b)), _ => { unreachable!() } } } pub fn bitor(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; match (a, b) { (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { let bits = a.bitor(b); Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. 
}) => { let bits = a | b; Ok(Self::BigInt { bits, ty }) } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a | b)), _ => { unreachable!() } } } pub fn bitxor(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; match (a, b) { (ComptimeInt::Native { bits: a, ty }, ComptimeInt::Native { bits: b, .. }) => { let bits = a.bitxor(b); Ok(Self::Native { bits, ty }) } (ComptimeInt::BigInt { bits: a, ty }, ComptimeInt::BigInt { bits: b, .. }) => { let bits = a ^ b; Ok(Self::BigInt { bits, ty }) } (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => Ok(Self::Comptime(a ^ b)), _ => { unreachable!() } } } pub fn cmp(self, other: Self) -> Result { let (a, b) = self.coalesce(other)?; let ord = match (a, b) { (ComptimeInt::Native { bits: a, .. }, ComptimeInt::Native { bits: b, .. }) => a.cmp(&b), (ComptimeInt::BigInt { bits: a, .. }, ComptimeInt::BigInt { bits: b, .. }) => a.cmp(&b), (ComptimeInt::Comptime(a), ComptimeInt::Comptime(b)) => a.cmp(&b), _ => { unreachable!() } }; Ok(ord) } pub fn shl(self, other: Self) -> Result { use core::ops::Shl; let shift = other.try_to_u32()?; match self { ComptimeInt::Native { bits, ty } => { let bits = if ty.signed { (bits as i128) .checked_shl(shift) .ok_or(Error::IntegerOverflow)? as u128 } else { (bits as u128) .checked_shl(shift) .ok_or(Error::IntegerOverflow)? as u128 } & ty.u128_bitmask(); Ok(Self::Native { bits, ty }) } ComptimeInt::BigInt { bits, ty } => { let mut bits = bits.shl(shift); for i in 0..shift as u16 { bits.set_bit((i * ty.bits) as u64, false); } Ok(Self::BigInt { bits, ty }) } ComptimeInt::Comptime(bits) => Ok(Self::Comptime(bits.shl(shift))), } } pub fn shr(self, other: Self) -> Result { use core::ops::Shr; let shift = other.try_to_u32()?; match self { ComptimeInt::Native { bits, ty } => { let bits = if ty.signed { (bits as i128) .checked_shr(shift) .ok_or(Error::IntegerOverflow)? as u128 } else { (bits as u128) .checked_shr(shift) .ok_or(Error::IntegerOverflow)? 
as u128 }; Ok(Self::Native { bits, ty }) } ComptimeInt::BigInt { bits, ty } => Ok(Self::BigInt { bits: bits.shr(shift), ty, }), ComptimeInt::Comptime(bits) => Ok(Self::Comptime(bits.shr(shift))), } } pub fn neg(self) -> Result { match self { Self::Native { bits: a, ty } => { if ty.signed { return Err(Error::UnsignedNegation); } let bits = (a as i128).checked_neg().ok_or(Error::IntegerOverflow)? as u128; if bits & !ty.u128_bitmask() != 0 { return Err(Error::IntegerOverflow); } Ok(Self::Native { bits, ty }) } Self::Comptime(a) => Ok(Self::Comptime(-a)), Self::BigInt { bits, ty } => Ok(Self::BigInt { bits: -bits, ty }), } } pub fn not(self) -> Result { match self { ComptimeInt::Native { bits, ty } => Ok(Self::Native { bits: !bits | ty.u128_bitmask(), ty, }), ComptimeInt::BigInt { bits, ty } => Ok(Self::BigInt { bits: !bits, ty }), ComptimeInt::Comptime(bigint) => Ok(Self::Comptime(!bigint)), } } fn try_to_u32(&self) -> Result { match self { ComptimeInt::Native { bits, .. } => bits.to_u32(), ComptimeInt::BigInt { bits, .. } => bits.to_u32(), ComptimeInt::Comptime(bits) => bits.to_u32(), } .ok_or(Error::ShiftTooLarge) } fn coalesce(self, other: Self) -> Result<(ComptimeInt, ComptimeInt)> { match (self, other) { (lhs @ ComptimeInt::Native { ty: a_ty, .. }, ComptimeInt::Comptime(b)) | (lhs @ ComptimeInt::Native { ty: a_ty, .. }, ComptimeInt::BigInt { bits: b, .. }) => { let b_signed = b.sign() == Sign::Minus; if !a_ty.signed && b_signed { return Err(Error::IncompatibleTypes); } let bits = b.bits() + a_ty.signed as u64; if bits as u16 > a_ty.bits { return Err(Error::IncompatibleTypes); } let b = if b_signed { b.to_i128().unwrap() as u128 } else { b.to_u128().unwrap() }; Ok((lhs, Self::Native { bits: b, ty: a_ty })) } (ComptimeInt::Comptime(b), rhs @ ComptimeInt::Native { ty: a_ty, .. }) | (ComptimeInt::BigInt { bits: b, .. }, rhs @ ComptimeInt::Native { ty: a_ty, .. 
}) => { let b_signed = b.sign() == Sign::Minus; if !a_ty.signed && b_signed { return Err(Error::IncompatibleTypes); } let bits = b.bits() + a_ty.signed as u64; if bits as u16 > a_ty.bits { return Err(Error::IncompatibleTypes); } let b = if b_signed { b.to_i128().unwrap() as u128 } else { b.to_u128().unwrap() }; Ok((Self::Native { bits: b, ty: a_ty }, rhs)) } (lhs @ ComptimeInt::BigInt { ty, .. }, ComptimeInt::Comptime(b)) => { let b_signed = b.sign() == Sign::Minus; if !ty.signed && b_signed { return Err(Error::IncompatibleTypes); } let bits = b.bits() + ty.signed as u64; if bits as u16 > ty.bits { return Err(Error::IncompatibleTypes); } Ok((lhs, Self::BigInt { bits: b, ty })) } (ComptimeInt::Comptime(b), rhs @ ComptimeInt::BigInt { ty, .. }) => { let b_signed = b.sign() == Sign::Minus; if !ty.signed && b_signed { return Err(Error::IncompatibleTypes); } let bits = b.bits() + ty.signed as u64; if bits as u16 > ty.bits { return Err(Error::IncompatibleTypes); } Ok((Self::BigInt { bits: b, ty }, rhs)) } (lhs @ ComptimeInt::Native { ty: a, .. }, rhs @ ComptimeInt::Native { ty: b, .. }) => { if a == b { Ok((lhs, rhs)) } else { Err(Error::IncompatibleTypes) } } (lhs @ ComptimeInt::BigInt { ty: a, .. }, rhs @ ComptimeInt::BigInt { ty: b, .. 
}) => { if a == b { Ok((lhs, rhs)) } else { Err(Error::IncompatibleTypes) } } (lhs, rhs) => Ok((lhs, rhs)), } } } #[derive(Debug, PartialEq)] pub enum ComptimeFloat { Binary32(f32), Binary64(f64), } impl ComptimeFloat { pub fn add(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a + b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a + b)), _ => Err(Error::IncompatibleTypes), } } pub fn sub(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a - b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a - b)), _ => Err(Error::IncompatibleTypes), } } pub fn mul(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a * b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a * b)), _ => Err(Error::IncompatibleTypes), } } pub fn div(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a / b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a / b)), _ => Err(Error::IncompatibleTypes), } } pub fn rem(self, other: Self) -> Result { match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => Ok(Self::Binary32(a % b)), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => Ok(Self::Binary64(a % b)), _ => Err(Error::IncompatibleTypes), } } pub fn neg(self) -> Result { match self { ComptimeFloat::Binary32(a) => Ok(Self::Binary32(-a)), ComptimeFloat::Binary64(a) => Ok(Self::Binary64(-a)), } } pub fn cmp(self, other: Self) -> Result { let ord = match (self, other) { (ComptimeFloat::Binary32(a), ComptimeFloat::Binary32(b)) => a.partial_cmp(&b), (ComptimeFloat::Binary64(a), ComptimeFloat::Binary64(b)) => a.partial_cmp(&b), _ 
=> { return Err(Error::IncompatibleTypes); } }; ord.ok_or(Error::FloatingCmp) } } pub enum ComptimeNumber { Integral(ComptimeInt), Bool(bool), Floating(ComptimeFloat), } impl From for ComptimeNumber { fn from(value: bool) -> Self { Self::Bool(value) } } impl From for ComptimeNumber { fn from(value: f32) -> Self { Self::Floating(ComptimeFloat::Binary32(value)) } } impl From for ComptimeNumber { fn from(value: f64) -> Self { Self::Floating(ComptimeFloat::Binary64(value)) } } impl From for ComptimeNumber { fn from(value: BigInt) -> Self { Self::Integral(ComptimeInt::Comptime(value)) } } impl From<(BigInt, IntegralType)> for ComptimeNumber { fn from((bits, ty): (BigInt, IntegralType)) -> Self { Self::Integral(ComptimeInt::BigInt { bits, ty }) } } impl From<(u128, IntegralType)> for ComptimeNumber { fn from((bits, ty): (u128, IntegralType)) -> Self { Self::Integral(ComptimeInt::Native { bits, ty }) } } impl ComptimeNumber { pub fn bit_count(&self) -> u16 { match self { ComptimeNumber::Integral(i) => match i { ComptimeInt::Native { ty, .. } => ty.bits, ComptimeInt::BigInt { ty, .. 
} => ty.bits, ComptimeInt::Comptime(i) => i.bits() as u16, }, ComptimeNumber::Bool(_) => 1, ComptimeNumber::Floating(f) => match f { ComptimeFloat::Binary32(_) => 32, ComptimeFloat::Binary64(_) => 64, }, } } pub fn add(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.add(b)?)) } (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { Ok(Self::Floating(a.add(b)?)) } // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.add(b)), _ => Err(Error::IncompatibleTypes), } } pub fn sub(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.sub(b)?)) } (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { Ok(Self::Floating(a.sub(b)?)) } // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.sub(b)), _ => Err(Error::IncompatibleTypes), } } pub fn mul(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.mul(b)?)) } (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { Ok(Self::Floating(a.mul(b)?)) } // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.mul(b)), _ => Err(Error::IncompatibleTypes), } } pub fn div(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.div(b)?)) } (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { Ok(Self::Floating(a.div(b)?)) } // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Self::Bool(a.div(b)), _ => Err(Error::IncompatibleTypes), } } pub fn rem(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.rem(b)?)) } (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { Ok(Self::Floating(a.rem(b)?)) } // (ComptimeNumber::Bool(a), 
ComptimeNumber::Bool(b)) => Self::Bool(a.rem(b)), _ => Err(Error::IncompatibleTypes), } } pub fn neg(self) -> Result { match self { ComptimeNumber::Integral(a) => Ok(Self::Integral(a.neg()?)), ComptimeNumber::Floating(a) => Ok(Self::Floating(a.neg()?)), //ComptimeNumber::Bool(a) => todo!(), _ => Err(Error::IncompatibleTypes), } } pub fn not(self) -> Result { match self { ComptimeNumber::Integral(a) => Ok(Self::Integral(a.not()?)), // ComptimeNumber::Floating(a) => Ok(Self::Floating(a.not()?)), ComptimeNumber::Bool(a) => Ok(Self::Bool(a.not())), _ => Err(Error::IncompatibleTypes), } } pub fn bitand(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.bitand(b)?)) } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { // Ok(Self::Floating(a.sub(b)?)) // } (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitand(b))), _ => Err(Error::IncompatibleTypes), } } pub fn bitor(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.bitor(b)?)) } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { // Ok(Self::Floating(a.bitor(b)?)) // } (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitor(b))), _ => Err(Error::IncompatibleTypes), } } pub fn bitxor(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.bitxor(b)?)) } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { // Ok(Self::Floating(a.bitxor(b)?)) // } (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitxor(b))), _ => Err(Error::IncompatibleTypes), } } pub fn shl(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.shl(b)?)) } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => 
{ // Ok(Self::Floating(a.bitxor(b)?)) // } // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitxor(b))), _ => Err(Error::IncompatibleTypes), } } pub fn shr(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { Ok(Self::Integral(a.shr(b)?)) } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { // Ok(Self::Floating(a.bitxor(b)?)) // } // (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a.bitxor(b))), _ => Err(Error::IncompatibleTypes), } } pub fn or(self, other: Self) -> Result { match (self, other) { // (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { // Ok(Self::Integral(a.shr(b)?)) // } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { // Ok(Self::Floating(a.bitxor(b)?)) // } (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a || b)), _ => Err(Error::IncompatibleTypes), } } pub fn and(self, other: Self) -> Result { match (self, other) { // (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => { // Ok(Self::Integral(a.shr(b)?)) // } // (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => { // Ok(Self::Floating(a.bitxor(b)?)) // } (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a && b)), _ => Err(Error::IncompatibleTypes), } } pub fn eq(self, other: Self) -> Result { match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => Ok(Self::Bool(a == b)), (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => Ok(Self::Bool(a == b)), (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => Ok(Self::Bool(a == b)), _ => Err(Error::IncompatibleTypes), } } pub fn cmp(self, other: Self) -> Result { let ord = match (self, other) { (ComptimeNumber::Integral(a), ComptimeNumber::Integral(b)) => a.cmp(b)?, (ComptimeNumber::Floating(a), ComptimeNumber::Floating(b)) => a.cmp(b)?, (ComptimeNumber::Bool(a), ComptimeNumber::Bool(b)) => a.cmp(&b), _ => { 
return Err(Error::IncompatibleTypes); } }; Ok(ord) } pub fn lt(self, other: Self) -> Result { Ok(Self::Bool(self.cmp(other)? == Ordering::Less)) } pub fn gt(self, other: Self) -> Result { Ok(Self::Bool(self.cmp(other)? == Ordering::Greater)) } pub fn ge(self, other: Self) -> Result { Ok(Self::Bool(self.cmp(other)? != Ordering::Less)) } pub fn le(self, other: Self) -> Result { Ok(Self::Bool(self.cmp(other)? != Ordering::Greater)) } pub fn into_bool(self) -> Result { match self { ComptimeNumber::Integral(i) => match i { ComptimeInt::Native { bits, .. } => Ok((bits != 0).into()), ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => { Ok((bits.sign() != Sign::NoSign).into()) } }, ComptimeNumber::Floating(ComptimeFloat::Binary32(f)) => Ok((f != 0.0).into()), ComptimeNumber::Floating(ComptimeFloat::Binary64(f)) => Ok((f != 0.0).into()), a => Ok(a), } } pub fn into_int(self, ty: IntegralType) -> Result { match self { ComptimeNumber::Integral(i) => match i { ComptimeInt::Native { bits, .. } => Ok((bits & ty.u128_bitmask(), ty).into()), ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. } => { let max = BigUint::from(2u32).pow((ty.bits - ty.signed as u16) as u32); let (sign, data) = bits.into_parts(); let data = data.clamp(BigUint::ZERO, max); Ok((BigInt::from_biguint(sign, data), ty).into()) } }, ComptimeNumber::Bool(b) => Ok((b as u128 & ty.u128_bitmask(), ty).into()), ComptimeNumber::Floating(f) => match f { ComptimeFloat::Binary32(f) => Ok((f as u128 & ty.u128_bitmask(), ty).into()), ComptimeFloat::Binary64(f) => Ok((f as u128 & ty.u128_bitmask(), ty).into()), }, } } pub fn into_float(self, ty: FloatingType) -> Result { let f = match self { ComptimeNumber::Integral(i) => match i { ComptimeInt::Native { bits, .. } => bits as f64, ComptimeInt::Comptime(bits) | ComptimeInt::BigInt { bits, .. 
} => { bits.to_f64().unwrap_or(f64::NAN) } }, ComptimeNumber::Bool(b) => { if b { 1.0f64 } else { 0.0f64 } } ComptimeNumber::Floating(f) => match f { ComptimeFloat::Binary32(f) => f as f64, ComptimeFloat::Binary64(f) => f as f64, }, }; match ty { FloatingType::Binary32 => Ok((f as f32).into()), FloatingType::Binary64 => Ok(f.into()), } } pub fn into_bytes_and_type(self) -> (Vec, Type) { match self { ComptimeNumber::Integral(i) => match i { ComptimeInt::Native { bits, ty } => { let bytes = (u128::BITS - bits.leading_zeros() + 7) / 8; ( bits.to_le_bytes()[..bytes as usize].to_vec(), Type::Integer(ty), ) } ComptimeInt::BigInt { bits, ty } => { (bits.to_le_bytes().to_vec(), Type::Integer(ty)) } ComptimeInt::Comptime(bits) => { (bits.to_le_bytes().to_vec(), Type::comptime_number()) } }, ComptimeNumber::Bool(b) => (vec![b as u8], Type::bool()), ComptimeNumber::Floating(f) => match f { ComptimeFloat::Binary32(f) => ( f.to_le_bytes().to_vec(), Type::Floating(FloatingType::Binary32), ), ComptimeFloat::Binary64(f) => ( f.to_le_bytes().to_vec(), Type::Floating(FloatingType::Binary64), ), }, } } }