diff --git a/src/ast.rs b/src/ast.rs index f4ac89c..0c987f9 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -322,6 +322,13 @@ impl Type { } } + pub fn pointee(&self) -> Option<&Type> { + match self { + Self::Pointer { pointee, .. } => Some(&pointee), + _ => None, + } + } + pub fn equal_type(&self, rhs: &Self) -> Option { match (self, rhs) { (Self::ComptimeNumber, Self::Floating(_)) diff --git a/src/common.rs b/src/common.rs index 284fdd6..fc99533 100644 --- a/src/common.rs +++ b/src/common.rs @@ -157,3 +157,13 @@ pub trait FallibleParse: Iterator + Clone { } impl FallibleParse for T where T: Iterator + Clone {} + +#[macro_export] +macro_rules! variant { + ($value:expr => $pattern:pat) => { + let $pattern = $value else { unreachable!() }; + }; + ($pattern:pat = $value:expr) => { + let $pattern = $value else { unreachable!() }; + }; +} diff --git a/src/lib.rs b/src/lib.rs index 00d6a61..2af9238 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ #![allow(unused_macros)] pub mod ast; -pub mod codegen; +// pub mod codegen; pub mod common; pub mod error; pub mod lexer; diff --git a/src/parser.rs b/src/parser.rs index a75bbb1..7833eaf 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -10,6 +10,7 @@ use crate::{ string_table::{ImmOrIndex, Index, StringTable}, symbol_table::{SymbolKind, SymbolTable}, tokens::Token, + variant, }; #[derive(Debug, thiserror::Error)] @@ -132,23 +133,26 @@ impl Tree { } } - pub fn global_decls(&self) -> impl Iterator { - self.global_decls.iter().map(|decl| { - let name = match self.nodes.get_node(*decl) { - Tag::FunctionDecl { proto, body } => { - let Tag::FunctionProto { name, .. } = self.nodes.get_node(*proto) else { - unreachable!() - }; + pub fn global_decls(&self) -> Vec<(Node, String)> { + self.global_decls + .iter() + .map(|decl| { + let name = match self.nodes.get_node(*decl) { + Tag::FunctionDecl { proto, .. } => { + let Tag::FunctionProto { name, .. } = self.nodes.get_node(*proto) else { + unreachable!() + }; - self.get_ident_str(*name).unwrap().to_owned() - } - Tag::GlobalDecl { name, .. } => self.get_ident_str(*name).unwrap().to_owned(), - _ => { - unreachable!() - } - }; - (*decl, name) - }) + self.get_ident_str(*name).unwrap().to_owned() + } + Tag::GlobalDecl { name, .. } => self.get_ident_str(*name).unwrap().to_owned(), + _ => { + unreachable!() + } + }; + (*decl, name) + }) + .collect::>() } #[allow(unused)] diff --git a/src/triples.rs b/src/triples.rs index b4037a8..d601736 100644 --- a/src/triples.rs +++ b/src/triples.rs @@ -1,59 +1,189 @@ #![allow(dead_code)] -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; use crate::{ ast::{Node as AstNode, Tag, Type}, parser::Tree, - string_table::{ImmOrIndex, Index as StringsIndex}, - writeln_indented, + string_table::{ImmOrIndex, Index as StringsIndex, StringTable}, + variant, writeln_indented, }; type Node = u32; +#[derive(Debug)] enum NodeOrList { Node(Node), // node of alloca location List(Vec), // list of references to `Node(_)` } -enum Inst { - Label(String), - Constant(Value), - UnresolvedRef, - ExternRef(AstNode), - Ref(Node), - Parameter { size: u32, align: u32 }, - Add { lhs: Node, rhs: Node }, - Sub { lhs: Node, rhs: Node }, - Div { lhs: Node, rhs: Node }, - Mul { lhs: Node, rhs: Node }, - Rem { lhs: Node, rhs: Node }, - BitAnd { lhs: Node, rhs: Node }, - BitOr { lhs: Node, rhs: Node }, - BitXOr { lhs: Node, rhs: Node }, - Negate { lhs: Node }, - ReturnValue { lhs: Node }, - Return, - ExplicitCast { node: Node, ty: Type }, - Alloc { size: u32, align: u32 }, - AddressOf(Node), - Load { source: Node }, - Store { dest: Node, source: Node }, +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum Type2 { + Integral(u16), + Binary32, + Binary64, + Bool, + Pointer, } -struct Value { - explicit_type: Option, - bytes: ImmOrIndex, +impl Type2 { + fn size(&self) -> u32 { + match self { + Type2::Integral(bits) => bits.div_ceil(8) as u32, + Type2::Binary32 => 4, + Type2::Binary64 => 8, + Type2::Bool => 1, + Type2::Pointer => 8, + } + } + + fn align(&self) -> u32 { + self.size() + } } -impl core::fmt::Display for Value { +impl core::fmt::Display for Type2 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{} {:?}", - self.explicit_type.as_ref().unwrap_or(&Type::any()), - self.bytes - ) + match self { + Type2::Integral(bits) => write!(f, "i{bits}"), + Type2::Binary32 => write!(f, "f32"), + Type2::Binary64 => write!(f, "f64"), + Type2::Bool => write!(f, "bool"), + Type2::Pointer => write!(f, "ptr"), + } + } +} + +impl From for Type2 { + fn from(value: Type) -> Self { + (&value).into() + } +} + +impl From<&Type> for Type2 { + fn from(value: &Type) -> Self { + match value { + Type::Bool => Type2::Bool, + Type::Integer(i) => Type2::Integral(i.bits), + Type::Floating(f) => match f { + crate::ast::FloatingType::Binary32 => Type2::Binary32, + crate::ast::FloatingType::Binary64 => Type2::Binary64, + }, + Type::Pointer { .. } => Type2::Pointer, + _ => todo!(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum Inst { + /// index + Label, + /// index + FunctionStart, + /// u32 + ConstantU32, + /// lo, hi + ConstantU64, + /// index + ConstantMultiByte, + /// ast-node + ExternRef, + /// size, align + Alloca, + /// src + Load(Type2), + /// src, dst + Store(Type2), + /// ptr, index, + GetElementPtr(Type2), + /// size, align + Parameter, + /// lhs, rhs + Add(Type2), + /// lhs, rhs + Sub(Type2), + /// lhs, rhs + Mul(Type2), + /// lhs, rhs + Div(Type2), + /// lhs, rhs + Rem(Type2), + /// lhs, rhs + BitAnd(Type2), + /// lhs, rhs + BitOr(Type2), + /// lhs, rhs + BitXOr(Type2), + /// lhs + Negate(Type2), + /// lhs + ReturnValue, + /// no parameters + Return, +} + +#[derive(Debug, Clone, Copy)] +struct Data { + lhs: u32, + rhs: u32, +} + +impl Data { + fn new(lhs: u32, rhs: u32) -> Self { + Self { lhs, rhs } + } + + fn lhs(lhs: u32) -> Data { + Self { lhs, rhs: 0 } + } + + fn as_u32(&self) -> u32 { + self.lhs + } + fn as_u64(&self) -> u64 { + self.lhs as u64 | (self.rhs as u64) << u32::BITS as u64 + } + fn as_index(&self) -> StringsIndex { + crate::string_table::Index { + start: self.lhs, + end: self.rhs, + } + } + fn as_lhs_rhs(&self) -> (u32, u32) { + (self.lhs, self.rhs) + } +} + +impl From for Data { + fn from(value: u32) -> Self { + Self { lhs: value, rhs: 0 } + } +} + +impl From for Data { + fn from(value: u64) -> Self { + let (lo, hi) = { (value as u32, (value >> u32::BITS as u64) as u32) }; + Self { lhs: lo, rhs: hi } + } +} + +impl From for Data { + fn from(value: crate::string_table::ImmOrIndex) -> Self { + match value { + ImmOrIndex::U64(v) => v.into(), + ImmOrIndex::U32(v) => v.into(), + ImmOrIndex::Index(v) => v.into(), + } + } +} + +impl From for Data { + fn from(value: crate::string_table::Index) -> Self { + Self { + lhs: value.start, + rhs: value.end, + } } } @@ -82,77 +212,48 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> { } } - fn patch_unresolved(&mut self, node: AstNode, resolved: Node) { - match self.lookup.entry(node) { - Entry::Occupied(mut o) => { - match o.get_mut() { - NodeOrList::Node(_) => { - panic!("there shouldn't be a node here.") - } - NodeOrList::List(refs) => { - for &r in refs.iter() { - self.ir.nodes[r as usize] = Inst::Ref(resolved); - } - } - } - o.insert(NodeOrList::Node(resolved)); - } - Entry::Vacant(v) => { - v.insert(NodeOrList::Node(resolved)); - } - } - } - fn visit(&mut self, node: AstNode) -> Node { match &self.tree.nodes[node].clone() { Tag::FunctionDecl { proto, body } => { - self.visit(*proto); + variant!( + Tag::FunctionProto { + name, + parameters, + .. + } = self.tree.nodes.get_node(*proto) + ); + + self.ir.push(Inst::FunctionStart, { + variant!(Tag::Ident { name } = self.tree.nodes.get_node(*name)); + Some((*name).into()) + }); + + if let Some(parameters) = parameters { + variant!( + Tag::ParameterList { parameters } = self.tree.nodes.get_node(*parameters) + ); + + for param in parameters { + variant!(Tag::Parameter { ty, .. } = self.tree.nodes.get_node(*param)); + let ty = self.tree.type_of_node(*ty); + let ir = self.ir.push( + Inst::Parameter, + Some(Data::new(ty.size_of(), ty.align_of())), + ); + + self.lookup.insert(*param, NodeOrList::Node(ir)); + } + } + self.tree.st.into_child(node); let value = self.visit(*body); // TODO: return value of body expression - let node = if value != !0 { - let return_type = { - match self.tree.nodes.get_node(*proto) { - Tag::FunctionProto { return_type, .. } => *return_type, - _ => unreachable!(), - } - }; - self.type_check(return_type, *body); - self.ir.push(Inst::ReturnValue { lhs: value }) + self.tree.st.into_parent(); + if value != !0 { + self.ir.push(Inst::ReturnValue, Some(Data::lhs(value))) } else { !0 - }; - - self.tree.st.into_parent(); - node - } - Tag::FunctionProto { - parameters, name, .. - } => { - let label = self.ir.push(Inst::Label( - self.tree.get_ident_str(*name).unwrap().to_string(), - )); - parameters.map(|p| self.visit(p)); - - self.patch_unresolved(node, label); - - label - } - Tag::ParameterList { parameters } => { - for param in parameters { - self.visit(*param); } - !0 - } - Tag::Parameter { ty, .. } => { - let ty = self.tree.type_of_node(*ty); - let param = self.ir.push(Inst::Parameter { - size: ty.size_of(), - align: ty.align_of(), - }); - - self.lookup.insert(node, NodeOrList::Node(param)); - param } Tag::Block { statements, @@ -170,278 +271,142 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> { } Tag::VarDecl { .. } => { let ty = self.tree.type_of_node(node); - let alloca = self.ir.push(Inst::Alloc { - size: ty.size_of(), - align: ty.align_of(), - }); - self.patch_unresolved(node, alloca); + let alloca = self + .ir + .push(Inst::Alloca, Some(Data::new(ty.size_of(), ty.align_of()))); + + self.lookup.insert(node, NodeOrList::Node(alloca)); alloca } - Tag::GlobalDecl { name, .. } => { - let ty = self.tree.type_of_node(node); - let _label = self.ir.push(Inst::Label( - self.tree.get_ident_str(*name).unwrap().to_string(), - )); - let alloca = self.ir.push(Inst::Alloc { - size: ty.size_of(), - align: ty.align_of(), - }); - self.patch_unresolved(node, alloca); - alloca + Tag::GlobalDecl { .. } => { + // self.ir.push(Inst::Label, { + // variant!(Tag::Ident { name } = self.tree.nodes.get_node(*name)); + // Some((*name).into()) + // }); + unimplemented!() } Tag::ReturnStmt { expr } => { if let Some(expr) = expr { let expr = self.visit(*expr); - self.ir.push(Inst::ReturnValue { lhs: expr }) + self.ir.push(Inst::ReturnValue, Some(Data::lhs(expr))) } else { - self.ir.push(Inst::Return) + self.ir.push(Inst::Return, None) } } Tag::ExprStmt { expr } => self.visit(*expr), Tag::Deref { lhs } => { + let ty = self.tree.type_of_node(*lhs).pointee().unwrap().clone(); let lhs = self.visit(*lhs); - self.ir.push(Inst::Load { source: lhs }) + self.ir.push(Inst::Load(ty.into()), Some(Data::lhs(lhs))) } Tag::Assign { lhs, rhs } => { + let ty = self.tree.type_of_node(*rhs); let dest = self.visit(*lhs); let source = self.visit(*rhs); - self.type_check(*lhs, *rhs); - - self.ir.push(Inst::Store { dest, source }) + self.ir + .push(Inst::Store(ty.into()), Some(Data::new(source, dest))) } - Tag::Add { - lhs: lhs0, - rhs: rhs0, - } => { - let lhs = self.visit(*lhs0); - let rhs = self.visit(*rhs0); - let ty = self.type_check(*lhs0, *rhs0); - if !ty.can_add_sub() { - eprintln!("add is not available for type {ty:?}"); - } - self.ir.push(Inst::Add { lhs, rhs }) + Tag::Add { lhs, rhs } => { + let ty = self.tree.type_of_node(*lhs); + let lhs = self.visit(*lhs); + let rhs = self.visit(*rhs); + self.ir + .push(Inst::Add(ty.into()), Some(Data::new(lhs, rhs))) } - Tag::Sub { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_add_sub() { - eprintln!("sub is not available for type {ty:?}"); - } - - self.ir.push(Inst::Sub { lhs, rhs }) + Tag::Mul { lhs, rhs } => { + let ty = self.tree.type_of_node(*lhs); + let lhs = self.visit(*lhs); + let rhs = self.visit(*rhs); + self.ir + .push(Inst::Mul(ty.into()), Some(Data::new(lhs, rhs))) } - Tag::Mul { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_mul_div_rem() { - eprintln!("mul is not available for type {ty:?}"); - } - - self.ir.push(Inst::Mul { lhs, rhs }) - } - Tag::Div { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_mul_div_rem() { - eprintln!("div is not available for type {ty:?}"); - } - - self.ir.push(Inst::Div { lhs, rhs }) - } - Tag::Rem { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_mul_div_rem() { - eprintln!("rem is not available for type {ty:?}"); - } - - self.ir.push(Inst::Rem { lhs, rhs }) - } - // bitwise - Tag::BitAnd { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_bitxor_and_or() { - eprintln!("bitand is not available for type {ty:?}"); - } - - self.ir.push(Inst::BitAnd { lhs, rhs }) - } - Tag::BitOr { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_bitxor_and_or() { - eprintln!("bitor is not available for type {ty:?}"); - } - - self.ir.push(Inst::BitOr { lhs, rhs }) - } - Tag::BitXOr { - lhs: left, - rhs: right, - } => { - let lhs = self.visit(*left); - let rhs = self.visit(*right); - - let ty = self.type_check(*left, *right); - if !ty.can_bitxor_and_or() { - eprintln!("bitxor is not available for type {ty:?}"); - } - - self.ir.push(Inst::BitXOr { lhs, rhs }) - } - Tag::Negate { lhs: left } => { - let lhs = self.visit(*left); - let ty = self.tree.type_of_node(*left); - if !ty.can_negate() { - eprintln!("negation is not available for type {ty:?}"); - } - self.ir.push(Inst::Negate { lhs }) + Tag::Negate { lhs } => { + let ty = self.tree.type_of_node(*lhs); + let lhs = self.visit(*lhs); + self.ir.push(Inst::Negate(ty.into()), Some(Data::lhs(lhs))) } Tag::DeclRef(decl) => match self.lookup.get_mut(decl) { Some(NodeOrList::Node(decl)) => *decl, - Some(NodeOrList::List(refs)) => { - let unresolved = self.ir.push(Inst::UnresolvedRef); - refs.push(unresolved); - unresolved - } - None => { - let unresolved = self.ir.push(Inst::UnresolvedRef); - self.lookup - .insert(*decl, NodeOrList::List(vec![unresolved])); - unresolved + lookup => { + println!("lookup for ast decl %{}", decl.get()); + println!("{lookup:?}"); + panic!("should not have any unresolved lookups") } }, - Tag::GlobalRef(decl) => self.ir.push(Inst::ExternRef(*decl)), + Tag::GlobalRef(decl) => self.ir.push(Inst::ExternRef, Some(Data::lhs(decl.get()))), Tag::Ref { lhs } => { + let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); - self.ir.push(Inst::AddressOf(lhs)) - } - Tag::Constant { bytes, ty } => { - let bytes = match ty { - Type::ComptimeNumber | Type::Floating(_) | Type::Integer(_) => Value { - explicit_type: Some(ty.clone()), - bytes: *bytes, - }, - _ => { - unimplemented!() - } - }; - self.ir.push(Inst::Constant(bytes)) + self.ir.push(Inst::Load(ty.into()), Some(Data::lhs(lhs))) } + Tag::Constant { bytes, .. } => match bytes { + ImmOrIndex::U64(v) => self.ir.push(Inst::ConstantU64, Some((*v).into())), + ImmOrIndex::U32(v) => self.ir.push(Inst::ConstantU32, Some((*v).into())), + ImmOrIndex::Index(idx) => { + self.ir.push(Inst::ConstantMultiByte, Some((*idx).into())) + } + }, _ => { dbg!(&self.tree.nodes[node]); todo!() } } } - - fn type_check(&mut self, lhs: AstNode, rhs: AstNode) -> Type { - let left_t = match self.type_map.entry(lhs.clone()) { - Entry::Occupied(o) => o.get().clone(), - Entry::Vacant(v) => v.insert(self.tree.type_of_node(lhs)).clone(), - }; - let right_t = match self.type_map.entry(rhs.clone()) { - Entry::Occupied(o) => o.get().clone(), - Entry::Vacant(v) => v.insert(self.tree.type_of_node(rhs)).clone(), - }; - match left_t.equal_type(&right_t) { - Some(t) => { - if left_t == Type::ComptimeNumber { - self.type_map.insert(lhs, t.clone()); - } - if right_t == Type::ComptimeNumber { - self.type_map.insert(rhs, t.clone()); - } - - t - } - None => { - eprintln!( - "incompatible types %{}: {left_t:?} and %{}: {right_t:?}!", - lhs.get(), - rhs.get() - ); - Type::void() - } - } - } } struct IR { nodes: Vec, + data: Vec>, } impl IR { pub fn new() -> Self { - Self { nodes: Vec::new() } + Self { + nodes: Vec::new(), + data: Vec::new(), + } } - fn push(&mut self, inst: Inst) -> u32 { + fn push(&mut self, inst: Inst, data: Option) -> u32 { let node = self.nodes.len() as u32; self.nodes.push(inst); + self.data.push(data); node } - pub fn build(&mut self, tree: &mut Tree) { + pub fn build<'a, 'tree>(&'a mut self, tree: &'tree mut Tree) -> IRBuilder<'tree, 'a> { let global_decls = tree.global_decls.clone(); let mut builder = IRBuilder::new(self, tree); for node in &global_decls { builder.visit(*node); } + builder } } -impl IR { +impl<'tree, 'ir> IRBuilder<'tree, 'ir> { fn render_node( &self, w: &mut W, node: Node, indent: u32, ) -> core::fmt::Result { - match &self.nodes[node as usize] { - Inst::Label(label) => { + let data = self.ir.data[node as usize] + .clone() + .unwrap_or(Data::new(0, 0)); + match &self.ir.nodes[node as usize] { + Inst::Label => { + let label = self.tree.strings.get_str(data.as_index()); writeln_indented!(indent - 1, w, "%{} = {label}:", node)?; } - Inst::UnresolvedRef => { - writeln_indented!(indent, w, "%{} = unresolved reference", node)?; + Inst::FunctionStart => { + let label = self.tree.strings.get_str(data.as_index()); + writeln_indented!(indent - 1, w, "%{} = func {label}:", node)?; } - Inst::Ref(reference) => { - writeln_indented!(indent, w, "%{} = reference(%{})", node, reference)?; - } - Inst::Parameter { size, align } => { + Inst::Parameter => { + let (size, align) = data.as_lhs_rhs(); writeln_indented!( indent, w, @@ -451,78 +416,650 @@ impl IR { align )?; } - Inst::Constant(value) => { - writeln_indented!(indent, w, "%{} = {}", node, value)?; + Inst::ConstantU32 => { + writeln_indented!(indent, w, "%{} = const i32 {}", node, data.as_u32())?; } - Inst::Add { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} + %{}", node, lhs, rhs)?; + Inst::ConstantU64 => { + writeln_indented!(indent, w, "%{} = const i64 {}", node, data.as_u64())?; } - Inst::Sub { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} - %{}", node, lhs, rhs)?; + Inst::ConstantMultiByte => { + let value = self.tree.strings.get_bytes(data.as_index()); + writeln_indented!(indent, w, "%{} = const bytes {:x?}", node, value)?; } - Inst::Mul { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} * %{}", node, lhs, rhs)?; + Inst::Add(ty) => { + let (lhs, rhs) = data.as_lhs_rhs(); + writeln_indented!(indent, w, "%{} = add_{ty}(%{} + %{})", node, lhs, rhs)?; } - Inst::Div { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} / %{}", node, lhs, rhs)?; + Inst::Sub(ty) => { + let (lhs, rhs) = data.as_lhs_rhs(); + writeln_indented!(indent, w, "%{} = sub_{ty}(%{} + %{})", node, lhs, rhs)?; } - Inst::Rem { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} % %{}", node, lhs, rhs)?; + Inst::Mul(ty) => { + let (lhs, rhs) = data.as_lhs_rhs(); + writeln_indented!(indent, w, "%{} = mul_{ty}(%{} + %{})", node, lhs, rhs)?; } - Inst::BitAnd { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} & %{}", node, lhs, rhs)?; + Inst::Negate(ty) => { + writeln_indented!(indent, w, "%{} = negate_{ty}(%{})", node, data.lhs)?; } - Inst::BitOr { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} | %{}", node, lhs, rhs)?; - } - Inst::BitXOr { lhs, rhs } => { - writeln_indented!(indent, w, "%{} = %{} ^ %{}", node, lhs, rhs)?; - } - Inst::Negate { lhs } => { - writeln_indented!(indent, w, "%{} = !%{}", node, lhs)?; - } - Inst::ReturnValue { lhs } => { - writeln_indented!(indent, w, "%{} = return %{}", node, lhs)?; + Inst::ReturnValue => { + writeln_indented!(indent, w, "%{} = return %{}", node, data.lhs)?; } Inst::Return => { writeln_indented!(indent, w, "%{} = return", node)?; } - Inst::Alloc { size, align } => { + Inst::Alloca => { + let (size, align) = data.as_lhs_rhs(); writeln_indented!(indent, w, "%{} = alloca {size} (align: {align})", node)?; } - Inst::AddressOf(val) => { - writeln_indented!(indent, w, "%{} = addr %{val}", node)?; - } - Inst::Load { source } => { - writeln_indented!(indent, w, "%{} = load ptr %{source}", node)?; - } - Inst::Store { dest, source } => { - writeln_indented!(indent, w, "%{} = store ptr %{dest} from %{source}", node)?; - } - Inst::ExternRef(ast_node) => { + Inst::GetElementPtr(ty) => { + let (ptr, idx) = data.as_lhs_rhs(); writeln_indented!( indent, w, - "%{} = extern reference ast-node %{}", - node, - ast_node.get() + "%{node} = getelementptr {ty}, ptr: %{}, idx: {}", + ptr, + idx )?; } - Inst::ExplicitCast { node: lhs, ty } => { - writeln_indented!(indent, w, "%{} = explicit_cast %{} to {}", node, lhs, ty)?; + Inst::Load(ty) => { + let source = data.lhs; + writeln_indented!(indent, w, "%{} = load {ty}, %{source}", node)?; + } + Inst::Store(ty) => { + let (src, dst) = data.as_lhs_rhs(); + writeln_indented!(indent, w, "%{} = store {ty}, ptr %{dst}, %{src}", node)?; + } + Inst::ExternRef => { + let ast = data.lhs; + writeln_indented!(indent, w, "%{} = extern reference ast-node %{}", node, ast)?; + } + _ => { + unimplemented!() } } Ok(()) } pub fn render(&self, w: &mut W) -> core::fmt::Result { - for node in 0..self.nodes.len() { + for node in 0..self.ir.nodes.len() { self.render_node(w, node as u32, 1)?; } Ok(()) } } +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum Registers { + A, + B, + C, + D, + SI, + DI, + R8, + R9, + R10, + R11, + R12, + R13, + R14, + R15, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Width { + QWord, + DWord, + Word, + Byte, +} + +impl Width { + fn from_size(size: u32) -> Option { + match size { + 0..=1 => Some(Self::Byte), + 1..=2 => Some(Self::Word), + 3..=4 => Some(Self::DWord), + 5..=8 => Some(Self::QWord), + _ => None, + } + } +} + +struct RegisterDisplay { + reg: Registers, + width: Width, +} + +impl core::fmt::Display for RegisterDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let prefix = match self.reg { + Registers::SI + | Registers::DI + | Registers::A + | Registers::B + | Registers::C + | Registers::D => match self.width { + Width::QWord => "r", + Width::DWord => "e", + Width::Word | Width::Byte => "", + }, + Registers::R8 + | Registers::R9 + | Registers::R10 + | Registers::R11 + | Registers::R12 + | Registers::R13 + | Registers::R14 + | Registers::R15 => "", + }; + let suffix = match self.reg { + Registers::SI | Registers::DI => match self.width { + Width::QWord | Width::DWord | Width::Word => "", + Width::Byte => "l", + }, + Registers::A | Registers::B | Registers::C | Registers::D => match self.width { + Width::QWord | Width::DWord | Width::Word => "x", + Width::Byte => "l", + }, + Registers::R8 + | Registers::R9 + | Registers::R10 + | Registers::R11 + | Registers::R12 + | Registers::R13 + | Registers::R14 + | Registers::R15 => match self.width { + Width::QWord => "", + Width::DWord => "d", + Width::Word => "w", + Width::Byte => "b", + }, + }; + + let name = match self.reg { + Registers::A => "a", + Registers::B => "b", + Registers::C => "c", + Registers::D => "d", + Registers::SI => "si", + Registers::DI => "di", + Registers::R8 => "r8", + Registers::R9 => "r9", + Registers::R10 => "r10", + Registers::R11 => "r11", + Registers::R12 => "r12", + Registers::R13 => "r13", + Registers::R14 => "r14", + Registers::R15 => "r15", + }; + + write!(f, "%{prefix}{name}{suffix}") + } +} + +impl Registers { + fn display(self, width: Width) -> RegisterDisplay { + RegisterDisplay { reg: self, width } + } + fn all() -> [Registers; 14] { + [ + Self::A, + Self::B, + Self::C, + Self::D, + Self::SI, + Self::DI, + Self::R8, + Self::R9, + Self::R10, + Self::R11, + Self::R12, + Self::R13, + Self::R14, + Self::R15, + ] + } + + fn sysv_param_idx(idx: u32) -> Option { + match idx { + 0 => Some(Self::DI), + 1 => Some(Self::SI), + 2 => Some(Self::D), + 3 => Some(Self::C), + 4 => Some(Self::R8), + 5 => Some(Self::R9), + _ => None, + } + } +} + +struct RegisterStore { + registers: [Option; 14], + used: BTreeSet, +} + +impl RegisterStore { + fn new() -> RegisterStore { + Self { + registers: Registers::all().map(|r| Some(r)), + used: BTreeSet::new(), + } + } + fn take_any(&mut self) -> Option { + let a = self.registers.iter_mut().filter(|r| r.is_some()).next()?; + let reg = a.take()?; + self.used.insert(reg); + Some(reg) + } + fn force_take(&mut self, reg: Registers) { + self.registers[reg as usize] = None; + self.used.insert(reg); + } + fn free(&mut self, reg: Registers) { + self.registers[reg as usize] = Some(reg); + } +} + +struct StackMem { + offset: u32, +} + +impl StackMem { + fn new(offset: u32) -> Self { + Self { offset } + } +} + +impl core::fmt::Display for StackMem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "-{}(%rbp)", self.offset) + } +} + +enum ImmRegMem { + ImmU32(u32), + ImmU64(u64), + Mem(StackMem), + Reg(Registers, Width), +} + +impl core::fmt::Display for ImmRegMem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ImmRegMem::ImmU32(v) => write!(f, "{v}"), + ImmRegMem::ImmU64(v) => write!(f, "{v}"), + ImmRegMem::Mem(mem) => write!(f, "{mem}"), + ImmRegMem::Reg(reg, width) => write!(f, "{}", reg.display(*width)), + } + } +} + +#[derive(Debug, Default)] +struct Function { + name: String, + entry: String, + branches: HashMap, + stack_size: u32, + used_registers: Vec, +} + +impl Function { + fn write(&self, w: &mut W) -> core::fmt::Result { + writeln!(w, "{}:", self.name)?; + + for reg in self.used_registers.iter().filter(|r| r != &&Registers::A) { + writeln!(w, "push {}", reg.display(Width::QWord))?; + } + + writeln!(w, "push %rbp")?; + writeln!(w, "mov %rsp, %rbp")?; + + write!(w, "{}", self.entry)?; + + writeln!(w, "{}__body:", self.name)?; + write!(w, "{}", self.branches.get("main").unwrap())?; + + for (name, content) in &self.branches { + if name != "main" { + writeln!(w, "{}__{name}:", self.name)?; + write!(w, "{content}")?; + } + } + + writeln!(w, "{}__epilogue:", self.name)?; + writeln!(w, "mov %rbp, %rsp")?; + writeln!(w, "pop %rbp")?; + + for reg in self + .used_registers + .iter() + .rev() + .filter(|r| r != &&Registers::A) + { + writeln!(w, "pop {}", reg.display(Width::QWord))?; + } + + writeln!(w, "ret")?; + + Ok(()) + } +} + +struct IRIter<'a> { + ir: &'a IR, + offset: usize, + item: Option<(Inst, Option)>, +} + +impl<'a> IRIter<'a> { + fn node(&self) -> u32 { + self.offset as Node + } +} +impl<'a> Iterator for IRIter<'a> { + type Item = (Inst, Option); + + fn next(&mut self) -> Option { + let inst = self.ir.nodes.get(self.offset)?; + let data = self.ir.data.get(self.offset)?; + self.offset += 1; + self.item = Some((*inst, *data)); + Some((*inst, *data)) + } +} + +struct Assembler<'a> { + ir: IRIter<'a>, + strings: StringTable, + constants: HashMap>, + functions: Vec, +} +use core::fmt::Write; + +impl<'a> Assembler<'a> { + fn from_ir(ir: &'a IR, strings: StringTable) -> Assembler<'a> { + Self { + ir: IRIter { + ir, + offset: 0, + item: None, + }, + strings, + constants: HashMap::new(), + functions: Vec::new(), + } + } + fn assemble_function(&mut self, name: String) -> core::fmt::Result { + // hashmap of node indices and offsets from the base pointer + let mut allocas = HashMap::::new(); + + let mut register_store = RegisterStore::new(); + let mut registers = BTreeMap::::new(); + let mut stack_offset = 0; + let mut func = Function::default(); + func.name = name; + let mut param_count = 0; + + let mut current_branch = "main".to_owned(); + func.branches.insert(current_branch.clone(), String::new()); + + loop { + let node = self.ir.node(); + let Some((inst, data)) = self.ir.next() else { + break; + }; + + match inst { + Inst::FunctionStart => { + self.ir.offset -= 1; + break; + } + Inst::Label => { + current_branch = self.strings.get_str(data.unwrap().as_index()).to_owned(); + func.branches.insert(current_branch.clone(), String::new()); + } + Inst::ConstantU32 => { + let value = data.unwrap().as_u32(); + match self.constants.entry(ImmOrIndex::U32(value)) { + Entry::Occupied(mut o) => o.get_mut().push(node), + Entry::Vacant(v) => { + v.insert(vec![node]); + } + } + let reg = register_store.take_any().unwrap(); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "mov ${value}, {}", + reg.display(Width::DWord) + )?; + registers.insert(reg, node); + } + Inst::ConstantU64 => { + let value = data.unwrap().as_u64(); + match self.constants.entry(ImmOrIndex::U64(value)) { + Entry::Occupied(mut o) => o.get_mut().push(node), + Entry::Vacant(v) => { + v.insert(vec![node]); + } + } + let reg = register_store.take_any().unwrap(); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "mov ${value}, {}", + reg.display(Width::QWord) + )?; + registers.insert(reg, node); + } + Inst::ConstantMultiByte => { + let value = data.unwrap().as_index(); + match self.constants.entry(ImmOrIndex::Index(value)) { + Entry::Occupied(mut o) => o.get_mut().push(node), + Entry::Vacant(v) => { + v.insert(vec![node]); + } + } + todo!() + } + Inst::ExternRef => todo!(), + Inst::Alloca => { + let (size, align) = data.unwrap().as_lhs_rhs(); + let size = size.next_multiple_of(align); + writeln!(&mut func.entry, "sub ${size}, %rsp")?; + stack_offset += size; + allocas.insert(node, stack_offset); + } + Inst::Load(ty) => { + let src = data.unwrap().lhs; + let src = registers + .iter() + .find(|(_, node)| node == &&src) + .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) + .or_else(|| { + allocas + .get(&src) + .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) + }) + .expect(&format!( + "src_reg from node %{src} not found: {registers:?}" + )); + let dst_reg = register_store.take_any().unwrap(); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "mov {}, {}", + src, + dst_reg.display(Width::from_size(ty.size()).unwrap()), + )?; + if let ImmRegMem::Reg(reg, _) = src { + register_store.free(reg); + registers.remove(®); + } + registers.insert(dst_reg, node); + } + Inst::Store(ty) => { + let (src, dst) = data.unwrap().as_lhs_rhs(); + let src = registers + .iter() + .find(|(_, node)| node == &&src) + .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) + .or_else(|| { + allocas + .get(&src) + .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) + }) + .expect(&format!( + "src_reg from node %{src} not found: {registers:?}" + )); + let dst = registers + .iter() + .find(|(_, node)| node == &&dst) + .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) + .or_else(|| { + allocas + .get(&dst) + .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) + }) + .expect(&format!( + "src_reg from node %{src} not found: {registers:?}" + )); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "mov {}, {}", + src, + dst, + )?; + + if let ImmRegMem::Reg(reg, _) = src { + register_store.free(reg); + registers.remove(®); + } + } + Inst::GetElementPtr(_) => todo!(), + Inst::Parameter => { + let param_reg = Registers::sysv_param_idx(param_count).unwrap(); + param_count += 1; + register_store.force_take(param_reg); + registers.insert(param_reg, node); + } + Inst::Add(ty) => { + let (src, dst) = data.unwrap().as_lhs_rhs(); + let (&src_reg, _) = registers.iter().find(|(_, node)| node == &&src).unwrap(); + let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap(); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "add {}, {}", + src_reg.display(Width::from_size(ty.size()).unwrap()), + dst_reg.display(Width::from_size(ty.size()).unwrap()), + )?; + + if src_reg != dst_reg { + register_store.free(src_reg); + registers.remove(&src_reg); + } + registers.insert(dst_reg, node); + } + Inst::Sub(ty) => { + let (src, dst) = data.unwrap().as_lhs_rhs(); + let (&src_reg, _) = registers.iter().find(|(_, node)| node == &&src).unwrap(); + let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap(); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "sub {}, {}", + src_reg.display(Width::from_size(ty.size()).unwrap()), + dst_reg.display(Width::from_size(ty.size()).unwrap()), + )?; + + if src_reg != dst_reg { + register_store.free(src_reg); + registers.remove(&src_reg); + } + registers.insert(dst_reg, node); + } + Inst::Mul(ty) => { + let (src, dst) = data.unwrap().as_lhs_rhs(); + let (&src_reg, _) = registers.iter().find(|(_, node)| node == &&src).unwrap(); + let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap(); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "imul {}, {}", + src_reg.display(Width::from_size(ty.size()).unwrap()), + dst_reg.display(Width::from_size(ty.size()).unwrap()), + )?; + + if src_reg != dst_reg { + register_store.free(src_reg); + registers.remove(&src_reg); + } + registers.insert(dst_reg, node); + } + Inst::Div(_) => todo!(), + Inst::Rem(_) => todo!(), + Inst::BitAnd(_) => todo!(), + Inst::BitOr(_) => todo!(), + Inst::BitXOr(_) => todo!(), + Inst::Negate(_) => todo!(), + Inst::ReturnValue => { + let val = data.unwrap().lhs; + let (®, _) = registers.iter().find(|(_, node)| node == &&val) + .expect(&format!( + "location for node %{val} not found: \nregisters: {registers:?}\nallocas: {allocas:?}" + )); + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "mov {}, %rax\njmp {}__epilogue", + reg.display(Width::QWord), + func.name + )?; + + register_store.free(reg); + registers.remove(®); + } + Inst::Return => { + writeln!( + func.branches.get_mut(¤t_branch).unwrap(), + "jmp {}__epilogue", + func.name + )?; + } + } + } + + func.stack_size = stack_offset; + func.used_registers = register_store.used.into_iter().collect(); + + self.functions.push(func); + Ok(()) + } + + fn assemble(&mut self) -> core::fmt::Result { + loop { + let Some((inst, data)) = self.ir.next() else { + break; + }; + match inst { + Inst::FunctionStart => { + let name = self.strings.get_str(data.unwrap().as_index()); + self.assemble_function(name.to_owned())? + } + _ => {} + } + } + + Ok(()) + } + + fn finish(&self, w: &mut W) -> core::fmt::Result { + for func in self.functions.iter() { + writeln!(w, ".globl {}", func.name)?; + } + for func in self.functions.iter() { + func.write(w)?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { use crate::lexer::Tokenizer; @@ -535,14 +1072,12 @@ mod tests { fn main() -> u32 { let a: u32 = 0u32 + 3u32; let ptr_a = &a; - return *ptr_a * global; + return *ptr_a * 2u32; } fn square(x: u32) -> u32 { x * x } - -const global: u32 = 42u32; "; let tokens = Tokenizer::new(src.as_bytes()).unwrap(); @@ -552,12 +1087,18 @@ const global: u32 = 42u32; let mut buf = String::new(); tree.render(&mut buf).unwrap(); println!("{buf}"); - println!("{:#?}", tree.strings); let mut ir = IR::new(); - ir.build(&mut tree); + let builder = ir.build(&mut tree); let mut buf = String::new(); - ir.render(&mut buf).unwrap(); + builder.render(&mut buf).unwrap(); + println!("{buf}"); + + let mut assembler = Assembler::from_ir(&ir, tree.strings); + assembler.assemble().unwrap(); + + let mut buf = String::new(); + assembler.finish(&mut buf).unwrap(); println!("{buf}"); } }