using System.Diagnostics.CodeAnalysis;

namespace InnovEnergy.Lib.Utils;

/// <summary>
/// Traversal helpers for general graphs (which may contain cycles or shared nodes).
/// Each helper wraps one of the <c>TreeTraversal</c> algorithms and filters the children
/// through a set of already seen nodes, so that every node is yielded at most once.
/// </summary>
/// <remarks>
/// All returned sequences are lazy. The bookkeeping set is created anew every time a
/// returned sequence is enumerated, so enumerating the same result twice yields the same
/// nodes both times (provided <c>getChildren</c> itself is repeatable).
/// </remarks>
public static class GraphTraversal
{
    /// <summary>
    /// Yields the nodes reachable from <paramref name="root"/>, each at most once, in the
    /// order produced by <c>TreeTraversal.TraverseDepthFirstPreOrder</c>.
    /// </summary>
    /// <param name="root">The node to start from.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <param name="comparer">Decides whether two nodes are the same; <c>null</c> selects the default comparer.</param>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is <c>null</c>.</exception>
    public static IEnumerable<T> TraverseDepthFirstPreOrder<T>(T root,
                                         Func<T, IEnumerable<T>> getChildren,
                                           IEqualityComparer<T>? comparer = null)
    {
        return Traverse(root, TreeTraversal.TraverseDepthFirstPreOrder, getChildren, comparer);
    }


    /// <summary>
    /// Yields the nodes reachable from <paramref name="root"/>, each at most once, in the
    /// order produced by <c>TreeTraversal.TraverseDepthFirstPostOrder</c>.
    /// </summary>
    /// <param name="root">The node to start from.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <param name="comparer">Decides whether two nodes are the same; <c>null</c> selects the default comparer.</param>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is <c>null</c>.</exception>
    public static IEnumerable<T> TraverseDepthFirstPostOrder<T>(T root,
                                          Func<T, IEnumerable<T>> getChildren,
                                            IEqualityComparer<T>? comparer = null)
    {
        return Traverse(root, TreeTraversal.TraverseDepthFirstPostOrder, getChildren, comparer);
    }

    /// <summary>
    /// Yields the nodes reachable from <paramref name="root"/>, each at most once, in the
    /// order produced by <c>TreeTraversal.TraverseBreadthFirst</c>.
    /// </summary>
    /// <param name="root">The node to start from.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <param name="comparer">Decides whether two nodes are the same; <c>null</c> selects the default comparer.</param>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is <c>null</c>.</exception>
    public static IEnumerable<T> TraverseBreadthFirst<T>(T root,
                                   Func<T, IEnumerable<T>> getChildren,
                                     IEqualityComparer<T>? comparer = null)
    {
        return Traverse(root, TreeTraversal.TraverseBreadthFirst, getChildren, comparer);
    }

    /// <summary>
    /// Multi-root variant: traverses from every element of <paramref name="sources"/> using
    /// <c>TreeTraversal.TraverseDepthFirstPreOrder</c>, yielding each node at most once.
    /// </summary>
    /// <param name="sources">The start nodes; enumerated once per enumeration of the result.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <param name="comparer">Decides whether two nodes are the same; <c>null</c> selects the default comparer.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sources"/> or <paramref name="getChildren"/> is <c>null</c>.</exception>
    public static IEnumerable<T> TraverseDepthFirstPreOrder<T>(IEnumerable<T> sources,
                                                      Func<T, IEnumerable<T>> getChildren,
                                                        IEqualityComparer<T>? comparer = null)
    {
        return Traverse(sources, TreeTraversal.TraverseDepthFirstPreOrder, getChildren, comparer);
    }


    /// <summary>
    /// Multi-root variant: traverses from every element of <paramref name="sources"/> using
    /// <c>TreeTraversal.TraverseDepthFirstPostOrder</c>, yielding each node at most once.
    /// </summary>
    /// <param name="sources">The start nodes; enumerated once per enumeration of the result.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <param name="comparer">Decides whether two nodes are the same; <c>null</c> selects the default comparer.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sources"/> or <paramref name="getChildren"/> is <c>null</c>.</exception>
    public static IEnumerable<T> TraverseDepthFirstPostOrder<T>(IEnumerable<T> sources,
                                                       Func<T, IEnumerable<T>> getChildren,
                                                         IEqualityComparer<T>? comparer = null)
    {
        return Traverse(sources, TreeTraversal.TraverseDepthFirstPostOrder, getChildren, comparer);
    }

    /// <summary>
    /// Multi-root variant: traverses from every element of <paramref name="sources"/> using
    /// <c>TreeTraversal.TraverseBreadthFirst</c>, yielding each node at most once.
    /// </summary>
    /// <param name="sources">The start nodes; enumerated once per enumeration of the result.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <param name="comparer">Decides whether two nodes are the same; <c>null</c> selects the default comparer.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sources"/> or <paramref name="getChildren"/> is <c>null</c>.</exception>
    public static IEnumerable<T> TraverseBreadthFirst<T>(IEnumerable<T> sources,
                                                Func<T, IEnumerable<T>> getChildren,
                                                  IEqualityComparer<T>? comparer = null)
    {
        return Traverse(sources, TreeTraversal.TraverseBreadthFirst, getChildren, comparer);
    }


