ast walking

This commit is contained in:
janis 2025-10-30 21:41:27 +01:00
parent 525b78cdf4
commit ad3b0205c2
Signed by: janis
SSH key fingerprint: SHA256:bB1qbbqmDXZNT0KKD5c2Dfjg53JGhj7B3CFcLIzSqq8
4 changed files with 572 additions and 14 deletions

View file

@ -37,8 +37,11 @@ section .text
extern vec_init_with
extern vec_push
extern vec_get
extern vec_insert_sorted
extern vec_get_or
extern panic
extern memcpy
extern strcmp
extern vec_binary_search_by
extern vec_insert
@ -63,6 +66,8 @@ global parse_binary_expr
global parse_primary_expr
global parse_statement
global parse_block
global ast_build_symtable
global ast_walk_for_each
;; start very simple, with only functions and addition
;; ```rust
@ -960,3 +965,451 @@ ast_place_to_value:
.done:
pop rbp
ret
;; rdi: ctx
;; rsi: a: *const SymKey
;; rdx: b: *const SymKey
;; define-fn: fn symkey_cmp(a: *const SymKey, b: *const SymKey) -> i32
symkey_cmp:
push rbp
mov rbp, rsp
push rbx
mov al, byte [rsi] ; a.kind
mov bl, byte [rdx] ; b.kind
cmp al, bl
jl .a_less
jg .a_greater
mov rax, [rsi + 8] ; a.scope_index
mov rbx, [rdx + 8] ; b.scope_index
cmp rax, rbx
jl .a_less
jg .a_greater
mov rax, [rsi + 16] ; a.span
mov rbx, [rdx + 16] ; b.span
cmp rax, rbx
jl .a_less
jg .a_greater
mov rdi, [rsi + 24] ; a.ident
mov rsi, [rsi + 32] ; a.ident_len
mov rcx, [rdx + 32] ; b.ident_len
mov rdx, [rdx + 24] ; b.ident
call strcmp
cmp rax, 0
jl .a_less
jg .a_greater
xor rax, rax
jmp .epilogue
.a_less:
mov rax, -1
jmp .epilogue
.a_greater:
mov rax, 1
.epilogue:
pop rbx
pop rbp
ret
section .rdata
KEY_SCOPE equ 1
KEY_SCOPE_NAME equ 2
KEY_PARENT_SCOPE equ 3
KEY_START_LOCALS equ 4
KEY_ARG equ 5
KEY_VAR equ 6
KEY_END_LOCALS equ 7
section .text
;; rdi: Ast
;; rsi: root index
;; rdx: *SymbolTable
;; define-fn: fn ast_build_symtable(ast: *mut Ast, root_index: u64, symtable: *mut core::mem::MaybeUninit<SymbolTable>)
ast_build_symtable:
push rbp
mov rbp, rsp
; BuildSymtableCtx [24..104]
; *SymbolTable [16..24]
; root_index [8..16]
; Ast [0..8]
sub rsp, 104
mov [rsp], rdi ; Ast
mov [rsp + 8], rsi ; root_index
mov [rsp + 16], rdx ; *SymbolTable
; initialise scope_stack and symtable vecs
lea rdi, [rsp + 24] ; &BuildSymtableCtx.scope_stack
mov rsi, 8 ; size of u64
mov rdx, 0 ; drop = None
mov rcx, 128 ; capacity
call vec_init_with
lea rdi, [rsp + 24 + 40] ; &BuildSymtableCtx.symtable
mov rsi, 56 ; size_of::<SymEntry>
mov rdx, 0 ; drop = None
mov rcx, 128 ; capacity
call vec_init_with
mov rdi, [rsp] ; Ast
mov rsi, [rsp + 8] ; root_index
lea rdx, [rsp + 24] ; &BuildSymtableCtx
mov rcx, ast_build_symtable_for_each
call ast_walk_for_each
; memcpy symtable out
mov rdi, [rsp + 16] ; *SymbolTable
lea rsi, [rsp + 24 + 40] ; &BuildSymtableCtx.symtable
mov rdx, 40 ; size_of::<Vec<SymEntry>>
call memcpy
add rsp, 96
pop rbp
ret
;; symtable is a sorted vec pretending to be a b-tree:
;; entries are sorted by a key in order to get the following ordering:
;; scope (index0) -> (ident0)
;; scope (index1) -> (ident1)
;; scope (index2) -> (ident2)
;; scope-name (ident1) -> (index1)
;; scope-name (ident1) -> (index1)
;; parent-scope (scope1) -> (index0)
;; arg (scope1, span, ident) -> (index)
;; var (scope1, span, ident) -> (index)
;; var (scope1, span, ident) -> (index)
;; arg (scope0, span, ident) -> (index)
;; var (scope0, span, ident) -> (index)
;; var (scope0, span, ident) -> (index)
;;
;; arguments are ordered before variables in order to allow shadowing of variables by arguments.
;; variables are ordered by span in order to allow shadowing of variables by variables.
;; all references within a scope are in the range parent-scope(scopeN)..var
;; (scopeN, u64::MAX, u64::MAX)
;;
;; the symtable contains `SymEntries`, which hold a `SymKey` and an index into the AST node list.
;; for scope entries, the index holds the pointer to the scope's ident,
;; and `extra` holds the length; for other keys, `extra` is 0.
;;
;; start-structs
;; struct SymbolTable {
;; symtable: Vec<SymEntry>,
;; }
;; struct SymKey {
;; kind: u8,
;; scope_index: u64,
;; span: u64,
;; ident: *const u8,
;; ident_len: usize,
;; }
;; struct SymEntry {
;; key: SymKey,
;; index: u64,
;; extra: u64,
;; }
;; end-structs
;; size_of::<SymKey> == 40
;; size_of::<SymEntry> == 56
;;
;; #start-structs
;; struct BuildSymtableCtx {
;; scope_stack: Vec<u64>,
;; symtable: Vec<SymEntry>,
;; }
;; #end-structs
;;
;; scope_stack [0..40]
;; symtable [40..80]
;;
;; rdi: Ctx
;; rsi: Ast
;; rdx: index
ast_build_symtable_for_each:
push rbp
mov rbp, rsp
push rbx
; SymEntry [32..88]
; SymKey [32..72]
; *AstNode [24..32]
; index [16..24]
; ctx [8..16]
; ast [0..8]
sub rsp, 88
mov [rsp], rsi ; Ast
mov [rsp + 8], rdi ; Ctx
mov [rsp + 16], rdx ; index
mov rdi, rsi ; Ast
mov rsi, rdx ; index
call vec_get
mov [rsp + 24], rax ; *AstNode
mov bl, byte [rax] ; AstNode.kind
cmp bl, AST_FUNCTION
je .func
cmp bl, AST_VAR_DECL
je .var_decl
cmp bl, AST_ARG
je .arg
jmp .done
.func:
; insert scope entry
mov byte [rsp + 32], KEY_SCOPE ; SymKey.kind
mov rdx, [rsp + 16] ; index
mov qword [rsp + 40], rdx ; SymKey.scope_index
mov qword [rsp + 48], 0 ; SymKey.span
mov qword [rsp + 56], 0 ; SymKey.ident
mov qword [rsp + 64], 0 ; SymKey.ident_len
; mov rbx, [rax + 16] ; AstNode.data
; mov rdx, [rbx + 8] ; Func.name
; mov rcx, [rbx + 16] ; Func.name_len
mov rbx, [rax + 8] ; AstNode.data
mov rdx, [rbx + 0] ; Func.name
mov rcx, [rbx + 8] ; Func.name_len
mov qword [rsp + 72], rdx ; SymEntry.index
mov qword [rsp + 80], rcx ; SymEntry.extra
mov rdi, [rsp + 8] ; *Ctx
lea rdi, [rdi + 40] ; Ctx.symtable
lea rsi, [rsp + 32] ; &SymEntry
mov rcx, 0 ; cmp_ctx
mov rdx, symkey_cmp ; cmp
call vec_insert_sorted
; push scope index onto scope_stack
mov rdi, [rsp + 8] ; *Ctx
lea rdi, [rdi + 0] ; Ctx.scope_stack
lea rsi, [rsp + 16] ; &index
call vec_push
jmp .done
.var_decl:
; insert variable entry
mov byte [rsp + 32], KEY_VAR ; SymKey.kind
; TODO: set span correctly
mov qword [rsp + 48], 0 ; SymKey.span
mov rbx, [rsp + 24] ; AstNode.data
mov rdx, [rbx + 0] ; AstVarDecl.name
mov rcx, [rbx + 8] ; AstVarDecl.name_len
mov [rsp + 56], rdx ; SymKey.ident
mov [rsp + 64], rcx ; SymKey.ident_len
mov rdx, [rsp + 16] ; index
mov [rsp + 72], rdx ; SymEntry.index
mov qword [rsp + 80], 0 ; SymEntry.extra
mov qword [rsp + 40], 0 ; SymKey.scope_index = default
lea rdx, [rsp + 40]
mov rdi, [rsp + 8] ; *Ctx
mov rsi, [rdi + 8] ; Ctx.scope_stack.len()
dec rsi
call vec_get_or
mov rax, [rax] ; current scope index
mov [rsp + 40], rax ; SymKey.scope_index = scope_stack.last_or(0)
mov rdi, [rsp + 8] ; *Ctx
lea rdi, [rdi + 40] ; Ctx.symtable
lea rsi, [rsp + 32] ; &SymEntry
mov rcx, 0 ; cmp_ctx
mov rdx, symkey_cmp ; cmp
call vec_insert_sorted
.arg:
; insert variable entry
mov byte [rsp + 32], KEY_ARG ; SymKey.kind
; TODO: set span correctly
mov qword [rsp + 48], 0 ; SymKey.span
mov rbx, [rsp + 24] ; AstNode.data
mov rdx, [rbx + 0] ; AstArgument.name
mov rcx, [rbx + 8] ; AstArgument.name_len
mov [rsp + 56], rdx ; SymKey.ident
mov [rsp + 64], rcx ; SymKey.ident_len
mov rdx, [rsp + 16] ; index
mov [rsp + 72], rdx ; SymEntry.index
mov qword [rsp + 80], 0 ; SymEntry.extra
mov qword [rsp + 40], 0 ; SymKey.scope_index = default
lea rdx, [rsp + 40]
mov rdi, [rsp + 8] ; *Ctx
mov rsi, [rdi + 8] ; Ctx.scope_stack.len()
dec rsi
call vec_get_or
mov rax, [rax] ; current scope index
mov [rsp + 40], rax ; SymKey.scope_index = scope_stack.last_or(0)
mov rdi, [rsp + 8] ; *Ctx
lea rdi, [rdi + 40] ; Ctx.symtable
lea rsi, [rsp + 32] ; &SymEntry
mov rcx, 0 ; cmp_ctx
mov rdx, symkey_cmp ; cmp
call vec_insert_sorted
.done:
add rsp, 88
pop rbx
pop rbp
ret
;; rdi: Ast
;; rsi: start_index
;; rdx: ctx
;; rcx: for_each
;; define-fn: fn ast_walk_for_each(ast: *mut Ast, start_index: u64, ctx: *mut (), for_each: unsafe extern "C" fn(ctx: *mut (), *mut Ast, node_index: u64))
ast_walk_for_each:
push rbp
push r15
push r14
push rbx
; current_index [24..32]
; for_each [16..24]
; ctx [8..16]
; ast [0..8]
sub rsp, 32
mov [rsp], rdi ; Ast
mov [rsp + 8], rdx ; ctx
mov [rsp + 16], rcx ; for_each
mov [rsp + 24], rsi ; current_node_ptr
mov rbp, rsp
push rsi
.loop:
cmp rsp, rbp
jge .done
; call for_each(ctx, ast, current_index)
mov rdi, [rbp + 8] ; ctx
mov rsi, [rbp] ; Ast
mov rdx, [rsp] ; current_index
mov rax, [rbp + 16] ; for_each
; align stack to 16 bytes before call
mov rbx, rsp
sub rsp, 8
and rsp, -16
mov [rsp], rbx
call rax
pop rsp
; get current_node_ptr
mov rdi, [rbp] ; Ast
pop rsi ; current_index
call vec_get
mov [rbp + 24], rax ; current_node_ptr
mov bl, byte [rax] ; AstNode.kind
cmp bl, AST_FUNCTION
je .func
cmp bl, AST_BLOCK
je .block
cmp bl, AST_BINARY_OP
je .binary_op
cmp bl, AST_ASSIGNMENT
je .assignment
cmp bl, AST_VALUE_TO_PLACE
je .value_to_place
cmp bl, AST_PLACE_TO_VALUE
je .place_to_value
cmp bl, AST_DEREF
je .deref
cmp bl, AST_ADDRESS_OF
je .address_of
cmp bl, AST_RETURN_STATEMENT
je .return_statement
jmp .loop
.func:
; push child indices to stack
mov rbx, [rax + 8] ; AstNode.data
mov r15, [rbx + 24] ; AstFunction.args_len
xor r14, r14 ; index
.arg_loop:
cmp r14, r15
jge .arg_loop_done
mov rdx, [rbx + 16] ; AstFunction.args
lea rdx, [rdx + r14*8]
push qword [rdx] ; push arg index
inc r14
jmp .arg_loop
.arg_loop_done:
mov r15, [rbx + 48] ; AstFunction.body
push r15 ; push body index
jmp .loop
.block:
mov rbx, [rax + 8] ; AstNode.data
mov r15, [rax + 16] ; AstNode.extra
.stmt_loop:
cmp r15, 0
jle .stmt_loop_done
dec r15
mov rdx, [rbx + r15*8] ; statement index
push rdx ; push statement index
jmp .stmt_loop
.stmt_loop_done:
jmp .loop
.binary_op:
mov rbx, [rax + 8] ; AstNode.data
mov rdx, [rbx + 16] ; right index
push rdx ; push right index
mov rdx, [rbx + 0] ; left index
push rdx ; push left index
jmp .loop
.assignment:
mov rbx, [rax + 8] ; AstNode.data = dest
mov rdx, [rax + 16] ; AstNode.extra = source
push rdx ; push source index
push rbx ; push dest index
jmp .loop
.value_to_place:
.place_to_value:
.deref:
.address_of:
.return_statement:
mov rbx, [rax + 8] ; AstNode.data
push rbx ; push inner expr index
jmp .loop
; weird alloca thing
.done:
add rsp, 32
pop rbx
pop r14
pop r15
pop rbp
ret
;; rdi: Ast
;; define-fn: fn ast_resolve_var_refs(ast: *mut Ast)
ast_resolve_var_refs:
push rbp
mov rbp, rsp
push r15
push r14
push rbx
.epilogue:
add rsp, 8
pop rbx
pop r14
pop r15
pop rbp
ret

View file

@ -18,6 +18,7 @@ global vec_drop_last
global vec_swap
global vec_remove
global vec_get
global vec_get_or
global vec_drop
global vec_find
global vec_insert
@ -104,6 +105,29 @@ vec_init_with:
pop rbp
ret
;; rdi: pointer to Vec struct
;; rsi: index
;; rdx: pointer to default value
;; fn vec_get(vec: *mut Vec, index: usize, default: *mut u8) -> *mut u8
vec_get_or:
push rbp
mov rbp, rsp
; if (index >= vec.len) panic();
mov rax, [rdi + 8] ; len
cmp rsi, rax
jge .default
; return &mut vec.data[index * vec.item_size];
mov rax, [rdi + 24] ; item_size
mul rsi ; index * item_size
mov rsi, [rdi] ; data
add rax, rsi ; data + index * item_size
pop rbp
ret
.default:
mov rax, rdx
pop rbp
ret
;; rdi: pointer to Vec struct
;; rsi: index
;; fn vec_get(vec: *mut Vec, index: usize) -> *mut u8

View file

@ -1,3 +1,5 @@
#![feature(debug_closure_helpers)]
#[path = "shared/shared.rs"]
mod util;
@ -26,23 +28,36 @@ fn main() {
let expr_id = parser(&mut ast);
eprintln!("Parsed expression ID: {}", expr_id);
println!("{:#}", &ast);
unsafe extern "C" fn visit_node(_this: *mut (), ast: *mut Ast, node_id: u64) {
let ast = unsafe { &*ast };
let node = ast.nodes.get(node_id as usize).unwrap();
eprintln!("Visiting node {node_id}: {node}");
}
util::defs::ast_walk_for_each(&mut ast, expr_id, core::ptr::null_mut(), visit_node);
// let mut symtable = core::mem::MaybeUninit::<util::defs::SymbolTable>::uninit();
// util::defs::ast_build_symtable(&mut ast, expr_id, &mut symtable);
// let symtable = symtable.assume_init();
// println!("Symbol Table: {:#?}", symtable);
};
}
print_ast(b"3 + 4", |ast| unsafe { parse_expr(ast) });
print_ast(b"fn main() -> void { return 1 + 2; }", |ast| unsafe {
parse_func(ast)
});
print_ast(b"fn main() -> void { return (1 + (2)); }", |ast| unsafe {
parse_func(ast)
});
print_ast(
b"fn main() -> void { return (1 + (2 * 3)) / 4; }",
|ast| unsafe { parse_func(ast) },
);
print_ast(b"fn main() -> void { return 1 + 2 * 3; }", |ast| unsafe {
parse_func(ast)
});
// print_ast(b"3 + 4", |ast| unsafe { parse_expr(ast) });
// print_ast(b"fn main() -> void { return 1 + 2; }", |ast| unsafe {
// parse_func(ast)
// });
// print_ast(b"fn main() -> void { return (1 + (2)); }", |ast| unsafe {
// parse_func(ast)
// });
// print_ast(
// b"fn main() -> void { return (1 + (2 * 3)) / 4; }",
// |ast| unsafe { parse_func(ast) },
// );
// print_ast(b"fn main() -> void { return 1 + 2 * 3; }", |ast| unsafe {
// parse_func(ast)
// });
print_ast(b"fn main() -> void { let x: u32 = 4; }", |ast| unsafe {
parse_func(ast)
@ -182,3 +197,41 @@ impl core::fmt::Display for Ast {
writeln!(f, "\n]")
}
}
impl core::fmt::Display for util::defs::SymEntry {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("SymEntry")
.field_with("key", |f| {
f.debug_struct("Key")
.field("kind", &self.key.kind)
.field("scope", &self.key.scope_index)
.field("span", &self.key.span)
.field_with("ident", |f| {
f.write_str(unsafe {
&core::str::from_utf8_unchecked(core::slice::from_raw_parts(
self.key.ident,
self.key.ident_len,
))
})
})
.finish()
})
.field_with("value", |f| {
let stct = &mut f.debug_struct("Value");
if self.extra == 0 {
stct.field("ast_index", &self.index).finish()
} else {
stct.field_with("ident", |f| {
f.write_str(unsafe {
core::str::from_utf8_unchecked(core::slice::from_raw_parts(
self.index as *const u8,
self.extra as usize,
))
})
})
.finish()
}
})
.finish()
}
}

View file

@ -13,6 +13,10 @@ unsafe extern "C" {
pub unsafe fn parse_prefix_expr(ast: *mut Ast) -> (u64, bool);
pub unsafe fn parse_assignment(ast: *mut Ast) -> (u64, bool);
pub unsafe fn ast_parse_let(ast: *mut Ast) -> (u64, bool);
pub unsafe fn symkey_cmp(a: *const SymKey, b: *const SymKey) -> i32;
pub unsafe fn ast_build_symtable(ast: *mut Ast, root_index: u64, symtable: *mut core::mem::MaybeUninit<SymbolTable>);
pub unsafe fn ast_walk_for_each(ast: *mut Ast, start_index: u64, ctx: *mut (), for_each: unsafe extern "C" fn(ctx: *mut (), *mut Ast, node_index: u64));
pub unsafe fn ast_resolve_var_refs(ast: *mut Ast);
}
pub const AST_FUNCTION: u8 = 1;
@ -154,4 +158,28 @@ pub struct AstAssignment {
pub expr: u64,
}
#[repr(C)]
#[derive(Debug)]
pub struct SymbolTable {
pub symtable: Vec<SymEntry>,
}
#[repr(C)]
#[derive(Debug)]
pub struct SymKey {
pub kind: u8,
pub scope_index: u64,
pub span: u64,
pub ident: *const u8,
pub ident_len: usize,
}
#[repr(C)]
#[derive(Debug)]
pub struct SymEntry {
pub key: SymKey,
pub index: u64,
pub extra: u64,
}
use super::vec::Vec;