#![allow(dead_code)] use std::collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap}; use crate::{ ast::{IntegralType, Node as AstNode, Tag, Type}, parser::Tree, string_table::{ImmOrIndex, Index as StringsIndex, StringTable}, variant, writeln_indented, }; type Node = u32; #[derive(Debug)] enum NodeOrList { Node(Node), // node of alloca location List(Vec), // list of references to `Node(_)` } #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Type2 { Integral(bool, u16), Binary32, Binary64, Bool, Pointer, } impl Into for Type2 { fn into(self) -> mir::Type { self.mir_type() } } impl Type2 { fn mir_type(self) -> mir::Type { match self { Type2::Integral(_, bits) => mir::Type::from_bitsize_int(bits as u32), Type2::Binary32 => mir::Type::SinglePrecision, Type2::Binary64 => mir::Type::DoublePrecision, Type2::Bool => mir::Type::from_bitsize_int(1), Type2::Pointer => mir::Type::QWord, } } fn is_signed(self) -> bool { match self { Type2::Integral(signed, _) => signed, _ => false, } } fn mir_unalignment(self) -> Option<(bool, u16)> { match self { Type2::Integral(signed, bits) => match bits { 8 | 16 | 32 | 64 => None, bits => Some((signed, bits)), }, _ => None, } } fn size(&self) -> u32 { match self { Type2::Integral(_signed, bits) => bits.div_ceil(8) as u32, Type2::Binary32 => 4, Type2::Binary64 => 8, Type2::Bool => 1, Type2::Pointer => 8, } } fn align(&self) -> u32 { self.size() } } impl core::fmt::Display for Type2 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Type2::Integral(signed, bits) => write!(f, "{}{bits}", if *signed { "i" } else { "u" }), Type2::Binary32 => write!(f, "f32"), Type2::Binary64 => write!(f, "f64"), Type2::Bool => write!(f, "bool"), Type2::Pointer => write!(f, "ptr"), } } } impl From for Type2 { fn from(value: Type) -> Self { (&value).into() } } impl From<&Type> for Type2 { fn from(value: &Type) -> Self { match value { Type::Bool => Type2::Bool, Type::Integer(i) => Type2::Integral(i.signed, i.bits), 
Type::Floating(f) => match f { crate::ast::FloatingType::Binary32 => Type2::Binary32, crate::ast::FloatingType::Binary64 => Type2::Binary64, }, Type::Pointer { .. } => Type2::Pointer, _ => { unimplemented!("conversion from {value:?} to triples type not implemented") } } } } #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Inst { /// index Label, /// index FunctionStart, /// u32 ConstantU32, /// lo, hi ConstantU64, /// index ConstantMultiByte, /// ast-node ExternRef, /// size, align Alloca, /// src Load(Type2), /// src, dst Store(Type2), /// ptr, index, GetElementPtr(Type2), /// size, align Parameter(Type2), /// lhs, rhs Add(Type2), /// lhs, rhs Sub(Type2), /// lhs, rhs Mul(Type2), /// lhs, rhs Div(Type2), /// lhs, rhs Rem(Type2), /// lhs, rhs BitAnd(Type2), /// lhs, rhs BitOr(Type2), /// lhs, rhs BitXOr(Type2), /// lhs, rhs ShiftLeft(Type2), /// lhs, rhs ShiftRight(Type2), /// lhs Negate(Type2), /// lhs ExplicitCast(Type2, Type2), /// lhs ReturnValue(Type2), /// no parameters Return, } impl Inst { fn is_constant(self) -> bool { match self { Inst::ConstantU32 | Inst::ConstantU64 | Inst::ConstantMultiByte => true, _ => false, } } } #[derive(Debug, Clone, Copy)] struct Data { lhs: u32, rhs: u32, } impl Data { fn new(lhs: u32, rhs: u32) -> Self { Self { lhs, rhs } } fn lhs(lhs: u32) -> Data { Self { lhs, rhs: 0 } } fn as_u32(&self) -> u32 { self.lhs } fn as_u64(&self) -> u64 { self.lhs as u64 | (self.rhs as u64) << u32::BITS as u64 } fn as_index(&self) -> StringsIndex { crate::string_table::Index { start: self.lhs, end: self.rhs, } } fn as_lhs_rhs(&self) -> (u32, u32) { (self.lhs, self.rhs) } } impl From for Data { fn from(value: u32) -> Self { Self { lhs: value, rhs: 0 } } } impl From for Data { fn from(value: u64) -> Self { let (lo, hi) = { (value as u32, (value >> u32::BITS as u64) as u32) }; Self { lhs: lo, rhs: hi } } } impl From for Data { fn from(value: crate::string_table::ImmOrIndex) -> Self { match value { ImmOrIndex::U64(v) => v.into(), 
ImmOrIndex::U32(v) => v.into(), ImmOrIndex::Index(v) => v.into(), } } } impl From for Data { fn from(value: crate::string_table::Index) -> Self { Self { lhs: value.start, rhs: value.end, } } } pub struct IRBuilder<'tree, 'ir> { ir: &'ir mut IR, tree: &'tree mut Tree, type_map: HashMap, lookup: HashMap, } impl core::ops::Index for IR { type Output = Inst; fn index(&self, index: Node) -> &Self::Output { &self.nodes[index as usize] } } impl<'tree, 'ir> IRBuilder<'tree, 'ir> { fn new(ir: &'ir mut IR, tree: &'tree mut Tree) -> Self { Self { ir, tree, type_map: HashMap::new(), lookup: HashMap::new(), } } fn visit(&mut self, node: AstNode) -> Node { match &self.tree.nodes[node].clone() { Tag::FunctionDecl { proto, body } => { variant!( Tag::FunctionProto { name, parameters, .. } = self.tree.nodes.get_node(*proto) ); self.ir.push(Inst::FunctionStart, { variant!(Tag::Ident { name } = self.tree.nodes.get_node(*name)); Some((*name).into()) }); if let Some(parameters) = parameters { variant!( Tag::ParameterList { parameters } = self.tree.nodes.get_node(*parameters) ); for param in parameters { variant!(Tag::Parameter { ty, .. } = self.tree.nodes.get_node(*param)); let ty = self.tree.type_of_node(*ty); let size = ty.size_of(); let align = ty.align_of(); let ir = self .ir .push(Inst::Parameter(ty.into()), Some(Data::new(size, align))); self.lookup.insert(*param, NodeOrList::Node(ir)); } } self.tree.st.into_child(node); let value = self.visit(*body); // TODO: return value of body expression self.tree.st.into_parent(); if value != !0 { let ty = self.tree.type_of_node(*body); self.ir .push(Inst::ReturnValue(ty.into()), Some(Data::lhs(value))) } else { !0 } } Tag::Block { statements, trailing_expr, } => { for stmt in statements { self.visit(*stmt); } if let Some(expr) = trailing_expr { self.visit(*expr) } else { !0 } } Tag::VarDecl { .. 
} => { let ty = self.tree.type_of_node(node); let alloca = self .ir .push(Inst::Alloca, Some(Data::new(ty.size_of(), ty.align_of()))); self.lookup.insert(node, NodeOrList::Node(alloca)); alloca } Tag::GlobalDecl { .. } => { // self.ir.push(Inst::Label, { // variant!(Tag::Ident { name } = self.tree.nodes.get_node(*name)); // Some((*name).into()) // }); unimplemented!() } Tag::ReturnStmt { expr } => { if let Some(expr) = expr { let ty = self.tree.type_of_node(*expr); let expr = self.visit(*expr); self.ir .push(Inst::ReturnValue(ty.into()), Some(Data::lhs(expr))) } else { self.ir.push(Inst::Return, None) } } Tag::ExprStmt { expr } => self.visit(*expr), Tag::Deref { lhs } => { let ty = self.tree.type_of_node(*lhs).pointee().unwrap().clone(); let lhs = self.visit(*lhs); self.ir.push(Inst::Load(ty.into()), Some(Data::lhs(lhs))) } Tag::Assign { lhs, rhs } => { let ty = self.tree.type_of_node(*rhs); let dest = self.visit(*lhs); let source = self.visit(*rhs); self.ir .push(Inst::Store(ty.into()), Some(Data::new(source, dest))) } Tag::Add { lhs, rhs } => { let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); let rhs = self.visit(*rhs); self.ir .push(Inst::Add(ty.into()), Some(Data::new(lhs, rhs))) } Tag::Sub { lhs, rhs } => { let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); let rhs = self.visit(*rhs); self.ir .push(Inst::Sub(ty.into()), Some(Data::new(lhs, rhs))) } Tag::Mul { lhs, rhs } => { let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); let rhs = self.visit(*rhs); self.ir .push(Inst::Mul(ty.into()), Some(Data::new(lhs, rhs))) } Tag::Negate { lhs } => { let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); self.ir.push(Inst::Negate(ty.into()), Some(Data::lhs(lhs))) } Tag::Shl { lhs, rhs } => { let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); let rhs = self.visit(*rhs); self.ir .push(Inst::ShiftLeft(ty.into()), Some(Data::new(lhs, rhs))) } Tag::Shr { lhs, rhs } => { let ty = 
self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); let rhs = self.visit(*rhs); self.ir .push(Inst::ShiftRight(ty.into()), Some(Data::new(lhs, rhs))) } Tag::DeclRef(decl) => match self.lookup.get_mut(decl) { Some(NodeOrList::Node(decl)) => *decl, lookup => { println!("lookup for ast decl %{}", decl.get()); println!("{lookup:?}"); panic!("should not have any unresolved lookups") } }, Tag::GlobalRef(decl) => self.ir.push(Inst::ExternRef, Some(Data::lhs(decl.get()))), Tag::Ref { lhs } => { let ty = self.tree.type_of_node(*lhs); let lhs = self.visit(*lhs); // self.ir.push(Inst::Load(ty.into()), Some(Data::lhs(lhs))) // nothing happens here because lhs is of type pointer self.ir .push(Inst::GetElementPtr(ty.into()), Some(Data::new(lhs, 0))) } Tag::Constant { bytes, .. } => match bytes { ImmOrIndex::U64(v) => self.ir.push(Inst::ConstantU64, Some((*v).into())), ImmOrIndex::U32(v) => self.ir.push(Inst::ConstantU32, Some((*v).into())), ImmOrIndex::Index(idx) => { self.ir.push(Inst::ConstantMultiByte, Some((*idx).into())) } }, Tag::ExplicitCast { lhs, typename } => { let l_ty = self.tree.type_of_node(*lhs).clone(); let r_ty = self.tree.type_of_node(*typename).clone(); let lhs = self.visit(*lhs); if l_ty.bit_width() == r_ty.bit_width() { //noop? 
lhs } else { self.ir.push( Inst::ExplicitCast(l_ty.into(), r_ty.into()), Some(Data::lhs(lhs)), ) } } _ => { dbg!(&self.tree.nodes[node]); todo!() } } } } pub struct IR { nodes: Vec, data: Vec>, } impl IR { pub fn new() -> Self { Self { nodes: Vec::new(), data: Vec::new(), } } fn push(&mut self, inst: Inst, data: Option) -> u32 { let node = self.nodes.len() as u32; self.nodes.push(inst); self.data.push(data); node } pub fn build<'a, 'tree>(&'a mut self, tree: &'tree mut Tree) -> IRBuilder<'tree, 'a> { let global_decls = tree.global_decls.clone(); let mut builder = IRBuilder::new(self, tree); for node in &global_decls { builder.visit(*node); } builder } } impl<'tree, 'ir> IRBuilder<'tree, 'ir> { fn render_node( &self, w: &mut W, node: Node, indent: u32, ) -> core::fmt::Result { let data = self.ir.data[node as usize] .clone() .unwrap_or(Data::new(0, 0)); let inst = self.ir.nodes[node as usize]; match inst { Inst::Label => { let label = self.tree.strings.get_str(data.as_index()); writeln_indented!(indent - 1, w, "%{} = {label}:", node)?; } Inst::FunctionStart => { let label = self.tree.strings.get_str(data.as_index()); writeln_indented!(indent - 1, w, "%{} = func {label}:", node)?; } Inst::Parameter(ty) => { let (size, align) = data.as_lhs_rhs(); writeln_indented!( indent, w, "%{} = param {ty} (size: {}, align: {})", node, size, align )?; } Inst::ConstantU32 => { writeln_indented!(indent, w, "%{} = const i32 {}", node, data.as_u32())?; } Inst::ConstantU64 => { writeln_indented!(indent, w, "%{} = const i64 {}", node, data.as_u64())?; } Inst::ConstantMultiByte => { let value = self.tree.strings.get_bytes(data.as_index()); writeln_indented!(indent, w, "%{} = const bytes {:x?}", node, value)?; } Inst::Add(ty) => { let (lhs, rhs) = data.as_lhs_rhs(); writeln_indented!(indent, w, "%{} = add_{ty}(%{} + %{})", node, lhs, rhs)?; } Inst::Sub(ty) => { let (lhs, rhs) = data.as_lhs_rhs(); writeln_indented!(indent, w, "%{} = sub_{ty}(%{} - %{})", node, lhs, rhs)?; } Inst::Mul(ty) => 
{ let (lhs, rhs) = data.as_lhs_rhs(); writeln_indented!(indent, w, "%{} = mul_{ty}(%{} * %{})", node, lhs, rhs)?; } Inst::Negate(ty) => { writeln_indented!(indent, w, "%{} = negate_{ty}(%{})", node, data.lhs)?; } Inst::ExplicitCast(from, to) => { writeln_indented!(indent, w, "%{} = cast_{from}_to_{to}(%{})", node, data.lhs)?; } Inst::ShiftLeft(ty) => { writeln_indented!( indent, w, "%{} = shl_{ty}(%{} << %{})", node, data.lhs, data.rhs )?; } Inst::ShiftRight(ty) => { writeln_indented!( indent, w, "%{} = shr_{ty}(%{} >> %{})", node, data.lhs, data.rhs )?; } Inst::ReturnValue(ty) => { writeln_indented!(indent, w, "%{} = return {ty} %{}", node, data.lhs)?; } Inst::Return => { writeln_indented!(indent, w, "%{} = return", node)?; } Inst::Alloca => { let (size, align) = data.as_lhs_rhs(); writeln_indented!(indent, w, "%{} = alloca {size} (align: {align})", node)?; } Inst::GetElementPtr(ty) => { let (ptr, idx) = data.as_lhs_rhs(); writeln_indented!( indent, w, "%{node} = getelementptr {ty}, ptr: %{}, idx: {}", ptr, idx )?; } Inst::Load(ty) => { let source = data.lhs; writeln_indented!(indent, w, "%{} = load {ty}, %{source}", node)?; } Inst::Store(ty) => { let (src, dst) = data.as_lhs_rhs(); writeln_indented!(indent, w, "%{} = store {ty}, ptr %{dst}, %{src}", node)?; } Inst::ExternRef => { let ast = data.lhs; writeln_indented!(indent, w, "%{} = extern reference ast-node %{}", node, ast)?; } _ => { unimplemented!("{inst:?} rendering unimplemented") } } Ok(()) } pub fn render(&self, w: &mut W) -> core::fmt::Result { for node in 0..self.ir.nodes.len() { self.render_node(w, node as u32, 1)?; } Ok(()) } } #[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] enum Registers { A, B, C, D, SI, DI, R8, R9, R10, R11, R12, R13, R14, R15, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum Width { QWord, DWord, Word, Byte, } impl Width { fn from_size(size: u32) -> Option { match size { 1 => Some(Self::Byte), 2 => Some(Self::Word), 3..=4 => Some(Self::DWord), 5..=8 
=> Some(Self::QWord), _ => None, } } } struct RegisterDisplay { reg: Registers, width: Width, } impl core::fmt::Display for RegisterDisplay { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let prefix = match self.reg { Registers::SI | Registers::DI | Registers::A | Registers::B | Registers::C | Registers::D => match self.width { Width::QWord => "r", Width::DWord => "e", Width::Word | Width::Byte => "", }, Registers::R8 | Registers::R9 | Registers::R10 | Registers::R11 | Registers::R12 | Registers::R13 | Registers::R14 | Registers::R15 => "", }; let suffix = match self.reg { Registers::SI | Registers::DI => match self.width { Width::QWord | Width::DWord | Width::Word => "", Width::Byte => "l", }, Registers::A | Registers::B | Registers::C | Registers::D => match self.width { Width::QWord | Width::DWord | Width::Word => "x", Width::Byte => "l", }, Registers::R8 | Registers::R9 | Registers::R10 | Registers::R11 | Registers::R12 | Registers::R13 | Registers::R14 | Registers::R15 => match self.width { Width::QWord => "", Width::DWord => "d", Width::Word => "w", Width::Byte => "b", }, }; let name = match self.reg { Registers::A => "a", Registers::B => "b", Registers::C => "c", Registers::D => "d", Registers::SI => "si", Registers::DI => "di", Registers::R8 => "r8", Registers::R9 => "r9", Registers::R10 => "r10", Registers::R11 => "r11", Registers::R12 => "r12", Registers::R13 => "r13", Registers::R14 => "r14", Registers::R15 => "r15", }; write!(f, "{prefix}{name}{suffix}") } } impl Registers { fn display(self, width: Width) -> RegisterDisplay { RegisterDisplay { reg: self, width } } fn all() -> [Registers; 14] { [ Self::A, Self::B, Self::C, Self::D, Self::SI, Self::DI, Self::R8, Self::R9, Self::R10, Self::R11, Self::R12, Self::R13, Self::R14, Self::R15, ] } fn sysv_param_idx(idx: u32) -> Option { match idx { 0 => Some(Self::DI), 1 => Some(Self::SI), 2 => Some(Self::D), 3 => Some(Self::C), 4 => Some(Self::R8), 5 => Some(Self::R9), _ => None, } } } 
struct RegisterStore { registers: [Option; 14], used: BTreeSet, } impl RegisterStore { fn new() -> RegisterStore { Self { registers: Registers::all().map(|r| Some(r)), used: BTreeSet::new(), } } fn take_any(&mut self) -> Option { let a = self.registers.iter_mut().filter(|r| r.is_some()).next()?; let reg = a.take()?; self.used.insert(reg); Some(reg) } fn force_take(&mut self, reg: Registers) { self.registers[reg as usize] = None; } fn free(&mut self, reg: Registers) { self.registers[reg as usize] = Some(reg); } } struct StackMem { offset: u32, } impl StackMem { fn new(offset: u32) -> Self { Self { offset } } } impl core::fmt::Display for StackMem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[rbp - 0x{:x}]", self.offset) } } enum ImmRegMem { ImmU32(u32), ImmU64(u64), Mem(StackMem), Reg(Registers, Width), } impl core::fmt::Display for ImmRegMem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ImmRegMem::ImmU32(v) => write!(f, "0x{v:x}"), ImmRegMem::ImmU64(v) => write!(f, "0x{v:x}"), ImmRegMem::Mem(mem) => write!(f, "{mem}"), ImmRegMem::Reg(reg, width) => write!(f, "{}", reg.display(*width)), } } } #[derive(Debug, Default)] struct Function { name: String, entry: String, branches: HashMap, stack_size: u32, used_registers: Vec, } impl Function { fn write(&self, w: &mut W) -> core::fmt::Result { writeln!(w, "{}:", self.name)?; for reg in self.used_registers.iter() { writeln!(w, "push {}", reg.display(Width::QWord))?; } writeln!(w, "push rbp")?; writeln!(w, "mov rbp, rsp")?; write!(w, "{}", self.entry)?; writeln!(w, "{}__body:", self.name)?; write!(w, "{}", self.branches.get("main").unwrap())?; for (name, content) in &self.branches { if name != "main" { writeln!(w, "{}__{name}:", self.name)?; write!(w, "{content}")?; } } writeln!(w, "{}__epilogue:", self.name)?; writeln!(w, "mov rsp, rbp")?; writeln!(w, "pop rbp")?; for reg in self.used_registers.iter().rev() { writeln!(w, "pop {}", 
reg.display(Width::QWord))?; } writeln!(w, "ret")?; Ok(()) } } struct IRIter<'a> { ir: &'a IR, offset: usize, item: Option<(Inst, Option)>, } impl<'a> IRIter<'a> { fn node(&self) -> u32 { self.offset as Node } } impl<'a> Iterator for IRIter<'a> { type Item = (Inst, Option); fn next(&mut self) -> Option { let inst = self.ir.nodes.get(self.offset)?; let data = self.ir.data.get(self.offset)?; self.offset += 1; self.item = Some((*inst, *data)); Some((*inst, *data)) } } struct Assembler<'a> { ir: IRIter<'a>, strings: StringTable, constants: HashMap>, functions: Vec, } use core::fmt::Write; impl<'a> Assembler<'a> { fn from_ir(ir: &'a IR, strings: StringTable) -> Assembler<'a> { Self { ir: IRIter { ir, offset: 0, item: None, }, strings, constants: HashMap::new(), functions: Vec::new(), } } fn assemble_function(&mut self, name: String) -> core::fmt::Result { // hashmap of node indices and offsets from the base pointer let mut allocas = HashMap::::new(); let mut register_store = RegisterStore::new(); // rax as scratch register register_store.force_take(Registers::A); let mut registers = BTreeMap::::new(); let mut stack_offset = 0; let mut func = Function::default(); func.name = name; let mut param_count = 0; let mut current_branch = "main".to_owned(); func.branches.insert(current_branch.clone(), String::new()); loop { let node = self.ir.node(); let Some((inst, data)) = self.ir.next() else { break; }; match inst { Inst::FunctionStart => { self.ir.offset -= 1; break; } Inst::Label => { current_branch = self.strings.get_str(data.unwrap().as_index()).to_owned(); func.branches.insert(current_branch.clone(), String::new()); } Inst::ConstantU32 => { let value = data.unwrap().as_u32(); match self.constants.entry(ImmOrIndex::U32(value)) { Entry::Occupied(mut o) => o.get_mut().push(node), Entry::Vacant(v) => { v.insert(vec![node]); } } let reg = register_store.take_any().unwrap(); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov {}, {value}", reg.display(Width::DWord) )?; 
registers.insert(reg, node); } Inst::ConstantU64 => { let value = data.unwrap().as_u64(); match self.constants.entry(ImmOrIndex::U64(value)) { Entry::Occupied(mut o) => o.get_mut().push(node), Entry::Vacant(v) => { v.insert(vec![node]); } } let reg = register_store.take_any().unwrap(); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov {}, {value}", reg.display(Width::QWord) )?; registers.insert(reg, node); } Inst::ConstantMultiByte => { let value = data.unwrap().as_index(); match self.constants.entry(ImmOrIndex::Index(value)) { Entry::Occupied(mut o) => o.get_mut().push(node), Entry::Vacant(v) => { v.insert(vec![node]); } } todo!() } Inst::ExternRef => todo!(), Inst::Alloca => { let (size, align) = data.unwrap().as_lhs_rhs(); let size = size.next_multiple_of(align); writeln!(&mut func.entry, "sub rsp, 0x{size:x}")?; stack_offset += size; allocas.insert(node, stack_offset); } Inst::Load(ty) => { let src = data.unwrap().lhs; let src = registers .iter() .find(|(_, node)| node == &&src) .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) .or_else(|| { allocas .get(&src) .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) }) .expect(&format!( "src_reg from node %{src} not found: {registers:?}" )); let dst_reg = register_store.take_any().unwrap(); match src { ImmRegMem::Reg(_, _) => { writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov {}, [{}]", dst_reg.display(Width::from_size(ty.size()).unwrap()), src, )?; } ImmRegMem::Mem(ref mem) => { let tmp_reg = register_store.take_any().unwrap(); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov {}, {}", tmp_reg.display(Width::QWord), mem, )?; writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov {}, [{}]", dst_reg.display(Width::from_size(ty.size()).unwrap()), tmp_reg.display(Width::QWord), )?; register_store.free(tmp_reg); } _ => {} } if let ImmRegMem::Reg(reg, _) = src { register_store.free(reg); registers.remove(®); } registers.insert(dst_reg, node); } Inst::Store(ty) => { 
let (src, dst) = data.unwrap().as_lhs_rhs(); let src = registers .iter() .find(|(_, node)| node == &&src) .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) .or_else(|| { allocas .get(&src) .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) }) .expect(&format!( "src_reg from node %{src} not found: {registers:?}" )); let dst = registers .iter() .find(|(_, node)| node == &&dst) .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) .or_else(|| { allocas .get(&dst) .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) }) .expect(&format!( "src_reg from node %{src} not found: {registers:?}" )); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov {}, {}", dst, src, )?; if let ImmRegMem::Reg(reg, _) = src { register_store.free(reg); registers.remove(®); } } Inst::GetElementPtr(ty) => { let (ptr, idx) = data.unwrap().as_lhs_rhs(); let src = registers .iter() .find(|(_, node)| node == &&ptr) .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) .or_else(|| { allocas .get(&ptr) .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) }) .expect(&format!( "src_reg from node %{ptr} not found: {registers:?}" )); let dst_reg = register_store.take_any().unwrap(); if let ImmRegMem::Mem(_) = &src { writeln!( func.branches.get_mut(¤t_branch).unwrap(), "lea {}, {}", ImmRegMem::Reg(dst_reg, Width::QWord), src, )?; } let offset = idx * ty.size(); if offset != 0 { writeln!( func.branches.get_mut(¤t_branch).unwrap(), "lea {}, [{} + {offset}]", ImmRegMem::Reg(dst_reg, Width::QWord), ImmRegMem::Reg(dst_reg, Width::QWord), )?; } if let ImmRegMem::Reg(reg, _) = src { register_store.free(reg); registers.remove(®); } registers.insert(dst_reg, node); } Inst::Parameter(_) => { let param_reg = Registers::sysv_param_idx(param_count).unwrap(); param_count += 1; register_store.force_take(param_reg); writeln!(&mut func.entry, "push {}", param_reg.display(Width::QWord))?; stack_offset += 8; allocas.insert(node, stack_offset); 
registers.insert(param_reg, node); } Inst::Add(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let (&src_reg, _) = registers.iter().find(|(_, node)| node == &&src).unwrap(); let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap(); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "add {}, {}", dst_reg.display(Width::from_size(ty.size()).unwrap()), src_reg.display(Width::from_size(ty.size()).unwrap()), )?; if src_reg != dst_reg { register_store.free(src_reg); registers.remove(&src_reg); } registers.insert(dst_reg, node); } Inst::Sub(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let src = registers .iter() .find(|(_, node)| node == &&src) .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) .or_else(|| { allocas .get(&src) .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) }) .expect(&format!( "src_reg from node %{src} not found: {registers:?}" )); let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap(); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "sub {}, {}", dst_reg.display(Width::from_size(ty.size()).unwrap()), src, )?; if let ImmRegMem::Reg(reg, _) = src { if reg != dst_reg { register_store.free(reg); registers.remove(®); } } registers.insert(dst_reg, node); } Inst::Mul(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let src = registers .iter() .find(|(_, node)| node == &&src) .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap())) .or_else(|| { allocas .get(&src) .map(|&offset| ImmRegMem::Mem(StackMem::new(offset))) }) .expect(&format!( "src_reg from node %{src} not found: {registers:?}" )); let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap(); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "imul {}, {}", dst_reg.display(Width::from_size(ty.size()).unwrap()), src, )?; if let ImmRegMem::Reg(reg, _) = src { if reg != dst_reg { register_store.free(reg); registers.remove(®); } } registers.insert(dst_reg, node); } 
Inst::ShiftLeft(_) => todo!(), Inst::ShiftRight(_) => todo!(), Inst::Div(_) => todo!(), Inst::Rem(_) => todo!(), Inst::BitAnd(_) => todo!(), Inst::BitOr(_) => todo!(), Inst::BitXOr(_) => todo!(), Inst::Negate(_) => todo!(), Inst::ExplicitCast(_, _) => todo!(), Inst::ReturnValue(_) => { let val = data.unwrap().lhs; let (®, _) = registers.iter().find(|(_, node)| node == &&val) .expect(&format!( "location for node %{val} not found: \nregisters: {registers:?}\nallocas: {allocas:?}" )); writeln!( func.branches.get_mut(¤t_branch).unwrap(), "mov rax, {}\njmp {}__epilogue", reg.display(Width::QWord), func.name )?; register_store.free(reg); registers.remove(®); } Inst::Return => { writeln!( func.branches.get_mut(¤t_branch).unwrap(), "jmp {}__epilogue", func.name )?; } } } func.stack_size = stack_offset; func.used_registers = register_store.used.into_iter().collect(); self.functions.push(func); Ok(()) } fn assemble(&mut self) -> core::fmt::Result { loop { let Some((inst, data)) = self.ir.next() else { break; }; match inst { Inst::FunctionStart => { let name = self.strings.get_str(data.unwrap().as_index()); self.assemble_function(name.to_owned())? 
} _ => {} } } Ok(()) } fn finish(&self, w: &mut W) -> core::fmt::Result { writeln!(w, ".intel_syntax")?; writeln!(w, ".text")?; for func in self.functions.iter() { writeln!(w, ".globl {}", func.name)?; } for func in self.functions.iter() { func.write(w)?; } Ok(()) } } use crate::mir; pub struct MirBuilder<'a> { ir: IRIter<'a>, pub strings: StringTable, pub functions: HashMap, } impl<'a> MirBuilder<'a> { pub fn new(ir: &'a IR, strings: StringTable) -> MirBuilder<'a> { Self { ir: IRIter { ir, offset: 0, item: None, }, strings, functions: HashMap::new(), } } fn build_function(&mut self, name: StringsIndex) { let mut mir = mir::Mir::new(name); let mut mapping = BTreeMap::::new(); loop { let ir_node = self.ir.node(); let Some((inst, data)) = self.ir.next() else { break; }; let node = match inst { Inst::FunctionStart => { self.ir.offset -= 1; break; } Inst::Label => mir.gen_label(data.unwrap().as_index()), Inst::ConstantU32 => mir.push( mir::Inst::ConstantDWord, mir::Data::imm32(data.unwrap().as_u32()), ), Inst::ConstantU64 => mir.push( mir::Inst::ConstantQWord, mir::Data::imm64(data.unwrap().as_u64()), ), Inst::ConstantMultiByte => { let bytes = self.strings.get_bytes(data.unwrap().as_index()); let mut buf = [0u8; 8]; match bytes.len() { 1 => mir.gen_u8(bytes[0]), 2 => mir.gen_u16(u16::from_le_bytes(bytes[..2].try_into().unwrap())), 3..=4 => { buf[..bytes.len()].copy_from_slice(bytes); mir.gen_u32(u32::from_le_bytes(buf[..4].try_into().unwrap())) } 5..=8 => { buf[..bytes.len()].copy_from_slice(bytes); mir.gen_u64(u64::from_le_bytes(buf[..8].try_into().unwrap())) } _ => { unimplemented!( "constants larger than 8 bytes are not currently supported!" 
) } } } Inst::ExternRef => todo!(), Inst::Alloca => { let (l, r) = data.unwrap().as_lhs_rhs(); mir.gen_alloca(l, r) } Inst::Load(ty) => { let ty = mir::Type::from_bytesize_int(ty.size()); let src = *mapping.get(&data.unwrap().as_u32()).unwrap(); mir.gen_load(ty, src) } Inst::Store(ty) => { let ty = mir::Type::from_bytesize_int(ty.size()); let (src, dst) = data.unwrap().as_lhs_rhs(); let src = *mapping.get(&src).unwrap(); let dst = *mapping.get(&dst).unwrap(); mir.gen_store(ty, src, dst) } Inst::GetElementPtr(ty) => { let ty = mir::Type::from_bytesize_int(ty.size()); let (ptr, idx) = data.unwrap().as_lhs_rhs(); let src = *mapping.get(&ptr).unwrap(); mir.gen_get_element_ptr(ty, src, idx) } Inst::Parameter(ty) => { // let (size, _) = data.unwrap().as_lhs_rhs(); mir.gen_param(ty.into()) } Inst::Add(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let lhs = *mapping.get(&src).unwrap(); let rhs = *mapping.get(&dst).unwrap(); match ty { Type2::Integral(signed, bits) => match bits { 8 => mir.gen_add(mir::Type::Byte, lhs, rhs), 16 => mir.gen_add(mir::Type::Word, lhs, rhs), 32 => mir.gen_add(mir::Type::DWord, lhs, rhs), 64 => mir.gen_add(mir::Type::QWord, lhs, rhs), 64.. 
=> { unimplemented!() } bits => { let ty = mir::Type::from_bitsize_int(bits as u32); let sum = mir.gen_add(ty, lhs, rhs); mir.gen_truncate_integer(sum, ty, signed, bits) } }, Type2::Binary32 => mir.gen_add(mir::Type::SinglePrecision, lhs, rhs), Type2::Binary64 => mir.gen_add(mir::Type::DoublePrecision, lhs, rhs), Type2::Pointer => mir.gen_add(mir::Type::QWord, lhs, rhs), _ => unreachable!(), } } Inst::Sub(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let lhs = *mapping.get(&src).unwrap(); let rhs = *mapping.get(&dst).unwrap(); let unalignment = ty.mir_unalignment(); let ty = ty.mir_type(); let sum = mir.gen_sub(ty, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::Mul(ty) => { let (lhs, rhs) = data.unwrap().as_lhs_rhs(); let unalignment = ty.mir_unalignment(); let signed = ty.is_signed(); let ty = ty.mir_type(); let lhs = *mapping.get(&lhs).unwrap(); let rhs = *mapping.get(&rhs).unwrap(); let sum = mir.gen_mul(ty, signed, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::Div(ty) => { let (lhs, rhs) = data.unwrap().as_lhs_rhs(); let unalignment = ty.mir_unalignment(); let signed = ty.is_signed(); let ty = ty.mir_type(); let lhs = *mapping.get(&lhs).unwrap(); let rhs = *mapping.get(&rhs).unwrap(); let sum = mir.gen_div(ty, signed, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::Rem(ty) => { let (lhs, rhs) = data.unwrap().as_lhs_rhs(); let unalignment = ty.mir_unalignment(); let signed = ty.is_signed(); let ty = ty.mir_type(); let lhs = *mapping.get(&lhs).unwrap(); let rhs = *mapping.get(&rhs).unwrap(); let sum = mir.gen_rem(ty, signed, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::BitAnd(ty) => { let (lhs, rhs) = data.unwrap().as_lhs_rhs(); let 
unalignment = ty.mir_unalignment(); let ty = ty.mir_type(); let (lhs, rhs) = if self.ir.ir.nodes[lhs as usize].is_constant() { (rhs, lhs) } else { (lhs, rhs) }; let lhs = *mapping.get(&lhs).unwrap(); let rhs = *mapping.get(&rhs).unwrap(); let sum = mir.gen_bitand(ty, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::BitOr(ty) => { let (lhs, rhs) = data.unwrap().as_lhs_rhs(); let unalignment = ty.mir_unalignment(); let ty = ty.mir_type(); let (lhs, rhs) = if self.ir.ir.nodes[lhs as usize].is_constant() { (rhs, lhs) } else { (lhs, rhs) }; let lhs = *mapping.get(&lhs).unwrap(); let rhs = *mapping.get(&rhs).unwrap(); let sum = mir.gen_bitor(ty, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::BitXOr(ty) => { let (lhs, rhs) = data.unwrap().as_lhs_rhs(); let unalignment = ty.mir_unalignment(); let ty = ty.mir_type(); let (lhs, rhs) = if self.ir.ir.nodes[lhs as usize].is_constant() { (rhs, lhs) } else { (lhs, rhs) }; let lhs = *mapping.get(&lhs).unwrap(); let rhs = *mapping.get(&rhs).unwrap(); let sum = mir.gen_bitxor(ty, lhs, rhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::ShiftLeft(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let lhs = *mapping.get(&src).unwrap(); let rhs = *mapping.get(&dst).unwrap(); // TODO: check rhs type and pass it to gen_sh{l,r}? let rhs = mir.gen_truncate_integer(rhs, ty.into(), false, 8); match ty { Type2::Integral(signed, bits) => match bits { 8 => mir.gen_shl(mir::Type::Byte, lhs, rhs), 16 => mir.gen_shl(mir::Type::Word, lhs, rhs), 32 => mir.gen_shl(mir::Type::DWord, lhs, rhs), 64 => mir.gen_shl(mir::Type::QWord, lhs, rhs), 64.. 
=> { unimplemented!() } bits => { let ty = mir::Type::from_bitsize_int(bits as u32); let sum = mir.gen_shl(ty, lhs, rhs); mir.gen_truncate_integer(sum, ty, signed, bits) } }, _ => unreachable!(), } } Inst::ShiftRight(ty) => { let (src, dst) = data.unwrap().as_lhs_rhs(); let lhs = *mapping.get(&src).unwrap(); let rhs = *mapping.get(&dst).unwrap(); match ty { Type2::Integral(signed, bits) => match bits { 8 | 16 | 32 | 64 => { let ty = mir::Type::from_bitsize_int(bits as u32); if signed { mir.gen_sar(ty, lhs, rhs) } else { mir.gen_shr(ty, lhs, rhs) } } 64.. => { unimplemented!() } bits => { let ty = mir::Type::from_bitsize_int(bits as u32); let sum = if signed { mir.gen_sar(ty, lhs, rhs) } else { mir.gen_shr(ty, lhs, rhs) }; mir.gen_truncate_integer(sum, ty, signed, bits) } }, _ => unreachable!(), } } Inst::Negate(ty) => { let lhs = data.unwrap().as_u32(); let unalignment = ty.mir_unalignment(); let ty = ty.mir_type(); let lhs = *mapping.get(&lhs).unwrap(); let sum = mir.gen_negate(ty, lhs); if let Some((signed, bits)) = unalignment { mir.gen_truncate_integer(sum, ty, signed, bits) } else { sum } } Inst::ExplicitCast(from, to) => { let lhs = data.unwrap().as_u32(); let from_mir = from.mir_type(); let to_mir = to.mir_type(); let lhs = *mapping.get(&lhs).unwrap(); match (from, to) { (Type2::Integral(a_signed, a), Type2::Integral(b_signed, b)) => { if a > b { mir.gen_truncate_integer(lhs, to_mir, b_signed, b) } else if a < b { mir.gen_extend_integer( lhs, IntegralType::new(a_signed, a), IntegralType::new(b_signed, b), ) } else { unreachable!() } } (Type2::Integral(_, _), Type2::Bool) => { let is_zero = mir.gen_is_zero(from_mir, lhs); mir.gen_negate(mir::Type::Byte, is_zero) } (Type2::Bool, Type2::Integral(b_signed, b)) => mir.gen_extend_integer( lhs, IntegralType::u1(), IntegralType::new(b_signed, b), ), _ => unimplemented!(), } } Inst::ReturnValue(ty) => { let src = data.unwrap().as_u32(); let src = *mapping.get(&src).unwrap(); mir.gen_ret_val(ty.mir_type(), src) } 
Inst::Return => mir.gen_ret(), #[allow(unreachable_patterns)] _ => { unimplemented!() } }; mapping.insert(ir_node, node); } self.functions.insert(name, mir); }

    /// Walks the IR stream and lowers every function found in it.
    ///
    /// Instructions are pulled from `self.ir.next()` until it yields nothing.
    /// Each `Inst::FunctionStart` is handed to `build_function`, with the
    /// instruction's data reinterpreted as a string-table index; any other
    /// instruction seen by this loop is ignored.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if a `FunctionStart` instruction has no data
    /// attached.
    pub fn build(&mut self) {
        while let Some((inst, data)) = self.ir.next() {
            if matches!(inst, Inst::FunctionStart) {
                self.build_function(data.unwrap().as_index());
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::lexer::Tokenizer;

    use super::*;

    /// Smoke test: pushes an `inverse_sqrt` function (with pointer casts
    /// between `f32` and `u32`) through tokenizing, parsing, comptime
    /// folding, IR and MIR construction, then assembles each resulting
    /// function. Output is only printed; failures surface through `unwrap`.
    #[test]
    fn mir() {
        let src = " fn inverse_sqrt(x: f32) -> f32 { let three_halfs: f32 = 1.5f32; let x2 = x * 0.5f32; var y = x; let i = 0x5f3759dfu32 - (*(&y as *u32) >> 1u32); y = *(&i as *f32); y = y * (three_halfs - (x2 * y * y)); return y; } ";

        let tokens = Tokenizer::new(src.as_bytes()).unwrap();
        let mut tree = Tree::new();
        tree.parse(tokens.iter()).unwrap();
        tree.fold_comptime();

        let mut ast_dump = String::new();
        tree.render(&mut ast_dump).unwrap();
        println!("AST:\n{ast_dump}");

        let mut ir = IR::new();
        let ir_builder = ir.build(&mut tree);
        let mut ir_dump = String::new();
        ir_builder.render(&mut ir_dump).unwrap();
        println!("IR:\n{ir_dump}");

        let mut mir_builder = MirBuilder::new(&ir, tree.strings);
        mir_builder.build();

        // Take the builder apart so the string table can be borrowed while
        // the lowered functions are consumed.
        let MirBuilder {
            strings, functions, ..
        } = mir_builder;
        for (_name, function) in functions {
            // Assemble first: an assembly failure aborts before anything
            // is printed for this function.
            let assembly = function.assemble(&strings).unwrap();
            println!("mir:\n{}", function.display(&strings));
            println!("assembly:\n{assembly}");
        }
    }

    /// Smoke test: lowers a function operating on the non-power-of-two
    /// integer type `i10` down to MIR. Output is only printed; failures
    /// surface through `unwrap`.
    #[test]
    fn mir_u10() {
        let src = " fn u10(x: i10) -> i10 { 5i10 * 3i10 } ";

        let tokens = Tokenizer::new(src.as_bytes()).unwrap();
        let mut tree = Tree::new();
        tree.parse(tokens.iter()).unwrap();

        let mut ast_dump = String::new();
        tree.render(&mut ast_dump).unwrap();
        println!("AST:\n{ast_dump}");

        let mut ir = IR::new();
        let ir_builder = ir.build(&mut tree);
        let mut ir_dump = String::new();
        ir_builder.render(&mut ir_dump).unwrap();
        println!("IR:\n{ir_dump}");

        let mut mir_builder = MirBuilder::new(&ir, tree.strings);
        mir_builder.build();
    }

    /// Smoke test: builds IR for two small `u32` functions, runs the MIR
    /// builder over it, and also runs the `Assembler` directly on the IR,
    /// printing what it produces. Failures surface through `unwrap`.
    #[test]
    fn ir() {
        let src = " fn main() -> u32 { let a: u32 = 0u32 + 3u32; let b = &a; return *b * 2u32; } fn square(x: u32) -> u32 { x * x } ";

        let tokens = Tokenizer::new(src.as_bytes()).unwrap();
        let mut tree = Tree::new();
        tree.parse(tokens.iter()).unwrap();

        let mut ast_dump = String::new();
        tree.render(&mut ast_dump).unwrap();
        println!("{ast_dump}");

        let mut ir = IR::new();
        let ir_builder = ir.build(&mut tree);
        let mut ir_dump = String::new();
        ir_builder.render(&mut ir_dump).unwrap();
        println!("{ir_dump}");

        // The string table is needed twice: the MIR builder gets a clone,
        // the assembler takes the original.
        let strings = tree.strings;
        let mut mir_builder = MirBuilder::new(&ir, strings.clone());
        mir_builder.build();

        let mut assembler = Assembler::from_ir(&ir, strings);
        assembler.assemble().unwrap();

        let mut asm_dump = String::new();
        assembler.finish(&mut asm_dump).unwrap();
        println!("{asm_dump}");
    }
}