diff --git a/lang/src/ast.asm b/lang/src/ast.asm
index 46bacba..14f7c48 100644
--- a/lang/src/ast.asm
+++ b/lang/src/ast.asm
@@ -37,8 +37,11 @@ section .text
 extern vec_init_with
 extern vec_push
 extern vec_get
+extern vec_insert_sorted
+extern vec_get_or
 extern panic
 extern memcpy
+extern strcmp
 extern vec_binary_search_by
 extern vec_insert
 
@@ -63,6 +66,8 @@ global parse_binary_expr
 global parse_primary_expr
 global parse_statement
 global parse_block
+global ast_build_symtable
+global ast_walk_for_each
 
 ;; start very simple, with only functions and addition
 ;; ```rust
@@ -960,3 +965,457 @@ ast_place_to_value:
 .done:
     pop rbp
     ret
+
+
+;; Orders two `SymKey`s by kind, then scope_index, then span (each compared as
+;; an unsigned value), and finally by identifier through `strcmp`.
+;; Returns -1, 0 or 1 in rax.
+;; rdi: ctx (unused; overwritten before the strcmp call)
+;; rsi: a: *const SymKey
+;; rdx: b: *const SymKey
+;; define-fn: fn symkey_cmp(ctx: *mut (), a: *const SymKey, b: *const SymKey) -> i32
+symkey_cmp:
+    push rbp
+    mov rbp, rsp
+    push rbx ; rbx is callee-saved and used as scratch below
+
+    mov al, byte [rsi] ; a.kind
+    mov bl, byte [rdx] ; b.kind
+    cmp al, bl
+    jb .a_less ; kind is a u8: unsigned compare
+    ja .a_greater
+    mov rax, [rsi + 8] ; a.scope_index
+    mov rbx, [rdx + 8] ; b.scope_index
+    cmp rax, rbx
+    jb .a_less ; scope_index is a u64: unsigned, so u64::MAX sorts last
+    ja .a_greater
+    mov rax, [rsi + 16] ; a.span
+    mov rbx, [rdx + 16] ; b.span
+    cmp rax, rbx
+    jb .a_less ; span is a u64: unsigned, so u64::MAX sorts last
+    ja .a_greater
+
+    mov rdi, [rsi + 24] ; a.ident
+    mov rsi, [rsi + 32] ; a.ident_len (rsi is dead as `a` from here on)
+    mov rcx, [rdx + 32] ; b.ident_len (read before rdx is overwritten)
+    mov rdx, [rdx + 24] ; b.ident
+    call strcmp
+    cmp rax, 0 ; NOTE(review): assumes strcmp returns a sign-extended 64-bit result - confirm
+    jl .a_less
+    jg .a_greater
+    xor rax, rax
+    jmp .epilogue
+.a_less:
+    mov rax, -1
+    jmp .epilogue
+.a_greater:
+    mov rax, 1
+.epilogue:
+    pop rbx
+    pop rbp
+    ret
+
+section .rdata
+    KEY_SCOPE equ 1
+    KEY_SCOPE_NAME equ 2
+    KEY_PARENT_SCOPE equ 3
+    KEY_START_LOCALS equ 4
+    KEY_ARG equ 5
+    KEY_VAR equ 6
+    KEY_END_LOCALS equ 7
+
+section .text
+
+;; Walks the AST from `root_index` with `ast_build_symtable_for_each` as the
+;; visitor and copies the resulting symtable Vec (40 bytes) into `*symtable`.
+;; rdi: Ast
+;; rsi: root index
+;; rdx: *SymbolTable
+;; define-fn: fn ast_build_symtable(ast: *mut Ast, root_index: u64, symtable: *mut core::mem::MaybeUninit<SymbolTable>)
+ast_build_symtable:
+    push rbp
+    mov rbp, rsp
+
+    ; padding [104..112]
+    ; BuildSymtableCtx [24..104]
+    ; *SymbolTable [16..24]
+    ; root_index [8..16]
+    ; Ast [0..8]
+    sub rsp, 112 ; 104 bytes of locals, rounded up so rsp is 16-byte aligned at the calls (SysV entry alignment assumed)
+    mov [rsp], rdi ; Ast
+    mov [rsp + 8], rsi ; root_index
+    mov [rsp + 16], rdx ; *SymbolTable
+
+    ; initialise scope_stack and symtable vecs
+    lea rdi, [rsp + 24] ; &BuildSymtableCtx.scope_stack
+    mov rsi, 8 ; size of u64
+    mov rdx, 0 ; drop = None
+    mov rcx, 128 ; capacity
+    call vec_init_with
+
+    lea rdi, [rsp + 24 + 40] ; &BuildSymtableCtx.symtable
+    mov rsi, 56 ; size_of::<SymEntry>()
+    mov rdx, 0 ; drop = None
+    mov rcx, 128 ; capacity
+    call vec_init_with
+
+    mov rdi, [rsp] ; Ast
+    mov rsi, [rsp + 8] ; root_index
+    lea rdx, [rsp + 24] ; &BuildSymtableCtx
+    mov rcx, ast_build_symtable_for_each
+    call ast_walk_for_each
+
+    ; memcpy symtable out (NOTE(review): scope_stack is not released here - confirm whether it owns an allocation)
+
+    mov rdi, [rsp + 16] ; *SymbolTable
+    lea rsi, [rsp + 24 + 40] ; &BuildSymtableCtx.symtable
+    mov rdx, 40 ; size_of::<Vec<SymEntry>>()
+    call memcpy
+
+    add rsp, 112 ; must match the `sub rsp` above, otherwise pop/ret read the wrong slots
+    pop rbp
+    ret
+
+;; symtable is a sorted vec pretending to be a b-tree:
+;; entries are sorted by a key in order to get the following ordering:
+;; scope (index0) -> (ident0)
+;; scope (index1) -> (ident1)
+;; scope (index2) -> (ident2)
+;; scope-name (ident1) -> (index1)
+;; scope-name (ident1) -> (index1)
+;; parent-scope (scope1) -> (index0)
+;; arg (scope1, span, ident) -> (index)
+;; var (scope1, span, ident) -> (index)
+;; var (scope1, span, ident) -> (index)
+;; arg (scope0, span, ident) -> (index)
+;; var (scope0, span, ident) -> (index)
+;; var (scope0, span, ident) -> (index)
+;;
+;; arguments are ordered before variables in order to allow shadowing of variables by arguments.
+;; variables are ordered by span in order to allow shadowing of variables by variables.
+;; all references within a scope are in the range parent-scope(scopeN)..var
+;; (scopeN, u64::MAX, u64::MAX)
+;;
+;; the symtable contains `SymEntries`, which hold a `SymKey` and an index into the AST node list.
+;; for scope entries, the index holds the pointer to the scope's ident,
+;; and `extra` holds the length; for other keys, `extra` is 0.
+;;
+;; start-structs
+;; struct SymbolTable {
+;;     symtable: Vec<SymEntry>,
+;; }
+;; struct SymKey {
+;;     kind: u8,
+;;     scope_index: u64,
+;;     span: u64,
+;;     ident: *const u8,
+;;     ident_len: usize,
+;; }
+;; struct SymEntry {
+;;     key: SymKey,
+;;     index: u64,
+;;     extra: u64,
+;; }
+;; end-structs
+;; size_of::<SymKey>() == 40
+;; size_of::<SymEntry>() == 56
+;;
+;; #start-structs
+;; struct BuildSymtableCtx {
+;;     scope_stack: Vec<u64>,
+;;     symtable: Vec<SymEntry>,
+;; }
+;; #end-structs
+;;
+;; scope_stack [0..40]
+;; symtable [40..80]
+;;
+;; rdi: Ctx (*mut BuildSymtableCtx); visitor called once per node: adds a symtable entry for function, var-decl and arg nodes
+;; rsi: Ast
+;; rdx: index
+ast_build_symtable_for_each:
+    push rbp
+    mov rbp, rsp
+    push rbx
+
+    ; SymEntry [32..88]
+    ; SymKey [32..72]
+    ; *AstNode [24..32]
+    ; index [16..24]
+    ; ctx [8..16]
+    ; ast [0..8]
+    sub rsp, 88
+    mov [rsp], rsi ; Ast
+    mov [rsp + 8], rdi ; Ctx
+    mov [rsp + 16], rdx ; index
+
+    mov rdi, rsi ; Ast
+    mov rsi, rdx ; index
+    call vec_get
+    mov [rsp + 24], rax ; *AstNode
+
+    mov bl, byte [rax] ; AstNode.kind
+
+    cmp bl, AST_FUNCTION
+    je .func
+    cmp bl, AST_VAR_DECL
+    je .var_decl
+    cmp bl, AST_ARG
+    je .arg
+    jmp .done ; every other node kind adds nothing to the symtable
+.func:
+    ; insert scope entry
+    mov byte [rsp + 32], KEY_SCOPE ; SymKey.kind
+    mov rdx, [rsp + 16] ; index
+    mov qword [rsp + 40], rdx ; SymKey.scope_index
+    mov qword [rsp + 48], 0 ; SymKey.span
+    mov qword [rsp + 56], 0 ; SymKey.ident
+    mov qword [rsp + 64], 0 ; SymKey.ident_len
+
+    ; mov rbx, [rax + 16] ; AstNode.data
+    ; mov rdx, [rbx + 8] ; Func.name
+    ; mov rcx, [rbx + 16] ; Func.name_len
+    mov rbx, [rax + 8] ; AstNode.data
+    mov rdx, [rbx + 0] ; Func.name
+    mov rcx, [rbx + 8] ; Func.name_len
+
+    mov qword [rsp + 72], rdx ; SymEntry.index
+    mov qword [rsp + 80], rcx ; SymEntry.extra
+
+    mov rdi, [rsp + 8] ; *Ctx
+    lea rdi, [rdi + 40] ; Ctx.symtable
+    lea rsi, [rsp + 32] ; &SymEntry
+    mov rcx, 0 ; cmp_ctx
+    mov rdx, symkey_cmp ; cmp
+    call vec_insert_sorted
+
+    ; push scope index onto scope_stack
+    mov rdi, [rsp + 8] ; *Ctx
+    lea rdi, [rdi + 0] ; Ctx.scope_stack
+    lea rsi, [rsp + 16] ; &index
+    call vec_push
+
+    jmp .done
+.var_decl:
+    ; insert variable entry
+    mov byte [rsp + 32], KEY_VAR ; SymKey.kind
+
+    ; TODO: set span correctly
+    mov qword [rsp + 48], 0 ; SymKey.span
+
+    mov rbx, [rax + 8] ; AstNode.data (rax still holds *AstNode from vec_get; NOTE(review): assumes data points at AstVarDecl)
+    mov rdx, [rbx + 0] ; AstVarDecl.name
+    mov rcx, [rbx + 8] ; AstVarDecl.name_len
+
+    mov [rsp + 56], rdx ; SymKey.ident
+    mov [rsp + 64], rcx ; SymKey.ident_len
+
+    mov rdx, [rsp + 16] ; index
+    mov [rsp + 72], rdx ; SymEntry.index
+    mov qword [rsp + 80], 0 ; SymEntry.extra
+
+    mov qword [rsp + 40], 0 ; SymKey.scope_index = default
+    lea rdx, [rsp + 40]
+    mov rdi, [rsp + 8] ; *Ctx
+    mov rsi, [rdi + 8] ; Ctx.scope_stack.len()
+    dec rsi
+    call vec_get_or
+    mov rax, [rax] ; current scope index
+    mov [rsp + 40], rax ; SymKey.scope_index = scope_stack.last_or(0)
+
+    mov rdi, [rsp + 8] ; *Ctx
+    lea rdi, [rdi + 40] ; Ctx.symtable
+    lea rsi, [rsp + 32] ; &SymEntry
+    mov rcx, 0 ; cmp_ctx
+    mov rdx, symkey_cmp ; cmp
+    call vec_insert_sorted
+    jmp .done ; do not fall through into .arg, which would insert a second (KEY_ARG) entry
+.arg:
+    ; insert argument entry
+    mov byte [rsp + 32], KEY_ARG ; SymKey.kind
+
+    ; TODO: set span correctly
+    mov qword [rsp + 48], 0 ; SymKey.span
+
+    mov rbx, [rax + 8] ; AstNode.data (rax still holds *AstNode from vec_get; NOTE(review): assumes data points at AstArgument)
+    mov rdx, [rbx + 0] ; AstArgument.name
+    mov rcx, [rbx + 8] ; AstArgument.name_len
+
+    mov [rsp + 56], rdx ; SymKey.ident
+    mov [rsp + 64], rcx ; SymKey.ident_len
+
+    mov rdx, [rsp + 16] ; index
+    mov [rsp + 72], rdx ; SymEntry.index
+    mov qword [rsp + 80], 0 ; SymEntry.extra
+
+    mov qword [rsp + 40], 0 ; SymKey.scope_index = default
+    lea rdx, [rsp + 40]
+    mov rdi, [rsp + 8] ; *Ctx
+    mov rsi, [rdi + 8] ; Ctx.scope_stack.len()
+    dec rsi
+    call vec_get_or
+    mov rax, [rax] ; current scope index
+    mov [rsp + 40], rax ; SymKey.scope_index = scope_stack.last_or(0)
+
+    mov rdi, [rsp + 8] ; *Ctx
+    lea rdi, [rdi + 40] ; Ctx.symtable
+    lea rsi, [rsp + 32] ; &SymEntry
+    mov rcx, 0 ; cmp_ctx
+    mov rdx, symkey_cmp ; cmp
+    call vec_insert_sorted
+
+.done:
+    add rsp, 88
+    pop rbx
+    pop rbp
+    ret
+
+;; rdi: Ast
+;; rsi: start_index
+;; rdx: ctx
+;; rcx: for_each (called as for_each(ctx, ast, node_index) for every visited node)
+;; define-fn: fn ast_walk_for_each(ast: *mut Ast, start_index: u64, ctx: *mut (), for_each: unsafe extern "C" fn(ctx: *mut (), *mut Ast, node_index: u64))
+ast_walk_for_each:
+    push rbp
+    push r15
+    push r14
+    push rbx
+
+    ; current_index [24..32]
+    ; for_each [16..24]
+    ; ctx [8..16]
+    ; ast [0..8]
+    sub rsp, 32
+    mov [rsp], rdi ; Ast
+    mov [rsp + 8], rdx ; ctx
+    mov [rsp + 16], rcx ; for_each
+    mov [rsp + 24], rsi ; start index (the slot is reused for current_node_ptr inside the loop)
+
+    mov rbp, rsp ; rbp marks the bottom of the work-list; pending node indices are pushed below it
+    push rsi
+
+.loop:
+    cmp rsp, rbp
+    jge .done ; work-list empty once rsp is back at rbp
+    ; call for_each(ctx, ast, current_index)
+    mov rdi, [rbp + 8] ; ctx
+    mov rsi, [rbp] ; Ast
+    mov rdx, [rsp] ; current_index
+    mov rax, [rbp + 16] ; for_each
+
+    ; align stack to 16 bytes before call
+    mov rbx, rsp
+    sub rsp, 8
+    and rsp, -16
+    mov [rsp], rbx ; save the unaligned rsp where `pop rsp` will find it
+    call rax
+    pop rsp
+
+    ; get current_node_ptr
+    mov rdi, [rbp] ; Ast
+    pop rsi ; current_index
+    call vec_get
+    mov [rbp + 24], rax ; current_node_ptr
+    mov bl, byte [rax] ; AstNode.kind
+    cmp bl, AST_FUNCTION
+    je .func
+    cmp bl, AST_BLOCK
+    je .block
+    cmp bl, AST_BINARY_OP
+    je .binary_op
+    cmp bl, AST_ASSIGNMENT
+    je .assignment
+    cmp bl, AST_VALUE_TO_PLACE
+    je .value_to_place
+    cmp bl, AST_PLACE_TO_VALUE
+    je .place_to_value
+    cmp bl, AST_DEREF
+    je .deref
+    cmp bl, AST_ADDRESS_OF
+    je .address_of
+    cmp bl, AST_RETURN_STATEMENT
+    je .return_statement
+    jmp .loop ; any other kind has no children pushed
+
+.func:
+    ; push child indices to stack (the body is pushed last, so it is visited before the args)
+    mov rbx, [rax + 8] ; AstNode.data
+
+    mov r15, [rbx + 24] ; AstFunction.args_len
+    xor r14, r14 ; index
+.arg_loop:
+    cmp r14, r15
+    jge .arg_loop_done
+    mov rdx, [rbx + 16] ; AstFunction.args
+    lea rdx, [rdx + r14*8]
+    push qword [rdx] ; push arg index
+    inc r14
+    jmp .arg_loop
+.arg_loop_done:
+    mov r15, [rbx + 48] ; AstFunction.body
+    push r15 ; push body index
+    jmp .loop
+
+.block:
+    mov rbx, [rax + 8] ; AstNode.data
+    mov r15, [rax + 16] ; AstNode.extra
+
+.stmt_loop:
+    cmp r15, 0
+    jle .stmt_loop_done
+    dec r15
+    mov rdx, [rbx + r15*8] ; statement index
+    push rdx ; push statement index (pushed last-to-first, so the first statement is popped first)
+    jmp .stmt_loop
+.stmt_loop_done:
+    jmp .loop
+
+.binary_op:
+    mov rbx, [rax + 8] ; AstNode.data
+    mov rdx, [rbx + 16] ; right index
+    push rdx ; push right index
+    mov rdx, [rbx + 0] ; left index
+    push rdx ; push left index
+    jmp .loop
+
+.assignment:
+    mov rbx, [rax + 8] ; AstNode.data = dest
+    mov rdx, [rax + 16] ; AstNode.extra = source
+    push rdx ; push source index
+    push rbx ; push dest index
+    jmp .loop
+
+.value_to_place:
+.place_to_value:
+.deref:
+.address_of:
+.return_statement:
+    mov rbx, [rax + 8] ; AstNode.data
+    push rbx ; push inner expr index
+    jmp .loop
+
+    ; weird alloca thing: the work-list lived on the machine stack, rsp is back at rbp here
+.done:
+    add rsp, 32
+    pop rbx
+    pop r14
+    pop r15
+    pop rbp
+    ret
+
+;; rdi: Ast
+;; define-fn: fn ast_resolve_var_refs(ast: *mut Ast)
+ast_resolve_var_refs:
+    push rbp
+    mov rbp, rsp
+    push r15
+    push r14
+    push rbx
+    sub rsp, 8 ; pairs with the `add rsp, 8` in the epilogue (without it the pops and ret read the wrong slots)
+.epilogue:
+    add rsp, 8
+    pop rbx
+    pop r14
+    pop r15
+    pop rbp
+    ret
diff --git a/lang/src/vec.asm b/lang/src/vec.asm
index 9ab3e93..08bd02c 100644
--- a/lang/src/vec.asm
+++ b/lang/src/vec.asm
@@ -18,6 +18,7 @@ global vec_drop_last
 global vec_swap
 global vec_remove
 global vec_get
+global vec_get_or
 global vec_drop
 global vec_find
 global vec_insert
@@ -104,6 +105,29 @@ vec_init_with:
     pop rbp
     ret
 
+;; rdi: pointer to Vec struct
+;; rsi: index
+;; rdx: pointer to default value (returned as-is when index is out of range)
+;; fn vec_get_or(vec: *mut Vec, index: usize, default: *mut u8) -> *mut u8
+vec_get_or:
+    push rbp
+    mov rbp, rsp
+    ; if (index >= vec.len) return default;
+    mov rax, [rdi + 8] ; len
+    cmp rsi, rax
+    jae .default ; unsigned: index is a usize, so `len - 1` of an empty vec (usize::MAX) takes the default
+    ; return &mut vec.data[index * vec.item_size];
+    mov rax, [rdi + 24] ; item_size
+    mul rsi ; index * item_size (clobbers rdx; the default pointer is not needed on this path)
+    mov rsi, [rdi] ; data
+    add rax, rsi ; data + index * item_size
+    pop rbp
+    ret
+.default:
+    mov rax, rdx
+    pop rbp
+    ret
+
 ;; rdi: pointer to Vec struct
 ;; rsi: index
 ;; fn vec_get(vec: *mut Vec, index: usize) -> *mut u8
diff --git a/lang/tests/ast.rs b/lang/tests/ast.rs
index 67a9d26..245ad49 100644
--- a/lang/tests/ast.rs
+++ b/lang/tests/ast.rs
@@ -1,3 +1,5 @@
+#![feature(debug_closure_helpers)]
+
 #[path = "shared/shared.rs"]
 mod util;
 
@@ -26,23 +28,36 @@ fn main() {
             let expr_id = parser(&mut ast);
             eprintln!("Parsed expression ID: {}", expr_id);
             println!("{:#}", &ast);
+
+            unsafe extern "C" fn visit_node(_this: *mut (), ast: *mut Ast, node_id: u64) {
+                let ast = unsafe { &*ast };
+                let node = ast.nodes.get(node_id as usize).unwrap();
+                eprintln!("Visiting node {node_id}: {node}");
+            }
+
+            util::defs::ast_walk_for_each(&mut ast, expr_id, core::ptr::null_mut(), visit_node);
+
+            // let mut symtable = core::mem::MaybeUninit::<util::defs::SymbolTable>::uninit();
+            // util::defs::ast_build_symtable(&mut ast, expr_id, &mut symtable);
+            // let symtable = symtable.assume_init();
+            // println!("Symbol Table: {:#?}", symtable);
         };
     }
 
-    print_ast(b"3 + 4", |ast| unsafe { parse_expr(ast) });
-    print_ast(b"fn main() -> void { return 1 + 2; }", |ast| unsafe {
-        parse_func(ast)
-    });
-    print_ast(b"fn main() -> void { return (1 + (2)); }", |ast| unsafe {
-        parse_func(ast)
-    });
-    print_ast(
-        b"fn main() -> void { return (1 + (2 * 3)) / 4; }",
-        |ast| unsafe { parse_func(ast) },
-    );
-    print_ast(b"fn main() -> void { return 1 + 2 * 3; }", |ast| unsafe {
-        parse_func(ast)
-    });
+    // print_ast(b"3 + 4", |ast| unsafe { parse_expr(ast) });
+    // print_ast(b"fn main() -> void { return 1 + 2; }", |ast| unsafe {
+    //     parse_func(ast)
+    // });
+    // print_ast(b"fn main() -> void { return (1 + (2)); }", |ast| unsafe {
+    //     parse_func(ast)
+    // });
+    // print_ast(
+    //     b"fn main() -> void { return (1 + (2 * 3)) / 4; }",
+    //     |ast| unsafe { parse_func(ast) },
+    // );
+    // print_ast(b"fn main() -> void { return 1 + 2 * 3; }", |ast| unsafe {
+    //     parse_func(ast)
+    // });
 
     print_ast(b"fn main() -> void { let x: u32 = 4; }", |ast| unsafe {
         parse_func(ast)
@@ -182,3 +197,49 @@ impl core::fmt::Display for Ast {
         writeln!(f, "\n]")
     }
 }
+
+/// Renders the `len` bytes at `ptr` as text for diagnostic output.
+///
+/// A null pointer or a zero length yields an empty string, and invalid UTF-8 is
+/// replaced instead of trusted, so entries without an identifier print safely.
+///
+/// # Safety
+/// A non-null `ptr` must point to `len` readable bytes that stay valid for `'a`.
+unsafe fn ident_text<'a>(ptr: *const u8, len: usize) -> std::borrow::Cow<'a, str> {
+    if ptr.is_null() || len == 0 {
+        return std::borrow::Cow::Borrowed("");
+    }
+    // SAFETY: `ptr` is non-null here and the caller vouches for `len` readable bytes.
+    String::from_utf8_lossy(unsafe { core::slice::from_raw_parts(ptr, len) })
+}
+
+impl core::fmt::Display for util::defs::SymEntry {
+    /// Writes the entry as a nested `SymEntry { key: Key { .. }, value: Value { .. } }`.
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        // SAFETY: relies on the table builder storing either null or a live
+        // pointer/length pair in the key - NOTE(review): confirm the pointee's lifetime.
+        let key_ident = unsafe { ident_text(self.key.ident, self.key.ident_len) };
+        f.debug_struct("SymEntry")
+            .field_with("key", |f| {
+                f.debug_struct("Key")
+                    .field("kind", &self.key.kind)
+                    .field("scope", &self.key.scope_index)
+                    .field("span", &self.key.span)
+                    .field_with("ident", |f| f.write_str(&key_ident))
+                    .finish()
+            })
+            .field_with("value", |f| {
+                let mut value = f.debug_struct("Value");
+                if self.extra == 0 {
+                    // `extra == 0`: `index` is printed as an AST node index.
+                    value.field("ast_index", &self.index).finish()
+                } else {
+                    // Otherwise `index`/`extra` are read as an identifier pointer and length.
+                    // SAFETY: same assumption as for the key identifier above.
+                    let name = unsafe { ident_text(self.index as *const u8, self.extra as usize) };
+                    value.field_with("ident", |f| f.write_str(&name)).finish()
+                }
+            })
+            .finish()
+    }
+}
diff --git a/lang/tests/shared/defs.rs b/lang/tests/shared/defs.rs
index ed96a8c..5919082 100644
--- a/lang/tests/shared/defs.rs
+++ b/lang/tests/shared/defs.rs
@@ -13,6 +13,10 @@ unsafe extern "C" {
     pub unsafe fn parse_prefix_expr(ast: *mut Ast) -> (u64, bool);
     pub unsafe fn parse_assignment(ast: *mut Ast) -> (u64, bool);
     pub unsafe fn ast_parse_let(ast: *mut Ast) -> (u64, bool);
+    pub unsafe fn symkey_cmp(ctx: *mut (), a: *const SymKey, b: *const SymKey) -> i32;
+    pub unsafe fn ast_build_symtable(ast: *mut Ast, root_index: u64, symtable: *mut core::mem::MaybeUninit<SymbolTable>);
+    pub unsafe fn ast_walk_for_each(ast: *mut Ast, start_index: u64, ctx: *mut (), for_each: unsafe extern "C" fn(ctx: *mut (), *mut Ast, node_index: u64));
+    pub unsafe fn ast_resolve_var_refs(ast: *mut Ast);
 }
 
 pub const AST_FUNCTION: u8 = 1;
@@ -154,4 +158,28 @@ pub struct AstAssignment {
     pub expr: u64,
 }
 
+#[repr(C)]
+#[derive(Debug)]
+pub struct SymbolTable {
+    pub symtable: Vec<SymEntry>,
+}
+
+#[repr(C)]
+#[derive(Debug)]
+pub struct SymKey {
+    pub kind: u8,
+    pub scope_index: u64,
+    pub span: u64,
+    pub ident: *const u8,
+    pub ident_len: usize,
+}
+
+#[repr(C)]
+#[derive(Debug)]
+pub struct SymEntry {
+    pub key: SymKey,
+    pub index: u64,
+    pub extra: u64,
+}
+
 use super::vec::Vec;