#![feature(test)]

/// Thread count handed to the explicitly sized benchmark pool.
///
/// Returns whatever `bevy_tasks::available_parallelism` reports, raised to a
/// floor of four.
fn available_parallelism() -> usize {
    const MIN_THREADS: usize = 4;
    let detected = bevy_tasks::available_parallelism();
    std::cmp::max(detected, MIN_THREADS)
}
use executor::{self};
use test::Bencher;

extern crate test;

/// Index-based binary tree used as the shared workload for the benchmarks.
mod tree {

    /// A binary tree stored as one flat slice of nodes that refer to their
    /// children by index into that slice.
    pub struct Tree<T> {
        nodes: Box<[Node<T>]>,
        root: Option<usize>,
    }

    /// One tree node: a payload value plus the indices of its children.
    pub struct Node<T> {
        /// Payload carried by this node (set on every node, not only on leaves).
        pub leaf: T,
        /// Index of the left child, or `None` on the bottom level.
        pub left: Option<usize>,
        /// Index of the right child, or `None` on the bottom level.
        pub right: Option<usize>,
    }

    impl<T> Tree<T> {
        /// Builds a complete binary tree with `depth + 1` levels in which every
        /// node carries a copy of `t`.
        ///
        /// `depth == 0` yields a single node without children.
        pub fn new(depth: usize, t: T) -> Tree<T>
        where
            T: Copy,
        {
            // Reserve the exact node count so building never reallocates and
            // `into_boxed_slice` has nothing to shrink. If the count does not
            // fit in `usize`, fall back to growing on demand.
            let mut nodes = Vec::with_capacity(Self::node_count(depth).unwrap_or(0));
            let root = Self::build_node(&mut nodes, depth, t);
            Self {
                nodes: nodes.into_boxed_slice(),
                root: Some(root),
            }
        }

        /// Index of the root node; always `Some` for trees created by [`Tree::new`].
        pub fn root(&self) -> Option<usize> {
            self.root
        }

        /// Returns the node stored at `index`.
        ///
        /// # Panics
        ///
        /// Panics if `index` is out of bounds.
        pub fn get(&self, index: usize) -> &Node<T> {
            &self.nodes[index]
        }

        /// Recursively appends a complete subtree with `depth + 1` levels to
        /// `nodes` and returns the index of the subtree's root.
        ///
        /// Children are appended before their parent (left subtree first, then
        /// right), so the returned index is always the last element pushed.
        pub fn build_node(nodes: &mut Vec<Node<T>>, depth: usize, t: T) -> usize
        where
            T: Copy,
        {
            let (left, right) = if depth == 0 {
                (None, None)
            } else {
                let left = Self::build_node(nodes, depth - 1, t);
                let right = Self::build_node(nodes, depth - 1, t);
                (Some(left), Some(right))
            };

            let index = nodes.len();
            nodes.push(Node {
                leaf: t,
                left,
                right,
            });
            index
        }

        /// Number of nodes in a complete tree built with `depth`, i.e.
        /// `2^(depth + 1) - 1`, or `None` when that does not fit in `usize`.
        fn node_count(depth: usize) -> Option<usize> {
            depth
                .checked_add(1)
                .filter(|&levels| levels < usize::BITS as usize)
                .map(|levels| (1usize << levels) - 1)
        }
    }
}

/// Depth passed to `tree::Tree::new` by every benchmark in this file.
///
/// Despite the name this is a depth, not a node count: the tree builder gives
/// every node above the bottom level two children, so the resulting tree holds
/// `2^(TREE_SIZE + 1) - 1` nodes.
const TREE_SIZE: usize = 16;