    /// <summary>
    /// Single-root traversal; a thin adapter over the multi-root overload.
    /// </summary>
    private static IEnumerable<T> Traverse<T>(T root,
                                              Func<T, Func<T, IEnumerable<T>>, IEnumerable<T>> traversor,
                                              Func<T, IEnumerable<T>> getChildren,
                                              IEqualityComparer<T>? comparer = null)
    {
        IEnumerable<T> singleRoot = new[] { root };
        return Traverse<T>(singleRoot, traversor, getChildren, comparer);
    }

    /// <summary>
    /// Runs <paramref name="traversor"/> from every distinct root, sharing one visited-set
    /// across all roots so that no node is yielded twice.
    /// </summary>
    /// <remarks>
    /// Arguments are validated immediately; everything else is deferred into an iterator so
    /// that the visited-set is rebuilt on every enumeration of the result.
    /// </remarks>
    private static IEnumerable<T> Traverse<T>(
        IEnumerable<T> sources,
        Func<T, Func<T, IEnumerable<T>>, IEnumerable<T>> traversor,
        Func<T, IEnumerable<T>> getChildren,
        IEqualityComparer<T>? comparer = null)
    {
        if (sources is null)
        {
            throw new ArgumentNullException(nameof(sources));
        }

        if (getChildren is null)
        {
            throw new ArgumentNullException(nameof(getChildren));
        }

        return Iterate();

        IEnumerable<T> Iterate()
        {
            var visited = new HashSet<T>(comparer ?? EqualityComparer<T>.Default);

            // Single pass over sources: registers every root as visited *before* any
            // traversal starts (so a root reachable from another root is only yielded at
            // its own position) and drops repeated roots.
            var roots = sources.Where(visited.Add).ToList();

            var getUniqueChildren = GetUniqueChildren(getChildren, visited);

            foreach (var root in roots)
            {
                foreach (var node in traversor(root, getUniqueChildren))
                {
                    yield return node;
                }
            }
        }
    }

    /// <summary>
    /// Yields the edges reachable from <paramref name="node"/>; see the multi-source overload.
    /// </summary>
    /// <param name="node">The node to start from.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <exception cref="ArgumentNullException"><paramref name="getChildren"/> is <c>null</c>.</exception>
    // TODO: IEqualityComparer
    public static IEnumerable<(T source, T target)> Edges<T>(T node, Func<T, IEnumerable<T>> getChildren)
    {
        return Edges(node.AsSingleEnumerable(), getChildren);
    }


    /// <summary>
    /// Yields every edge <c>(source, target)</c> reachable from any of <paramref name="sources"/>,
    /// each edge at most once. Edges are compared with the default equality of the value tuple.
    /// </summary>
    /// <param name="sources">The start nodes.</param>
    /// <param name="getChildren">Returns the direct successors of a node.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sources"/> or <paramref name="getChildren"/> is <c>null</c>.</exception>
    // TODO: IEqualityComparer
    public static IEnumerable<(T source, T target)> Edges<T>(IEnumerable<T> sources,
                                                    Func<T, IEnumerable<T>> getChildren)
    {
        if (sources is null)
        {
            throw new ArgumentNullException(nameof(sources));
        }

        if (getChildren is null)
        {
            throw new ArgumentNullException(nameof(getChildren));
        }

        return Iterate();

        IEnumerable<(T source, T target)> Iterate()
        {
            // Created per enumeration, so re-enumerating the result reports the edges again.
            var seenEdges = new HashSet<(T source, T target)>();

            // The "children" of an edge are the not yet seen edges leaving its target.
            IEnumerable<(T source, T target)> GetChildEdges((T source, T target) edge)
            {
                return getChildren(edge.target)
                      .Select(child => (edge.target, child))
                      .Where(seenEdges.Add);
            }

            foreach (var source in sources)
            {
                // Artificial self-edge used only to seed the traversal; Skip(1) drops it again.
                var start = (source, source);

                foreach (var edge in TreeTraversal.TraverseDepthFirstPreOrder(start, GetChildEdges).Skip(1))
                {
                    yield return edge;
                }
            }
        }
    }

    /// <summary>
    /// Wraps <paramref name="getChildren"/> so that it only returns children that could be
    /// added to <paramref name="visited"/>, i.e. children not seen before.
    /// </summary>
    private static Func<T, IEnumerable<T>> GetUniqueChildren<T>(Func<T, IEnumerable<T>> getChildren,
                                                                HashSet<T> visited)
    {
        return node => getChildren(node).Where(visited.Add);
    }
}