if-else branches in IR

This commit is contained in:
Janis 2024-09-01 12:28:47 +02:00
parent 6bab60d168
commit dd6ce88ad6
3 changed files with 76 additions and 444 deletions

View file

@ -1806,6 +1806,15 @@ impl Tree {
Tag::Ge { .. } => Type::bool(), Tag::Ge { .. } => Type::bool(),
Tag::DeclRef(decl) => self.type_of_node(*decl), Tag::DeclRef(decl) => self.type_of_node(*decl),
Tag::GlobalRef(decl) => self.type_of_node(*decl), Tag::GlobalRef(decl) => self.type_of_node(*decl),
Tag::IfExpr { .. } => Type::void(),
Tag::IfElseExpr {
body, else_expr, ..
} => self.peer_type_of_nodes(*body, *else_expr).expect({
let (lhs, rhs) = (body, else_expr);
let at = self.type_of_node(*lhs);
let bt = self.type_of_node(*rhs);
&format!("incompatible types for %{lhs}({at}) and %{rhs}({bt})")
}),
_ => Type::void(), _ => Type::void(),
} }
} }

View file

@ -63,12 +63,20 @@ impl StringTable {
} }
/// Returns the string stored at `idx`, or `""` when `idx` equals `Index::none()`.
pub fn get_str(&self, idx: Index) -> &str { pub fn get_str(&self, idx: Index) -> &str {
if idx == Index::none() {
""
} else {
// NOTE(review): `from_utf8_unchecked` performs no validation, so this relies on
// the bytes at `idx` being valid UTF-8, while `insert` accepts any
// `AsRef<[u8]>` input — confirm that only UTF-8 is ever read back through here.
unsafe { core::str::from_utf8_unchecked(&self[idx]) } unsafe { core::str::from_utf8_unchecked(&self[idx]) }
} }
}
/// Returns the raw bytes stored at `idx`, or an empty slice when `idx` equals
/// `Index::none()`.
pub fn get_bytes(&self, idx: Index) -> &[u8] { pub fn get_bytes(&self, idx: Index) -> &[u8] {
if idx == Index::none() {
&[]
} else {
&self[idx] &self[idx]
} }
}
pub fn insert<B: AsRef<[u8]>>(&mut self, bytes: B) -> Index { pub fn insert<B: AsRef<[u8]>>(&mut self, bytes: B) -> Index {
let bytes = bytes.as_ref(); let bytes = bytes.as_ref();

View file

@ -176,6 +176,12 @@ pub enum Inst {
ReturnValue(Type2), ReturnValue(Type2),
/// no parameters /// no parameters
Return, Return,
/// Node is a bool
/// two labels
/// lhs, rhs
Branch(Node),
/// lhs: Label
Jump,
} }
impl Inst { impl Inst {
@ -543,6 +549,42 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
) )
} }
} }
Tag::IfExpr { condition, body } => {
assert_eq!(self.tree.type_of_node(*condition), Type::Bool);
let condition = self.visit(*condition);
let br = self.ir.push(Inst::Branch(condition), None);
let lhs = self.visit(*body);
let jmp = self.ir.push(Inst::Jump, None);
let nojump = self.ir.push(Inst::Label, Some(StringsIndex::none().into()));
self.ir.data[br as usize] = Some(Data::new(lhs, nojump));
self.ir.data[jmp as usize] = Some(Data::lhs(nojump));
br
}
Tag::IfElseExpr {
condition,
body,
else_expr,
} => {
assert_eq!(self.tree.type_of_node(*condition), Type::Bool);
let condition = self.visit(*condition);
let br = self.ir.push(Inst::Branch(condition), None);
let lhs = self.visit(*body);
let ljmp = self.ir.push(Inst::Jump, None);
let rhs = self.visit(*else_expr);
let rjmp = self.ir.push(Inst::Jump, None);
let nojump = self.ir.push(Inst::Label, Some(StringsIndex::none().into()));
self.ir.data[br as usize] = Some(Data::new(lhs, rhs));
self.ir.data[ljmp as usize] = Some(Data::lhs(nojump));
self.ir.data[rjmp as usize] = Some(Data::lhs(nojump));
br
}
_ => { _ => {
dbg!(&self.tree.nodes[node]); dbg!(&self.tree.nodes[node]);
todo!() todo!()
@ -596,7 +638,7 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
match inst { match inst {
Inst::Label => { Inst::Label => {
let label = self.tree.strings.get_str(data.as_index()); let label = self.tree.strings.get_str(data.as_index());
writeln_indented!(indent - 1, w, "%{} = {label}:", node)?; writeln_indented!(indent - 1, w, "%{} = label \"{label}\":", node)?;
} }
Inst::FunctionStart => { Inst::FunctionStart => {
let label = self.tree.strings.get_str(data.as_index()); let label = self.tree.strings.get_str(data.as_index());
@ -633,27 +675,27 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
} }
Inst::Eq(ty) => { Inst::Eq(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{} = eq_{ty}(%{} - %{})", node, lhs, rhs)?; writeln_indented!(indent, w, "%{} = eq_{ty}(%{} == %{})", node, lhs, rhs)?;
} }
Inst::Neq(ty) => { Inst::Neq(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{} = neq_{ty}(%{} - %{})", node, lhs, rhs)?; writeln_indented!(indent, w, "%{} = neq_{ty}(%{} != %{})", node, lhs, rhs)?;
} }
Inst::Gt(ty) => { Inst::Gt(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{} = gt_{ty}(%{} - %{})", node, lhs, rhs)?; writeln_indented!(indent, w, "%{} = gt_{ty}(%{} > %{})", node, lhs, rhs)?;
} }
Inst::Lt(ty) => { Inst::Lt(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{} = lt_{ty}(%{} - %{})", node, lhs, rhs)?; writeln_indented!(indent, w, "%{} = lt_{ty}(%{} < %{})", node, lhs, rhs)?;
} }
Inst::Ge(ty) => { Inst::Ge(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{} = ge_{ty}(%{} - %{})", node, lhs, rhs)?; writeln_indented!(indent, w, "%{} = ge_{ty}(%{} >= %{})", node, lhs, rhs)?;
} }
Inst::Le(ty) => { Inst::Le(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{} = le_{ty}(%{} - %{})", node, lhs, rhs)?; writeln_indented!(indent, w, "%{} = le_{ty}(%{} <= %{})", node, lhs, rhs)?;
} }
Inst::Mul(ty) => { Inst::Mul(ty) => {
let (lhs, rhs) = data.as_lhs_rhs(); let (lhs, rhs) = data.as_lhs_rhs();
@ -720,6 +762,14 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
let ast = data.lhs; let ast = data.lhs;
writeln_indented!(indent, w, "%{} = extern reference ast-node %{}", node, ast)?; writeln_indented!(indent, w, "%{} = extern reference ast-node %{}", node, ast)?;
} }
Inst::Branch(condition) => {
let (lhs, rhs) = data.as_lhs_rhs();
writeln_indented!(indent, w, "%{node} = br bool %{condition} [%{lhs}, %{rhs}]")?;
}
Inst::Jump => {
let lhs = data.lhs;
writeln_indented!(indent, w, "%{node} = jmp %{lhs}")?;
}
_ => { _ => {
unimplemented!("{inst:?} rendering unimplemented") unimplemented!("{inst:?} rendering unimplemented")
} }
@ -1011,401 +1061,6 @@ impl<'a> Iterator for IRIter<'a> {
} }
} }
/// Turns IR into textual assembly, producing one `Function` per
/// `Inst::FunctionStart` encountered in the instruction stream.
struct Assembler<'a> {
/// Cursor over the IR instructions being assembled.
ir: IRIter<'a>,
/// String table used to resolve label and function names.
strings: StringTable,
/// For every constant value seen, the IR nodes that define it.
constants: HashMap<ImmOrIndex, Vec<Node>>,
/// Functions assembled so far, in the order they were pushed.
functions: Vec<Function>,
}
use core::fmt::Write;
impl<'a> Assembler<'a> {
/// Creates an assembler that reads `ir` starting at offset 0, with no
/// constants or functions recorded yet.
fn from_ir(ir: &'a IR, strings: StringTable) -> Assembler<'a> {
    let cursor = IRIter {
        ir,
        offset: 0,
        item: None,
    };
    let constants = HashMap::new();
    let functions = Vec::new();

    Self {
        ir: cursor,
        strings,
        constants,
        functions,
    }
}
fn assemble_function(&mut self, name: String) -> core::fmt::Result {
// hashmap of node indices and offsets from the base pointer
let mut allocas = HashMap::<Node, u32>::new();
let mut register_store = RegisterStore::new();
// rax as scratch register
register_store.force_take(Registers::A);
let mut registers = BTreeMap::<Registers, Node>::new();
let mut stack_offset = 0;
let mut func = Function::default();
func.name = name;
let mut param_count = 0;
let mut current_branch = "main".to_owned();
func.branches.insert(current_branch.clone(), String::new());
loop {
let node = self.ir.node();
let Some((inst, data)) = self.ir.next() else {
break;
};
match inst {
Inst::FunctionStart => {
self.ir.offset -= 1;
break;
}
Inst::Label => {
current_branch = self.strings.get_str(data.unwrap().as_index()).to_owned();
func.branches.insert(current_branch.clone(), String::new());
}
Inst::ConstantU32 => {
let value = data.unwrap().as_u32();
match self.constants.entry(ImmOrIndex::U32(value)) {
Entry::Occupied(mut o) => o.get_mut().push(node),
Entry::Vacant(v) => {
v.insert(vec![node]);
}
}
let reg = register_store.take_any().unwrap();
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov {}, {value}",
reg.display(Width::DWord)
)?;
registers.insert(reg, node);
}
Inst::ConstantU64 => {
let value = data.unwrap().as_u64();
match self.constants.entry(ImmOrIndex::U64(value)) {
Entry::Occupied(mut o) => o.get_mut().push(node),
Entry::Vacant(v) => {
v.insert(vec![node]);
}
}
let reg = register_store.take_any().unwrap();
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov {}, {value}",
reg.display(Width::QWord)
)?;
registers.insert(reg, node);
}
Inst::ConstantMultiByte => {
let value = data.unwrap().as_index();
match self.constants.entry(ImmOrIndex::Index(value)) {
Entry::Occupied(mut o) => o.get_mut().push(node),
Entry::Vacant(v) => {
v.insert(vec![node]);
}
}
todo!()
}
Inst::ExternRef => todo!(),
Inst::Alloca => {
let (size, align) = data.unwrap().as_lhs_rhs();
let size = size.next_multiple_of(align);
writeln!(&mut func.entry, "sub rsp, 0x{size:x}")?;
stack_offset += size;
allocas.insert(node, stack_offset);
}
Inst::Load(ty) => {
let src = data.unwrap().lhs;
let src = registers
.iter()
.find(|(_, node)| node == &&src)
.map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
.or_else(|| {
allocas
.get(&src)
.map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
})
.expect(&format!(
"src_reg from node %{src} not found: {registers:?}"
));
let dst_reg = register_store.take_any().unwrap();
match src {
ImmRegMem::Reg(_, _) => {
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov {}, [{}]",
dst_reg.display(Width::from_size(ty.size()).unwrap()),
src,
)?;
}
ImmRegMem::Mem(ref mem) => {
let tmp_reg = register_store.take_any().unwrap();
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov {}, {}",
tmp_reg.display(Width::QWord),
mem,
)?;
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov {}, [{}]",
dst_reg.display(Width::from_size(ty.size()).unwrap()),
tmp_reg.display(Width::QWord),
)?;
register_store.free(tmp_reg);
}
_ => {}
}
if let ImmRegMem::Reg(reg, _) = src {
register_store.free(reg);
registers.remove(&reg);
}
registers.insert(dst_reg, node);
}
Inst::Store(ty) => {
let (src, dst) = data.unwrap().as_lhs_rhs();
let src = registers
.iter()
.find(|(_, node)| node == &&src)
.map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
.or_else(|| {
allocas
.get(&src)
.map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
})
.expect(&format!(
"src_reg from node %{src} not found: {registers:?}"
));
let dst = registers
.iter()
.find(|(_, node)| node == &&dst)
.map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
.or_else(|| {
allocas
.get(&dst)
.map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
})
.expect(&format!(
"src_reg from node %{src} not found: {registers:?}"
));
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov {}, {}",
dst,
src,
)?;
if let ImmRegMem::Reg(reg, _) = src {
register_store.free(reg);
registers.remove(&reg);
}
}
Inst::GetElementPtr(ty) => {
let (ptr, idx) = data.unwrap().as_lhs_rhs();
let src = registers
.iter()
.find(|(_, node)| node == &&ptr)
.map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
.or_else(|| {
allocas
.get(&ptr)
.map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
})
.expect(&format!(
"src_reg from node %{ptr} not found: {registers:?}"
));
let dst_reg = register_store.take_any().unwrap();
if let ImmRegMem::Mem(_) = &src {
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"lea {}, {}",
ImmRegMem::Reg(dst_reg, Width::QWord),
src,
)?;
}
let offset = idx * ty.size();
if offset != 0 {
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"lea {}, [{} + {offset}]",
ImmRegMem::Reg(dst_reg, Width::QWord),
ImmRegMem::Reg(dst_reg, Width::QWord),
)?;
}
if let ImmRegMem::Reg(reg, _) = src {
register_store.free(reg);
registers.remove(&reg);
}
registers.insert(dst_reg, node);
}
Inst::Parameter(_) => {
let param_reg = Registers::sysv_param_idx(param_count).unwrap();
param_count += 1;
register_store.force_take(param_reg);
writeln!(&mut func.entry, "push {}", param_reg.display(Width::QWord))?;
stack_offset += 8;
allocas.insert(node, stack_offset);
registers.insert(param_reg, node);
}
Inst::Eq(_)
| Inst::Neq(_)
| Inst::Gt(_)
| Inst::Lt(_)
| Inst::Ge(_)
| Inst::Le(_) => {}
Inst::Add(ty) => {
let (src, dst) = data.unwrap().as_lhs_rhs();
let (&src_reg, _) = registers.iter().find(|(_, node)| node == &&src).unwrap();
let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap();
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"add {}, {}",
dst_reg.display(Width::from_size(ty.size()).unwrap()),
src_reg.display(Width::from_size(ty.size()).unwrap()),
)?;
if src_reg != dst_reg {
register_store.free(src_reg);
registers.remove(&src_reg);
}
registers.insert(dst_reg, node);
}
Inst::Sub(ty) => {
let (src, dst) = data.unwrap().as_lhs_rhs();
let src = registers
.iter()
.find(|(_, node)| node == &&src)
.map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
.or_else(|| {
allocas
.get(&src)
.map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
})
.expect(&format!(
"src_reg from node %{src} not found: {registers:?}"
));
let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap();
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"sub {}, {}",
dst_reg.display(Width::from_size(ty.size()).unwrap()),
src,
)?;
if let ImmRegMem::Reg(reg, _) = src {
if reg != dst_reg {
register_store.free(reg);
registers.remove(&reg);
}
}
registers.insert(dst_reg, node);
}
Inst::Mul(ty) => {
let (src, dst) = data.unwrap().as_lhs_rhs();
let src = registers
.iter()
.find(|(_, node)| node == &&src)
.map(|(reg, _)| ImmRegMem::Reg(*reg, Width::from_size(ty.size()).unwrap()))
.or_else(|| {
allocas
.get(&src)
.map(|&offset| ImmRegMem::Mem(StackMem::new(offset)))
})
.expect(&format!(
"src_reg from node %{src} not found: {registers:?}"
));
let (&dst_reg, _) = registers.iter().find(|(_, node)| node == &&dst).unwrap();
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"imul {}, {}",
dst_reg.display(Width::from_size(ty.size()).unwrap()),
src,
)?;
if let ImmRegMem::Reg(reg, _) = src {
if reg != dst_reg {
register_store.free(reg);
registers.remove(&reg);
}
}
registers.insert(dst_reg, node);
}
Inst::ShiftLeft(_) => todo!(),
Inst::ShiftRight(_) => todo!(),
Inst::Div(_) => todo!(),
Inst::Rem(_) => todo!(),
Inst::BitAnd(_) => todo!(),
Inst::BitOr(_) => todo!(),
Inst::BitXOr(_) => todo!(),
Inst::Negate(_) => todo!(),
Inst::Not(_) => todo!(),
Inst::ExplicitCast(_, _) => todo!(),
Inst::ReturnValue(_) => {
let val = data.unwrap().lhs;
let (&reg, _) = registers.iter().find(|(_, node)| node == &&val)
.expect(&format!(
"location for node %{val} not found: \nregisters: {registers:?}\nallocas: {allocas:?}"
));
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"mov rax, {}\njmp {}__epilogue",
reg.display(Width::QWord),
func.name
)?;
register_store.free(reg);
registers.remove(&reg);
}
Inst::Return => {
writeln!(
func.branches.get_mut(&current_branch).unwrap(),
"jmp {}__epilogue",
func.name
)?;
}
}
}
func.stack_size = stack_offset;
func.used_registers = register_store.used.into_iter().collect();
self.functions.push(func);
Ok(())
}
/// Walks the whole IR stream and assembles every function in it.
///
/// Each `Inst::FunctionStart` hands control to `assemble_function` with the
/// function's name; any other instruction met at this level is skipped.
fn assemble(&mut self) -> core::fmt::Result {
    while let Some((inst, data)) = self.ir.next() {
        if let Inst::FunctionStart = inst {
            let name = self.strings.get_str(data.unwrap().as_index()).to_owned();
            self.assemble_function(name)?;
        }
    }

    Ok(())
}
/// Writes the assembled program to `w`: the `.intel_syntax` and `.text`
/// directives, one `.globl` line per function, and then every function
/// through its own `write` method.
fn finish<W: core::fmt::Write>(&self, w: &mut W) -> core::fmt::Result {
    writeln!(w, ".intel_syntax")?;
    writeln!(w, ".text")?;

    // export all symbols first
    self.functions
        .iter()
        .try_for_each(|func| writeln!(w, ".globl {}", func.name))?;

    for func in &self.functions {
        func.write(w)?;
    }

    Ok(())
}
}
use crate::mir; use crate::mir;
pub struct MirBuilder<'a> { pub struct MirBuilder<'a> {
@ -1911,44 +1566,4 @@ fn u10(x: i10) -> i10 {
let mut mir = MirBuilder::new(&ir, tree.strings); let mut mir = MirBuilder::new(&ir, tree.strings);
mir.build(); mir.build();
} }
/// Smoke test for the whole pipeline: parse a small program, print the AST
/// and the IR, build MIR, then assemble and print the resulting assembly.
#[test]
fn ir() {
    let src = "
fn main() -> u32 {
let a: u32 = 0u32 + 3u32;
let b = &a;
return *b * 2u32;
}
fn square(x: u32) -> u32 {
x * x
}
";
    let tokens = Tokenizer::new(src.as_bytes()).unwrap();

    let mut tree = Tree::new();
    tree.parse(tokens.iter()).unwrap();

    let mut ast_dump = String::new();
    tree.render(&mut ast_dump).unwrap();
    println!("{ast_dump}");

    let mut ir = IR::new();
    {
        // the builder borrows both `ir` and `tree`; keep it confined to the IR dump
        let builder = ir.build(&mut tree);
        let mut ir_dump = String::new();
        builder.render(&mut ir_dump).unwrap();
        println!("{ir_dump}");
    }

    let strings = tree.strings;

    let mut mir = MirBuilder::new(&ir, strings.clone());
    mir.build();

    let mut assembler = Assembler::from_ir(&ir, strings);
    assembler.assemble().unwrap();

    let mut asm_dump = String::new();
    assembler.finish(&mut asm_dump).unwrap();
    println!("{asm_dump}");
}
} }