SeaLang/src/ast2/typecheck.rs
2025-08-02 21:25:04 +02:00

503 lines
14 KiB
Rust

// Visitor pattern has lots of unused arguments
#![allow(unused_variables)]
use crate::variant;
use super::{
Ast, Data, Index, Tag,
intern::{self},
tag::{AstNode, AstNodeExt},
visitor::{AstExt, AstVisitorTrait},
};
/// Categories of problems the type checker records while visiting the AST.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ErrorKind {
/// Two types that were compared against each other are not the same
/// interned type (e.g. the two operands of a binary operator, or the two
/// branches of an if/else expression).
MismatchingTypes,
/// An operation that needs to look through a pointer was applied to an
/// operand that does not offer one. Recorded, for example, by the
/// place-to-value conversion check.
DerefNonPointer,
}
/// A single diagnostic recorded by the type checker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Error {
/// AST node the diagnostic is attached to.
idx: Index,
/// What went wrong at that node.
kind: ErrorKind,
}
/// State of the type-checking visitor: the intern pool it resolves types
/// through and the diagnostics it has collected so far.
struct TypeChecker<'a> {
/// Intern pool used to look up type keys and to create interned types.
ip: &'a mut intern::InternPool,
/// Diagnostics pushed by the individual `visit_*` checks.
errors: Vec<Error>,
}
impl<'a> TypeChecker<'a> {
    /// Type-checks every child of the node at `idx`, in the order
    /// `get_node_children` yields them.
    ///
    /// The children's types are discarded: the result is `Ok(None)` unless a
    /// child visit fails, in which case that failure is returned immediately
    /// and the remaining children are not visited.
    fn visit_children<'b>(
        &mut self,
        ast: &'b mut super::Ast,
        idx: Index,
    ) -> Result<Option<intern::Index>, ()> {
        let children = ast.get_node_children(idx);
        for child in children {
            let _child_ty = self.visit_any(ast, child)?;
        }
        Ok(None)
    }
}
// TODO: actually handle errors here instead of aborting.
impl<'a> AstVisitorTrait<&'a mut Ast> for TypeChecker<'_> {
type Error = ();
type Value = Option<intern::Index>;
const UNIMPL: Self::Error = ();
/// Type-checks a parameter by resolving its declared type.
///
/// Returns whatever visiting the type node `ty` yields; `idx` and `name`
/// play no part in the result.
fn visit_parameter_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    ty: Index,
    name: intern::Index,
) -> Result<Self::Value, Self::Error> {
    // Only the type node matters here; explicitly discard the rest.
    let _ = idx;
    let _ = name;
    let param_ty = self.visit_any(ast, ty)?;
    Ok(param_ty)
}
fn visit_array_type_impl(
&mut self,
ast: &'a mut Ast,
idx: Index,
length: Index,
pointer: Index,
) -> Result<Self::Value, Self::Error> {
_ = self.visit_any(ast, length)?;
let pointee = self.visit_any(ast, pointer)?.expect("no type?");
variant!(self.ip.get_key(pointee) => intern::Key::PointerType { pointee, flags });
// get interened value from constant node
let length = {
let value = ast.datas[length].as_index_intern().1;
match self.ip.get_key(value) {
intern::Key::SIntSmall { bits } => bits as u32,
intern::Key::UIntSmall { bits } => bits as u32,
intern::Key::SInt64 { bits } => bits as u32,
intern::Key::UInt64 { bits } => bits as u32,
intern::Key::NegativeInt { bigint } | intern::Key::PositiveInt { bigint } => {
bigint.iter_u32_digits().next().unwrap_or(0)
}
_ => 0,
}
};
let ty = self.ip.get_array_type(pointee, Some(flags), length);
ast.tags[idx] = Tag::InternedType;
ast.datas[idx] = Data::intern(ty);
Ok(Some(ty))
}
/// Resolves a pointer type node into an interned pointer type.
///
/// Visits the pointee type node `ty`, interns a pointer to it carrying
/// `flags`, and rewrites the node at `idx` into an `InternedType` node.
///
/// # Panics
///
/// Panics with "no type?" if the pointee node yields no type.
fn visit_pointer_type_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    ty: Index,
    flags: intern::PointerFlags,
) -> Result<Self::Value, Self::Error> {
    let pointee_ty = self.visit_any(ast, ty)?.expect("no type?");
    let pointer_ty = self.ip.get_pointer_type(pointee_ty, Some(flags));

    // Replace the syntactic node with its resolved type.
    ast.tags[idx] = Tag::InternedType;
    ast.datas[idx] = Data::intern(pointer_ty);
    Ok(Some(pointer_ty))
}
/// Resolves a reference to a type declaration.
///
/// Visits the referenced declaration `decl` and rewrites the node at `idx`
/// into an `InternedType` node holding the declaration's type.
///
/// # Panics
///
/// Panics with "no type?" if the declaration yields no type.
fn visit_type_decl_ref_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    decl: Index,
) -> Result<Self::Value, Self::Error> {
    let decl_ty = self.visit_any(ast, decl)?.expect("no type?");

    // Replace the reference with the type it resolves to.
    ast.tags[idx] = Tag::InternedType;
    ast.datas[idx] = Data::intern(decl_ty);
    Ok(Some(decl_ty))
}
/// Resolves a function prototype into an interned function type.
///
/// Visits the return type first and then each parameter in list order,
/// interns the function type built from them, and rewrites the node at
/// `idx` into a `FunctionProtoInterned` node holding `name` and that type.
///
/// # Panics
///
/// Panics with "no type?" if the return type node yields no type, and
/// panics if any parameter yields no type.
fn visit_function_proto_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    name: intern::Index,
    return_type: Index,
    parameter_list: Index,
) -> Result<Self::Value, Self::Error> {
    variant!(ast.get_ast_node(parameter_list) => AstNode::ParameterList { params });
    let return_type = self.visit_any(ast, return_type)?.expect("no type?");

    // Stop at the first parameter whose visit fails.
    let mut param_types = Vec::new();
    for param in params {
        let param_ty = self.visit_parameter(ast, param)?;
        param_types.push(param_ty.unwrap());
    }

    let fn_ty = self.ip.get_function_type(return_type, param_types);
    ast.tags[idx] = Tag::FunctionProtoInterned;
    ast.datas[idx] = Data::two_interns(name, fn_ty);
    Ok(Some(fn_ty))
}
/// Returns the function type stored on a prototype node that already
/// carries its interned type; nothing is recomputed or modified.
fn visit_function_proto_interned_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    name: intern::Index,
    ty: intern::Index,
) -> Result<Self::Value, Self::Error> {
    let interned = Some(ty);
    Ok(interned)
}
/// Resolves a struct declaration into an interned struct type.
///
/// Visits every field type in declaration order, registers the struct with
/// the intern pool (pairing each field name with its resolved type), and
/// rewrites the node at `idx` into a `StructDeclInterned` node holding
/// `name` and the struct type.
///
/// # Panics
///
/// Panics if any field type node yields no type.
fn visit_struct_decl_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    name: intern::Index,
    flags: intern::StructFlags,
    field_names: Vec<intern::Index>,
    field_types: Vec<Index>,
) -> Result<Self::Value, Self::Error> {
    // Stop at the first field whose visit fails.
    let mut resolved_fields = Vec::new();
    for field_ty in field_types {
        let resolved = self.visit_any(ast, field_ty)?;
        resolved_fields.push(resolved.unwrap());
    }

    let struct_ty = self.ip.insert_or_replace_struct_type(
        name,
        idx,
        flags.packed,
        flags.c_like,
        field_names.into_iter().zip(resolved_fields),
    );
    ast.tags[idx] = Tag::StructDeclInterned;
    ast.datas[idx] = Data::two_interns(name, struct_ty);
    Ok(Some(struct_ty))
}
/// Type-checks a function declaration: its prototype and then its body.
///
/// A body without a type counts as `VOID`. When the body's type equals the
/// prototype's return type the function type is returned; otherwise a
/// `MismatchingTypes` error is recorded at `idx` and `VOID` is returned.
///
/// # Panics
///
/// Panics with "no type?" if the prototype yields no type or the intern
/// pool cannot provide a return type for it.
fn visit_function_decl_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    proto: Index,
    body: Index,
) -> Result<Self::Value, Self::Error> {
    let fn_ty = self.visit_any(ast, proto)?.expect("no type?");
    let body_ty = self.visit_block(ast, body)?.unwrap_or(intern::Index::VOID);
    let ret_ty = self.ip.try_get_return_type(fn_ty).expect("no type?");

    if body_ty == ret_ty {
        return Ok(Some(fn_ty));
    }

    self.errors.push(Error {
        idx,
        kind: ErrorKind::MismatchingTypes,
    });
    Ok(Some(intern::Index::VOID))
}
/// Type-checks a block.
///
/// Every statement is visited in order and its type discarded. The block's
/// type is the type of the trailing expression when there is one, and
/// `None` otherwise.
fn visit_block_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    statements: Vec<Index>,
    expr: Option<Index>,
) -> Result<Self::Value, Self::Error> {
    for stmt in statements {
        let _stmt_ty = self.visit_any(ast, stmt)?;
    }

    match expr {
        Some(trailing) => self.visit_any(ast, trailing),
        None => Ok(None),
    }
}
/// Type-checks every declaration of a file, in order.
///
/// A file has no type of its own, so the result is `Ok(None)` unless a
/// declaration's visit fails, in which case that failure is returned.
fn visit_file_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    decls: Vec<Index>,
) -> Result<Self::Value, Self::Error> {
    for decl in decls {
        let _decl_ty = self.visit_any(ast, decl)?;
    }
    Ok(None)
}
/// Returns the struct type stored on a declaration node that already
/// carries its interned type; nothing is recomputed or modified.
fn visit_struct_decl_interned_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    name: intern::Index,
    ty: intern::Index,
) -> Result<Self::Value, Self::Error> {
    let interned = Some(ty);
    Ok(interned)
}
/// Returns the type held by an interned-type node unchanged.
fn visit_interned_type_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    intern: intern::Index,
) -> Result<Self::Value, Self::Error> {
    let resolved = Some(intern);
    Ok(resolved)
}
/// Type-checks the node at `idx`, dispatching on its tag.
///
/// Returns the interned type computed for the node, or `None` when the
/// visited node produces no type.
///
/// # Panics
///
/// Panics via `todo!()` for tags whose checking is not implemented yet, via
/// `unreachable!()` for `Tag::Root`, and when an operand of a binary
/// operator yields no type.
fn visit_any(
    &mut self,
    ast: &'a mut Ast,
    idx: super::Index,
) -> Result<Self::Value, Self::Error> {
    let (tag, data) = ast.get_node_tag_and_data(idx);
    let ty = match tag {
        Tag::Root => unreachable!(),
        Tag::InternedType => Some(data.as_intern()),
        Tag::File => self.visit_file(ast, idx)?,
        Tag::ArrayType => self.visit_array_type(ast, idx)?,
        Tag::PointerType => self.visit_pointer_type(ast, idx)?,
        Tag::TypeDeclRef => self.visit_type_decl_ref(ast, idx)?,
        Tag::FunctionProtoInterned => self.visit_function_proto_interned(ast, idx)?,
        Tag::FunctionProto => self.visit_function_proto(ast, idx)?,
        Tag::Parameter => self.visit_parameter(ast, idx)?,
        Tag::StructDecl => self.visit_struct_decl(ast, idx)?,
        Tag::StructDeclInterned => self.visit_struct_decl_interned(ast, idx)?,
        Tag::FunctionDecl => self.visit_function_decl(ast, idx)?,
        Tag::ParameterList => todo!(),
        Tag::Block | Tag::BlockTrailingExpr => self.visit_block(ast, idx)?,
        Tag::Constant => self.visit_constant(ast, idx)?,
        Tag::ExprStmt => self.visit_expr_stmt(ast, idx)?,
        Tag::ReturnStmt => self.visit_return_stmt(ast, idx)?,
        Tag::ReturnExprStmt => todo!(),
        Tag::VarDecl => todo!(),
        Tag::MutVarDecl => todo!(),
        Tag::VarDeclAssignment => todo!(),
        Tag::MutVarDeclAssignment => todo!(),
        Tag::GlobalDecl => todo!(),
        Tag::FieldDecl => todo!(),
        Tag::DeclRef => todo!(),
        Tag::DeclRefUnresolved => todo!(),
        Tag::TypeDeclRefUnresolved => todo!(),
        Tag::CallExpr => todo!(),
        Tag::FieldAccess => todo!(),
        Tag::ArgumentList => todo!(),
        Tag::Argument => todo!(),
        Tag::NamedArgument => todo!(),
        Tag::ExplicitCast => self.visit_explicit_cast(ast, idx)?,
        Tag::Deref => self.visit_deref(ast, idx)?,
        Tag::AddressOf => self.visit_address_of(ast, idx)?,
        // Unary operators take the type of their operand.
        Tag::Not | Tag::Negate => self.visit_any(ast, data.as_index())?,
        Tag::PlaceToValueConversion => self.visit_place_to_value_conversion(ast, idx)?,
        Tag::ValueToPlaceConversion => self.visit_value_to_place_conversion(ast, idx)?,
        Tag::Or
        | Tag::And
        | Tag::BitOr
        | Tag::BitXOr
        | Tag::BitAnd
        | Tag::Eq
        | Tag::NEq
        | Tag::Lt
        | Tag::Gt
        | Tag::Le
        | Tag::Ge
        | Tag::Shl
        | Tag::Shr
        | Tag::Add
        | Tag::Sub
        | Tag::Mul
        | Tag::Div
        | Tag::Rem => {
            // Both operands must have the same type; the result takes the
            // type of the left operand.
            // NOTE(review): comparison and logical operators also report the
            // left operand's type — confirm whether a boolean type is
            // intended for them.
            let (lhs, rhs) = data.as_two_indices();
            let lhs = self.visit_any(ast, lhs)?.unwrap();
            let rhs = self.visit_any(ast, rhs)?.unwrap();
            if lhs != rhs {
                self.errors.push(Error {
                    idx,
                    kind: ErrorKind::MismatchingTypes,
                });
            }
            Some(lhs)
        }
        Tag::Assign => self.visit_assign(ast, idx)?,
        Tag::SubscriptExpr => self.visit_subscript_expr(ast, idx)?,
        Tag::IfExpr => self.visit_if_expr(ast, idx)?,
        Tag::IfElseExpr => self.visit_if_else_expr(ast, idx)?,
        Tag::Error => todo!(),
        Tag::Undefined => todo!(),
    };

    // Hand the computed type back to the caller. This used to be an
    // unconditional `todo!()`, which made every successful visit panic.
    Ok(ty)
}
/// Type-checks an explicit cast.
///
/// The operand `expr` is visited and its type discarded; the cast's type is
/// whatever visiting the target type node `ty` yields.
fn visit_explicit_cast_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    expr: Index,
    ty: Index,
) -> Result<Self::Value, Self::Error> {
    self.visit_any(ast, expr)?;
    // TODO: make sure this cast is legal
    self.visit_any(ast, ty)
}
/// Type-checks an address-of expression; its type is exactly the type that
/// visiting the operand `expr` yields.
fn visit_address_of_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    expr: Index,
) -> Result<Self::Value, Self::Error> {
    let operand_ty = self.visit_any(ast, expr)?;
    Ok(operand_ty)
}
fn visit_deref_impl(
&mut self,
ast: &'a mut Ast,
idx: Index,
expr: Index,
) -> Result<Self::Value, Self::Error> {
let ty = self.visit_any(ast, expr)?.unwrap();
if let intern::Key::PointerType { pointee, .. } = self.ip.get_key(ty) {
Ok(Some(pointee))
} else {
Ok(None)
}
}
/// Type-checks a place-to-value conversion.
///
/// The value's type is the pointee type of the operand `expr`'s type, as
/// reported by the intern pool. When no pointee is available — the operand
/// has no type, or the intern pool reports no pointee for its type — a
/// `DerefNonPointer` error is recorded at `idx` and `Ok(None)` is returned.
fn visit_place_to_value_conversion_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    expr: Index,
) -> Result<Self::Value, Self::Error> {
    let pointee = match self.visit_any(ast, expr)? {
        Some(place_ty) => self.ip.try_get_pointee_type(place_ty),
        None => None,
    };

    // Report both failure modes; previously a typed operand without a
    // pointee produced `None` without any diagnostic.
    if pointee.is_none() {
        self.errors.push(Error {
            idx,
            kind: ErrorKind::DerefNonPointer,
        });
    }
    Ok(pointee)
}
/// Type-checks a value-to-place conversion.
///
/// When the operand `expr` has a type, the result is an interned pointer to
/// that type (requested without explicit flags); an operand without a type
/// yields `None`.
fn visit_value_to_place_conversion_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    expr: Index,
) -> Result<Self::Value, Self::Error> {
    let value_ty = self.visit_any(ast, expr)?;
    Ok(value_ty.map(|ty| self.ip.get_pointer_type(ty, None)))
}
/// Type-checks an if/else expression.
///
/// The condition is visited and its type discarded, then both branches are
/// checked as blocks. A `MismatchingTypes` error is recorded at `idx` when
/// the branch types differ; the expression's type is always that of the
/// first branch `a`.
fn visit_if_else_expr_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    cond: Index,
    a: Index,
    b: Index,
) -> Result<Self::Value, Self::Error> {
    let _cond_ty = self.visit_any(ast, cond)?;
    let then_ty = self.visit_block(ast, a)?;
    let else_ty = self.visit_block(ast, b)?;

    let branches_agree = then_ty == else_ty;
    if !branches_agree {
        self.errors.push(Error {
            idx,
            kind: ErrorKind::MismatchingTypes,
        });
    }
    Ok(then_ty)
}
/// Type-checks an if expression without an else branch.
///
/// The condition and the body are both visited and their types discarded;
/// the expression itself has no type, so the result is `Ok(None)`.
fn visit_if_expr_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    cond: Index,
    body: Index,
) -> Result<Self::Value, Self::Error> {
    let _cond_ty = self.visit_any(ast, cond)?;
    let _body_ty = self.visit_block(ast, body)?;
    Ok(None)
}
/// Type-checks a subscript expression.
///
/// A `MismatchingTypes` error is recorded at `idx` when the index is not of
/// type `USIZE`. The result is an interned pointer (requested without
/// explicit flags) to the pointee type of `lhs`.
///
/// # Panics
///
/// Panics if `lhs` or `index` yields no type, or if the intern pool reports
/// no pointee type for the type of `lhs`.
fn visit_subscript_expr_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    lhs: Index,
    index: Index,
) -> Result<Self::Value, Self::Error> {
    let base_ty = self.visit_any(ast, lhs)?.unwrap();
    let index_ty = self.visit_any(ast, index)?.unwrap();

    let index_is_usize = index_ty == intern::Index::USIZE;
    if !index_is_usize {
        self.errors.push(Error {
            idx,
            kind: ErrorKind::MismatchingTypes,
        });
    }

    let element_ty = self.ip.try_get_pointee_type(base_ty).unwrap();
    let element_ptr = self.ip.get_pointer_type(element_ty, None);
    Ok(Some(element_ptr))
}
/// Type-checks an assignment.
///
/// The pointee type of the left-hand side must equal the type of the
/// right-hand side; otherwise a `MismatchingTypes` error is recorded at
/// `idx`. An assignment has no type, so the result is `Ok(None)`.
///
/// # Panics
///
/// Panics if either side yields no type.
fn visit_assign_impl(
    &mut self,
    ast: &'a mut Ast,
    idx: Index,
    lhs: Index,
    rhs: Index,
) -> Result<Self::Value, Self::Error> {
    let place_ty = self.visit_any(ast, lhs)?.unwrap();
    let value_ty = self.visit_any(ast, rhs)?.unwrap();

    let stored_ty = self.ip.try_get_pointee_type(place_ty);
    if stored_ty != Some(value_ty) {
        self.errors.push(Error {
            idx,
            kind: ErrorKind::MismatchingTypes,
        });
    }
    Ok(None)
}
}