From e79af0192535d7f2a9513bd2f700c873f6e57f6f Mon Sep 17 00:00:00 2001 From: janis Date: Fri, 31 Oct 2025 01:55:04 +0100 Subject: [PATCH] ast: correctly resolves var refs --- lang/src/ast.asm | 93 +++++++++++++++++++++++---------------- lang/tests/ast.rs | 48 +++++++++++--------- lang/tests/shared/defs.rs | 2 +- 3 files changed, 83 insertions(+), 60 deletions(-) diff --git a/lang/src/ast.asm b/lang/src/ast.asm index ee4629f..fe86cac 100644 --- a/lang/src/ast.asm +++ b/lang/src/ast.asm @@ -68,6 +68,7 @@ global parse_statement global parse_block global ast_build_symtable global ast_walk_for_each +global ast_resolve_var_refs ;; start very simple, with only functions and addition ;; ```rust @@ -1528,22 +1529,27 @@ ast_walk_for_each: pop rbp ret -;; rdi: BuildSymtableCtx +;; rdi: *mut SymbolTable ;; rsi: *mut Ast ;; rdx: node_index +;; rcx: scope ast_resolve_var_refs_for_each: push rbp mov rbp, rsp push rbx + ; lower_bound [88..96] + ; scope: u64 [80..88] ; SymEntry [24..80] ; *AstNode [16..24] ; *BuildSymtableCtx [8..16] ; *Ast [0..8] - sub rsp, 24 + sub rsp, 96 mov [rsp], rsi ; Ast mov [rsp + 8], rdi ; Ctx + mov [rsp + 80], rcx ; SymKey.scope_index + mov rdi, rsi ; Ast mov rsi, rdx ; node_index call vec_get @@ -1556,45 +1562,59 @@ ast_resolve_var_refs_for_each: ; lookup variable in symbol table - ; get current scope index - mov qword [rsp + 32], 0 ; SymKey.scope_index = default - mov rdi, [rsp + 8] ; *Ctx - lea rdi, [rdi + 0] ; Ctx.scope_stack - mov rsi, [rdi + 8] ; Ctx.scope_stack.len() - dec rsi - lea rdx, [rsp + 32] - call vec_get_or - mov rax, [rax] ; current scope index - - ; construct key - mov byte [rsp + 24 + 0], SYM_KEY_VAR ; SymKey.kind - mov [rsp + 24 + 8], rax ; SymKey.scope_index = scope_stack.last_or(0) - mov rax, [rsp + 16] ; *AstNode - mov rbx, [rax + 24] ; AstNode.span - mov [rsp + 24 + 16], rbx ; SymKey.span - mov rbx, [rax + 8] ; AstNode.data - mov rax, [rbx + 8] ; AstVarRef.name - mov rbx, [rbx + 16] ; AstVarRef.name_len - mov [rsp + 24 + 24], rax ; SymKey.ident - mov [rsp + 24 + 32], rbx ; SymKey.ident_len + ; binary search lower bound + mov byte [rsp + 24 + 0], SYM_KEY_START_LOCALS ; SymKey.kind + mov qword [rsp + 24 + 8], 0 ; SymKey.scope_index + mov qword [rsp + 24 + 16], 0 ; SymKey.span + mov qword [rsp + 24 + 24], 1 ; SymKey.name + mov qword [rsp + 24 + 32], 0 ; SymKey.name_len ; binary search in symbol table mov rdi, [rsp + 8] ; *Ctx - lea rdi, [rdi + 40] ; Ctx.symtable lea rsi, [rsp + 24] ; &SymKey - mov rdx, 0 ; cmp_ctx - mov rcx, symkey_cmp ; cmp + mov rdx, symkey_cmp ; cmp + mov rcx, 0 ; cmp_ctx + call vec_binary_search_by + mov [rsp + 88], rax ; lower_bound + + ; construct key + mov byte [rsp + 24 + 0], SYM_KEY_VAR ; SymKey.kind + mov rax, [rsp + 80] ; scope + mov [rsp + 24 + 8], rax ; SymKey.scope_index + mov rax, [rsp + 16] ; *AstNode + mov rbx, [rax + 24] ; AstNode.span + mov [rsp + 24 + 16], rbx ; SymKey.span + mov rbx, [rax + 8] ; AstNode.data + mov rax, [rbx + 8] ; AstVarRef.name + mov rbx, [rbx + 16] ; AstVarRef.name_len + mov [rsp + 24 + 24], rax ; SymKey.ident + mov [rsp + 24 + 32], rbx ; SymKey.ident_len + + ; binary search in symbol table + mov rdi, [rsp + 8] ; *Ctx + lea rsi, [rsp + 24] ; &SymKey + mov rdx, symkey_cmp ; cmp + mov rcx, 0 ; cmp_ctx call vec_binary_search_by test rdx, rdx - jnz .panic + jz .fixup + dec rax + +.fixup: + cmp rax, [rsp + 88] ; lower_bound + jl .panic + + mov rdi, [rsp + 8] ; *Ctx + mov rsi, rax ; index + call vec_get + mov rax, [rax + 40] ; SymEntry.index mov rdx, [rsp + 16] ; *AstNode mov rdx, [rdx + 8] ; AstNode.data mov [rdx + 0], rax ; AstVarRef.resolved_index - .epilogue: - add rsp, 24 + add rsp, 96 pop rbx pop rbp ret @@ -1602,18 +1622,17 @@ ast_resolve_var_refs_for_each: call panic ;; rdi: Ast -;; define-fn: fn ast_resolve_var_refs(ast: *mut Ast) +;; rsi: *mut SymbolTable +;; rdx: root_index +;; define-fn: fn ast_resolve_var_refs(ast: *mut Ast, ctx: *mut SymbolTable, root_index: u64) ast_resolve_var_refs: push rbp mov rbp, rsp - push r15 - push r14 - push rbx + + xchg rsi, rdx + mov rcx, ast_resolve_var_refs_for_each + call ast_walk_for_each .epilogue: - add rsp, 8 - pop rbx - pop r14 - pop r15 pop rbp ret diff --git a/lang/tests/ast.rs b/lang/tests/ast.rs index b0c81cf..6b9e0af 100644 --- a/lang/tests/ast.rs +++ b/lang/tests/ast.rs @@ -44,37 +44,41 @@ fn main() { let mut symtable = core::mem::MaybeUninit::::uninit(); util::defs::ast_build_symtable(&mut ast, expr_id, &mut symtable); - let symtable = symtable.assume_init(); + let mut symtable = symtable.assume_init(); use util::DisplayedSliceExt; println!( "Symbol Table: {:#?}", symtable.symtable.as_slice().displayed() ); + + util::defs::ast_resolve_var_refs(&mut ast, &mut symtable, expr_id); + + println!("{:#}", &ast); }; } - print_ast(b"3 + 4", |ast| unsafe { parse_expr(ast) }); - print_ast(b"fn main() -> void { return 1 + 2; }", |ast| unsafe { - parse_func(ast) - }); - print_ast(b"fn main() -> void { return (1 + (2)); }", |ast| unsafe { - parse_func(ast) - }); - print_ast( - b"fn main() -> void { return (1 + (2 * 3)) / 4; }", - |ast| unsafe { parse_func(ast) }, - ); - print_ast(b"fn main() -> void { return 1 + 2 * 3; }", |ast| unsafe { - parse_func(ast) - }); + // print_ast(b"3 + 4", |ast| unsafe { parse_expr(ast) }); + // print_ast(b"fn main() -> void { return 1 + 2; }", |ast| unsafe { + // parse_func(ast) + // }); + // print_ast(b"fn main() -> void { return (1 + (2)); }", |ast| unsafe { + // parse_func(ast) + // }); + // print_ast( + // b"fn main() -> void { return (1 + (2 * 3)) / 4; }", + // |ast| unsafe { parse_func(ast) }, + // ); + // print_ast(b"fn main() -> void { return 1 + 2 * 3; }", |ast| unsafe { + // parse_func(ast) + // }); - print_ast(b"fn main() -> void { let x: u32 = 4; }", |ast| unsafe { - parse_func(ast) - }); - print_ast( - b"fn main(a: u32) -> void { let x: u32 = a + 4; }", - |ast| unsafe { parse_func(ast) }, - ); + // print_ast(b"fn main() -> void { let x: u32 = 4; }", |ast| unsafe { + // parse_func(ast) + // }); + // print_ast( + // b"fn main(a: u32) -> void { let x: u32 = a + 4; }", + // |ast| unsafe { parse_func(ast) }, + // ); print_ast( b"fn main(a: u32) -> void { let y: u32 = a + 4; diff --git a/lang/tests/shared/defs.rs b/lang/tests/shared/defs.rs index f43869e..0ec4d9b 100644 --- a/lang/tests/shared/defs.rs +++ b/lang/tests/shared/defs.rs @@ -16,7 +16,7 @@ unsafe extern "C" { pub unsafe fn symkey_cmp(a: *const SymKey, b: *const SymKey) -> i32; pub unsafe fn ast_build_symtable(ast: *mut Ast, root_index: u64, symtable: *mut core::mem::MaybeUninit); pub unsafe fn ast_walk_for_each(ast: *mut Ast, start_index: u64, ctx: *mut (), for_each: unsafe extern "C" fn(ctx: *mut (), *mut Ast, node_index: u64, scope: u64)); - pub unsafe fn ast_resolve_var_refs(ast: *mut Ast); + pub unsafe fn ast_resolve_var_refs(ast: *mut Ast, ctx: *mut SymbolTable, root_index: u64); } pub const AST_FUNCTION: u8 = 1;