mir to asm translation WORKS!!

This commit is contained in:
Janis 2024-08-29 21:34:05 +02:00
parent 3aee606ca2
commit 9f0e2c4a31
8 changed files with 1557 additions and 152 deletions

View file

@ -245,11 +245,11 @@ impl Register {
}
}
const fn byte_size(self) -> u32 {
pub const fn byte_size(self) -> u32 {
self.bit_size().div_ceil(8)
}
const fn into_parent_register(self) -> Self {
pub const fn parent_reg(self) -> Self {
use Register::*;
match self {
rax | eax | ax | ah | al => rax,
@ -279,20 +279,20 @@ impl Register {
unsafe { *(&val as *const u8 as *const Self) }
}
const fn into_qword(self) -> Self {
Self::from_u8(self.into_parent_register() as u8)
pub const fn into_qword(self) -> Self {
Self::from_u8(self.parent_reg() as u8)
}
const fn into_dword(self) -> Self {
Self::from_u8(self.into_parent_register() as u8 + 1)
pub const fn into_dword(self) -> Self {
Self::from_u8(self.parent_reg() as u8 + 1)
}
const fn into_word(self) -> Self {
Self::from_u8(self.into_parent_register() as u8 + 2)
pub const fn into_word(self) -> Self {
Self::from_u8(self.parent_reg() as u8 + 2)
}
const fn into_byte(self) -> Self {
Self::from_u8(self.into_parent_register() as u8 + 3)
pub const fn into_byte(self) -> Self {
Self::from_u8(self.parent_reg() as u8 + 3)
}
fn into_bitsize(self, size: u32) -> Self {
@ -305,6 +305,14 @@ impl Register {
_ => panic!("unsupported bitsize {size}"),
}
}
/// Byte-oriented wrapper around `into_bitsize`.
///
/// Multiplies `size` (a width in bytes) by 8 and forwards the resulting bit
/// width to `into_bitsize`, returning whatever register that call selects.
pub fn into_bytesize(self, size: u32) -> Self {
    let bits = size * 8;
    self.into_bitsize(bits)
}
/// Registers this backend treats as callee-saved for the System V calling
/// convention: `r12`, `r13`, `r14`, `r15` and `rbx`.
///
/// NOTE(review): the System V AMD64 ABI also lists `rbp` as callee-saved; it
/// is not part of this table — presumably because it is handled separately as
/// the frame pointer. Confirm against the prologue/epilogue generation.
pub const SYSV_CALLEE_SAVED: [Register; 5] = {
use Register::*;
[r12, r13, r14, r15, rbx]
};
pub const GPR: [Register; 14] = {
use Register::*;
@ -320,7 +328,7 @@ impl Register {
]
};
const fn is_gp(self) -> bool {
pub const fn is_gp(self) -> bool {
use Register::*;
match self {
rax | eax | ax | ah | al | rbx | ebx | bx | bh | bl | rcx | ecx | cx | ch | cl
@ -331,6 +339,15 @@ impl Register {
_ => false,
}
}
/// Returns `true` when this register is one of `xmm0`–`xmm15` or
/// `ymm0`–`ymm15`, and `false` for every other register.
pub const fn is_sse(self) -> bool {
    use Register::*;
    matches!(
        self,
        xmm0 | xmm1
            | xmm2
            | xmm3
            | xmm4
            | xmm5
            | xmm6
            | xmm7
            | xmm8
            | xmm9
            | xmm10
            | xmm11
            | xmm12
            | xmm13
            | xmm14
            | xmm15
            | ymm0
            | ymm1
            | ymm2
            | ymm3
            | ymm4
            | ymm5
            | ymm6
            | ymm7
            | ymm8
            | ymm9
            | ymm10
            | ymm11
            | ymm12
            | ymm13
            | ymm14
            | ymm15
    )
}
}
pub enum Operands {

View file

@ -221,6 +221,12 @@ impl core::fmt::Display for FloatingType {
}
impl IntegralType {
pub fn new(signed: bool, bits: u16) -> Self {
Self { signed, bits }
}
pub fn u1() -> Self {
Self::new(false, 1)
}
pub fn u32() -> IntegralType {
Self {
signed: false,
@ -543,15 +549,6 @@ pub mod tree_visitor {
Post(Node),
}
impl PrePost {
    /// Unwraps the `Node` carried by either the `Pre` or the `Post` variant.
    fn node(self) -> Node {
        match self {
            PrePost::Pre(inner) | PrePost::Post(inner) => inner,
        }
    }
}
/// Don't modify `node` in `pre()`
/// Don't modify `children` in `pre()`
pub struct Visitor<F1, F2> {

93
src/bin/main.rs Normal file
View file

@ -0,0 +1,93 @@
use std::{io::Read, path::PathBuf};
use clap::{Arg, Command};
use compiler::{
lexer::Tokenizer,
parser::Tree,
triples::{MirBuilder, IR},
};
fn main() {
let cmd = clap::Command::new("sea")
.bin_name("sea")
.arg(
clap::Arg::new("input")
.short('i')
.help("sea source file.")
.value_parser(clap::builder::PathBufValueParser::new()),
)
.subcommands([
Command::new("ast").about("output AST."),
Command::new("mir").about("output machine-level intermediate representation."),
Command::new("ir").about("output intermediate representation."),
Command::new("asm").about("output x86-64 assembly (intel syntax)."),
]);
let matches = cmd.get_matches();
let path = matches.get_one::<PathBuf>("input");
let source = path
.and_then(|p| std::fs::read(p).ok())
.or_else(|| {
let mut buf = Vec::new();
std::io::stdin().read(&mut buf).ok()?;
Some(buf)
})
.expect("no source bytes.");
let tokens = Tokenizer::new(&source).unwrap();
let mut tree = Tree::new();
tree.parse(tokens.iter()).unwrap();
tree.fold_comptime();
if let Some((cmd, _matches)) = matches.subcommand() {
match cmd {
"ast" => {
let mut buf = String::new();
tree.render(&mut buf).unwrap();
println!("AST:\n{buf}");
}
"ir" => {
let mut ir = IR::new();
let builder = ir.build(&mut tree);
let mut buf = String::new();
builder.render(&mut buf).unwrap();
println!("IR:\n{buf}");
}
"mir" => {
let mut ir = IR::new();
ir.build(&mut tree);
let mut mir = MirBuilder::new(&ir, tree.strings);
mir.build();
let MirBuilder {
strings, functions, ..
} = mir;
for (name, mir) in functions {
println!("{}:\n{}", strings.get_str(name), mir.display(&strings));
}
}
"asm" => {
let mut ir = IR::new();
ir.build(&mut tree);
let mut mir = MirBuilder::new(&ir, tree.strings);
mir.build();
let MirBuilder {
strings, functions, ..
} = mir;
println!(".intel_syntax");
println!(".text");
for (_name, mir) in functions {
let assembly = mir.assemble(&strings).unwrap();
println!("{assembly}");
}
}
_ => {}
}
}
}

1407
src/mir.rs

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, fmt::Display};
use std::collections::HashMap;
use itertools::Itertools;
use num_bigint::{BigInt, BigUint};
@ -12,7 +12,6 @@ use crate::{
string_table::{ImmOrIndex, Index, StringTable},
symbol_table::{SymbolKind, SymbolTable},
tokens::Token,
variant,
};
#[derive(Debug, thiserror::Error)]
@ -1950,7 +1949,7 @@ impl Tree {
ast::tree_visitor::Visitor::new_seek(
self,start,
|_: &Tree, node| {
|_: &Tree, _| {
},
|tree: &Tree, node| match tree.nodes.get_node(node) {
&Tag::Assign { lhs, rhs } => {

View file

@ -3,7 +3,7 @@
use std::collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap};
use crate::{
ast::{Node as AstNode, Tag, Type},
ast::{IntegralType, Node as AstNode, Tag, Type},
parser::Tree,
string_table::{ImmOrIndex, Index as StringsIndex, StringTable},
variant, writeln_indented,
@ -18,7 +18,7 @@ enum NodeOrList {
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Type2 {
pub enum Type2 {
Integral(bool, u16),
Binary32,
Binary64,
@ -108,7 +108,7 @@ impl From<&Type> for Type2 {
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Inst {
pub enum Inst {
/// index
Label,
/// index
@ -154,7 +154,9 @@ enum Inst {
/// lhs
Negate(Type2),
/// lhs
ReturnValue,
ExplicitCast(Type2, Type2),
/// lhs
ReturnValue(Type2),
/// no parameters
Return,
}
@ -233,7 +235,7 @@ impl From<crate::string_table::Index> for Data {
}
}
struct IRBuilder<'tree, 'ir> {
pub struct IRBuilder<'tree, 'ir> {
ir: &'ir mut IR,
tree: &'tree mut Tree,
type_map: HashMap<AstNode, Type>,
@ -297,7 +299,9 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
// TODO: return value of body expression
self.tree.st.into_parent();
if value != !0 {
self.ir.push(Inst::ReturnValue, Some(Data::lhs(value)))
let ty = self.tree.type_of_node(*body);
self.ir
.push(Inst::ReturnValue(ty.into()), Some(Data::lhs(value)))
} else {
!0
}
@ -334,8 +338,10 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
}
Tag::ReturnStmt { expr } => {
if let Some(expr) = expr {
let ty = self.tree.type_of_node(*expr);
let expr = self.visit(*expr);
self.ir.push(Inst::ReturnValue, Some(Data::lhs(expr)))
self.ir
.push(Inst::ReturnValue(ty.into()), Some(Data::lhs(expr)))
} else {
self.ir.push(Inst::Return, None)
}
@ -428,14 +434,10 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
//noop?
lhs
} else {
let alloc = self.ir.push(
Inst::Alloca,
Some(Data::new(r_ty.size_of(), r_ty.align_of())),
);
let load = self
.ir
.push(Inst::Load(l_ty.into()), Some(Data::lhs(alloc)));
load
self.ir.push(
Inst::ExplicitCast(l_ty.into(), r_ty.into()),
Some(Data::lhs(lhs)),
)
}
}
_ => {
@ -446,7 +448,7 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
}
}
struct IR {
pub struct IR {
nodes: Vec<Inst>,
data: Vec<Option<Data>>,
}
@ -533,6 +535,9 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
Inst::Negate(ty) => {
writeln_indented!(indent, w, "%{} = negate_{ty}(%{})", node, data.lhs)?;
}
Inst::ExplicitCast(from, to) => {
writeln_indented!(indent, w, "%{} = cast_{from}_to_{to}(%{})", node, data.lhs)?;
}
Inst::ShiftLeft(ty) => {
writeln_indented!(
indent,
@ -553,8 +558,8 @@ impl<'tree, 'ir> IRBuilder<'tree, 'ir> {
data.rhs
)?;
}
Inst::ReturnValue => {
writeln_indented!(indent, w, "%{} = return %{}", node, data.lhs)?;
Inst::ReturnValue(ty) => {
writeln_indented!(indent, w, "%{} = return {ty} %{}", node, data.lhs)?;
}
Inst::Return => {
writeln_indented!(indent, w, "%{} = return", node)?;
@ -1200,7 +1205,8 @@ impl<'a> Assembler<'a> {
Inst::BitOr(_) => todo!(),
Inst::BitXOr(_) => todo!(),
Inst::Negate(_) => todo!(),
Inst::ReturnValue => {
Inst::ExplicitCast(_, _) => todo!(),
Inst::ReturnValue(_) => {
let val = data.unwrap().lhs;
let (&reg, _) = registers.iter().find(|(_, node)| node == &&val)
.expect(&format!(
@ -1265,14 +1271,14 @@ impl<'a> Assembler<'a> {
use crate::mir;
struct MirBuilder<'a> {
pub struct MirBuilder<'a> {
ir: IRIter<'a>,
strings: StringTable,
mir: mir::Mir,
pub strings: StringTable,
pub functions: HashMap<StringsIndex, mir::Mir>,
}
impl<'a> MirBuilder<'a> {
fn new(ir: &'a IR, strings: StringTable) -> MirBuilder<'a> {
pub fn new(ir: &'a IR, strings: StringTable) -> MirBuilder<'a> {
Self {
ir: IRIter {
ir,
@ -1280,12 +1286,12 @@ impl<'a> MirBuilder<'a> {
item: None,
},
strings,
mir: mir::Mir::new(),
functions: HashMap::new(),
}
}
fn build_function(&mut self, name: StringsIndex) {
let mut mir = mir::Mir::new();
let mut mir = mir::Mir::new(name);
let mut mapping = BTreeMap::<u32, u32>::new();
loop {
@ -1298,9 +1304,7 @@ impl<'a> MirBuilder<'a> {
self.ir.offset -= 1;
break;
}
Inst::Label => {
mir.push(mir::Inst::Label, mir::Data::index(data.unwrap().as_index()))
}
Inst::Label => mir.gen_label(data.unwrap().as_index()),
Inst::ConstantU32 => mir.push(
mir::Inst::ConstantDWord,
mir::Data::imm32(data.unwrap().as_u32()),
@ -1375,7 +1379,7 @@ impl<'a> MirBuilder<'a> {
let ty = mir::Type::from_bitsize_int(bits as u32);
let sum = mir.gen_add(ty, lhs, rhs);
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
}
},
Type2::Binary32 => mir.gen_add(mir::Type::SinglePrecision, lhs, rhs),
@ -1394,7 +1398,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_sub(ty, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1411,7 +1415,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_mul(ty, signed, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1428,7 +1432,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_div(ty, signed, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1445,7 +1449,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_rem(ty, signed, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1467,7 +1471,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_bitand(ty, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1489,7 +1493,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_bitor(ty, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1511,7 +1515,7 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_bitxor(ty, lhs, rhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
@ -1522,6 +1526,7 @@ impl<'a> MirBuilder<'a> {
let rhs = *mapping.get(&dst).unwrap();
// TODO: check rhs type and pass it to gen_sh{l,r}?
let rhs = mir.gen_truncate_integer(rhs, ty.into(), false, 8);
match ty {
Type2::Integral(signed, bits) => match bits {
8 => mir.gen_shl(mir::Type::Byte, lhs, rhs),
@ -1535,7 +1540,7 @@ impl<'a> MirBuilder<'a> {
let ty = mir::Type::from_bitsize_int(bits as u32);
let sum = mir.gen_shl(ty, lhs, rhs);
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
}
},
_ => unreachable!(),
@ -1568,7 +1573,7 @@ impl<'a> MirBuilder<'a> {
mir.gen_shr(ty, lhs, rhs)
};
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
}
},
_ => unreachable!(),
@ -1584,16 +1589,49 @@ impl<'a> MirBuilder<'a> {
let sum = mir.gen_negate(ty, lhs);
if let Some((signed, bits)) = unalignment {
mir.gen_truncate(sum, ty, signed, bits)
mir.gen_truncate_integer(sum, ty, signed, bits)
} else {
sum
}
}
Inst::ReturnValue => {
Inst::ExplicitCast(from, to) => {
let lhs = data.unwrap().as_u32();
let from_mir = from.mir_type();
let to_mir = to.mir_type();
let lhs = *mapping.get(&lhs).unwrap();
match (from, to) {
(Type2::Integral(a_signed, a), Type2::Integral(b_signed, b)) => {
if a > b {
mir.gen_truncate_integer(lhs, to_mir, b_signed, b)
} else if a < b {
mir.gen_extend_integer(
lhs,
IntegralType::new(a_signed, a),
IntegralType::new(b_signed, b),
)
} else {
unreachable!()
}
}
(Type2::Integral(_, _), Type2::Bool) => {
let is_zero = mir.gen_is_zero(from_mir, lhs);
mir.gen_negate(mir::Type::Byte, is_zero)
}
(Type2::Bool, Type2::Integral(b_signed, b)) => mir.gen_extend_integer(
lhs,
IntegralType::u1(),
IntegralType::new(b_signed, b),
),
_ => unimplemented!(),
}
}
Inst::ReturnValue(ty) => {
let src = data.unwrap().as_u32();
let src = *mapping.get(&src).unwrap();
mir.gen_ret_val(src)
mir.gen_ret_val(ty.mir_type(), src)
}
Inst::Return => mir.gen_ret(),
#[allow(unreachable_patterns)]
@ -1605,14 +1643,10 @@ impl<'a> MirBuilder<'a> {
mapping.insert(ir_node, node);
}
println!(
"{} mir:\n{}",
self.strings.get_str(name),
mir.display(&self.strings)
);
self.functions.insert(name, mir);
}
fn build(&mut self) {
pub fn build(&mut self) {
loop {
let Some((inst, data)) = self.ir.next() else {
break;
@ -1664,6 +1698,16 @@ fn inverse_sqrt(x: f32) -> f32 {
let mut mir = MirBuilder::new(&ir, tree.strings);
mir.build();
let MirBuilder {
strings, functions, ..
} = mir;
for (_name, mir) in functions {
let assembly = mir.assemble(&strings).unwrap();
println!("mir:\n{}", mir.display(&strings));
println!("assembly:\n{assembly}");
}
}
#[test]

View file

@ -0,0 +1,7 @@
fn inverse_sqrt(n: f32) -> f32 {
let x = n;
var i = *(&x as *i32);
i = 0x5f3759dfi32 - (i >> 1u8);
let y = *(&i as *f32);
y * (1.5f32 - (x * 0.5f32 * y * y))
}

View file

@ -0,0 +1,5 @@
fn main() -> u32 {
let a: u32 = 0u32 + 3u32;
let b = &a;
return *b * 2u32;
}