namespace InnovEnergy.Lib.Utils;

public static class TreeTraversal
{
    /// <summary>
    /// Depth-first, pre-order traversal: every node is emitted before its children.
    /// </summary>
    /// <param name="root">The node the traversal starts at; it is emitted first.</param>
    /// <param name="getChildren">Returns the children of a node; an empty sequence marks a leaf.</param>
    /// <returns>A lazy sequence of the nodes of the tree rooted at <paramref name="root"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    public static IEnumerable<T> TraverseDepthFirstPreOrder<T>(this T root, Func<T, IEnumerable<T>> getChildren)
    {
        return Traverse(root,
                        getChildren,
                        branchOpen: true,
                        branchClose: false,
                        leaf: true);
    }

    /// <summary>
    /// Depth-first, post-order traversal: every node is emitted after all of its children.
    /// </summary>
    /// <param name="root">The node the traversal starts at; it is emitted last.</param>
    /// <param name="getChildren">Returns the children of a node; an empty sequence marks a leaf.</param>
    /// <returns>A lazy sequence of the nodes of the tree rooted at <paramref name="root"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    public static IEnumerable<T> TraverseDepthFirstPostOrder<T>(this T root, Func<T, IEnumerable<T>> getChildren)
    {
        return Traverse(root,
                        getChildren,
                        branchOpen: false,
                        branchClose: true,
                        leaf: true);
    }

    /// <summary>
    /// Emits only the leaves (nodes without children), in depth-first order.
    /// </summary>
    /// <param name="root">The node the traversal starts at.</param>
    /// <param name="getChildren">Returns the children of a node; an empty sequence marks a leaf.</param>
    /// <returns>A lazy sequence of the leaves below (or equal to) <paramref name="root"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    public static IEnumerable<T> TraverseLeaves<T>(this T root, Func<T, IEnumerable<T>> getChildren)
    {
        return Traverse(root,
                        getChildren,
                        branchOpen: false,
                        branchClose: false,
                        leaf: true);
    }

    /// <summary>
    /// Emits only the leaves, in depth-first order, each together with its ancestors.
    /// </summary>
    /// <param name="root">The node the traversal starts at.</param>
    /// <param name="getChildren">Returns the children of a node; an empty sequence marks a leaf.</param>
    /// <returns>
    /// A lazy sequence of <c>(node, path)</c> pairs, where <c>path</c> lists the ancestors of
    /// <c>node</c> nearest first (parent, ..., root) as a snapshot that stays valid after the
    /// traversal has moved on.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    public static IEnumerable<(T node, IEnumerable<T> path)> TraverseLeavesWithPath<T>(this T root, Func<T, IEnumerable<T>> getChildren)
    {
        return TraverseWithPath(root,
                                getChildren,
                                branchOpen: false,
                                branchClose: false,
                                leaf: true);
    }

    /// <summary>
    /// General depth-first traversal. A node is a branch when <paramref name="getChildren"/>
    /// yields at least one element for it, otherwise it is a leaf; <paramref name="getChildren"/>
    /// is called once per visited node, lazily, as the result is enumerated.
    /// </summary>
    /// <param name="root">The node the traversal starts at.</param>
    /// <param name="getChildren">Returns the children of a node.</param>
    /// <param name="branchOpen">Emit each branch before its children (pre-order).</param>
    /// <param name="branchClose">Emit each branch after its children (post-order).</param>
    /// <param name="leaf">Emit the leaves.</param>
    /// <returns>The selected nodes in depth-first order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    /// <remarks>
    /// There is no cycle detection: <paramref name="getChildren"/> must describe a tree.
    /// Every child enumerator is disposed, also when the caller stops enumerating early.
    /// </remarks>
    public static IEnumerable<T> Traverse<T>(T root,
                                             Func<T, IEnumerable<T>> getChildren,
                                             Boolean branchOpen,
                                             Boolean branchClose,
                                             Boolean leaf)
    {
        // validate eagerly; the iterator below would only fail on the first MoveNext
        ArgumentNullException.ThrowIfNull(getChildren);

        return TraverseDepthFirst<T, T>(root, getChildren, branchOpen, branchClose, leaf,
                                        static (node, _) => node);
    }

