From 3aee606ca206558363b8de87d1518c7bf4ce03bb Mon Sep 17 00:00:00 2001 From: Janis Date: Thu, 29 Aug 2024 01:50:58 +0200 Subject: [PATCH] mir.rs changes for last commit's triples.rs changes --- src/mir.rs | 487 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 469 insertions(+), 18 deletions(-) diff --git a/src/mir.rs b/src/mir.rs index 65f04bf..14d4577 100644 --- a/src/mir.rs +++ b/src/mir.rs @@ -109,7 +109,7 @@ pub enum Inst { /// imm64 ConstantQWord, /// src - LoadConstant(Type), // hint for loading constant into register + LoadRegister(Type), // hint for loading value into register /// ast-node ExternRef, /// size, align @@ -162,18 +162,18 @@ impl Inst { match self { Inst::Label | Inst::ConstantBytes - | Inst::ConstantByte - | Inst::ConstantWord - | Inst::ConstantDWord - | Inst::ConstantQWord | Inst::ExternRef | Inst::Alloca | Inst::ReturnValue | Inst::Store(_) + | Inst::ConstantByte + | Inst::ConstantWord + | Inst::ConstantDWord + | Inst::ConstantQWord | Inst::Return => None, Inst::GetElementPtr(ty) | Inst::Load(ty) - | Inst::LoadConstant(ty) + | Inst::LoadRegister(ty) | Inst::Parameter(ty) | Inst::Add(ty) | Inst::Sub(ty) @@ -194,6 +194,7 @@ impl Inst { } fn has_value(&self) -> bool { // basically, when an arithmetic instruction has two immediates, then just replace it with a mov into the dst reg + // TODO: need to account for spilled values eventually; probably move this to `Mir`. match self { Inst::Label | Inst::ConstantBytes @@ -208,7 +209,7 @@ impl Inst { | Inst::Return => false, Inst::GetElementPtr(_) | Inst::Load(_) - | Inst::LoadConstant(_) + | Inst::LoadRegister(_) | Inst::Parameter(_) | Inst::Add(_) | Inst::Sub(_) @@ -311,6 +312,281 @@ impl Data { } } +use bitflags::bitflags; +bitflags! { + #[derive(Debug, Clone, Copy)] + struct BinaryOperandFlags: u8 { + const RhsReg = 0b00000001; + const RhsMem = 0b00000010; + const RhsImm = 0b00000100; + const LhsReg = 0b10000000; + const LhsMem = 0b01000000; + const LhsImm = 0b00100000; + + const RegReg = 0b10000001; + const RegMem = 0b10000010; + const RegImm = 0b10000100; + const MemReg = 0b01000001; + const MemMem = 0b01000010; + const MemImm = 0b01000100; + + const RhsAll = 0b00000111; + const LhsAll = 0b11100000; + + const NULL = 0b0; + } +} +bitflags! { + #[derive(Debug, Clone, Copy)] + struct OperandKinds: u8 { + const RegReg = 0b00000001; + const RegMem = 0b00000010; + const RegImm = 0b00000100; + const MemReg = 0b00001000; + const MemMem = 0b00010000; + const MemImm = 0b00100000; + } +} + +impl OperandKinds { + /// works for: add,sub,or,and,sbb,adc + fn add() -> Self { + Self::RegImm | Self::MemImm | Self::RegMem | Self::MemReg | Self::RegReg + } + /// imul is special... + fn imul() -> Self { + Self::RegImm | Self::MemImm | Self::RegMem | Self::RegReg + } + /// works for: div,idiv,mul + fn mul() -> Self { + Self::RegMem | Self::RegReg + } + /// works for: mulss,mulsd,divss,divsd,addss,addsd,subss,subsd + fn sse() -> Self { + Self::RegMem | Self::RegReg + } + /// works for: shl,shr,sar,sal + fn shift() -> Self { + Self::RegReg | Self::RegImm | Self::MemImm | Self::MemReg + } + const fn to_rhs_binop(self) -> BinaryOperandFlags { + let reg = if self.intersects(Self::RegReg.union(Self::MemReg)) { + BinaryOperandFlags::RhsReg + } else { + BinaryOperandFlags::empty() + }; + let mem = if self.intersects(Self::RegMem.union(Self::MemMem)) { + BinaryOperandFlags::RhsMem + } else { + BinaryOperandFlags::empty() + }; + let imm = if self.intersects(Self::RegImm.union(Self::MemImm)) { + BinaryOperandFlags::RhsImm + } else { + BinaryOperandFlags::empty() + }; + + reg.union(mem).union(imm) + } + const fn to_lhs_binop(self) -> BinaryOperandFlags { + let reg = if self.intersects(Self::RegReg.union(Self::RegImm).union(Self::RegMem)) { + BinaryOperandFlags::LhsReg + } else { + BinaryOperandFlags::empty() + }; + let mem = if self.intersects(Self::MemReg.union(Self::MemMem).union(Self::MemImm)) { + BinaryOperandFlags::LhsMem + } else { + BinaryOperandFlags::empty() + }; + reg.union(mem) + } + + fn reduce_with_rhs(self, lhs: BinaryOperandFlags) -> OperandKinds { + let mut out = self; + if !lhs.contains(BinaryOperandFlags::RhsImm) { + out = out.difference(Self::MemImm | Self::RegImm); + } + if !lhs.contains(BinaryOperandFlags::RhsMem) { + out = out.difference(Self::MemMem | Self::RegMem); + } + if !lhs.contains(BinaryOperandFlags::RhsReg) { + out = out.difference(Self::RegReg | Self::MemReg); + } + out + } + fn reduce_with_lhs(self, lhs: BinaryOperandFlags) -> OperandKinds { + let mut out = self; + if !lhs.contains(BinaryOperandFlags::LhsMem) { + out = out.difference(Self::MemReg | Self::MemMem | Self::MemImm); + } + if !lhs.contains(BinaryOperandFlags::LhsReg) { + out = out.difference(Self::RegReg | Self::RegMem | Self::RegImm); + } + out + } +} + +enum OperandKind { + Mem, + Imm, + Reg, +} + +impl OperandKind { + const fn as_rhs(self) -> BinaryOperandFlags { + match self { + OperandKind::Mem => BinaryOperandFlags::RhsMem, + OperandKind::Imm => BinaryOperandFlags::RhsImm, + OperandKind::Reg => BinaryOperandFlags::RhsReg, + } + } + const fn as_lhs(self) -> BinaryOperandFlags { + match self { + OperandKind::Mem => BinaryOperandFlags::LhsMem, + OperandKind::Imm => BinaryOperandFlags::LhsImm, + OperandKind::Reg => BinaryOperandFlags::LhsReg, + } + } +} + +struct BinaryOperands<'a> { + mir: &'a mut Mir, + commutative: bool, + kinds: OperandKinds, +} + +struct BinaryOperandsRunner<'a> { + inner: BinaryOperands<'a>, + lhs: (u32, Type), + rhs: (u32, Type), +} + +impl<'a> BinaryOperandsRunner<'a> { + fn new(inner: BinaryOperands<'a>, lhs: u32, lhs_type: Type, rhs: u32, rhs_type: Type) -> Self { + Self { + inner, + lhs: (lhs, lhs_type), + rhs: (rhs, rhs_type), + } + } + fn mir_mut(&mut self) -> &mut Mir { + self.inner.mir + } + fn mir(&self) -> &Mir { + self.inner.mir + } + fn lhs(&self) -> u32 { + self.lhs.0 + } + fn rhs(&self) -> u32 { + self.rhs.0 + } + fn lhs_type(&self) -> Type { + self.lhs.1 + } + fn rhs_type(&self) -> Type { + self.rhs.1 + } + fn lhs_rhs(&self) -> (u32, u32) { + (self.lhs(), self.rhs()) + } + fn canonicalise_lhs_with_reduced_kinds(&mut self, kinds: OperandKinds) { + let (lhs, ty) = self.lhs; + + let l_legal = kinds.to_lhs_binop(); + let l_kind = self.mir().as_operand_kind(self.lhs()).as_lhs(); + + if l_legal.contains(l_kind) { + } else if l_legal.contains(BinaryOperandFlags::LhsReg) { + self.lhs.0 = self.mir_mut().to_reg(ty, lhs); + } else if l_legal.contains(BinaryOperandFlags::LhsMem) { + self.lhs.0 = self.mir_mut().gen_spill_value(lhs); + } else { + unreachable!() + } + } + + fn try_swap(&mut self) { + let (lhs, rhs) = self.lhs_rhs(); + let l_legal = self.inner.kinds.to_lhs_binop(); + let l_kind = self.mir().as_operand_kind(lhs).as_lhs(); + let r_kind = self.mir().as_operand_kind(rhs).as_rhs(); + + if self.inner.commutative && (!l_legal.contains(l_kind) && l_legal.contains(r_kind)) { + core::mem::swap(&mut self.lhs, &mut self.rhs); + } + } + + fn order(&mut self) { + self.try_swap(); + let rhs = self.rhs(); + let ty = self.rhs_type(); + let r_legal = self.inner.kinds.to_rhs_binop(); + let r_kind = self.mir().as_operand_kind(rhs).as_rhs(); + + if r_legal.contains(r_kind) { + } else if r_legal.contains(BinaryOperandFlags::RhsReg) { + self.rhs.0 = self.mir_mut().to_reg(ty, rhs); + } else if r_legal.contains(BinaryOperandFlags::RhsMem) { + self.rhs.0 = self.mir_mut().gen_spill_value(rhs); + } else { + unreachable!() + } + + let rhs = self.rhs(); + self.canonicalise_lhs_with_reduced_kinds( + self.inner + .kinds + .reduce_with_rhs(self.mir().as_operand_kind(rhs).as_rhs()), + ); + } +} + +impl<'a> BinaryOperands<'a> { + fn new(mir: &'a mut Mir, commutative: bool, kinds: OperandKinds) -> Self { + Self { + mir, + commutative, + kinds, + } + } + + fn new_add_or_and_xor_adc(mir: &'a mut Mir) -> Self { + Self::new(mir, true, OperandKinds::add()) + } + fn new_sub_sbb(mir: &'a mut Mir) -> Self { + Self::new(mir, false, OperandKinds::add()) + } + fn new_sse(mir: &'a mut Mir) -> Self { + Self::new(mir, true, OperandKinds::sse()) + } + fn new_mul(mir: &'a mut Mir) -> Self { + Self::new(mir, true, OperandKinds::mul()) + } + fn new_imul(mir: &'a mut Mir) -> Self { + Self::new(mir, true, OperandKinds::imul()) + } + fn new_div_idiv_rem_irem(mir: &'a mut Mir) -> Self { + Self::new(mir, false, OperandKinds::mul()) + } + fn new_shift(mir: &'a mut Mir) -> Self { + Self::new(mir, false, OperandKinds::shift()) + } + + fn wrangle(self, lhs: u32, lhs_type: Type, rhs: u32, rhs_type: Type) -> (u32, u32) { + let mut runner = BinaryOperandsRunner { + inner: self, + lhs: (lhs, lhs_type), + rhs: (rhs, rhs_type), + }; + + runner.order(); + + runner.lhs_rhs() + } +} + pub struct Mir { pub nodes: Vec, pub data: Vec, @@ -324,6 +600,62 @@ impl Mir { } } + fn as_operand_kind(&self, node: u32) -> OperandKind { + if self.is_imm(node) { + OperandKind::Imm + } else if self.is_register(node) { + OperandKind::Reg + } else { + OperandKind::Mem + } + } + + pub fn to_reg(&mut self, ty: Type, node: u32) -> u32 { + if !self.is_register(node) { + self.gen_load_register(ty, node) + } else { + node + } + } + + pub fn type_of_node(&self, node: u32) -> Option { + self.nodes[node as usize].value_type() + } + + pub fn is_register(&self, node: u32) -> bool { + match self.nodes[node as usize] { + Inst::LoadRegister(_) + | Inst::Load(_) + | Inst::GetElementPtr(_) + | Inst::Parameter(_) + | Inst::Add(_) + | Inst::Sub(_) + | Inst::Mul(_) + | Inst::MulSigned(_) + | Inst::Div(_) + | Inst::DivSigned(_) + | Inst::Rem(_) + | Inst::RemSigned(_) + | Inst::BitAnd(_) + | Inst::BitOr(_) + | Inst::BitXOr(_) + | Inst::Negate(_) + | Inst::ShiftLeft(_) + | Inst::ShiftRightSigned(_) + | Inst::ShiftRightUnsigned(_) => true, + _ => false, + } + } + + pub fn is_imm(&self, node: u32) -> bool { + match self.nodes[node as usize] { + Inst::ConstantByte | Inst::ConstantWord | Inst::ConstantDWord | Inst::ConstantQWord => { + true + } + _ => false, + } + } + pub fn push(&mut self, inst: Inst, data: Data) -> u32 { let node = self.nodes.len() as u32; self.nodes.push(inst); @@ -344,8 +676,15 @@ impl Mir { pub fn gen_u64(&mut self, value: u64) -> u32 { self.push(Inst::ConstantQWord, Data::imm64(value)) } - pub fn gen_load_const(&mut self, ty: Type, src: u32) -> u32 { - self.push(Inst::LoadConstant(ty), Data::node(src)) + pub fn gen_load_register(&mut self, ty: Type, src: u32) -> u32 { + self.push(Inst::LoadRegister(ty), Data::node(src)) + } + pub fn gen_spill_value(&mut self, src: u32) -> u32 { + let ty = self.type_of_node(src).unwrap(); + let size = ty.bytes(); + let alloc = self.gen_alloca(size, size); + _ = self.gen_store(ty, src, alloc); + alloc } pub fn gen_label(&mut self, name: StringsIndex) -> u32 { @@ -394,76 +733,125 @@ impl Mir { masked } } + fn imm_to_reg(&mut self, src: u32) -> u32 { + if self.is_imm(src) { + // SAFETY: imms have values and thus types + self.gen_load_register(self.type_of_node(src).unwrap(), src) + } else { + src + } + } pub fn gen_add(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = if ty.is_floating() { + BinaryOperands::new_sse(self).wrangle(lhs, ty, rhs, ty) + } else { + BinaryOperands::new_add_or_and_xor_adc(self).wrangle(lhs, ty, rhs, ty) + }; self.push(Inst::Add(ty), Data::binary(lhs, rhs)) } pub fn gen_sub(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = if ty.is_floating() { + BinaryOperands::new_sse(self).wrangle(lhs, ty, rhs, ty) + } else { + BinaryOperands::new_sub_sbb(self).wrangle(lhs, ty, rhs, ty) + }; self.push(Inst::Sub(ty), Data::binary(lhs, rhs)) } pub fn gen_mul(&mut self, ty: Type, signed: bool, lhs: u32, rhs: u32) -> u32 { - if signed && !ty.is_floating() { + if ty.is_floating() { + self.gen_mul_sse(ty, lhs, rhs) + } else if signed { self.gen_mul_signed(ty, lhs, rhs) } else { self.gen_mul_unsigned(ty, lhs, rhs) } } + pub fn gen_mul_sse(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_sse(self).wrangle(lhs, ty, rhs, ty); + self.push(Inst::Mul(ty), Data::binary(lhs, rhs)) + } pub fn gen_mul_unsigned(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_mul(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::Mul(ty), Data::binary(lhs, rhs)) } pub fn gen_mul_signed(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_imul(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::MulSigned(ty), Data::binary(lhs, rhs)) } pub fn gen_div(&mut self, ty: Type, signed: bool, lhs: u32, rhs: u32) -> u32 { - if signed && !ty.is_floating() { + if ty.is_floating() { + self.gen_div_sse(ty, lhs, rhs) + } else if signed { self.gen_div_signed(ty, lhs, rhs) } else { self.gen_div_unsigned(ty, lhs, rhs) } } + pub fn gen_div_sse(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_sse(self).wrangle(lhs, ty, rhs, ty); + self.push(Inst::Div(ty), Data::binary(lhs, rhs)) + } pub fn gen_div_unsigned(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_div_idiv_rem_irem(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::Div(ty), Data::binary(lhs, rhs)) } pub fn gen_div_signed(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_div_idiv_rem_irem(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::DivSigned(ty), Data::binary(lhs, rhs)) } pub fn gen_rem(&mut self, ty: Type, signed: bool, lhs: u32, rhs: u32) -> u32 { - if signed && !ty.is_floating() { + if ty.is_floating() { + self.gen_rem_fp(ty, lhs, rhs) + } else if signed { self.gen_rem_signed(ty, lhs, rhs) } else { self.gen_rem_unsigned(ty, lhs, rhs) } } + pub fn gen_rem_fp(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + _ = (ty, lhs, rhs); + todo!() + } pub fn gen_rem_unsigned(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_div_idiv_rem_irem(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::Rem(ty), Data::binary(lhs, rhs)) } pub fn gen_rem_signed(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_div_idiv_rem_irem(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::RemSigned(ty), Data::binary(lhs, rhs)) } pub fn gen_bitand(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { + let (lhs, rhs) = BinaryOperands::new_add_or_and_xor_adc(self).wrangle(lhs, ty, rhs, ty); self.push(Inst::BitAnd(ty), Data::binary(lhs, rhs)) } pub fn gen_bitor(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { - self.push(Inst::BitAnd(ty), Data::binary(lhs, rhs)) + let (lhs, rhs) = BinaryOperands::new_add_or_and_xor_adc(self).wrangle(lhs, ty, rhs, ty); + self.push(Inst::BitOr(ty), Data::binary(lhs, rhs)) } pub fn gen_bitxor(&mut self, ty: Type, lhs: u32, rhs: u32) -> u32 { - self.push(Inst::BitAnd(ty), Data::binary(lhs, rhs)) + let (lhs, rhs) = BinaryOperands::new_add_or_and_xor_adc(self).wrangle(lhs, ty, rhs, ty); + self.push(Inst::BitXOr(ty), Data::binary(lhs, rhs)) } pub fn gen_negate(&mut self, ty: Type, src: u32) -> u32 { + let src = self.imm_to_reg(src); self.push(Inst::Negate(ty), Data::node(src)) } #[doc(alias = "gen_shift_left")] pub fn gen_shl(&mut self, ty: Type, src: u32, shift: u32) -> u32 { + let (src, shift) = BinaryOperands::new_shift(self).wrangle(src, ty, shift, ty); self.push(Inst::ShiftLeft(ty), Data::binary(src, shift)) } #[doc(alias = "gen_shift_right")] pub fn gen_shr(&mut self, ty: Type, src: u32, shift: u32) -> u32 { + let (src, shift) = BinaryOperands::new_shift(self).wrangle(src, ty, shift, ty); self.push(Inst::ShiftRightUnsigned(ty), Data::binary(src, shift)) } #[doc(alias = "gen_shift_right_signed")] pub fn gen_sar(&mut self, ty: Type, src: u32, shift: u32) -> u32 { + let (src, shift) = BinaryOperands::new_shift(self).wrangle(src, ty, shift, ty); self.push(Inst::ShiftRightSigned(ty), Data::binary(src, shift)) } pub fn gen_ret_val(&mut self, val: u32) -> u32 { @@ -501,9 +889,9 @@ impl Mir { Inst::ConstantWord => writeln!(w, "%{node} = imm16({:x?})", data.as_imm16()), Inst::ConstantDWord => writeln!(w, "%{node} = imm32({:x?})", data.as_imm32()), Inst::ConstantQWord => writeln!(w, "%{node} = imm64({:x?})", data.as_imm64()), - Inst::LoadConstant(ty) => { + Inst::LoadRegister(ty) => { let src = data.as_node(); - writeln!(w, "%{node} = load constant {ty} %{src}") + writeln!(w, "%{node} = load register {ty} %{src}") } Inst::ExternRef => writeln!(w, "%{node} = extern %%{}", data.as_node()), Inst::Alloca => { @@ -639,7 +1027,6 @@ impl Mir { } else { in_colors.pop() } { - println!("prefering {reg} for param"); preferred_colors.insert(node, reg); }; inouts.push(node); @@ -667,7 +1054,6 @@ impl Mir { if !ty.is_floating() { _ = preferred_colors.try_insert(lhs, amd64::Register::rax); - _ = preferred_colors.try_insert(rhs, amd64::Register::rax); } } // div wants lhs to be rax, idiv can do either. @@ -909,6 +1295,71 @@ impl Mir { } } +use crate::asm::amd64::{self, Mnemonic, Operand, Operands}; +use crate::variant; +type Asm = Vec<(Mnemonic, Operands)>; + +#[allow(dead_code, unused)] +impl Mir { + fn assemble(&self, strings: &StringTable) { + let mut mapping = HashMap::::new(); + let mut entry = Asm::new(); + let mut branches = HashMap::::new(); + let mut stack_offset = 0; + let mut current_branch = StringsIndex::none(); + branches.insert(current_branch, Asm::new()); + + for i in 0..self.nodes.len() { + let inst = self.nodes[i]; + let data = self.data[i]; + + let branch = branches.get_mut(¤t_branch).unwrap(); + + match inst { + Inst::Label => todo!(), + Inst::ConstantBytes => todo!(), + Inst::ConstantByte => todo!(), + Inst::ConstantWord => todo!(), + Inst::ConstantDWord => todo!(), + Inst::ConstantQWord => todo!(), + Inst::LoadRegister(ty) => todo!(), + Inst::ExternRef => todo!(), + Inst::Alloca => { + let (size, align) = data.as_binary(); + let size = size.next_multiple_of(align); + stack_offset += size; + mapping.insert(i, (current_branch, branch.len())); + branch.push((Mnemonic::alloca, Operands::One(Operand::imm32(size)))) + } + Inst::Load(ty) => { + // stuff + // branch.push((Mnemonic::mov, Operands::One(Operand::imm32(size)))) + } + Inst::Store(ty) => todo!(), + Inst::GetElementPtr(ty) => todo!(), + Inst::Parameter(ty) => todo!(), + Inst::Add(_) => todo!(), + Inst::Sub(_) => todo!(), + Inst::Mul(_) => todo!(), + Inst::MulSigned(_) => todo!(), + Inst::Div(_) => todo!(), + Inst::DivSigned(_) => todo!(), + Inst::Rem(_) => todo!(), + Inst::RemSigned(_) => todo!(), + Inst::BitAnd(_) => todo!(), + Inst::BitOr(_) => todo!(), + Inst::BitXOr(_) => todo!(), + Inst::Negate(_) => todo!(), + Inst::ShiftLeft(_) => todo!(), + Inst::ShiftRightSigned(_) => todo!(), + Inst::ShiftRightUnsigned(_) => todo!(), + Inst::ReturnValue => todo!(), + Inst::Return => todo!(), + } + } + } +} + pub struct DisplayMir<'a, 'b> { mir: &'a Mir, strings: &'b StringTable,