Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions core/patina_internal_collections/src/bst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ where
///
pub fn get(&self, key: &D::Key) -> Option<&D> {
match self.get_node(key) {
Some(node) => Some(&node.data),
Some(node) => {
// SAFETY: Nodes in the tree always have initialized data
Some(unsafe { node.data() })
}
None => None,
}
}
Expand All @@ -186,7 +189,7 @@ where
// SAFETY: The pointer comes from as_mut_ptr() on a valid node reference obtained from get_node().
// The caller is responsible for ensuring that the mutable reference doesn't modify key-affecting
// values.
Some(unsafe { &mut (*ptr).data })
Some(unsafe { (*ptr).data_mut() })
}
None => None,
}
Expand All @@ -209,7 +212,10 @@ where
///
pub fn get_with_idx(&self, idx: usize) -> Option<&D> {
match self.storage.get(idx) {
Some(node) => Some(&node.data),
Some(node) => {
// SAFETY: Nodes in storage always have initialized data
Some(unsafe { node.data() })
}
None => None,
}
}
Expand All @@ -236,7 +242,10 @@ where
///
pub unsafe fn get_with_idx_mut(&mut self, idx: usize) -> Option<&mut D> {
match self.storage.get_mut(idx) {
Some(node) => Some(&mut node.data),
Some(node) => {
// SAFETY: Nodes in storage always have initialized data
Some(unsafe { node.data_mut() })
}
None => None,
}
}
Expand Down Expand Up @@ -281,7 +290,8 @@ where
let mut current = self.root();
let mut closest = None;
while let Some(node) = current {
match key.cmp(node.data.key()) {
// SAFETY: Nodes in the tree always have initialized data
match key.cmp(unsafe { node.data() }.key()) {
Ordering::Equal => return Some(self.storage.idx(node.as_mut_ptr())),
Ordering::Less => current = node.left(),
Ordering::Greater => {
Expand Down Expand Up @@ -494,7 +504,8 @@ where
fn get_node(&self, key: &D::Key) -> Option<&Node<D>> {
let mut current_idx = self.root();
while let Some(node) = current_idx {
match key.cmp(node.data.key()) {
// SAFETY: Nodes in the tree always have initialized data
match key.cmp(unsafe { node.data() }.key()) {
Ordering::Equal => return Some(node),
Ordering::Less => current_idx = node.left(),
Ordering::Greater => current_idx = node.right(),
Expand Down Expand Up @@ -646,7 +657,8 @@ where
fn _dfs(node: Option<&Node<D>>, values: &mut alloc::vec::Vec<D>) {
if let Some(node) = node {
Self::_dfs(node.left(), values);
values.push(node.data);
// SAFETY: Nodes in the tree always have initialized data
values.push(unsafe { *node.data() });
Self::_dfs(node.right(), values);
}
}
Expand All @@ -666,6 +678,7 @@ where

#[cfg(test)]
#[coverage(off)]
#[allow(clippy::undocumented_unsafe_blocks)]
mod tests {
use crate::{Bst, node_size};

Expand Down Expand Up @@ -883,6 +896,7 @@ mod tests {
}

#[cfg(test)]
#[allow(clippy::undocumented_unsafe_blocks)]
mod fuzz_tests {
extern crate std;
use crate::{Bst, node_size};
Expand Down
133 changes: 103 additions & 30 deletions core/patina_internal_collections/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//!
//! SPDX-License-Identifier: Apache-2.0
//!
use core::{cell::Cell, mem, ptr::NonNull, slice};
use core::{cell::Cell, mem, mem::MaybeUninit, ptr::NonNull, slice};

use crate::{Error, Result, SliceKey};

Expand Down Expand Up @@ -51,23 +51,32 @@ where

/// Create a new storage container with a slice of memory.
pub fn with_capacity(slice: &'a mut [u8]) -> Storage<'a, D> {
let storage = Storage {
// SAFETY: This is reinterpreting a byte slice as a Node<D> slice.
// 1. The alignment is checked implicitly by the slice bounds.
// 2. The correct number of Node<D> elements that fit in the byte slice is calculated.
// 3. The lifetime ensures the byte slice remains valid for the storage's lifetime
data: unsafe {
slice::from_raw_parts_mut::<'a, Node<D>>(
slice as *mut [u8] as *mut Node<D>,
slice.len() / mem::size_of::<Node<D>>(),
)
},
length: 0,
available: Cell::default(),
// SAFETY: This is reinterpreting a byte slice as a MaybeUninit<Node<D>> slice.
// Using MaybeUninit explicitly represents uninitialized memory.
let uninit_buffer = unsafe {
slice::from_raw_parts_mut::<'a, MaybeUninit<Node<D>>>(
slice as *mut [u8] as *mut MaybeUninit<Node<D>>,
slice.len() / mem::size_of::<Node<D>>(),
)
};

Self::build_linked_list(storage.data);
storage.available.set(storage.data[0].as_mut_ptr());
// Initialize nodes with uninitialized data fields
for elem in uninit_buffer.iter_mut() {
elem.write(Node::new_uninit());
}

// SAFETY: All nodes have been initialized (though their data fields are uninitialized).
// We can now safely convert from MaybeUninit<Node<D>> to Node<D>.
let buffer =
unsafe { slice::from_raw_parts_mut(uninit_buffer.as_mut_ptr() as *mut Node<D>, uninit_buffer.len()) };

let storage = Storage { data: buffer, length: 0, available: Cell::default() };

if !storage.data.is_empty() {
Self::build_linked_list(storage.data);
storage.available.set(storage.data[0].as_mut_ptr());
}

storage
}

Expand Down Expand Up @@ -105,7 +114,11 @@ where
node.set_left(None);
node.set_right(None);
node.set_parent(None);
node.data = data;
// SAFETY: The node is from the available list, so its data field is uninitialized.
// We initialize it here when moving the node to the "in use" state.
unsafe {
node.init_data(data);
}
self.length += 1;
Ok((self.idx(node.as_mut_ptr()), node))
} else {
Expand Down Expand Up @@ -216,7 +229,13 @@ where
for i in 0..self.capacity() {
let old = &self.data[i];

buffer[i].data = old.data;
// SAFETY: Nodes at indices 0..self.len() are "in use" and have initialized data.
// We copy the initialized data from old to new.
unsafe {
let old_data = old.data();
// Use ptr::copy to copy the data from old to new
buffer[i].data = MaybeUninit::new(*old_data);
}
buffer[i].set_color(old.color());

if let Some(left) = old.left() {
Expand Down Expand Up @@ -467,7 +486,7 @@ pub struct Node<D>
where
D: SliceKey,
{
pub data: D,
pub data: MaybeUninit<D>,
color: Cell<bool>,
parent: Cell<*mut Node<D>>,
left: Cell<*mut Node<D>>,
Expand All @@ -478,8 +497,48 @@ impl<D> Node<D>
where
D: SliceKey,
{
/// Create a new node with uninitialized data.
/// The data field must be initialized separately using `init_data()`.
pub fn new_uninit() -> Self {
Node {
data: MaybeUninit::uninit(),
color: Cell::new(RED),
parent: Cell::default(),
left: Cell::default(),
right: Cell::default(),
}
}

/// Initialize the data field of an uninitialized node.
/// # Safety
/// The caller must ensure the data field has not been previously initialized.
pub unsafe fn init_data(&mut self, data: D) {
self.data.write(data);
}

/// Creates a new Node with initialized data.
/// Used for testing purposes.
#[cfg(test)]
pub fn new(data: D) -> Self {
Node { data, color: Cell::new(RED), parent: Cell::default(), left: Cell::default(), right: Cell::default() }
let mut node = Self::new_uninit();
node.data.write(data);
node
}

/// Get a reference to the data, assuming it is initialized.
/// # Safety
/// The caller must ensure the data field has been initialized.
pub unsafe fn data(&self) -> &D {
// SAFETY: Caller guarantees data is initialized
unsafe { self.data.assume_init_ref() }
}

/// Get a mutable reference to the data, assuming it is initialized.
/// # Safety
/// The caller must ensure the data field has been initialized.
pub unsafe fn data_mut(&mut self) -> &mut D {
// SAFETY: Caller guarantees data is initialized
unsafe { self.data.assume_init_mut() }
}

pub fn height_and_balance(node: Option<&Node<D>>) -> (i32, bool) {
Expand Down Expand Up @@ -587,12 +646,15 @@ where
impl<D: SliceKey> SliceKey for Node<D> {
type Key = D::Key;
fn key(&self) -> &Self::Key {
self.data.key()
// SAFETY: This method is only called on nodes that are in use (initialized).
// Nodes in the available list are never accessed for their key.
unsafe { self.data().key() }
}
}

#[cfg(test)]
#[coverage(off)]
#[allow(clippy::undocumented_unsafe_blocks)]
mod tests {
use super::*;

Expand All @@ -605,7 +667,8 @@ mod tests {
for i in 0..10 {
let (index, node) = storage.add(i).unwrap();
assert_eq!(index, i);
assert_eq!(node.data, i);
// SAFETY: Node was just added with data, so it's initialized
assert_eq!(unsafe { *node.data() }, i);
assert_eq!(storage.len(), i + 1);
}

Expand All @@ -616,16 +679,22 @@ mod tests {
storage.delete(storage.get(5).unwrap().as_mut_ptr());
let (index, node) = storage.add(11).unwrap();
assert_eq!(index, 5);
assert_eq!(node.data, 11);
// SAFETY: Node was just added with data, so it's initialized
assert_eq!(unsafe { *node.data() }, 11);

// Try and get a mutable reference to a node
{
let node = storage.get_mut(5).unwrap();
assert_eq!(node.data, 11);
node.data = 12;
// SAFETY: Node is in use, so data is initialized
assert_eq!(unsafe { *node.data() }, 11);
// SAFETY: Node is in use, we can modify the initialized data
unsafe {
*node.data_mut() = 12;
}
}
let node = storage.get(5).unwrap();
assert_eq!(node.data, 12);
// SAFETY: Node is in use, so data is initialized
assert_eq!(unsafe { *node.data() }, 12);
}

#[test]
Expand All @@ -643,8 +712,10 @@ mod tests {

p4.set_parent(Some(p1));

assert_eq!(Node::sibling(p2).unwrap().data, 3);
assert_eq!(Node::sibling(p3).unwrap().data, 2);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::sibling(p2).unwrap().data() }, 3);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::sibling(p3).unwrap().data() }, 2);
assert!(Node::sibling(p1).is_none());
}

Expand Down Expand Up @@ -683,7 +754,8 @@ mod tests {
p2.set_right(Some(p4));
p4.set_parent(Some(p2));

assert_eq!(Node::predecessor(p1).unwrap().data, 4);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::predecessor(p1).unwrap().data() }, 4);
assert!(Node::predecessor(p4).is_none());
}

Expand All @@ -703,7 +775,8 @@ mod tests {
p2.set_right(Some(p4));
p4.set_parent(Some(p2));

assert_eq!(Node::successor(p1).unwrap().data, 3);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::successor(p1).unwrap().data() }, 3);
assert!(Node::successor(p4).is_none());
}

Expand Down
Loading