comptime folding changes/fixes for declrefs

This commit is contained in:
Janis 2024-08-27 21:52:36 +02:00
parent 62cf214bde
commit 2d8c75ba0d
3 changed files with 127 additions and 117 deletions

View file

@ -572,6 +572,32 @@ pub mod tree_visitor {
pub fn new_range_inclusive(start: Node, end: Node, pre: F1, post: F2) -> Self { pub fn new_range_inclusive(start: Node, end: Node, pre: F1, post: F2) -> Self {
Self::new_inner(start, End::Inclusive(end), pre, post) Self::new_inner(start, End::Inclusive(end), pre, post)
} }
pub fn new_seek(tree: &Tree, start: Node, pre: F1, post: F2) -> Self {
let root_frame = Frame {
node: Node::MAX,
children: tree.global_decls.clone(),
};
Self {
frames: vec![root_frame],
current_node: None,
end: End::Open,
pre,
post,
}
.skip_until(tree, start)
}
pub fn until_before(self, end: Node) -> Self {
Self {
end: End::Exclusive(end),
..self
}
}
pub fn until_after(self, end: Node) -> Self {
Self {
end: End::Inclusive(end),
..self
}
}
} }
impl<F1, F2> Visitor<F1, F2> { impl<F1, F2> Visitor<F1, F2> {

View file

@ -2271,7 +2271,11 @@ impl ComptimeNumber {
match self { match self {
ComptimeNumber::Integral(i) => match i { ComptimeNumber::Integral(i) => match i {
ComptimeInt::Native { bits, ty } => { ComptimeInt::Native { bits, ty } => {
(bits.to_le_bytes().to_vec(), Type::Integer(ty)) let bytes = (u128::BITS - bits.leading_zeros() + 7) / 8;
(
bits.to_le_bytes()[..bytes as usize].to_vec(),
Type::Integer(ty),
)
} }
ComptimeInt::BigInt { bits, ty } => { ComptimeInt::BigInt { bits, ty } => {
(bits.to_le_bytes().to_vec(), Type::Integer(ty)) (bits.to_le_bytes().to_vec(), Type::Integer(ty))

View file

@ -1664,7 +1664,7 @@ impl Tree {
// simplify tree with compile-time math // simplify tree with compile-time math
impl Tree { impl Tree {
fn is_node_comptime(&self, node: Node) -> bool { fn is_node_comptime(&self, node: Node, check_declrefs: bool) -> bool {
match self.nodes.get_node(node) { match self.nodes.get_node(node) {
Tag::Block { Tag::Block {
statements, statements,
@ -1672,10 +1672,10 @@ impl Tree {
} => statements } => statements
.iter() .iter()
.chain(trailing_expr.into_iter()) .chain(trailing_expr.into_iter())
.all(|n| self.is_node_comptime(*n)), .all(|n| self.is_node_comptime(*n, true)),
Tag::Constant { .. } => true, Tag::Constant { .. } => true,
Tag::ExplicitCast { lhs, typename } => { Tag::ExplicitCast { lhs, typename } => {
self.is_node_comptime(*lhs) self.is_node_comptime(*lhs, true)
&& match self.type_of_node(*typename) { && match self.type_of_node(*typename) {
Type::Bool Type::Bool
| Type::ComptimeNumber | Type::ComptimeNumber
@ -1684,9 +1684,34 @@ impl Tree {
_ => false, _ => false,
} }
} }
Tag::DeclRef(lhs) | Tag::Not { lhs } | Tag::Negate { lhs } => { &Tag::DeclRef(lhs) if check_declrefs => {
self.is_node_comptime(*lhs) let start = lhs;
let end = node;
let mut is_comptime = true;
ast::tree_visitor::Visitor::new_seek(
self,start,
|_: &Tree, _| {
},
|tree: &Tree, node| match tree.nodes.get_node(node) {
&Tag::Assign { lhs, rhs } => {
if lhs == start || matches!(tree.nodes.get_node(lhs), &Tag::DeclRef(decl) if decl == start) {
is_comptime &= self.is_node_comptime(rhs, true);
} }
}
&Tag::Ref { lhs } if lhs == start => {
// recursively checking for derefs would get very complicated.
is_comptime = false;
}
_ => {}
},
)
.until_after(end)
.visit(self);
is_comptime
}
Tag::Not { lhs } | Tag::Negate { lhs } => self.is_node_comptime(*lhs, true),
Tag::Or { lhs, rhs } Tag::Or { lhs, rhs }
| Tag::And { lhs, rhs } | Tag::And { lhs, rhs }
| Tag::BitOr { lhs, rhs } | Tag::BitOr { lhs, rhs }
@ -1704,72 +1729,19 @@ impl Tree {
| Tag::Sub { lhs, rhs } | Tag::Sub { lhs, rhs }
| Tag::Mul { lhs, rhs } | Tag::Mul { lhs, rhs }
| Tag::Rem { lhs, rhs } | Tag::Rem { lhs, rhs }
| Tag::Div { lhs, rhs } => self.is_node_comptime(*lhs) && self.is_node_comptime(*rhs), | Tag::Div { lhs, rhs } => {
self.is_node_comptime(*lhs, true) && self.is_node_comptime(*rhs, true)
}
_ => false, _ => false,
} }
} }
fn try_fold_comptime_inner(&mut self, node: Node) {
if self.is_node_comptime(node) {
_ = self.fold_comptime_inner(node);
}
}
fn fold_comptime_with_visitor(&mut self, decl: Node) { fn fold_comptime_with_visitor(&mut self, decl: Node) {
ast::tree_visitor::Visitor::new( ast::tree_visitor::Visitor::new(
decl, decl,
|_: &mut Tree, _| {}, |_: &mut Tree, _| {},
|tree: &mut Tree, node| { |tree: &mut Tree, node| {
if let Ok(value) = tree.fold_comptime_inner(node, false) {
let value_node = if let &Tag::DeclRef(lhs) = tree.nodes.get_node(node) {
let start = lhs;
let end = node;
let mut is_comptime = true;
let mut last_value = None;
eprintln!(
"checking if %{}, referencing %{} is comptime-evaluable",
node.get(),
lhs.get()
);
ast::tree_visitor::Visitor::new_range_inclusive(
decl,
end,
|_: &Tree, _| {
},
|tree: &Tree, node| match tree.nodes.get_node(node) {
&Tag::Assign { lhs, rhs } => {
if lhs == start || matches!(tree.nodes.get_node(lhs), &Tag::DeclRef(decl) if decl == start) {
eprintln!("found assignment at %{}", node.get());
is_comptime &= tree.is_node_comptime(rhs);
if is_comptime {
last_value = Some(rhs);
}
}
}
&Tag::Ref { lhs } if lhs == start => {
// recursively checking for derefs would get very complicated.
is_comptime = false;
}
_ => {}
},
)
.skip_until(tree, start)
.visit(tree);
eprintln!(
"%{} is {}comptime-evaluable.",
node.get(),
if is_comptime { "" } else { "not " }
);
eprintln!("%{node} comptime-value is %{last_value:?}");
is_comptime.then_some(last_value).flatten().unwrap_or(node)
}else {
node
};
if let Ok(value) = tree.fold_comptime_inner(value_node) {
let (bytes, ty) = value.into_bytes_and_type(); let (bytes, ty) = value.into_bytes_and_type();
let idx = tree.strings.insert(bytes); let idx = tree.strings.insert(bytes);
@ -1783,8 +1755,12 @@ impl Tree {
.visit_mut(self); .visit_mut(self);
} }
fn fold_comptime_inner(&mut self, decl: Node) -> comptime::Result<ComptimeNumber> { fn fold_comptime_inner(
if self.is_node_comptime(decl) { &mut self,
decl: Node,
check_declrefs: bool,
) -> comptime::Result<ComptimeNumber> {
if self.is_node_comptime(decl, check_declrefs) {
match self.nodes.get_node(decl) { match self.nodes.get_node(decl) {
Tag::Constant { bytes, ty } => { Tag::Constant { bytes, ty } => {
let bytes = match bytes { let bytes = match bytes {
@ -1823,12 +1799,12 @@ impl Tree {
return Ok(number); return Ok(number);
} }
Tag::Negate { lhs } => { Tag::Negate { lhs } => {
let lhs = self.fold_comptime_inner(*lhs)?; let lhs = self.fold_comptime_inner(*lhs, true)?;
return Ok(lhs.neg()?); return Ok(lhs.neg()?);
} }
Tag::ExplicitCast { lhs, typename } => { Tag::ExplicitCast { lhs, typename } => {
let ty = self.type_of_node(*typename); let ty = self.type_of_node(*typename);
let lhs = self.fold_comptime_inner(*lhs)?; let lhs = self.fold_comptime_inner(*lhs, true)?;
return match ty { return match ty {
Type::Bool => lhs.into_bool(), Type::Bool => lhs.into_bool(),
@ -1838,156 +1814,160 @@ impl Tree {
}; };
} }
Tag::Not { lhs } => { Tag::Not { lhs } => {
let lhs = self.fold_comptime_inner(*lhs)?; let lhs = self.fold_comptime_inner(*lhs, true)?;
return lhs.not(); return lhs.not();
} }
Tag::Or { lhs, rhs } => { Tag::Or { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.or(rhs); return lhs.or(rhs);
} }
Tag::And { lhs, rhs } => { Tag::And { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.and(rhs); return lhs.and(rhs);
} }
Tag::Eq { lhs, rhs } => { Tag::Eq { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.eq(rhs); return lhs.eq(rhs);
} }
Tag::NEq { lhs, rhs } => { Tag::NEq { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.eq(rhs)?.not(); return lhs.eq(rhs)?.not();
} }
Tag::Lt { lhs, rhs } => { Tag::Lt { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.lt(rhs); return lhs.lt(rhs);
} }
Tag::Gt { lhs, rhs } => { Tag::Gt { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.gt(rhs); return lhs.gt(rhs);
} }
Tag::Le { lhs, rhs } => { Tag::Le { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.le(rhs); return lhs.le(rhs);
} }
Tag::Ge { lhs, rhs } => { Tag::Ge { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.ge(rhs); return lhs.ge(rhs);
} }
Tag::BitOr { lhs, rhs } => { Tag::BitOr { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.bitor(rhs); return lhs.bitor(rhs);
} }
Tag::BitAnd { lhs, rhs } => { Tag::BitAnd { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.bitand(rhs); return lhs.bitand(rhs);
} }
Tag::BitXOr { lhs, rhs } => { Tag::BitXOr { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.bitxor(rhs); return lhs.bitxor(rhs);
} }
Tag::Shl { lhs, rhs } => { Tag::Shl { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.shl(rhs); return lhs.shl(rhs);
} }
Tag::Shr { lhs, rhs } => { Tag::Shr { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.shr(rhs); return lhs.shr(rhs);
} }
Tag::Add { lhs, rhs } => { Tag::Add { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.add(rhs); return lhs.add(rhs);
} }
Tag::Sub { lhs, rhs } => { Tag::Sub { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.sub(rhs); return lhs.sub(rhs);
} }
Tag::Mul { lhs, rhs } => { Tag::Mul { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.mul(rhs); return lhs.mul(rhs);
} }
Tag::Rem { lhs, rhs } => { Tag::Rem { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.rem(rhs); return lhs.rem(rhs);
} }
Tag::Div { lhs, rhs } => { Tag::Div { lhs, rhs } => {
let (lhs, rhs) = (*lhs, *rhs); let (lhs, rhs) = (*lhs, *rhs);
let lhs = self.fold_comptime_inner(lhs)?; let lhs = self.fold_comptime_inner(lhs, true)?;
let rhs = self.fold_comptime_inner(rhs)?; let rhs = self.fold_comptime_inner(rhs, true)?;
return lhs.div(rhs); return lhs.div(rhs);
} }
&Tag::DeclRef(lhs) => { &Tag::DeclRef(lhs) => {
variant!(self.nodes.get_node(lhs) => &Tag::VarDecl { assignment, .. }); let start = lhs;
let start = assignment.unwrap_or(lhs);
let end = decl; let end = decl;
let mut last_value = None; let mut last_value = None;
ast::tree_visitor::Visitor::new_range(
start, ast::tree_visitor::Visitor::new_seek(
end, self,start,
|_: &Tree, _| {}, |_: &Tree, node| {
},
|tree: &Tree, node| match tree.nodes.get_node(node) { |tree: &Tree, node| match tree.nodes.get_node(node) {
&Tag::Assign { lhs, rhs } if lhs == start => { &Tag::Assign { lhs, rhs } => {
if lhs == start || matches!(tree.nodes.get_node(lhs), &Tag::DeclRef(decl) if decl == start) {
last_value = Some(rhs); last_value = Some(rhs);
} }
}
_ => {} _ => {}
}, },
) )
.until_after(end)
.visit(self); .visit(self);
return self return self.fold_comptime_inner(
.fold_comptime_inner(last_value.ok_or(comptime::Error::NotComptime)?); last_value.ok_or(comptime::Error::NotComptime)?,
true,
);
} }
_ => { _ => {
unreachable!() unreachable!()
@ -2002,10 +1982,10 @@ impl Tree {
for decl in self.global_decls.clone() { for decl in self.global_decls.clone() {
match self.nodes.get_node(decl) { match self.nodes.get_node(decl) {
Tag::FunctionDecl { body, .. } => { Tag::FunctionDecl { body, .. } => {
_ = self.fold_comptime_inner(*body); self.fold_comptime_with_visitor(*body);
} }
Tag::GlobalDecl { assignment, .. } => { Tag::GlobalDecl { assignment, .. } => {
_ = self.fold_comptime_inner(*assignment); self.fold_comptime_with_visitor(*assignment);
} }
_ => unreachable!(), _ => unreachable!(),
} }