/// Benchmarks a recursive tree sum in which every node hands its two subtrees
/// to `join` on an `executor::melange::WorkerThread`.
#[bench]
fn join_melange(b: &mut Bencher) {
    // Sums the subtree rooted at `node`; both children go through `scope.join`,
    // and a missing child contributes zero.
    fn sum(
        tree: &tree::Tree<u32>,
        node: usize,
        scope: &mut executor::melange::WorkerThread,
    ) -> u32 {
        let current = tree.get(node);
        let (left_total, right_total) = scope.join(
            |s| current.left.map_or(0, |child| sum(tree, child, s)),
            |s| current.right.map_or(0, |child| sum(tree, child, s)),
        );

        current.leaf + left_total + right_total
    }

    let pool = executor::melange::ThreadPool::new(available_parallelism());

    let mut scope = pool.new_worker();

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    b.iter(move || {
        let root = tree.root().unwrap();
        let total = sum(&tree, root, &mut scope);
        assert_ne!(total, 0);
    });
}

/// Benchmarks a recursive tree sum in which every node obtains the current
/// `executor::praetor::Scope` and hands its two subtrees to that scope's `join`.
#[bench]
fn join_praetor(b: &mut Bencher) {
    use executor::praetor::Scope;

    // Sums the subtree rooted at `node`; a missing child contributes zero.
    fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
        let current = tree.get(node);
        Scope::with(|s| {
            let (left_total, right_total) = s.join(
                || current.left.map_or(0, |child| sum(tree, child)),
                || current.right.map_or(0, |child| sum(tree, child)),
            );

            current.leaf + left_total + right_total
        })
    }

    let pool = executor::praetor::ThreadPool::global();

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    b.iter(move || {
        // The whole traversal runs inside one pool scope per iteration.
        let total = pool.scope(|_| sum(&tree, tree.root().unwrap()));
        assert_ne!(total, 0);
    });
}

/// Baseline benchmark: the same recursive tree sum as the other benchmarks,
/// computed with plain sequential recursion and no executor.
#[bench]
fn join_sync(b: &mut Bencher) {
    // Sums the subtree rooted at `node`, left child first, then right;
    // a missing child contributes zero.
    fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
        let current = tree.get(node);
        let left_total = current.left.map_or(0, |child| sum(tree, child));
        let right_total = current.right.map_or(0, |child| sum(tree, child));

        current.leaf + left_total + right_total
    }

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    b.iter(move || {
        let total = sum(&tree, tree.root().unwrap());
        assert_ne!(total, 0);
    });
}

/// Benchmarks a recursive tree sum in which every node hands its two subtrees
/// to `chili::Scope::join`.
#[bench]
fn join_chili(b: &mut Bencher) {
    // Sums the subtree rooted at `node`; both children go through `scope.join`,
    // and a missing child contributes zero.
    fn sum(tree: &tree::Tree<u32>, node: usize, scope: &mut chili::Scope<'_>) -> u32 {
        let current = tree.get(node);
        let (left_total, right_total) = scope.join(
            |s| current.left.map_or(0, |child| sum(tree, child, s)),
            |s| current.right.map_or(0, |child| sum(tree, child, s)),
        );

        current.leaf + left_total + right_total
    }

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    b.iter(move || {
        let root = tree.root().unwrap();
        // A fresh global scope is obtained on every iteration.
        let mut scope = chili::Scope::global();
        let total = sum(&tree, root, &mut scope);
        assert_ne!(total, 0);
    });
}

/// Benchmarks a recursive tree sum in which every node hands its two subtrees
/// to `rayon::join`.
#[bench]
fn join_rayon(b: &mut Bencher) {
    // Sums the subtree rooted at `node`; a missing child contributes zero.
    fn sum(tree: &tree::Tree<u32>, node: usize) -> u32 {
        let current = tree.get(node);
        let (left_total, right_total) = rayon::join(
            || current.left.map_or(0, |child| sum(tree, child)),
            || current.right.map_or(0, |child| sum(tree, child)),
        );

        current.leaf + left_total + right_total
    }

    let tree = tree::Tree::new(TREE_SIZE, 1u32);

    b.iter(move || {
        let total = sum(&tree, tree.root().unwrap());
        assert_ne!(total, 0);
    });
}