From dd6ce88ad6388a4fc220c21bf595b76bbfc2776c Mon Sep 17 00:00:00 2001
From: Janis <janis@nirgendwo.xyz>
Date: Sun, 1 Sep 2024 12:28:47 +0200
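Subject: [PATCH] if-else branches in IR

Lower `if` and `if`/`else` expressions to the IR using two new
instructions, `Inst::Branch` and `Inst::Jump`, and teach
`Tree::type_of_node` about them: `IfExpr` is void, `IfElseExpr` has
the peer type of its two branches.

Also:
- `StringTable::get_str`/`get_bytes` return an empty value for
  `Index::none()`; the new lowering emits labels with a `none` name.
- IR rendering prints labels as `label "name":` and shows the real
  comparison operator for eq/neq/gt/lt/ge/le instead of `-`.
- Remove the `Assembler` and the `ir` test that drove it.

---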
 src/parser.rs       |   9 +
 src/string_table.rs |  12 +-
 src/triples.rs      | 499 +++++---------------------------------------
 3 files changed, 76 insertions(+), 444 deletions(-)

diff --git a/src/parser.rs b/src/parser.rs
index 1a8d82e..d2af06f 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -1806,6 +1806,15 @@ impl Tree {
             Tag::Ge { .. } => Type::bool(),
             Tag::DeclRef(decl) => self.type_of_node(*decl),
             Tag::GlobalRef(decl) => self.type_of_node(*decl),
+            Tag::IfExpr { .. } => Type::void(),
+            Tag::IfElseExpr {
+                body, else_expr, ..
+            } => self.peer_type_of_nodes(*body, *else_expr).expect({
+                let (lhs, rhs) = (body, else_expr);
+                let at = self.type_of_node(*lhs);
+                let bt = self.type_of_node(*rhs);
+                &format!("incompatible types for %{lhs}({at}) and %{rhs}({bt})")
+            }),
             _ => Type::void(),
         }
     }
diff --git a/src/string_table.rs b/src/string_table.rs
index 88197f8..328bf06 100644
--- a/src/string_table.rs
+++ b/src/string_table.rs
@@ -63,11 +63,19 @@ impl StringTable {
     }
 
     pub fn get_str(&self, idx: Index) -> &str {
-        unsafe { core::str::from_utf8_unchecked(&self[idx]) }
+        if idx == Index::none() {
+            ""
+        } else {
+            unsafe { core::str::from_utf8_unchecked(&self[idx]) }
+        }
     }
 
     pub fn get_bytes(&self, idx: Index) -> &[u8] {
-        &self[idx]
+        if idx == Index::none() {
+            &[]
+        } else {
+            &self[idx]
+        }
     }
 
     pub fn insert<B: AsRef<[u8]>>(&mut self, bytes: B) -> Index {
diff --git a/src/triples.rs b/src/triples.rs
index 122e3ea..253ba7d 100644
--- a/src/triples.rs
+++ b/src/triples.rs
@@ -176,6 +176,12 @@ pub enum Inst {
     ReturnValue(Type2),
     /// no parameters
     Return,
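+    /// Conditional branch; the `Node` is the bool condition.
+    /// The data holds the two branch targets:
+    /// lhs (then-branch), rhs (else-branch, or the label after an else-less `if`).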
+    Branch(Node),
+    /// Unconditional jump; lhs: the target label.
+    Jump,
 }
 
 impl Inst {
@@ -543,6 +549,42 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
                     )
                 }
             }
+            Tag::IfExpr { condition, body } => {
+                assert_eq!(self.tree.type_of_node(*condition), Type::Bool);
+
+                let condition = self.visit(*condition);
+                let br = self.ir.push(Inst::Branch(condition), None);
+
+                let lhs = self.visit(*body);
+                let jmp = self.ir.push(Inst::Jump, None);
+                let nojump = self.ir.push(Inst::Label, Some(StringsIndex::none().into()));
+
+                self.ir.data[br as usize] = Some(Data::new(lhs, nojump));
+                self.ir.data[jmp as usize] = Some(Data::lhs(nojump));
+                br
+            }
+            Tag::IfElseExpr {
+                condition,
+                body,
+                else_expr,
+            } => {
+                assert_eq!(self.tree.type_of_node(*condition), Type::Bool);
+
+                let condition = self.visit(*condition);
+                let br = self.ir.push(Inst::Branch(condition), None);
+
+                let lhs = self.visit(*body);
+                let ljmp = self.ir.push(Inst::Jump, None);
+                let rhs = self.visit(*else_expr);
+                let rjmp = self.ir.push(Inst::Jump, None);
+
+                let nojump = self.ir.push(Inst::Label, Some(StringsIndex::none().into()));
+
+                self.ir.data[br as usize] = Some(Data::new(lhs, rhs));
+                self.ir.data[ljmp as usize] = Some(Data::lhs(nojump));
+                self.ir.data[rjmp as usize] = Some(Data::lhs(nojump));
+                br
+            }
             _ => {
                 dbg!(&self.tree.nodes[node]);
                 todo!()
@@ -596,7 +638,7 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
         match inst {
             Inst::Label => {
                 let label = self.tree.strings.get_str(data.as_index());
-                writeln_indented!(indent - 1, w, "%{} = {label}:", node)?;
+                writeln_indented!(indent - 1, w, "%{} = label \"{label}\":", node)?;
             }
             Inst::FunctionStart => {
                 let label = self.tree.strings.get_str(data.as_index());
@@ -633,27 +675,27 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
             }
             Inst::Eq(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
-                writeln_indented!(indent, w, "%{} = eq_{ty}(%{} - %{})", node, lhs, rhs)?;
+                writeln_indented!(indent, w, "%{} = eq_{ty}(%{} == %{})", node, lhs, rhs)?;
             }
             Inst::Neq(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
-                writeln_indented!(indent, w, "%{} = neq_{ty}(%{} - %{})", node, lhs, rhs)?;
+                writeln_indented!(indent, w, "%{} = neq_{ty}(%{} != %{})", node, lhs, rhs)?;
             }
             Inst::Gt(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
-                writeln_indented!(indent, w, "%{} = gt_{ty}(%{} - %{})", node, lhs, rhs)?;
+                writeln_indented!(indent, w, "%{} = gt_{ty}(%{} > %{})", node, lhs, rhs)?;
             }
             Inst::Lt(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
-                writeln_indented!(indent, w, "%{} = lt_{ty}(%{} - %{})", node, lhs, rhs)?;
+                writeln_indented!(indent, w, "%{} = lt_{ty}(%{} < %{})", node, lhs, rhs)?;
             }
             Inst::Ge(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
-                writeln_indented!(indent, w, "%{} = ge_{ty}(%{} - %{})", node, lhs, rhs)?;
+                writeln_indented!(indent, w, "%{} = ge_{ty}(%{} >= %{})", node, lhs, rhs)?;
             }
             Inst::Le(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
-                writeln_indented!(indent, w, "%{} = le_{ty}(%{} - %{})", node, lhs, rhs)?;
+                writeln_indented!(indent, w, "%{} = le_{ty}(%{} <= %{})", node, lhs, rhs)?;
             }
             Inst::Mul(ty) => {
                 let (lhs, rhs) = data.as_lhs_rhs();
@@ -720,6 +762,14 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
                 let ast = data.lhs;
                 writeln_indented!(indent, w, "%{} = extern reference ast-node %{}", node, ast)?;
             }
+            Inst::Branch(condition) => {
+                let (lhs, rhs) = data.as_lhs_rhs();
+                writeln_indented!(indent, w, "%{node} = br bool %{condition} [%{lhs}, %{rhs}]")?;
+            }
+            Inst::Jump => {
+                let lhs = data.lhs;
+                writeln_indented!(indent, w, "%{node} = jmp %{lhs}")?;
+            }
             _ => {
                 unimplemented!("{inst:?} rendering unimplemented")
             }
@@ -1011,401 +1061,6 @@ impl<'a> Iterator for IRIter<'a> {
     }
 }
 
-struct Assembler<'a> {
-    ir: IRIter<'a>,
-    strings: StringTable,
-    constants: HashMap<ImmOrIndex, Vec<Node>>,
-    functions: Vec<Function>,
-}
-use core::fmt::Write;
-
-impl<'a> Assembler<'a> {
-    fn from_ir(ir: &'a IR, strings: StringTable) -> Assembler<'a> {
-        Self {
-            ir: IRIter {
-                ir,
-                offset: 0,
-                item: None,
-            },
-            strings,
-            constants: HashMap::new(),
-            functions: Vec::new(),
-        }
-    }
-    fn assemble_function(&mut self, name: String) -> core::fmt::Result {
-        // hashmap of node indices and offsets from the base pointer
-        let mut allocas = HashMap::<Node, u32>::new();
-
-        let mut register_store = RegisterStore::new();
-        // rax as scratch register
-        register_store.force_take(Registers::A);
-        let mut registers = BTreeMap::<Registers, Node>::new();
-        let mut stack_offset = 0;
-        let mut func = Function::default();
-        func.name = name;
-        let mut param_count = 0;
-
-        let mut current_branch = "main".to_owned();
-        func.branches.insert(current_branch.clone(), String::new());
-
-        loop {
-            let node = self.ir.node();
-            let Some((inst, data)) = self.ir.next() else {
-                break;
-            };
-
-            match inst {
-                Inst::FunctionStart => {
-                    self.ir.offset -= 1;
-                    break;
-                }
-                Inst::Label => {
-                    current_branch = self.strings.get_str(data.unwrap().as_index()).to_owned();
-                    func.branches.insert(current_branch.clone(), String::new());
-                }
-                Inst::ConstantU32 => {
-                    let value = data.unwrap().as_u32();
-                    match self.constants.entry(ImmOrIndex::U32(value)) {
-                        Entry::Occupied(mut o) => o.get_mut().push(node),
-                        Entry::Vacant(v) => {
-                            v.insert(vec![node]);
-                        }
-                    }
-                    let reg = register_store.take_any().unwrap();
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "mov {}, {value}",
-                        reg.display(Width::DWord)
-                    )?;
-                    registers.insert(reg, node);
-                }
-                Inst::ConstantU64 => {
-                    let value = data.unwrap().as_u64();
-                    match self.constants.entry(ImmOrIndex::U64(value)) {
-                        Entry::Occupied(mut o) => o.get_mut().push(node),
-                        Entry::Vacant(v) => {
-                            v.insert(vec![node]);
-                        }
-                    }
-                    let reg = register_store.take_any().unwrap();
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "mov {}, {value}",
-                        reg.display(Width::QWord)
-                    )?;
-                    registers.insert(reg, node);
-                }
-                Inst::ConstantMultiByte => {
-                    let value = data.unwrap().as_index();
-                    match self.constants.entry(ImmOrIndex::Index(value)) {
-                        Entry::Occupied(mut o) => o.get_mut().push(node),
-                        Entry::Vacant(v) => {
-                            v.insert(vec![node]);
-                        }
-                    }
-                    todo!()
-                }
-                Inst::ExternRef => todo!(),
-                Inst::Alloca => {
-                    let (size, align) = data.unwrap().as_lhs_rhs();
-                    let size = size.next_multiple_of(align);
-                    writeln!(&mut func.entry, "sub rsp, 0x{size:x}")?;
-                    stack_offset += size;
-                    allocas.insert(node, stack_offset);
-                }
-                Inst::Load(ty) => {
-                    let src = data.unwrap().lhs;
-                    let src = registers
-                        .iter()
-                        .find(|(_, node)| node == &&src)
-                        .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
-                        .or_else(|| {
-                            allocas
-                                .get(&src)
-                                .map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
-                        })
-                        .expect(&format!(
-                            "src_reg from node %{src} not found: {registers:?}"
-                        ));
-                    let dst_reg = register_store.take_any().unwrap();
-
-                    match src {
-                        ImmRegMem::Reg(_, _) => {
-                            writeln!(
-                                func.branches.get_mut(&current_branch).unwrap(),
-                                "mov {}, [{}]",
-                                dst_reg.display(Width::from_size(ty.size()).unwrap()),
-                                src,
-                            )?;
-                        }
-                        ImmRegMem::Mem(ref mem) => {
-                            let tmp_reg = register_store.take_any().unwrap();
-                            writeln!(
-                                func.branches.get_mut(&current_branch).unwrap(),
-                                "mov {}, {}",
-                                tmp_reg.display(Width::QWord),
-                                mem,
-                            )?;
-                            writeln!(
-                                func.branches.get_mut(&current_branch).unwrap(),
-                                "mov {}, [{}]",
-                                dst_reg.display(Width::from_size(ty.size()).unwrap()),
-                                tmp_reg.display(Width::QWord),
-                            )?;
-                            register_store.free(tmp_reg);
-                        }
-                        _ => {}
-                    }
-
-                    if let ImmRegMem::Reg(reg, _) = src {
-                        register_store.free(reg);
-                        registers.remove(&reg);
-                    }
-                    registers.insert(dst_reg, node);
-                }
-                Inst::Store(ty) => {
-                    let (src, dst) = data.unwrap().as_lhs_rhs();
-                    let src = registers
-                        .iter()
-                        .find(|(_, node)| node == &&src)
-                        .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
-                        .or_else(|| {
-                            allocas
-                                .get(&src)
-                                .map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
-                        })
-                        .expect(&format!(
-                            "src_reg from node %{src} not found: {registers:?}"
-                        ));
-                    let dst = registers
-                        .iter()
-                        .find(|(_, node)| node == &&dst)
-                        .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
-                        .or_else(|| {
-                            allocas
-                                .get(&dst)
-                                .map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
-                        })
-                        .expect(&format!(
-                            "src_reg from node %{src} not found: {registers:?}"
-                        ));
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "mov {}, {}",
-                        dst,
-                        src,
-                    )?;
-
-                    if let ImmRegMem::Reg(reg, _) = src {
-                        register_store.free(reg);
-                        registers.remove(&reg);
-                    }
-                }
-                Inst::GetElementPtr(ty) => {
-                    let (ptr, idx) = data.unwrap().as_lhs_rhs();
-                    let src = registers
-                        .iter()
-                        .find(|(_, node)| node == &&ptr)
-                        .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
-                        .or_else(|| {
-                            allocas
-                                .get(&ptr)
-                                .map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
-                        })
-                        .expect(&format!(
-                            "src_reg from node %{ptr} not found: {registers:?}"
-                        ));
-                    let dst_reg = register_store.take_any().unwrap();
-
-                    if let ImmRegMem::Mem(_) = &src {
-                        writeln!(
-                            func.branches.get_mut(&current_branch).unwrap(),
-                            "lea {}, {}",
-                            ImmRegMem::Reg(dst_reg, Width::QWord),
-                            src,
-                        )?;
-                    }
-
-                    let offset = idx * ty.size();
-                    if offset != 0 {
-                        writeln!(
-                            func.branches.get_mut(&current_branch).unwrap(),
-                            "lea {}, [{} + {offset}]",
-                            ImmRegMem::Reg(dst_reg, Width::QWord),
-                            ImmRegMem::Reg(dst_reg, Width::QWord),
-                        )?;
-                    }
-                    if let ImmRegMem::Reg(reg, _) = src {
-                        register_store.free(reg);
-                        registers.remove(&reg);
-                    }
-                    registers.insert(dst_reg, node);
-                }
-                Inst::Parameter(_) => {
-                    let param_reg = Registers::sysv_param_idx(param_count).unwrap();
-                    param_count += 1;
-                    register_store.force_take(param_reg);
-                    writeln!(&mut func.entry, "push {}", param_reg.display(Width::QWord))?;
-                    stack_offset += 8;
-                    allocas.insert(node, stack_offset);
-                    registers.insert(param_reg, node);
-                }
-                Inst::Eq(_)
-                | Inst::Neq(_)
-                | Inst::Gt(_)
-                | Inst::Lt(_)
-                | Inst::Ge(_)
-                | Inst::Le(_) => {}
-                Inst::Add(ty) => {
-                    let (src, dst) = data.unwrap().as_lhs_rhs();
-                    let (&src_reg, _) = registers.iter().find(|(_, node)| node == &&src).unwrap();
-                    let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap();
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "add {}, {}",
-                        dst_reg.display(Width::from_size(ty.size()).unwrap()),
-                        src_reg.display(Width::from_size(ty.size()).unwrap()),
-                    )?;
-
-                    if src_reg != dst_reg {
-                        register_store.free(src_reg);
-                        registers.remove(&src_reg);
-                    }
-                    registers.insert(dst_reg, node);
-                }
-                Inst::Sub(ty) => {
-                    let (src, dst) = data.unwrap().as_lhs_rhs();
-                    let src = registers
-                        .iter()
-                        .find(|(_, node)| node == &&src)
-                        .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
-                        .or_else(|| {
-                            allocas
-                                .get(&src)
-                                .map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
-                        })
-                        .expect(&format!(
-                            "src_reg from node %{src} not found: {registers:?}"
-                        ));
-                    let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap();
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "sub {}, {}",
-                        dst_reg.display(Width::from_size(ty.size()).unwrap()),
-                        src,
-                    )?;
-
-                    if let ImmRegMem::Reg(reg, _) = src {
-                        if reg != dst_reg {
-                            register_store.free(reg);
-                            registers.remove(&reg);
-                        }
-                    }
-                    registers.insert(dst_reg, node);
-                }
-                Inst::Mul(ty) => {
-                    let (src, dst) = data.unwrap().as_lhs_rhs();
-                    let src = registers
-                        .iter()
-                        .find(|(_, node)| node == &&src)
-                        .map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
-                        .or_else(|| {
-                            allocas
-                                .get(&src)
-                                .map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
-                        })
-                        .expect(&format!(
-                            "src_reg from node %{src} not found: {registers:?}"
-                        ));
-                    let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap();
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "imul {}, {}",
-                        dst_reg.display(Width::from_size(ty.size()).unwrap()),
-                        src,
-                    )?;
-
-                    if let ImmRegMem::Reg(reg, _) = src {
-                        if reg != dst_reg {
-                            register_store.free(reg);
-                            registers.remove(&reg);
-                        }
-                    }
-                    registers.insert(dst_reg, node);
-                }
-                Inst::ShiftLeft(_) => todo!(),
-                Inst::ShiftRight(_) => todo!(),
-                Inst::Div(_) => todo!(),
-                Inst::Rem(_) => todo!(),
-                Inst::BitAnd(_) => todo!(),
-                Inst::BitOr(_) => todo!(),
-                Inst::BitXOr(_) => todo!(),
-                Inst::Negate(_) => todo!(),
-                Inst::Not(_) => todo!(),
-                Inst::ExplicitCast(_, _) => todo!(),
-                Inst::ReturnValue(_) => {
-                    let val = data.unwrap().lhs;
-                    let (&reg, _) = registers.iter().find(|(_, node)| node == &&val)
-                        .expect(&format!(
-                            "location for node %{val} not found: \nregisters: {registers:?}\nallocas: {allocas:?}"
-                        ));
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "mov rax, {}\njmp {}__epilogue",
-                        reg.display(Width::QWord),
-                        func.name
-                    )?;
-
-                    register_store.free(reg);
-                    registers.remove(&reg);
-                }
-                Inst::Return => {
-                    writeln!(
-                        func.branches.get_mut(&current_branch).unwrap(),
-                        "jmp {}__epilogue",
-                        func.name
-                    )?;
-                }
-            }
-        }
-
-        func.stack_size = stack_offset;
-        func.used_registers = register_store.used.into_iter().collect();
-
-        self.functions.push(func);
-        Ok(())
-    }
-
-    fn assemble(&mut self) -> core::fmt::Result {
-        loop {
-            let Some((inst, data)) = self.ir.next() else {
-                break;
-            };
-            match inst {
-                Inst::FunctionStart => {
-                    let name = self.strings.get_str(data.unwrap().as_index());
-                    self.assemble_function(name.to_owned())?
-                }
-                _ => {}
-            }
-        }
-
-        Ok(())
-    }
-
-    fn finish<W: core::fmt::Write>(&self, w: &mut W) -> core::fmt::Result {
-        writeln!(w, ".intel_syntax")?;
-        writeln!(w, ".text")?;
-        for func in self.functions.iter() {
-            writeln!(w, ".globl {}", func.name)?;
-        }
-        for func in self.functions.iter() {
-            func.write(w)?;
-        }
-        Ok(())
-    }
-}
-
 use crate::mir;
 
 pub struct MirBuilder<'a> {
@@ -1911,44 +1566,4 @@ fn u10(x: i10) -> i10 {
         let mut mir = MirBuilder::new(&ir, tree.strings);
         mir.build();
     }
-
-    #[test]
-    fn ir() {
-        let src = "
-fn main() -> u32 {
-    let a: u32 = 0u32 + 3u32;
-    let b = &a;
-    return *b * 2u32;
-}
-
-fn square(x: u32) -> u32 {
-    x * x
-}
-";
-        let tokens = Tokenizer::new(src.as_bytes()).unwrap();
-
-        let mut tree = Tree::new();
-        tree.parse(tokens.iter()).unwrap();
-
-        let mut buf = String::new();
-        tree.render(&mut buf).unwrap();
-        println!("{buf}");
-
-        let mut ir = IR::new();
-        let builder = ir.build(&mut tree);
-        let mut buf = String::new();
-        builder.render(&mut buf).unwrap();
-        println!("{buf}");
-
-        let strings = tree.strings;
-        let mut mir = MirBuilder::new(&ir, strings.clone());
-        mir.build();
-
-        let mut assembler = Assembler::from_ir(&ir, strings);
-        assembler.assemble().unwrap();
-
-        let mut buf = String::new();
-        assembler.finish(&mut buf).unwrap();
-        println!("{buf}");
-    }
 }