From 00359a306cc3cb22f8c41dbf04a766a67804dbad Mon Sep 17 00:00:00 2001 From: Janis Date: Sun, 18 Aug 2024 14:52:24 +0200 Subject: [PATCH] stufffffffffff --- src/ast.rs | 70 ++++++++++---- src/error.rs | 26 +++++ src/lexer.rs | 40 +++++--- src/lib.rs | 1 + src/parser.rs | 225 ++++++++++++++++++++++++++++++++++++++++---- src/string_table.rs | 80 +++++++++++++++- src/triples.rs | 7 +- 7 files changed, 396 insertions(+), 53 deletions(-) create mode 100644 src/error.rs diff --git a/src/ast.rs b/src/ast.rs index dc21424..f4ac89c 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -6,7 +6,7 @@ use crate::string_table::{self, ImmOrIndex}; pub type Node = NonZero; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Tag { Undefined, Root, @@ -46,14 +46,6 @@ pub enum Tag { Ident { name: string_table::Index, }, - IntegralConstant { - bits: string_table::Index, - ty: Option, - }, - FloatingConstant { - bits: u64, - ty: FloatingType, - }, Constant { bytes: ImmOrIndex, ty: Type, @@ -197,7 +189,7 @@ pub enum Tag { }, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum LetOrVar { Let, Var, @@ -298,14 +290,12 @@ impl PartialEq for Type { (Self::Floating(l0), Self::Floating(r0)) => l0 == r0, ( Self::Pointer { - constness: l_constness, - pointee: l_pointee, + pointee: l_pointee, .. }, Self::Pointer { - constness: r_constness, - pointee: r_pointee, + pointee: r_pointee, .. }, - ) => l_constness == r_constness && l_pointee == r_pointee, + ) => l_pointee == r_pointee, ( Self::Fn { parameter_types: l_parameter_types, @@ -322,12 +312,41 @@ impl PartialEq for Type { } impl Type { + pub fn as_primitive_type(&self) -> Option { + match self { + Type::Void => Some(PrimitiveType::Void), + Type::Bool => Some(PrimitiveType::Bool), + Type::Integer(t) => Some(PrimitiveType::IntegralType(*t)), + Type::Floating(t) => Some(PrimitiveType::FloatingType(*t)), + _ => None, + } + } + pub fn equal_type(&self, rhs: &Self) -> Option { match (self, rhs) { (Self::ComptimeNumber, Self::Floating(_)) | (Self::ComptimeNumber, Self::Integer(_)) => Some(rhs.clone()), (Self::Integer(_), Self::ComptimeNumber) | (Self::Floating(_), Self::ComptimeNumber) => Some(self.clone()), + ( + Self::Pointer { + constness: a_const, + pointee: a_ptr, + }, + Self::Pointer { + constness: b_const, + pointee: b_ptr, + }, + ) => { + if a_ptr == b_ptr { + Some(Self::Pointer { + constness: *a_const || *b_const, + pointee: a_ptr.clone(), + }) + } else { + None + } + } _ => { if self.eq(rhs) { Some(self.clone()) @@ -343,6 +362,9 @@ impl Type { pub fn bool() -> Type { Self::Void } + pub fn comptime_number() -> Type { + Self::ComptimeNumber + } pub fn any() -> Type { Self::Any } @@ -353,6 +375,22 @@ impl Type { } } + pub fn bit_width(&self) -> u16 { + match self { + Type::Any => 0, + Type::Void => 0, + Type::Bool => 1, + Type::ComptimeNumber => u16::MAX, + Type::Integer(i) => i.bits, + Type::Floating(f) => match f { + FloatingType::Binary32 => 32, + FloatingType::Binary64 => 64, + }, + Type::Pointer { .. } => 64, + Type::Fn { .. } => 64, + } + } + pub fn can_logical_and_or(&self) -> bool { match self { Type::Bool => true, @@ -452,7 +490,7 @@ impl Type { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PrimitiveType { FloatingType(FloatingType), IntegralType(IntegralType), diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..cdbb30e --- /dev/null +++ b/src/error.rs @@ -0,0 +1,26 @@ +use crate::ast::Type; + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum AnalysisErrorTag { + #[error("Mismatching types in function return.")] + MismatchingTypesFunctionReturn, + #[error("Insufficient bits in type {1} for constant with {0} bits")] + InsufficientBitsInTypeForConstant(u32, Type), +} + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub struct AnalysisError { + inner: AnalysisErrorTag, +} + +impl AnalysisError { + pub fn new(inner: AnalysisErrorTag) -> Self { + Self { inner } + } +} + +impl core::fmt::Display for AnalysisError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + core::fmt::Debug::fmt(&self.inner, f) + } +} diff --git a/src/lexer.rs b/src/lexer.rs index 25f9acc..7f78c18 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -593,18 +593,7 @@ pub mod bigint { } pub fn bit_width(&self) -> u32 { - let mut bits = self.0.len() as u32; - - for d in self.0.iter().rev() { - if *d == 0 { - bits -= u32::BITS; - } else { - bits -= d.leading_zeros(); - break; - } - } - - bits + count_bits(&self.0) } pub fn from_bytes_le(bytes: &[u8]) -> BigInt { @@ -677,6 +666,33 @@ pub mod bigint { } } + /// counts used bits in a u32 slice, discards leading zeros in MSB. + /// `[0xff,0xff,0x00,0x00]` -> 16 + /// `[0xff,0xff,0x00]` -> 16 + /// `[0xff,0xff,0x0f]` -> 20 + pub fn count_bits(bytes: &[u32]) -> u32 { + let mut bits = bytes.len() as u32; + + for d in bytes.iter().rev() { + if *d == 0 { + bits -= u32::BITS; + } else { + bits -= d.leading_zeros(); + break; + } + } + + bits + } + + #[test] + fn test_count_bits() { + assert_eq!(count_bits(&[0xffffffff, 0x00, 0x00]), 32); + assert_eq!(count_bits(&[0xffffffff, 0xff, 0x00]), 40); + assert_eq!(count_bits(&[0xffffffff, 0xff]), 40); + assert_eq!(count_bits(&[0xffffffff, 0xff, 0xffff]), 64 + 16); + } + #[allow(unused)] /// lhs must be bigger than rhs fn sub_bigint(lhs: &mut [u32], rhs: &[u32]) { diff --git a/src/lib.rs b/src/lib.rs index b6d5e3e..00d6a61 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ pub mod ast; pub mod codegen; pub mod common; +pub mod error; pub mod lexer; pub mod parser; pub mod string_table; diff --git a/src/parser.rs b/src/parser.rs index 0750104..a75bbb1 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -5,8 +5,9 @@ use itertools::Itertools; use crate::{ ast::{FloatingType, IntegralType, LetOrVar, Node, PrimitiveType, Tag, Type}, common::NextIf, + error::{AnalysisError, AnalysisErrorTag}, lexer::{bigint::BigInt, Radix, TokenIterator}, - string_table::{Index, StringTable}, + string_table::{ImmOrIndex, Index, StringTable}, symbol_table::{SymbolKind, SymbolTable}, tokens::Token, }; @@ -41,6 +42,11 @@ impl core::ops::Index for Nodes { &self.inner[index.get() as usize] } } +impl core::ops::IndexMut for Nodes { + fn index_mut(&mut self, index: Node) -> &mut Self::Output { + &mut self.inner[index.get() as usize] + } +} impl Nodes { fn new() -> Nodes { @@ -126,6 +132,25 @@ impl Tree { } } + pub fn global_decls(&self) -> impl Iterator { + self.global_decls.iter().map(|decl| { + let name = match self.nodes.get_node(*decl) { + Tag::FunctionDecl { proto, body } => { + let Tag::FunctionProto { name, .. } = self.nodes.get_node(*proto) else { + unreachable!() + }; + + self.get_ident_str(*name).unwrap().to_owned() + } + Tag::GlobalDecl { name, .. } => self.get_ident_str(*name).unwrap().to_owned(), + _ => { + unreachable!() + } + }; + (*decl, name) + }) + } + #[allow(unused)] fn is_integral_type(lexeme: &str) -> Option<()> { let mut iter = lexeme.chars(); @@ -822,16 +847,35 @@ impl Tree { | Token::IntegerConstant => { _ = tokens.next(); let (bits, ty) = Self::parse_integral_constant(token.token(), token.lexeme()); - let index = self.strings.insert(bits.into_bytes_le()); + let bytes = bits.into_bytes_le(); + + const BUF_SIZE: usize = core::mem::size_of::(); + let mut buf = [0u8; BUF_SIZE]; + buf[..bytes.len().min(BUF_SIZE)] + .copy_from_slice(&bytes[..bytes.len().min(BUF_SIZE)]); + let bytes = match bytes.len() { + 0..2 => { + let (buf, _) = buf.split_at(core::mem::size_of::()); + let dw = u32::from_le_bytes(buf.try_into().unwrap()); + ImmOrIndex::U32(dw) + } + 0..4 => { + let (buf, _) = buf.split_at(core::mem::size_of::()); + let qw = u64::from_le_bytes(buf.try_into().unwrap()); + ImmOrIndex::U64(qw) + } + 0.. => { + let idx = self.strings.insert(bytes); + ImmOrIndex::Index(idx) + } + }; + let ty = match ty { Some(int) => Type::Integer(int), None => Type::ComptimeNumber, }; - Ok(self.nodes.push_tag(Tag::Constant { - bytes: crate::string_table::ImmOrIndex::Index(index), - ty, - })) + Ok(self.nodes.push_tag(Tag::Constant { bytes, ty })) } Token::FloatingConstant | Token::FloatingExpConstant @@ -840,8 +884,13 @@ impl Tree { _ = tokens.next(); let (bits, ty) = Self::parse_floating_constant(token.token(), token.lexeme()); + let bytes = match ty { + FloatingType::Binary32 => ImmOrIndex::U32(bits as u32), + FloatingType::Binary64 => ImmOrIndex::U64(bits as u64), + }; + Ok(self.nodes.push_tag(Tag::Constant { - bytes: crate::string_table::ImmOrIndex::U64(bits), + bytes, ty: Type::Floating(ty), })) } @@ -906,7 +955,9 @@ impl Tree { _ => None, } } +} +impl Tree { fn render_node( &mut self, writer: &mut W, @@ -1360,17 +1411,13 @@ impl Tree { ) } Tag::Constant { bytes, ty } => { - let bytes = match bytes { - crate::string_table::ImmOrIndex::U64(i) => &i.to_le_bytes()[..], - crate::string_table::ImmOrIndex::U32(i) => &i.to_le_bytes()[..], - crate::string_table::ImmOrIndex::Index(idx) => self.strings.get_bytes(idx), - }; writeln_indented!( indent, writer, - "%{} = constant{{ ty: {}, bytes: {bytes:?}}}", + "%{} = constant{{ ty: {}, bytes: {}}}", node.get(), - ty + ty, + self.strings.display_idx(bytes) ) } _ => unreachable!(), @@ -1487,6 +1534,146 @@ impl Tree { } } +impl Tree { + /// type-checks and inserts appropriate explicit-cast nodes. + pub fn typecheck(&mut self) { + let mut errors = Vec::new(); + for decl in self.global_decls.clone() { + self.typecheck_node(&mut errors, decl); + } + } + + // TODO: inline types into the AST proper before tackling this. + // for now, comptime_number is not supported in IR gen, then. + fn typecheck_node(&mut self, errors: &mut Vec, node: Node) { + #[allow(unused_variables)] + match self.nodes[node].clone() { + Tag::FunctionProto { .. } => {} + Tag::FunctionDecl { proto, body } => { + let Tag::FunctionProto { return_type, .. } = self.nodes[proto] else { + unreachable!() + }; + + let body_t = self.type_of_node(body); + let ret_t = self.type_of_node(return_type); + + if let Some(peer_t) = body_t.equal_type(&ret_t) { + if body_t == Type::comptime_number() { + let Tag::Block { trailing_expr, .. } = self.nodes[body] else { + unreachable!() + }; + if let Some(expr) = trailing_expr { + let ty = self.nodes.push_tag(Tag::PrimitiveType( + peer_t + .as_primitive_type() + .expect("comptime cannot be cast into a non-primitive type"), + )); + let expr = self.nodes.push_tag(Tag::ExplicitCast { + lhs: expr, + typename: ty, + }); + + let Tag::Block { trailing_expr, .. } = &mut self.nodes[body] else { + unreachable!() + }; + *trailing_expr = Some(expr) + } + } + } else { + errors.push(AnalysisError::new( + AnalysisErrorTag::MismatchingTypesFunctionReturn, + )); + } + } + Tag::Constant { bytes, ty } => { + let bits = self.strings.count_bits(bytes); + if bits < ty.bit_width() as u32 { + errors.push(AnalysisError::new( + AnalysisErrorTag::InsufficientBitsInTypeForConstant(bits, ty.clone()), + )); + } + } + Tag::Block { + statements, + trailing_expr, + } => { + for statement in statements { + self.typecheck_node(errors, statement); + } + if let Some(expr) = trailing_expr { + self.typecheck_node(errors, expr); + } + } + Tag::ReturnStmt { expr } => { + if let Some(expr) = expr { + self.typecheck_node(errors, expr); + } + } + Tag::ExprStmt { expr } => { + self.typecheck_node(errors, expr); + } + Tag::VarDecl { + explicit_type, + assignment, + .. + } => { + assignment.map(|t| self.typecheck_node(errors, t)); + + let explicit_t = explicit_type.map(|t| self.type_of_node(t)); + let assignment_t = assignment.map(|t| self.type_of_node(t)); + + match (explicit_t, assignment_t) { + (None, None) => unreachable!(), + (Some(explicit_t), None) => {} + (Some(explicit_t), Some(assignment_t)) => { + // TODO: ensure types match, explicit-cast comptime_number + } + (None, Some(assignment_t)) => { + // TODO: set explicit_type to assignment_t + } + } + } + Tag::GlobalDecl { + name, + explicit_type, + assignment, + } => todo!(), + Tag::DeclRef(_) => todo!(), + Tag::GlobalRef(_) => todo!(), + Tag::CallExpr { lhs, rhs } => todo!(), + Tag::ArgumentList { parameters } => todo!(), + Tag::Argument { name, expr } => todo!(), + Tag::ExplicitCast { lhs, typename } => todo!(), + Tag::Deref { lhs } => todo!(), + Tag::Ref { lhs } => todo!(), + Tag::Not { lhs } => todo!(), + Tag::Negate { lhs } => todo!(), + Tag::Or { lhs, rhs } => todo!(), + Tag::And { lhs, rhs } => todo!(), + Tag::BitOr { lhs, rhs } => todo!(), + Tag::BitAnd { lhs, rhs } => todo!(), + Tag::BitXOr { lhs, rhs } => todo!(), + Tag::Eq { lhs, rhs } => todo!(), + Tag::NEq { lhs, rhs } => todo!(), + Tag::Lt { lhs, rhs } => todo!(), + Tag::Gt { lhs, rhs } => todo!(), + Tag::Le { lhs, rhs } => todo!(), + Tag::Ge { lhs, rhs } => todo!(), + Tag::Shl { lhs, rhs } => todo!(), + Tag::Shr { lhs, rhs } => todo!(), + Tag::Add { lhs, rhs } => todo!(), + Tag::Sub { lhs, rhs } => todo!(), + Tag::Mul { lhs, rhs } => todo!(), + Tag::Rem { lhs, rhs } => todo!(), + Tag::Div { lhs, rhs } => todo!(), + Tag::Assign { lhs, rhs } => todo!(), + _ => { + unreachable!() + } + } + } +} + static PRECEDENCE_MAP: std::sync::LazyLock> = std::sync::LazyLock::new(|| { HashMap::from([ (Token::PipePipe, 10), @@ -1518,7 +1705,7 @@ mod tests { #[test] fn render_ast() { - let src = "let a: u21 = 3;"; + let src = "let a: u21 = 3u32;"; let tokens = Tokenizer::new(src.as_bytes()).unwrap(); let mut tree = Tree::new(); @@ -1532,8 +1719,8 @@ mod tests { fn render_ast2() { let src = " fn main() -> void { -let a: u32 = 0; -a == 1 +let a: u32 = 0u32; +a == 1u32 } fn square(x: u32) -> u32 { x * x @@ -1553,10 +1740,10 @@ x * x fn render_ast3() { let src = " fn main() -> void { -let a: u32 = 0; +let a: u32 = 0u32; a == global } -const global: u32 = 42; +const global: u32 = 42u32; "; let tokens = Tokenizer::new(src.as_bytes()).unwrap(); diff --git a/src/string_table.rs b/src/string_table.rs index 12356d8..b98e58b 100644 --- a/src/string_table.rs +++ b/src/string_table.rs @@ -1,12 +1,12 @@ use std::{collections::BTreeMap, hash::Hasher}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Index { pub start: u32, pub end: u32, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ImmOrIndex { U64(u64), U32(u32), @@ -19,7 +19,6 @@ impl Index { } } -#[derive(Debug)] pub struct StringTable { bytes: Vec, indices: BTreeMap, @@ -41,6 +40,24 @@ impl StringTable { } } + pub fn display_idx(&self, idx: ImmOrIndex) -> ImmOrIndexDisplay { + ImmOrIndexDisplay::new(self, idx) + } + + pub fn count_bits(&self, idx: ImmOrIndex) -> u32 { + match idx { + ImmOrIndex::U64(v) => u64::BITS - v.leading_zeros(), + ImmOrIndex::U32(v) => u32::BITS - v.leading_zeros(), + ImmOrIndex::Index(idx) => { + let bytes = self.get_bytes(idx); + let ints = unsafe { + core::slice::from_raw_parts(bytes.as_ptr().cast::(), bytes.len() / 4) + }; + crate::lexer::bigint::count_bits(ints) + } + } + } + pub fn get_str(&self, idx: Index) -> &str { unsafe { core::str::from_utf8_unchecked(&self[idx]) } } @@ -73,3 +90,60 @@ impl StringTable { index } } + +mod display { + use core::{fmt::Debug, str}; + + use super::*; + + impl Debug for StringTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list() + .entries(self.indices.iter().map(|(_, idx)| { + struct Test<'a> { + bytes: &'a [u8], + str: Option<&'a str>, + } + impl<'a> Debug for Test<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{{ bytes: {:x?}", self.bytes)?; + if let Some(str) = self.str { + write!(f, ", str: {}", str)?; + } + write!(f, " }}") + } + } + let bytes = self.get_bytes(*idx); + let str = str::from_utf8(bytes).ok(); + Test { bytes, str } + })) + .finish() + } + } + + pub struct ImmOrIndexDisplay<'table> { + table: &'table StringTable, + idx: ImmOrIndex, + } + + impl<'table> ImmOrIndexDisplay<'table> { + pub fn new(table: &'table StringTable, idx: ImmOrIndex) -> Self { + Self { table, idx } + } + } + + impl<'table> core::fmt::Display for ImmOrIndexDisplay<'table> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.idx { + ImmOrIndex::U64(i) => write!(f, "0x{i:0>16x}"), + ImmOrIndex::U32(i) => write!(f, "0x{i:0>8x}"), + ImmOrIndex::Index(idx) => { + let bytes = self.table.get_bytes(idx); + write!(f, "{bytes:?}") + } + } + } + } +} + +pub use display::ImmOrIndexDisplay; diff --git a/src/triples.rs b/src/triples.rs index 10b6ece..b4037a8 100644 --- a/src/triples.rs +++ b/src/triples.rs @@ -3,7 +3,7 @@ use std::collections::{hash_map::Entry, HashMap}; use crate::{ - ast::{FloatingType, IntegralType, Node as AstNode, Tag, Type}, + ast::{Node as AstNode, Tag, Type}, parser::Tree, string_table::{ImmOrIndex, Index as StringsIndex}, writeln_indented, @@ -533,7 +533,7 @@ mod tests { fn ir() { let src = " fn main() -> u32 { - let a: u32 = 0 + 3u32; + let a: u32 = 0u32 + 3u32; let ptr_a = &a; return *ptr_a * global; } @@ -542,7 +542,7 @@ fn square(x: u32) -> u32 { x * x } -const global: u32 = 42; +const global: u32 = 42u32; "; let tokens = Tokenizer::new(src.as_bytes()).unwrap(); @@ -552,6 +552,7 @@ const global: u32 = 42; let mut buf = String::new(); tree.render(&mut buf).unwrap(); println!("{buf}"); + println!("{:#?}", tree.strings); let mut ir = IR::new(); ir.build(&mut tree);