use std::{hash::Hash, ops::Range, sync::Arc}; use chumsky::{ IterParser, Parser, error::EmptyErr, extra::{self, SimpleState}, input::{IterInput, MapExtra}, pratt::{infix, left, postfix, prefix, right}, prelude::{Recursive, choice, just, recursive}, recursive::Direct, select, text, }; use internment::Intern; use lexer::{Token, TokenItemIterator, TokenIterator}; use thiserror::Error; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum IntSize { Bits(u16), Pointer, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum InnerType { Top, Bottom, Unit, Bool, /// A signed integer constant; concrete type undetermined AnyInt, /// An unsigned integer constant; concrete type undetermined AnyUInt, /// A string slice Str, Int { signed: bool, size: IntSize, }, Float { float_type: FloatType, }, Pointer { pointee: Type, }, Array { element: Type, size: usize, }, Function { return_type: Type, parameter_types: Vec, }, Tuple { elements: Vec, }, TypeUnion { types: Vec, }, TypeIntersection { types: Vec, }, } type Type = internment::Intern; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum FloatType { F32, F64, } #[derive(Debug, PartialEq, Clone)] pub enum Value { Bool(bool), Int(i64), UInt(u64), F64(f64), F32(f32), String(String), Unit, } impl Eq for Value {} impl Hash for Value { fn hash(&self, state: &mut H) { core::mem::discriminant(self).hash(state); match self { Value::Bool(b) => b.hash(state), Value::Int(i) => i.hash(state), Value::UInt(u) => u.hash(state), Value::F64(f) => { werkzeug::util::hash_f64(state, f); } Value::F32(f) => { werkzeug::util::hash_f32(state, f); } Value::String(s) => s.hash(state), Value::Unit => {} } } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ControlFlowKind { Return, Break, Continue, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Index(u32); #[derive(Debug)] pub enum AstNode { Root { files: Vec, }, File { decls: Vec, }, FunctionProto { name: String, return_type: Type, parameter_list: Index, }, ParameterList { parameters: Vec, }, Parameter(Parameter), FunctionDecl(FunctionDecl), Block { statements: Vec, expr: Option, }, Constant { ty: Type, value: Intern, }, NoopExpr, Stmt { expr: Index, }, ControlFlow { kind: ControlFlowKind, expr: Option, }, VarDecl { mutable: bool, name: String, var_type: Type, }, Assignment { dest: Index, expr: Index, }, GlobalDecl { name: String, var_type: Type, value: Index, }, StructDecl { name: String, fields: Vec, }, FieldDecl { name: String, field_type: Type, }, FieldAccess { expr: Index, field: String, }, UnresolvedDeclRef { name: String, }, DeclRef { decl: Index, }, TypeDeclRef { ty: Index, }, ExplicitCast { expr: Index, ty: Type, }, Deref { expr: Index, }, AddressOf { expr: Index, }, PlaceToValue { expr: Index, }, ValueToPlace { expr: Index, }, CallExpr { callee: Index, arguments: Vec, }, Argument { expr: Index, }, Not(Index), Negate(Index), Multiply { left: Index, right: Index, }, Divide { left: Index, right: Index, }, Modulus { left: Index, right: Index, }, Add { left: Index, right: Index, }, Subtract { left: Index, right: Index, }, BitOr { left: Index, right: Index, }, BitAnd { left: Index, right: Index, }, BitXor { left: Index, right: Index, }, LogicalOr { left: Index, right: Index, }, LogicalAnd { left: Index, right: Index, }, Eq { left: Index, right: Index, }, NotEq { left: Index, right: Index, }, Less { left: Index, right: Index, }, LessEq { left: Index, right: Index, }, Greater { left: Index, right: Index, }, GreaterEq { left: Index, right: Index, }, ShiftLeft { left: Index, right: Index, }, ShiftRight { left: Index, right: Index, }, Subscript { expr: Index, index: Index, }, If { condition: Index, then: Index, r#else: Option, }, Else { expr: Index, }, Comment { text: String, }, Attributes { attrs: Vec, }, Doc { text: String, }, Error { err: Box, }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum Visibility { #[default] Private, Public, } #[derive(Debug, Error)] pub enum ParseError<'a> { #[error("End of file.")] EOF, #[error("Unexpected token: {0:?}")] UnexpectedToken(Token<'a>), #[error("Not a type.")] NotAType, } #[derive(Default, Debug)] pub struct Ast { nodes: Vec, } impl Ast { pub fn new() -> Self { Self::default() } pub fn push(&mut self, node: AstNode) -> Index { let index = self.nodes.len() as u32; self.nodes.push(node); Index(index) } } #[derive(Debug)] pub struct FunctionDecl { attrs: Option, name: String, visibility: Visibility, return_type: Type, parameter_list: ParameterList, body: Index, } #[derive(Debug)] pub struct Parameter { mutable: bool, name: String, param_type: Type, } #[derive(Debug)] pub struct ParameterList { parameters: Vec, } fn parse() { todo!() } struct SpannedToken<'a> { token: Token<'a>, span: std::ops::Range, } #[derive(Clone)] struct SpannedTokenInput<'a> { inner: TokenItemIterator<'a>, } impl<'a> Iterator for SpannedTokenInput<'a> { type Item = (Token<'a>, Range); fn next(&mut self) -> Option { self.inner.next().map(|item| (item.token, item.span)) } } type TokenInput<'a> = IterInput, Range>; fn new_token_input<'a>(input: &'a str) -> TokenInput<'a> { let num_bytes = input.len() as u32; let token_iter = TokenIterator::new(input).into_token_items(); let spanned_input = SpannedTokenInput { inner: token_iter }; IterInput::new(spanned_input, num_bytes..num_bytes) } fn type_parser<'a, E>() -> impl Parser<'a, TokenInput<'a>, Type, E> where E: chumsky::extra::ParserExtra<'a, TokenInput<'a>, Error = EmptyErr> + 'a, { let primitives = select! { Token::Void => InnerType::Unit, Token::F32 => InnerType::Float { float_type: FloatType::F32 }, Token::F64 => InnerType::Float { float_type: FloatType::F64 }, Token::Bool => InnerType::Bool, Token::U1 => InnerType::Int { signed: false, size: IntSize::Bits(1) }, Token::U8 => InnerType::Int { signed: false, size: IntSize::Bits(8) }, Token::U16 => InnerType::Int { signed: false, size: IntSize::Bits(16) }, Token::U32 => InnerType::Int { signed: false, size: IntSize::Bits(32) }, Token::U64 => InnerType::Int { signed: false, size: IntSize::Bits(64) }, Token::USize => InnerType::Int { signed: false, size: IntSize::Pointer }, Token::I8 => InnerType::Int { signed: true, size: IntSize::Bits(8) }, Token::I16 => InnerType::Int { signed: true, size: IntSize::Bits(16) }, Token::I32 => InnerType::Int { signed: true, size: IntSize::Bits(32) }, Token::I64 => InnerType::Int { signed: true, size: IntSize::Bits(64) }, Token::ISize => InnerType::Int { signed: true, size: IntSize::Pointer }, }; let u16 = text::int(10) .to_slice() .from_str::() .try_map(|u, _span| u.map_err(|_| EmptyErr::default())); let integral_type = choice((just::<_, _, extra::Default>('u'), just('i'))) .then(u16) .map(|(sign, size)| InnerType::Int { signed: sign == 'i', size: IntSize::Bits(size), }); let custom_int = select! {Token::Ident(ident) => ident}.try_map(move |s, _span| { integral_type .parse(s) .into_result() .map_err(|_| EmptyErr::default()) }); recursive(|ty| { let pointer = just(Token::Star) .ignore_then(choice(( just(Token::Mutable).to(true), just(Token::Const).to(false), ))) .then(ty) .map(|(_mutable, pointee)| InnerType::Pointer { pointee }); choice((primitives, custom_int, pointer)).map(|p| Intern::new(p)) }) } fn visibility<'a>() -> impl Parser<'a, TokenInput<'a>, Visibility, ParserExtra> { choice((just(Token::Pub).to(Visibility::Public),)) .or_not() .map(|v| v.unwrap_or(Visibility::Private)) } fn func_parser() { let ident = select! {Token::Ident(ident) => ident}; let param = just(Token::Mutable) .to(()) .or_not() .then(ident) .then_ignore(just(Token::Colon)) .then(type_parser::()) .map_with(|((mutable, name), param_type), e| { e.state().push(AstNode::Parameter(Parameter { mutable: mutable.is_some(), name: name.to_string(), param_type, })) }); let params = param .separated_by(just(Token::Comma)) .allow_trailing() .collect::>() .delimited_by(just(Token::OpenParens), just(Token::CloseParens)) .labelled("function parameters") .map(|params| ParameterList { parameters: params }); let func = visibility() .then_ignore(just(Token::Fn)) .then(ident) .then(params) // optional return type .then( just(Token::MinusGreater) .ignore_then(type_parser()) .or_not(), ) .then(block()) .map_with(|((((vis, ident), params), ret), body), e| { e.state().push(AstNode::FunctionDecl(FunctionDecl { attrs: None, name: ident.to_string(), visibility: vis, return_type: ret.unwrap_or_else(|| Intern::new(InnerType::Unit)), parameter_list: params, body, })) }); } type ParserExtra = chumsky::extra::Full, ()>; fn block<'a>() -> impl Parser<'a, TokenInput<'a>, Index, ParserExtra> + Clone { just(Token::OpenBrace) .ignored() .then_ignore(just(Token::CloseBrace)) .map_with(|_, e: &mut MapExtra<'_, '_, _, ParserExtra>| { e.state().push(AstNode::Block { statements: vec![], expr: None, }) }) } fn unit<'a>() -> impl Parser<'a, TokenInput<'a>, Index, ParserExtra> + Clone { just(Token::OpenParens) .ignored() .ignore_then(just(Token::CloseParens)) .map_with(|_, e: &mut MapExtra, ParserExtra>| { e.state().push(AstNode::Constant { ty: Intern::new(InnerType::Unit), value: Intern::new(Value::Unit), }) }) } type E<'a, 'b> = MapExtra<'a, 'b, TokenInput<'a>, ParserExtra>; fn simple_expr<'a, 'b>( expr: Recursive, Index, ParserExtra>>, ) -> impl Parser<'a, TokenInput<'a>, Index, ParserExtra> + Clone { let ident = select! {Token::Ident(ident) => ident}.map_with( |ident, e: &mut MapExtra, ParserExtra>| { e.state().push(AstNode::UnresolvedDeclRef { name: ident.to_string(), }) }, ); choice(( unit(), ident, expr.delimited_by(just(Token::OpenParens), just(Token::CloseParens)), block(), )) } fn expr<'a>() -> impl Parser<'a, TokenInput<'a>, Index, ParserExtra> { let assignment = choice(( just(Token::Equal), just(Token::PlusEqual), just(Token::MinusEqual), just(Token::StarEqual), just(Token::SlashEqual), just(Token::PercentEqual), just(Token::AmpersandEqual), just(Token::PipeEqual), just(Token::CaretEqual), just(Token::LessLessEqual), just(Token::GreaterGreaterEqual), )); let logical_or = just(Token::PipePipe); let logical_and = just(Token::AmpersandAmpersand); let or = just(Token::Pipe); let xor = just(Token::Caret); let and = just(Token::Ampersand); let equality = choice((just(Token::BangEqual), just(Token::EqualEqual))); let relational = choice(( just(Token::LessEqual), just(Token::Less), just(Token::GreaterEqual), just(Token::Greater), )); let shift = choice((just(Token::LessLess), just(Token::GreaterGreater))); let additive = choice((just(Token::Plus), just(Token::Minus))); let multiplicative = choice((just(Token::Star), just(Token::Slash), just(Token::Percent))); let prefixes = choice(( just(Token::Bang), just(Token::Minus), just(Token::Star), just(Token::Ampersand), )); let r#as = just(Token::As).ignore_then(type_parser::()); // TODO: postfix: function call, field access, array subscript recursive(|_expr| { let simple = simple_expr(_expr.clone()); let subscript = _expr.clone().delimited_by( just(Token::OpenSquareBracket), just(Token::CloseSquareBracket), ); let arguments = _expr .separated_by(just(Token::Comma)) .allow_trailing() .collect::>() .delimited_by(just(Token::OpenParens), just(Token::CloseParens)); let field = just(Token::Dot).ignore_then(select! {Token::Ident(ident) => ident}); let assignment_expr = simple.pratt(( postfix(100, subscript, |expr, index, e: &mut E| { let node = AstNode::Subscript { expr, index }; e.state().push(node) }), postfix(100, arguments, |callee, arguments, e: &mut E| { let node = AstNode::CallExpr { callee, arguments }; e.state().push(node) }), postfix(100, field, |expr, field: &str, e: &mut E| { let node = AstNode::FieldAccess { expr, field: field.to_string(), }; e.state().push(node) }), postfix(99, r#as, |expr, ty, e: &mut E| { let node = AstNode::ExplicitCast { expr, ty }; e.state().push(node) }), prefix(95, prefixes, |op, expr, e: &mut E| { let node = match op { Token::Bang => AstNode::Not(expr), Token::Minus => AstNode::Negate(expr), Token::Star => AstNode::Deref { expr }, Token::Ampersand => AstNode::AddressOf { expr }, _ => unreachable!(), }; e.state().push(node) }), infix(left(90), multiplicative, |left, op, right, e: &mut E| { let node = match op { Token::Star => AstNode::Multiply { left, right }, Token::Slash => AstNode::Divide { left, right }, Token::Percent => AstNode::Modulus { left, right }, _ => unreachable!(), }; e.state().push(node) }), infix(left(80), additive, |left, op, right, e: &mut E| { let node = match op { Token::Plus => AstNode::Add { left, right }, Token::Minus => AstNode::Subtract { left, right }, _ => unreachable!(), }; e.state().push(node) }), infix(left(70), shift, |left, op, right, e: &mut E| { let node = match op { Token::LessLess => AstNode::ShiftLeft { left, right }, Token::GreaterGreater => AstNode::ShiftRight { left, right }, _ => unreachable!(), }; e.state().push(node) }), infix(left(60), relational, |left, op, right, e: &mut E| { let node = match op { Token::Less => AstNode::Less { left, right }, Token::LessEqual => AstNode::LessEq { left, right }, Token::Greater => AstNode::Greater { left, right }, Token::GreaterEqual => AstNode::GreaterEq { left, right }, _ => unreachable!(), }; e.state().push(node) }), infix(left(50), equality, |left, op, right, e: &mut E| { let node = match op { Token::EqualEqual => AstNode::Eq { left, right }, Token::BangEqual => AstNode::NotEq { left, right }, _ => unreachable!(), }; e.state().push(node) }), infix(left(40), and, |left, _op, right, e: &mut E| { let node = AstNode::BitAnd { left, right }; e.state().push(node) }), infix(left(30), xor, |left, _op, right, e: &mut E| { let node = AstNode::BitXor { left, right }; e.state().push(node) }), infix(left(20), or, |left, _op, right, e: &mut E| { let node = AstNode::BitOr { left, right }; e.state().push(node) }), infix(left(10), logical_and, |left, _op, right, e: &mut E| { let node = AstNode::LogicalAnd { left, right }; e.state().push(node) }), infix(left(5), logical_or, |left, _op, right, e: &mut E| { let node = AstNode::LogicalOr { left, right }; e.state().push(node) }), infix(right(1), assignment, |left, op, right, e: &mut E| { let left = match op { Token::Equal => { let node = AstNode::Assignment { dest: left, expr: right, }; return e.state().push(node); } Token::PlusEqual => e.state().push(AstNode::Add { left, right }), Token::MinusEqual => e.state().push(AstNode::Subtract { left, right }), Token::StarEqual => e.state().push(AstNode::Multiply { left, right }), Token::SlashEqual => e.state().push(AstNode::Divide { left, right }), Token::PercentEqual => e.state().push(AstNode::Modulus { left, right }), Token::AmpersandEqual => e.state().push(AstNode::BitAnd { left, right }), Token::PipeEqual => e.state().push(AstNode::BitOr { left, right }), Token::CaretEqual => e.state().push(AstNode::BitXor { left, right }), Token::LessLessEqual => e.state().push(AstNode::ShiftLeft { left, right }), Token::GreaterGreaterEqual => { e.state().push(AstNode::ShiftRight { left, right }) } _ => unreachable!(), }; let node = AstNode::Assignment { dest: left, expr: right, }; e.state().push(node) }), )); let else_expr = just(Token::Else).ignore_then(_expr.clone()); let if_expr = just(Token::If) .ignore_then( _expr .clone() .delimited_by(just(Token::OpenParens), just(Token::CloseParens)), ) .then(_expr.clone()) .then(else_expr.or_not()) .map_with(|((condition, then), r#else), e: &mut E| { let node = AstNode::If { condition, then, r#else, }; e.state().push(node) }); let expr = choice((if_expr, assignment_expr)).labelled("expression"); Arc::new(expr) }) } mod constants; #[cfg(test)] mod tests { use chumsky::{Parser, extra::SimpleState}; use crate::{Ast, AstNode, new_token_input, type_parser}; #[test] fn print_ast_node_size() { eprintln!("Size of AstNode: {}", std::mem::size_of::()); } #[test] fn parse_types() { let ty = type_parser::() .parse(new_token_input("i32")) .unwrap(); assert_eq!( *ty, crate::InnerType::Int { signed: true, size: crate::IntSize::Bits(32) } ); let ty = type_parser::() .parse(new_token_input("*const i32")) .unwrap(); assert_eq!( *ty, crate::InnerType::Pointer { pointee: crate::Intern::new(crate::InnerType::Int { signed: true, size: crate::IntSize::Bits(32) }) } ); let ty = type_parser::() .parse(new_token_input("*mut *const u8")) .unwrap(); assert_eq!( *ty, crate::InnerType::Pointer { pointee: crate::Intern::new(crate::InnerType::Pointer { pointee: crate::Intern::new(crate::InnerType::Int { signed: false, size: crate::IntSize::Bits(8) }) }) } ); let ty = type_parser::() .parse(new_token_input("i10")) .unwrap(); assert_eq!( *ty, crate::InnerType::Int { signed: true, size: crate::IntSize::Bits(10) } ); } #[test] fn parse_exprs() { let print_ast = |tokens| { let mut state = SimpleState(Ast::new()); let out = crate::expr().parse_with_state(tokens, &mut state).unwrap(); eprintln!("{:?}", state.0); }; print_ast(new_token_input("()")); print_ast(new_token_input("!() as i32")); } }