From 2dd4697b297e8dddde326152418f8a1ab83a6a08 Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 8 Aug 2025 21:52:01 +0200 Subject: [PATCH] refactor(tree): improve node handling and add tests for leaf promotion - Refactored `InternalNode` initialization and parent-child relationships to improve clarity and safety. - made `InternalNode` `repr(C)` because I was stupid - Added `reparent` method to simplify node reparenting logic. - Introduced `as_leaf_non_null` for safer leaf node handling. - Updated test utilities to streamline insertion and debugging. - Added a new test case `promote_leaf` to validate leaf promotion logic. --- src/tree.rs | 146 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 53 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index 703bbeb..7ec3c91 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -75,6 +75,7 @@ struct LeafNode { } #[derive(Debug)] +#[repr(C)] struct InternalNode { data: LeafNode, next: Option>>, @@ -165,10 +166,10 @@ impl InternalNode { fn init_with_leaf(this: *mut Self, leaf: Box>) { let mut uninit: Box>> = unsafe { core::mem::transmute(leaf) }; unsafe { - core::ptr::swap(&raw mut (*this).data, uninit.as_mut_ptr()); + core::mem::swap(uninit.assume_init_mut(), &mut (*this).data); (&raw mut (*this).data.leaf).write(false); (&raw mut (*this).next).write(None); - (&raw mut (*this).len).write(1); + (&raw mut (*this).len).write(0); } } @@ -200,13 +201,6 @@ impl NodeRef { } impl NodeRef { - fn new_internal_with_child(child: Root) -> Self { - let mut node = InternalNode::new(); - node.edges[0].write(child.node); - - unsafe { NodeRef::from_new_internal(node) } - } - fn new_internal() -> Self { let node = InternalNode::new(); unsafe { NodeRef::from_new_internal(node) } @@ -560,6 +554,11 @@ impl NodeRef { // it should be unique or shared. this.node.as_ptr() } + + fn as_leaf_non_null(this: &Self) -> NonNull> { + // SAFETY: the static node type is `Leaf`. + unsafe { NonNull::new_unchecked(Self::as_leaf_ptr(this)) } + } } impl<'a, K: 'a, V: 'a> NodeRef, K, V, marker::Internal> { @@ -743,17 +742,43 @@ impl<'a, K, V> NodeRef, K, V, marker::Leaf> { let leaf = self.node; let mut parent = self.ascend().ok().unwrap(); - let internal = Box::into_non_null(InternalNode::new_with_leaf(unsafe { - Box::from_non_null(leaf) - })) - .cast(); - parent.node.as_internal_mut().edges[parent.idx].write(internal); + let mut internal = unsafe { + NodeRef::from_new_internal(InternalNode::new_with_leaf(Box::from_non_null(leaf))) + }; - NodeRef { - node: internal, - _marker: PhantomData, + parent.node.as_internal_mut().edges[parent.idx].write(NodeRef::as_leaf_non_null(&internal)); + + unsafe { internal.borrow_mut().dormant().awaken() } + } +} + +impl NodeRef { + pub(super) fn reparent<'a>( + mut self, + mut parent: Handle, K, V, marker::Internal>, marker::Edge>, + key: K, + ) -> NodeRef, K, V, Type> + where + K: 'a, + V: 'a, + { + self.borrow_mut().as_leaf_mut().parent = Some(NodeRef::as_internal_non_null(&parent.node)); + self.borrow_mut() + .as_leaf_mut() + .parent_idx + .write(parent.idx as u16); + + let new_len = parent.node.len() + 1; + + unsafe { + slice_insert(parent.node.key_area_mut(..new_len), parent.idx, key); + slice_insert(parent.node.edge_area_mut(..new_len), parent.idx, self.node); } + + *parent.node.len_mut() = new_len as u16; + + unsafe { self.borrow_mut().dormant().awaken() } } } @@ -860,29 +885,11 @@ impl<'a, K: 'a, V: 'a, Type> Handle, K, V, Type>, marker }; let last = unsafe { - let mut child = NodeRef::new_leaf(); + let child = NodeRef::new_leaf().reparent(internal, key); - // set parent link - child.borrow_mut().as_leaf_mut().parent = - Some(NodeRef::as_internal_non_null(&internal.node)); - child - .borrow_mut() - .as_leaf_mut() - .parent_idx - .write(internal.idx as u16); - - let new_len = internal.node.len() + 1; - - slice_insert(internal.node.key_area_mut(..new_len), internal.idx, key); - slice_insert( - internal.node.edge_area_mut(..new_len), - internal.idx, - child.node, - ); - *internal.node.len_mut() = new_len as u16; - - let child = Handle::new_edge(child.borrow_mut(), 0); - child.insert_recursing(key_seq, val).dormant() + Handle::new_edge(child, 0) + .insert_recursing(key_seq, val) + .dormant() }; unsafe { last.awaken() } @@ -961,11 +968,14 @@ mod search { SearchResult::Insert(key, unsafe { Handle::new_edge(leaf.forget_type(), 0) }) } ForceResult::Internal(internal) => { + // search through the keys of the internal node match unsafe { internal.find_key_index(&key, 0) } { IndexResult::Insert(idx) => Insert(key, unsafe { + // the key wasn't present, but should be inserted at `idx`. Handle::new_edge(internal.forget_type(), idx) }), IndexResult::Edge(idx) => { + // the key was found, continue searching down the edge GoDown(unsafe { Handle::new_edge(internal, idx) }) } } @@ -1002,7 +1012,7 @@ mod search { } } - std::eprintln!("insert key at index {}", keys.len()); + std::eprintln!("push_back key at index {}", keys.len()); IndexResult::Insert(keys.len()) } } @@ -1012,22 +1022,24 @@ mod search { use super::super::Tree; use super::*; + fn insert_and_dbg<'a>( + tree: &'a mut Tree, + key: &'a str, + value: &'static str, + ) { + let entry = tree.entry(key.chars()); + std::dbg!(&entry); + let entry = entry.or_insert(value); + std::dbg!(&entry); + } + #[test] fn asdf() { let mut tree = Tree::new(); - let entry = tree.entry("+".chars()); - std::dbg!(&entry); - entry.or_insert("Plus"); - - let entry = tree.entry("++".chars()); - std::dbg!(&entry); - entry.or_insert("PlusPlus"); - - let entry = tree.entry("+=".chars()); - std::dbg!(&entry); - entry.or_insert("PlusEqual"); - - std::dbg!(tree.entry("++".chars())); + insert_and_dbg(&mut tree, "+", "Plus"); + insert_and_dbg(&mut tree, "++", "PlusPlus"); + insert_and_dbg(&mut tree, "+=", "PlusEqual"); + insert_and_dbg(&mut tree, "++-", "PlusPlusMinus"); std::eprintln!("tree: {:?}", &tree); @@ -1035,6 +1047,7 @@ mod search { tree.entry("++".chars()).or_insert("asdf").get(), &"PlusPlus" ); + std::dbg!(tree.entry("+".chars())); assert_eq!(tree.entry("+".chars()).or_insert("asdf").get(), &"Plus"); } } @@ -1445,3 +1458,30 @@ unsafe fn slice_insert(slice: &mut [MaybeUninit], idx: usize, value: T) { (*slice_ptr.add(idx)).write(value); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn promote_leaf() { + #[derive(Debug, PartialEq, Eq)] + struct Test(&'static str); + + impl Drop for Test { + fn drop(&mut self) { + std::eprintln!("Dropping: {}", self.0); + } + } + + let mut leaf = NodeRef::<_, (), Test, _>::new_leaf(); + leaf.borrow_mut().as_leaf_mut().value = Some(Test("test")); + let mut root = NodeRef::new_internal(); + + let leaf = leaf.reparent(unsafe { Handle::new_edge(root.borrow_mut(), 0) }, ()); + + let mut internal = unsafe { leaf.make_internal_node() }; + + assert_eq!(internal.as_leaf_mut().value, Some(Test("test"))); + } +}