using Microsoft.CodeAnalysis;

namespace InnovEnergy.Lib.SrcGen;

/// <summary>
/// Fluent, immutable selector over a set of syntax nodes belonging to <typeparamref name="TRoot"/>.
/// Filtering/projection methods return a new <see cref="Rewriter{TRoot,TNode}"/> over the same root.
/// The terminal methods (<c>Replace</c> / <c>Remove</c>) apply the edit to the root and return the rewritten root.
/// </summary>
/// <remarks>
/// The selection is a deferred LINQ query: it is only enumerated when a terminal method runs.
/// </remarks>
/// <typeparam name="TRoot">Type of the node whose subtree is being edited.</typeparam>
/// <typeparam name="TNode">Type of the currently selected nodes.</typeparam>
public readonly ref struct Rewriter<TRoot, TNode> where TRoot : SyntaxNode where TNode : SyntaxNode
{
    /// <summary>The node whose subtree is being edited.</summary>
    internal readonly TRoot Root;

    /// <summary>The lazily evaluated sequence of nodes the terminal edit applies to.</summary>
    internal readonly IEnumerable<TNode> Descendants;

    internal Rewriter(TRoot root, IEnumerable<TNode> descendants)
    {
        Root = root;
        Descendants = descendants;
    }

    /// <summary>Keeps only the selected nodes that satisfy <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter applied to each selected node.</param>
    public Rewriter<TRoot, TNode> Where(Func<TNode, Boolean> predicate)
    {
        return new Rewriter<TRoot, TNode>(Root, Descendants.Where(predicate));
    }

    /// <summary>Keeps only the selected nodes of type <typeparamref name="T"/> and narrows the selection to that type.</summary>
    public Rewriter<TRoot, T> OfType<T>() where T : TNode
    {
        return new Rewriter<TRoot, T>(Root, Descendants.OfType<T>());
    }

    /// <summary>Keeps only the selected nodes that have at least one ancestor of type <typeparamref name="T"/>.</summary>
    /// <remarks>
    /// The search follows the node's full parent chain (<c>SyntaxNode.Ancestors</c>),
    /// which is not limited to <see cref="Root"/> and may include nodes above it.
    /// </remarks>
    public Rewriter<TRoot, TNode> HasAncestor<T>() where T : SyntaxNode
    {
        return new Rewriter<TRoot, TNode>(Root, Descendants.Where(d => d.Ancestors().OfType<T>().Any()));
    }

    /// <summary>Keeps only the selected nodes whose direct parent is of type <typeparamref name="T"/>.</summary>
    public Rewriter<TRoot, TNode> HasParent<T>() where T : SyntaxNode
    {
        return new Rewriter<TRoot, TNode>(Root, Descendants.Where(d => d.Parent is T));
    }

    /// <summary>Replaces each selected node with the (possibly empty) sequence of nodes returned by <paramref name="nodes"/>.</summary>
    /// <param name="nodes">Maps one selected node to zero or more nodes to select instead.</param>
    public Rewriter<TRoot, TNode> SelectNodes(Func<TNode, IEnumerable<TNode>> nodes)
    {
        return new Rewriter<TRoot, TNode>(Root, Descendants.SelectMany(nodes));
    }

    /// <summary>Projects each selected node to exactly one node of type <typeparamref name="T"/>.</summary>
    /// <param name="node">Projection applied to each selected node.</param>
    public Rewriter<TRoot, T> SelectNode<T>(Func<TNode, T> node) where T : SyntaxNode
    {
        return new Rewriter<TRoot, T>(Root, Descendants.Select(node));
    }

    /// <summary>
    /// Replaces each selected node with its nearest ancestor of type <typeparamref name="T"/>.
    /// Nodes that have no such ancestor are dropped from the selection.
    /// </summary>
    /// <remarks>
    /// Like <see cref="HasAncestor{T}"/>, the search is not limited to <see cref="Root"/>.
    /// Several selected nodes sharing one ancestor yield that ancestor more than once.
    /// </remarks>
    public Rewriter<TRoot, T> GetAncestor<T>() where T : SyntaxNode
    {
        // Take(1) yields the nearest match or nothing. Previously First() threw a deferred
        // InvalidOperationException (surfacing only inside Replace/Remove) for any node without a match.
        return new Rewriter<TRoot, T>(Root, Descendants.SelectMany(n => n.Ancestors().OfType<T>().Take(1)));
    }

    /// <summary>Replaces every selected node with the same <paramref name="syntaxNode"/>.</summary>
    /// <param name="syntaxNode">Replacement node used for every selected node.</param>
    /// <returns>The rewritten root.</returns>
    public TRoot Replace(SyntaxNode syntaxNode)
    {
        return Root.ReplaceNodes(Descendants, (_, _) => syntaxNode);
    }

    /// <summary>Replaces every selected node with the node produced by <paramref name="map"/>.</summary>
    /// <param name="map">
    /// Receives the second argument of Roslyn's <c>ReplaceNodes</c> callback, i.e. the selected node with any
    /// nested selected nodes already rewritten, and returns its replacement.
    /// </param>
    /// <returns>The rewritten root.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="map"/> is <c>null</c>.</exception>
    public TRoot Replace(Func<TNode, SyntaxNode> map)
    {
        // Fail fast here rather than with a NullReferenceException inside Roslyn's replacement callback.
        if (map is null) throw new ArgumentNullException(nameof(map));

        return Root.ReplaceNodes(Descendants, (_, n) => map(n));
    }

    /// <summary>Removes every selected node without keeping any of its trivia.</summary>
    /// <returns>The rewritten root.</returns>
    public TRoot Remove() => Remove(SyntaxRemoveOptions.KeepNoTrivia);

    /// <summary>Removes every selected node using the given trivia-handling <paramref name="options"/>.</summary>
    /// <returns>The rewritten root.</returns>
    /// <remarks>
    /// NOTE(review): the null-forgiving <c>!</c> assumes <see cref="Root"/> itself is never in the selection.
    /// It could be, e.g. via <see cref="GetAncestor{T}"/>, in which case <c>RemoveNodes</c> may yield null.
    /// </remarks>
    public TRoot Remove(SyntaxRemoveOptions options) => Root.RemoveNodes(Descendants, options)!;
}

/// <summary>Entry point for the fluent <see cref="Rewriter{TRoot,TNode}"/> syntax-editing API.</summary>
public static class Rewrite
{
    /// <summary>
    /// Begins an edit of <paramref name="root"/>, initially selecting all of its descendant nodes
    /// (<c>DescendantNodes()</c>, which does not include <paramref name="root"/> itself).
    /// </summary>
    /// <typeparam name="R">Type of the node being edited.</typeparam>
    /// <param name="root">The node whose subtree will be edited.</param>
    /// <returns>A rewriter over <paramref name="root"/> with every descendant selected.</returns>
    public static Rewriter<R, SyntaxNode> EditNodes<R>(this R root) where R : SyntaxNode
        => new(root, root.DescendantNodes());
}