    /// <summary>
    /// Same traversal as <see cref="Traverse{T}"/>, but every emitted node comes with its path.
    /// </summary>
    /// <param name="root">The node the traversal starts at.</param>
    /// <param name="getChildren">Returns the children of a node.</param>
    /// <param name="branchOpen">Emit each branch before its children (pre-order).</param>
    /// <param name="branchClose">Emit each branch after its children (post-order).</param>
    /// <param name="leaf">Emit the leaves.</param>
    /// <returns>
    /// <c>(node, path)</c> pairs in depth-first order. <c>path</c> holds the ancestors of
    /// <c>node</c>, nearest first (parent, ..., root), and is empty for the root. Each path is
    /// copied when its node is emitted, so it is safe to keep after the traversal continues.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    public static IEnumerable<(T node, IEnumerable<T> path)> TraverseWithPath<T>(T root,
                                                                                 Func<T, IEnumerable<T>> getChildren,
                                                                                 Boolean branchOpen,
                                                                                 Boolean branchClose,
                                                                                 Boolean leaf)
    {
        ArgumentNullException.ThrowIfNull(getChildren);

        // The top of the stack is the node's own level, the entries below it are its ancestors.
        // ToArray() snapshots them: a lazy Select over the stack would reflect whatever the
        // stack contains when the caller finally enumerates the path.
        return TraverseDepthFirst<T, (T node, IEnumerable<T> path)>(
            root, getChildren, branchOpen, branchClose, leaf,
            static (node, stack) => (node, stack.Skip(1).Select(level => level.Current).ToArray()));
    }

    // Shared depth-first engine behind Traverse and TraverseWithPath.
    //
    // Invariant: the top of `stack` enumerates the current node and its siblings and is positioned
    // on the current node; every entry below it is positioned on one of the node's ancestors, so
    // enumerating the stack (LIFO) visits the node's own level first and the root's level last.
    // `project` is always invoked with exactly that shape, before the stack changes again.
    private static IEnumerable<TResult> TraverseDepthFirst<T, TResult>(T root,
                                                                       Func<T, IEnumerable<T>> getChildren,
                                                                       Boolean branchOpen,
                                                                       Boolean branchClose,
                                                                       Boolean leaf,
                                                                       Func<T, Stack<IEnumerator<T>>, TResult> project)
    {
        // the if-checks on the constant booleans are
        // almost free because of modern branch predictors

        var stack = new Stack<IEnumerator<T>>();

        try
        {
            // a one-element "sibling list" for the root lets it be handled like any other node
            var rootLevel = Enumerable.Repeat(root, 1).GetEnumerator();
            stack.Push(rootLevel);
            rootLevel.MoveNext();

            while (true)
            {
                //////// going down ////////

                var node     = stack.Peek().Current;
                var children = getChildren(node).GetEnumerator();

                if (children.MoveNext()) // node has children, must be a branch (and not a leaf)
                {
                    if (branchOpen)
                    {
                        // project before the push so the result does not see the children's level,
                        // but push before yielding so an abandoned enumeration still disposes `children`
                        var opened = project(node, stack);
                        stack.Push(children);
                        yield return opened;
                    }
                    else
                    {
                        stack.Push(children);
                    }

                    continue; // descend into the first child
                }

                // no children, hence a leaf; the empty enumerator must be released as well
                children.Dispose();

                if (leaf)
                    yield return project(node, stack);

                //////// going up ////////

                // advance to the next sibling; every exhausted level closes its parent branch
                while (!stack.Peek().MoveNext())
                {
                    stack.Pop().Dispose();

                    if (stack.Count == 0)
                        yield break; // the root level is exhausted, we're done

                    if (branchClose)
                        yield return project(stack.Peek().Current, stack); // we've seen all its children: close the branch
                }
            }
        }
        finally
        {
            // runs on normal completion, on exceptions and when the caller stops enumerating early
            while (stack.Count > 0)
                stack.Pop().Dispose();
        }
    }

    /// <summary>
    /// Breadth-first (level-order) traversal: all nodes of one depth are emitted before any node
    /// of the next depth; siblings keep the order returned by <paramref name="getChildren"/>.
    /// </summary>
    /// <param name="node">The node the traversal starts at; it is emitted first.</param>
    /// <param name="getChildren">Returns the children of a node; called right after that node is emitted.</param>
    /// <returns>A lazy sequence of the nodes of the tree rooted at <paramref name="node"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is null.</exception>
    public static IEnumerable<T> TraverseBreadthFirst<T>(this T node, Func<T, IEnumerable<T>> getChildren)
    {
        ArgumentNullException.ThrowIfNull(getChildren);

        return TraverseBreadthFirstIterator(node, getChildren);
    }

    // Queue of sibling enumerators: the head is the level currently being emitted, and every
    // emitted node appends the enumerator over its own children at the tail.
    private static IEnumerable<T> TraverseBreadthFirstIterator<T>(T root, Func<T, IEnumerable<T>> getChildren)
    {
        var queue = new Queue<IEnumerator<T>>();

        try
        {
            queue.Enqueue(Enumerable.Repeat(root, 1).GetEnumerator());

            while (queue.Count > 0)
            {
                var siblings = queue.Peek();

                while (siblings.MoveNext())
                {
                    var current = siblings.Current;
                    yield return current;

                    // children are requested only after their parent has been emitted
                    queue.Enqueue(getChildren(current).GetEnumerator());
                }

                queue.Dequeue().Dispose(); // `siblings` is still the head of the queue
            }
        }
        finally
        {
            // releases pending enumerators when the caller stops early or getChildren throws
            while (queue.Count > 0)
                queue.Dequeue().Dispose();
        }
    }
}