From b0f52da5865283753b304e97f8251b51e38d20fa Mon Sep 17 00:00:00 2001 From: Janis Date: Sun, 1 Sep 2024 14:28:48 +0200 Subject: [PATCH] phi node --- src/mir.rs | 135 ++++++++++++++++++++++++++++++++++++------- src/parser.rs | 8 +++ src/triples.rs | 151 +++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 268 insertions(+), 26 deletions(-) diff --git a/src/mir.rs b/src/mir.rs index 94b40db..be069eb 100644 --- a/src/mir.rs +++ b/src/mir.rs @@ -211,6 +211,13 @@ pub enum Inst { IsGe(bool), /// lhs IsLe(bool), + // jrcxz for now + /// lhs, rhs + Branch(u32), + /// lhs + Jump, + /// lhs, rhs + Phi2(Type), } impl Inst { @@ -229,8 +236,11 @@ impl Inst { | Inst::ConstantSinglePrecision | Inst::ConstantDoublePrecision | Inst::Cmp(_) + | Inst::Branch(_) + | Inst::Jump | Inst::Return => None, - Inst::GetElementPtr(ty) + Inst::Phi2(ty) + | Inst::GetElementPtr(ty) | Inst::Load(ty) | Inst::LoadRegister(ty) | Inst::Parameter(ty) @@ -270,6 +280,8 @@ impl Inst { // TODO: need to account for spilled values eventually; probably move this to `Mir`. match self { Inst::Label + | Inst::Branch(_) + | Inst::Jump | Inst::ConstantBytes | Inst::ConstantByte | Inst::ConstantWord @@ -312,6 +324,7 @@ impl Inst { | Inst::IsLt(_) | Inst::IsGe(_) | Inst::IsLe(_) + | Inst::Phi2(_) | Inst::ShiftLeft(_) | Inst::ShiftRightSigned(_) | Inst::ShiftRightUnsigned(_) => true, @@ -374,6 +387,11 @@ impl Data { pub fn binary(lhs: u32, rhs: u32) -> Data { Self { binary: (lhs, rhs) } } + pub fn binary_noderefs(lhs: NodeRef, rhs: NodeRef) -> Data { + Self { + binary: (lhs.0, rhs.0), + } + } pub fn none() -> Data { Self { none: () } } @@ -731,8 +749,8 @@ impl NodeRef { } /// invalid pseudo-handle to the past-the-end node. - const MAX: Self = NodeRef(u32::MAX); - const MIN: Self = NodeRef(0); + pub const MAX: Self = NodeRef(u32::MAX); + pub const MIN: Self = NodeRef(0); } impl Display for NodeRef { @@ -742,14 +760,18 @@ impl Display for NodeRef { } impl Mir { - fn get_node(&self, node: NodeRef) -> (Inst, Data) { + pub fn get_node(&self, node: NodeRef) -> (Inst, Data) { (self.nodes[node.index()], self.data[node.index()]) } #[allow(dead_code)] - fn get_node_mut(&mut self, node: NodeRef) -> (&mut Inst, &mut Data) { + pub fn get_node_mut(&mut self, node: NodeRef) -> (&mut Inst, &mut Data) { (&mut self.nodes[node.index()], &mut self.data[node.index()]) } + pub fn set_node_data(&mut self, node: NodeRef, data: Data) { + self.data[node.index()] = data; + } + fn indices(&self) -> impl Iterator { (0..self.nodes.len() as u32).map(|n| NodeRef::from(n)) } @@ -845,6 +867,15 @@ impl Mir { }; self.push(Inst::Cmp(ty), Data::binary(lhs, rhs)) } + pub fn gen_jmp(&mut self, to: u32) -> u32 { + self.push(Inst::Jump, Data::node(to)) + } + pub fn gen_branch(&mut self, on: u32, lhs: u32, rhs: u32) -> u32 { + self.push(Inst::Branch(on), Data::binary(lhs, rhs)) + } + pub fn gen_phi2(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + self.push(Inst::Phi2(ty), Data::binary(lhs, rhs)) + } pub fn gen_cmp_byte( &mut self, @@ -1300,6 +1331,18 @@ impl Mir { Inst::Return => { writeln!(w, "%{node} = return") } + Inst::Jump => { + let lhs = data.as_node(); + writeln!(w, "%{node} = jmp %{lhs}") + } + Inst::Branch(condition) => { + let (lhs, rhs) = data.as_binary(); + writeln!(w, "%{node} = br bool %{condition}, [%{lhs}, %{rhs}]") + } + Inst::Phi2(ty) => { + let (lhs, rhs) = data.as_binary(); + writeln!(w, "%{node} = phi2 [{ty} %{lhs}, {ty} %{rhs}]") + } } } @@ -1555,6 +1598,9 @@ pub mod liveness { | Inst::LoadRegister(_) => { references.insert((data.as_noderef(), node)); } + Inst::Branch(condition) => { + references.insert((NodeRef(condition), node)); + } Inst::Cmp(_) | Inst::Store(_) | Inst::Add(_) @@ -1571,6 +1617,7 @@ pub mod liveness { | Inst::BitAnd(_) | Inst::BitOr(_) | Inst::BitXOr(_) + | Inst::Phi2(_) | Inst::ShiftLeft(_) | Inst::ShiftRightSigned(_) | Inst::ShiftRightUnsigned(_) => { @@ -1586,6 +1633,7 @@ pub mod liveness { // don't want a wildcard match here to make sure new instructions // are handled here when they are added. Inst::Return + | Inst::Jump | Inst::Parameter(_) | Inst::Label | Inst::ConstantBytes @@ -1705,6 +1753,15 @@ pub mod liveness { if let Some(dst) = self.mir.dst_node(noderef) { _ = self.colors.try_insert(dst, Color::Tentative(color)); } + + // for any Phi(y_1,y_2, y_n) give y_i the color of Phi + // reasonably certain that this will never fail to color all phi nodes the same color. + if let Some(inputs) = self.mir.get_phi_inputs(noderef) { + eprintln!("coloring {inputs:?} {color}"); + for node in inputs { + _ = self.colors.insert(node, Color::Tentative(color)); + } + } } fn colorise(&mut self) -> BTreeMap { @@ -1726,6 +1783,16 @@ pub mod liveness { } impl Mir { + fn get_phi_inputs(&self, node: NodeRef) -> Option> { + let (inst, data) = self.get_node(node); + match inst { + Inst::Phi2(_) => { + let (lhs, rhs) = data.as_binary_noderefs(); + Some([lhs, rhs].to_vec()) + } + _ => None, + } + } /// returns the in/out operand, if it exists: example would be (%node = add rax, rcx) -> rax fn dst_node(&self, node: NodeRef) -> Option { // for each node, look at the dst node and see if it has preferred @@ -1765,9 +1832,10 @@ impl Mir { | Inst::ConstantDoublePrecision | Inst::ExternRef | Inst::Alloca + | Inst::Jump + | Inst::Return | Inst::Store(_) | Inst::ReturnValue(_) - | Inst::Return | Inst::SignExtend(_) | Inst::ZeroExtend(_) | Inst::Mul(_) @@ -1777,6 +1845,8 @@ impl Mir { | Inst::Rem(_) | Inst::RemSigned(_) | Inst::Cmp(_) + | Inst::Branch(_) + | Inst::Phi2(_) | Inst::IsEq(_) | Inst::IsNeq(_) | Inst::IsGt(_) @@ -1906,8 +1976,8 @@ impl core::fmt::Display for RipRelative { pub struct Function { name: StringsIndex, constants: BTreeMap, - branches: HashMap, - current_branch: StringsIndex, + branches: BTreeMap, + current_branch: NodeRef, stack_offset: u32, dirty_registers: BTreeSet, } @@ -1916,8 +1986,8 @@ pub struct Function { impl Function { fn new(name: StringsIndex) -> Self { - let current_branch = StringsIndex::none(); - let branches = HashMap::from([(current_branch, String::new())]); + let current_branch = NodeRef::MIN; + let branches = BTreeMap::from([(current_branch, String::new())]); Self { name, constants: BTreeMap::new(), @@ -1932,9 +2002,9 @@ impl Function { fn dirty_register(&mut self, reg: Register) { self.dirty_registers.insert(reg); } - fn create_new_branch(&mut self, index: StringsIndex) { - self.current_branch = index; - self.branches.insert(index, String::new()); + fn create_new_branch(&mut self, node: NodeRef) { + self.current_branch = node; + self.branches.insert(node, String::new()); } fn get_constant_label(&self, i: usize) -> String { @@ -2009,15 +2079,11 @@ impl Function { writeln!(w, "mov rbp, rsp")?; writeln!(w, "sub rsp, {}", self.stack_offset)?; - write!( - w, - "{}", - self.branches.remove(&StringsIndex::none()).unwrap() - )?; + write!(w, "{}", self.branches.remove(&NodeRef::MIN).unwrap())?; for (branch, content) in &self.branches { if name != "main" { - writeln!(w, "{name}_{}:", strings.get_str(*branch))?; + writeln!(w, "{name}_L{}:", branch.0)?; write!(w, "{content}")?; } } @@ -2076,6 +2142,7 @@ impl Mir { } Inst::GetElementPtr(ty) => liveness.get_register(node.into()).unwrap().into(), Inst::Parameter(ty) + | Inst::Phi2(ty) | Inst::Add(ty) | Inst::Sub(ty) | Inst::Mul(ty) @@ -2144,7 +2211,7 @@ impl Mir { match inst { Inst::Label => { - func.create_new_branch(data.as_index()); + func.create_new_branch(NodeRef(node)); } Inst::ConstantBytes | Inst::ConstantByte @@ -3098,6 +3165,34 @@ impl Mir { Inst::Return => { writeln!(func.current_branch(), "jmp {name}__epilogue")?; } + Inst::Jump => { + let lhs = data.as_node(); + if lhs != node + 1 { + writeln!(func.current_branch(), "jmp {name}__L{lhs}")?; + } + } + Inst::Branch(condition) => { + let cond = + self.node_as_operand(&liveness, &mapping, &mut func, strings, condition); + let (lhs, rhs) = data.as_binary(); + writeln!(func.current_branch(), "test {cond}, {cond}")?; + + match (lhs, rhs) { + _ if lhs == node + 1 => { + writeln!(func.current_branch(), "jz {name}__L{rhs}")?; + } + _ if rhs == node + 1 => { + writeln!(func.current_branch(), "jnz {name}__L{lhs}")?; + } + _ => { + writeln!(func.current_branch(), "jnz {name}__L{lhs}")?; + writeln!(func.current_branch(), "jz {name}__L{rhs}")?; + } + } + } + Inst::Phi2(ty) => { + // noop, need to ensure that input nodes are merged within their branch + } } } diff --git a/src/parser.rs b/src/parser.rs index d2af06f..1da999e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1674,6 +1674,14 @@ impl Tree { Ok(()) } + pub fn peer_type_of_nodes_unwrap(&self, lhs: Node, rhs: Node) -> Type { + self.peer_type_of_nodes(lhs, rhs).expect({ + let at = self.type_of_node(lhs); + let bt = self.type_of_node(rhs); + &format!("incompatible types for %{lhs}({at}) and %{rhs}({bt})") + }) + } + pub fn peer_type_of_nodes(&self, lhs: Node, rhs: Node) -> Option { let lty = self.type_of_node(lhs); let rty = self.type_of_node(rhs); diff --git a/src/triples.rs b/src/triples.rs index 253ba7d..24ce946 100644 --- a/src/triples.rs +++ b/src/triples.rs @@ -73,6 +73,22 @@ impl Type2 { fn align(&self) -> u32 { self.size() } + + fn try_from_ast_type(ty: &Type) -> Option { + match ty { + Type::Bool => Some(Type2::Bool), + Type::Integer(i) => Some(Type2::Integral(i.signed, i.bits)), + Type::Floating(f) => match f { + crate::ast::FloatingType::Binary32 => Some(Type2::Binary32), + crate::ast::FloatingType::Binary64 => Some(Type2::Binary64), + }, + Type::Pointer { .. } => Some(Type2::Pointer), + _ => { + None + //unimplemented!("conversion from {value:?} to triples type not implemented") + } + } + } } impl core::fmt::Display for Type2 { @@ -182,6 +198,8 @@ pub enum Inst { Branch(Node), /// lhs: Label Jump, + /// lhs, rhs + Phi2(Type2), } impl Inst { @@ -555,11 +573,12 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> { let condition = self.visit(*condition); let br = self.ir.push(Inst::Branch(condition), None); - let lhs = self.visit(*body); + let label_lhs = self.ir.push(Inst::Label, Some(StringsIndex::none().into())); + let _ = self.visit(*body); let jmp = self.ir.push(Inst::Jump, None); let nojump = self.ir.push(Inst::Label, Some(StringsIndex::none().into())); - self.ir.data[br as usize] = Some(Data::new(lhs, nojump)); + self.ir.data[br as usize] = Some(Data::new(label_lhs, nojump)); self.ir.data[jmp as usize] = Some(Data::lhs(nojump)); br } @@ -569,21 +588,29 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> { else_expr, } => { assert_eq!(self.tree.type_of_node(*condition), Type::Bool); + let ty = self.tree.peer_type_of_nodes_unwrap(*body, *else_expr); let condition = self.visit(*condition); let br = self.ir.push(Inst::Branch(condition), None); + let label_lhs = self.ir.push(Inst::Label, Some(StringsIndex::none().into())); let lhs = self.visit(*body); let ljmp = self.ir.push(Inst::Jump, None); + let label_rhs = self.ir.push(Inst::Label, Some(StringsIndex::none().into())); let rhs = self.visit(*else_expr); let rjmp = self.ir.push(Inst::Jump, None); let nojump = self.ir.push(Inst::Label, Some(StringsIndex::none().into())); + let phi = if let Some(ty) = Type2::try_from_ast_type(&ty) { + self.ir.push(Inst::Phi2(ty), Some(Data::new(lhs, rhs))) + } else { + br + }; - self.ir.data[br as usize] = Some(Data::new(lhs, rhs)); + self.ir.data[br as usize] = Some(Data::new(label_lhs, label_rhs)); self.ir.data[ljmp as usize] = Some(Data::lhs(nojump)); self.ir.data[rjmp as usize] = Some(Data::lhs(nojump)); - br + phi } _ => { dbg!(&self.tree.nodes[node]); @@ -764,12 +791,20 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> { } Inst::Branch(condition) => { let (lhs, rhs) = data.as_lhs_rhs(); - writeln_indented!(indent, w, "%{node} = br bool %{condition} [%{lhs}, %{rhs}]")?; + writeln_indented!( + indent, + w, + "%{node} = br bool %{condition}, [%{lhs}, %{rhs}]" + )?; } Inst::Jump => { let lhs = data.lhs; writeln_indented!(indent, w, "%{node} = jmp %{lhs}")?; } + Inst::Phi2(ty) => { + let (lhs, rhs) = data.as_lhs_rhs(); + writeln_indented!(indent, w, "%{node} = phi [{ty} %{lhs}, {ty} %{rhs}]")?; + } _ => { unimplemented!("{inst:?} rendering unimplemented") } @@ -1085,6 +1120,22 @@ impl<'a> MirBuilder<'a> { fn build_function(&mut self, name: StringsIndex) { let mut mir = mir::Mir::new(name); let mut mapping = BTreeMap::::new(); + // map of label -> unresolved mir jump or branch instruction + // stored as a tree of (label, unresolved) + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] + enum LeftRight { + Left(mir::NodeRef), + Right(mir::NodeRef), + } + impl LeftRight { + fn noderef(self) -> mir::NodeRef { + match self { + LeftRight::Left(noderef) => noderef, + LeftRight::Right(noderef) => noderef, + } + } + } + let mut unresolved_jumps_branches = BTreeSet::<(Node, LeftRight)>::new(); loop { let ir_node = self.ir.node(); @@ -1096,7 +1147,43 @@ impl<'a> MirBuilder<'a> { self.ir.offset -= 1; break; } - Inst::Label => mir.gen_label(data.unwrap().as_index()), + Inst::Label => { + let label = mir.gen_label(data.unwrap().as_index()); + let range = unresolved_jumps_branches + .range( + (ir_node, LeftRight::Left(mir::NodeRef::MIN)) + ..=(ir_node, LeftRight::Right(mir::NodeRef::MAX)), + ) + .map(|(_, n)| n) + .cloned() + .collect::>(); + + for unresolved in range { + unresolved_jumps_branches.remove(&(ir_node, unresolved)); + + let mir_node = unresolved.noderef(); + let (inst, data) = mir.get_node_mut(mir_node); + + match inst { + mir::Inst::Jump => { + *data = mir::Data::node(label); + } + mir::Inst::Branch(_) => { + let (lhs, rhs) = data.as_binary_noderefs(); + + *data = match unresolved { + LeftRight::Left(_) => mir::Data::binary(label, rhs.0), + LeftRight::Right(_) => mir::Data::binary(lhs.0, label), + }; + } + _ => { + unreachable!() + } + } + } + + label + } Inst::ConstantU32 => mir.push( mir::Inst::ConstantDWord, mir::Data::imm32(data.unwrap().as_u32()), @@ -1464,6 +1551,58 @@ impl<'a> MirBuilder<'a> { mir.gen_ret_val(ty.mir_type(), src) } Inst::Return => mir.gen_ret(), + Inst::Jump => { + let label = data.unwrap().as_u32(); + + let jmp = mir.gen_jmp(label); + + let label = match mapping.get(&label) { + Some(&label) => label, + None => { + unresolved_jumps_branches + .insert((label, LeftRight::Left(mir::NodeRef(jmp)))); + 0 + } + }; + + mir.set_node_data(mir::NodeRef(jmp), mir::Data::node(label)); + + jmp + } + Inst::Branch(condition) => { + let condition = *mapping.get(&condition).unwrap(); + let (lhs, rhs) = data.unwrap().as_lhs_rhs(); + + let br = mir.gen_branch(condition, lhs, rhs); + + let lhs = match mapping.get(&lhs) { + Some(&n) => n, + None => { + unresolved_jumps_branches + .insert((lhs, LeftRight::Left(mir::NodeRef(br)))); + 0 + } + }; + let rhs = match mapping.get(&rhs) { + Some(&n) => n, + None => { + unresolved_jumps_branches + .insert((rhs, LeftRight::Right(mir::NodeRef(br)))); + 0 + } + }; + + mir.set_node_data(mir::NodeRef(br), mir::Data::binary(lhs, rhs)); + + br + } + Inst::Phi2(ty) => { + let (src, dst) = data.unwrap().as_lhs_rhs(); + let lhs = *mapping.get(&src).unwrap(); + let rhs = *mapping.get(&dst).unwrap(); + + mir.gen_phi2(ty.mir_type(), lhs, rhs) + } #[allow(unreachable_patterns)] _ => { unimplemented!()