From 2293b514cf3531987a9ff627697b0fc72c679920 Mon Sep 17 00:00:00 2001 From: janis Date: Fri, 19 Sep 2025 20:12:21 +0200 Subject: [PATCH] idk man.. --- src/ast2/biunification.rs | 140 ++++++++++++++++++++++++++++++++++---- src/ast2/internable.rs | 34 +++++++++ 2 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 src/ast2/internable.rs diff --git a/src/ast2/biunification.rs b/src/ast2/biunification.rs index 765140e..1636dd7 100644 --- a/src/ast2/biunification.rs +++ b/src/ast2/biunification.rs @@ -5,30 +5,43 @@ // Visitor pattern has lots of unused arguments #![allow(unused_variables)] -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; + +use crate::ast2::tag::{AstNode, AstNodeExt}; use super::{Ast, Index, intern, visitor::AstVisitorTrait}; type Id = u32; -enum Type { - Reified(intern::Index), - Variable(Id), +trait TypeVariance { + type T; + type Opposite; +} + +#[derive(Debug, Clone)] +enum TypeHead { + Real(intern::Index), + Function { args: Vec, ret: T::Opposite }, } /// Variance of a type parameter or constraint. /// A function of type `A -> B` is covariant in `B` and contravariant in `A`. -/// This means that a type `T` may be substituted for `A` if `T` is a subtype of `A`, but -/// a type `T` may only be substituted for `B` if `T` is a supertype of `B`. +/// This means that a type `T` may be substituted for `A` if `T` is a subtype of +/// `A`, that is, every `T` is also an `A`, +/// but a type `T` may only be substituted for `B` if `T` is a supertype of `B`, +/// that is, every `B` is also a `T`. /// /// Namely, in a type system with `int` and `nat <: int`, for a function `f: int /// -> int` in the expression `let u: int = 3; let t: nat = f(u);`, `u` may /// safely be used as an argument to `f` because `nat <: int`, but `f(u`)` may /// not be assigned to `t` because `int <: nat` is not true. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] enum Variance { - #[doc(alias = "Positive")] + /// A Positive, or union relationship between types. + /// used in value-places. Covariant, - #[doc(alias = "Negative")] + /// A Negative, or intersection relationship between types. + /// used in use-places. Contravariant, } @@ -37,27 +50,81 @@ struct Value(Id); #[derive(Debug, Clone, Copy)] struct Use(Id); +impl TypeVariance for Value { + type T = Value; + type Opposite = Use; +} + +impl TypeVariance for Use { + type T = Use; + type Opposite = Value; +} + /// Typechecking error. #[derive(Debug, Clone, thiserror::Error)] enum Error { #[error("Unimplemented feature")] Unimplemented, + #[error("{0}")] + StringError(String), } type Result = std::result::Result; struct Bindings { - inner: HashMap, + next_id: Id, + inner: HashMap, + bounds: HashSet<(Id, Id, Variance)>, + types: HashMap, +} + +impl Bindings { + fn new() -> Self { + Bindings { + next_id: 1, + inner: HashMap::new(), + bounds: HashSet::new(), + types: HashMap::new(), + } + } + + fn new_id(&mut self) -> Id { + let id = self.next_id; + self.next_id += 1; + id + } + + fn get_or_create(&mut self, idx: super::Index) -> Id { + self.inner.get(&idx).copied().unwrap_or_else(|| { + let id = self.new_id(); + self.inner.insert(idx, id); + id + }) + } + + /// retrieves the type Id for the given ast node. + fn get(&self, idx: super::Index) -> Option { + self.inner.get(&idx).copied() + } + + /// inserts a proper type for `id`. + fn insert_type(&mut self, id: Id, ty: intern::Index) { + self.types.insert(id, ty); + } } struct TypeChecker<'a> { pool: &'a mut intern::InternPool, + bindings: Bindings, } // Core impl TypeChecker<'_> { pub fn new(pool: &mut intern::InternPool) -> TypeChecker { - TypeChecker { pool } + TypeChecker { + pool, + bindings: Bindings::new(), + } } fn var(&mut self) -> (Value, Use) { @@ -73,6 +140,33 @@ impl<'a> AstVisitorTrait<&'a Ast> for TypeChecker<'_> { const UNIMPL: Self::Error = Error::Unimplemented; + fn visit_interned_type_impl( + &mut self, + ast: &'a Ast, + idx: Index, + intern: intern::Index, + ) -> std::result::Result { + let id = self.bindings.get_or_create(idx); + match self.pool.get_key(intern) { + intern::Key::SimpleType { + ty: intern::SimpleType::ComptimeInt, + } => { + // This is a type variable. + } + intern::Key::SimpleType { .. } + | intern::Key::PointerType { .. } + | intern::Key::ArrayType { .. } + | intern::Key::FunctionType { .. } + | intern::Key::StructType { .. } => { + // This is a real type. + self.bindings.insert_type(id, intern); + } + _ => unreachable!(), + } + + Ok(Value(id)) + } + fn visit_constant_impl( &mut self, ast: &'a Ast, @@ -80,8 +174,30 @@ impl<'a> AstVisitorTrait<&'a Ast> for TypeChecker<'_> { ty: Index, value: intern::Index, ) -> std::result::Result { - // constants may be of type `comptime_int`, which is a special type that - // cannot exist at runtime. + // get type from the pool + + let AstNode::InternedType { intern } = ast.get_ast_node(ty) else { + panic!( + "Expected an interned type node, got {:?}", + ast.get_ast_node(ty) + ); + }; + + match self.pool.get_key(intern) { + intern::Key::SimpleType { + ty: intern::SimpleType::ComptimeInt, + } => { + // This is a type variable. + } + intern::Key::SimpleType { .. } + | intern::Key::PointerType { .. } + | intern::Key::ArrayType { .. } + | intern::Key::FunctionType { .. } + | intern::Key::StructType { .. } => { + // This is a real type. + } + _ => unreachable!(), + } todo!() } diff --git a/src/ast2/internable.rs b/src/ast2/internable.rs new file mode 100644 index 0000000..caf48fa --- /dev/null +++ b/src/ast2/internable.rs @@ -0,0 +1,34 @@ +use super::*; +use core::hash::Hash; +use std::hash::Hasher; + +// Types implementing this trait can be stored in the internpool. +trait KeyTrait: Hash + Eq { + const TAG: Tag; + fn serialise(self, pool: &mut InternPool); + fn deserialise(index: Index, pool: &mut InternPool) -> Self; +} + +impl KeyTrait for String { + const TAG: Tag = Tag::String; + fn serialise(self, pool: &mut InternPool) { + todo!() + } + + fn deserialise(index: Index, pool: &mut InternPool) -> Self { + // let mut hasher = std::hash::DefaultHasher::new(); + // core::any::TypeId::of::().hash(&mut hasher); + // let tag = hasher.finish() as u32; + let item = pool.get_item(index).unwrap(); + assert_eq!(item.tag, Self::TAG); + + let start = pool.words[item.idx()] as usize; + let len = pool.words[item.idx() + 1] as usize; + let str = unsafe { + let bytes = &pool.strings[start..start + len]; + std::str::from_utf8_unchecked(bytes) + }; + + str.to_owned() + } +}