using static System.Threading.Tasks.TaskContinuationOptions;
using static System.Threading.Tasks.TaskStatus;

namespace InnovEnergy.Lib.Utils;

/// <summary>
/// Extension methods for composing <see cref="Task"/>s: a timeout wrapper, <c>WhenAny</c>/<c>WhenAll</c>
/// shorthands and continuation combinators (<c>Then</c>, <c>OnError</c>, <c>Catch</c>, <c>Match</c>, <c>Return</c>).
/// </summary>
/// <remarks>
/// The combinators are built on <c>ContinueWith</c> + <c>Unwrap</c>. Every combinator only reacts to the
/// outcome it is named after and forwards all other outcomes of the antecedent unchanged:
/// <c>OnError</c>/<c>Catch</c> pass results and cancellation through, <c>Then</c>/<c>Return</c> pass
/// faults and cancellation through. Error handlers receive <see cref="Task.Exception"/> of the
/// antecedent, i.e. the <see cref="AggregateException"/>, not its inner exception.
/// </remarks>
public static class TaskUtils
{
    /// <summary>
    /// Awaits <paramref name="task"/>, but gives up after <paramref name="timeout"/>.
    /// </summary>
    /// <param name="task">The task to wait for. It is not cancelled when the timeout elapses.</param>
    /// <param name="timeout">Maximum time to wait; passed unchanged to <c>Task.Delay</c>.</param>
    /// <returns>The result of <paramref name="task"/> if it completed before the timeout.</returns>
    /// <exception cref="TimeoutException">The timeout elapsed before <paramref name="task"/> completed.</exception>
    public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
    {
        // The token source lets us cancel the delay once it is no longer needed;
        // 'using' makes sure it is disposed on both the success and the timeout path.
        using var cts = new CancellationTokenSource();

        var timeoutTask = Task.Delay(timeout, cts.Token);

        var completedTask = await Task
                                 .WhenAny(task, timeoutTask)
                                 .ConfigureAwait(false);

        if (completedTask != task)
            throw new TimeoutException($"Task timed out after {timeout}");

        cts.Cancel();            // Cancel the "timeout" task so we don't leak a Timer

        // await (rather than .Result) so that faults/cancellation of the task bubble up unwrapped
        return await task.ConfigureAwait(false);
    }

    /// <summary>
    /// Intentionally does nothing. Presumably called on a task that is deliberately not awaited,
    /// to silence the compiler's "call is not awaited" warning.
    /// </summary>
    /// <param name="task">The task that is deliberately left un-awaited.</param>
    public static void SupressAwaitWarning(this Task task)
    {
    }


    /// <summary>
    /// Intentionally does nothing. Generic counterpart of <c>SupressAwaitWarning(Task)</c>.
    /// </summary>
    /// <param name="task">The task that is deliberately left un-awaited.</param>
    public static void SupressAwaitWarning<T>(this Task<T> task)
    {
    }


    /// <summary>
    /// Returns a task that mirrors the first of <paramref name="tasks"/> to complete
    /// (<c>Task.WhenAny</c> followed by <c>Unwrap</c>), including its fault or cancellation.
    /// </summary>
    public static Task WhenAny(this IEnumerable<Task> tasks)
    {
        return Task.WhenAny(tasks).Unwrap();
    }


    /// <summary>
    /// Returns a task that mirrors the first of <paramref name="tasks"/> to complete
    /// (<c>Task.WhenAny</c> followed by <c>Unwrap</c>), including its result, fault or cancellation.
    /// </summary>
    public static Task<TResult> WhenAny<TResult>(this IEnumerable<Task<TResult>> tasks)
    {
        return Task.WhenAny(tasks).Unwrap();
    }


    /// <summary>Extension-method shorthand for <c>Task.WhenAll</c> on tasks with results.</summary>
    public static Task<TResult[]> WhenAll<TResult>(this IEnumerable<Task<TResult>> tasks)
    {
        return Task.WhenAll(tasks);
    }


    /// <summary>Extension-method shorthand for <c>Task.WhenAll</c>.</summary>
    public static Task WhenAll(this IEnumerable<Task> tasks)
    {
        return Task.WhenAll(tasks);
    }


    /// <summary>
    /// If <paramref name="task"/> faults, continues with the task returned by <paramref name="action"/>;
    /// otherwise the result or cancellation of <paramref name="task"/> is passed through.
    /// </summary>
    /// <param name="task">The antecedent task.</param>
    /// <param name="action">Asynchronous fallback; receives <c>task.Exception</c>.</param>
    public static Task<T> OnError<T>(this Task<T> task, Func<Exception, Task<T>> action)
    {
        // No OnlyOnFaulted filter: it would cancel the continuation (and drop the result) on success.
        return task
              .ContinueWith(t => t.IsFaulted ? action(t.Exception!) : t)
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the value computed by <paramref name="onError"/>;
    /// otherwise the result or cancellation of <paramref name="task"/> is passed through.
    /// </summary>
    /// <param name="task">The antecedent task.</param>
    /// <param name="onError">Fallback; receives <c>task.Exception</c>.</param>
    public static Task<T> OnError<T>(this Task<T> task, Func<Exception, T> onError)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? Task.FromResult(onError(t.Exception!)) : t)
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the value produced by <paramref name="onError"/>;
    /// otherwise the result or cancellation of <paramref name="task"/> is passed through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Func<T> onError)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? Task.FromResult(onError()) : t)
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the fallback value <paramref name="onError"/>;
    /// otherwise the result or cancellation of <paramref name="task"/> is passed through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, T onError)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? Task.FromResult(onError) : t)
              .Unwrap();
    }

    /// <summary>
    /// Runs <paramref name="onError"/> as a side effect if <paramref name="task"/> faults.
    /// The outcome of <paramref name="task"/> (result, fault or cancellation) is forwarded unchanged.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Action onError)
    {
        return task
              .ContinueWith(t =>
               {
                   if (t.IsFaulted)
                       onError();

                   return t;   // forward the antecedent itself; Unwrap mirrors its outcome
               })
              .Unwrap();
    }

    /// <summary>
    /// Runs <paramref name="onError"/> with <c>task.Exception</c> as a side effect if
    /// <paramref name="task"/> faults. The outcome of <paramref name="task"/> (result, fault or
    /// cancellation) is forwarded unchanged.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Action<Exception> onError)
    {
        return task
              .ContinueWith(t =>
               {
                   if (t.IsFaulted)
                       onError(t.Exception!);

                   return t;   // forward the antecedent itself; Unwrap mirrors its outcome
               })
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, replaces its exception with the one returned by
    /// <paramref name="action"/>; otherwise the result or cancellation is passed through.
    /// </summary>
    /// <param name="task">The antecedent task.</param>
    /// <param name="action">Maps <c>task.Exception</c> to the exception the returned task faults with.</param>
    public static Task<T> OnError<T>(this Task<T> task, Func<Exception, Exception> action)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? Task.FromException<T>(action(t.Exception!)) : t)
              .Unwrap();
    }


    /// <summary>
    /// If <paramref name="task"/> succeeds, continues with the task returned by <paramref name="func"/>
    /// for its result; a fault or cancellation of <paramref name="task"/> is propagated.
    /// </summary>
    public static Task<R> Then<T, R>(this Task<T> task, Func<T, Task<R>> func)
    {
        // No OnlyOnRanToCompletion filter: it would turn a fault into a cancellation and lose the exception.
        return task
              .ContinueWith(t => t.Status == RanToCompletion ? func(t.Result) : Failed<R>(t))
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, maps its result with <paramref name="func"/>;
    /// a fault or cancellation of <paramref name="task"/> is propagated.
    /// </summary>
    public static Task<R> Then<T, R>(this Task<T> task, Func<T, R> func)
    {
        return task
              .ContinueWith(t => t.Status == RanToCompletion ? Task.FromResult(func(t.Result)) : Failed<R>(t))
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, yields the value produced by <paramref name="func"/>
    /// (the task's own result is ignored); a fault or cancellation is propagated.
    /// </summary>
    public static Task<R> Then<T,R>(this Task<T> task, Func<R> func)
    {
        return task
              .ContinueWith(t => t.Status == RanToCompletion ? Task.FromResult(func()) : Failed<R>(t))
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, yields <paramref name="r"/> (the task's own result is
    /// ignored); a fault or cancellation is propagated.
    /// </summary>
    public static Task<R> Then<T,R>(this Task<T> task, R r)
    {
        return task
              .ContinueWith(t => t.Status == RanToCompletion ? Task.FromResult(r) : Failed<R>(t))
              .Unwrap();
    }


    /// <summary>
    /// Runs <paramref name="action"/> as a side effect if <paramref name="task"/> succeeds and then
    /// yields the task's result; a fault or cancellation is forwarded without running the action.
    /// </summary>
    public static Task<T> Then<T>(this Task<T> task, Action action)
    {
        return task
              .ContinueWith(t =>
               {
                   if (t.Status == RanToCompletion)
                       action();

                   return t;   // on success this carries the original result
               })
              .Unwrap();
    }

    /// <summary>
    /// Runs <paramref name="action"/> with the result as a side effect if <paramref name="task"/>
    /// succeeds and then yields that result; a fault or cancellation is forwarded without running the action.
    /// </summary>
    public static Task<T> Then<T>(this Task<T> task, Action<T> action)
    {
        return task
              .ContinueWith(t =>
               {
                   if (t.Status == RanToCompletion)
                       action(t.Result);

                   return t;   // on success this carries the original result
               })
              .Unwrap();
    }


    /// <summary>
    /// Folds every possible outcome of <paramref name="task"/> into a value of type <typeparamref name="R"/>.
    /// </summary>
    /// <param name="task">The antecedent task.</param>
    /// <param name="onSuccess">Called with the result when the task ran to completion.</param>
    /// <param name="onError">Called with <c>task.Exception</c> when the task faulted.</param>
    /// <param name="onAborted">Called for any other final status (i.e. cancellation).</param>
    public static Task<R> Match<T, R>(this Task<T> task,
                                        Func<T, R> onSuccess,
                                Func<Exception, R> onError,
                                           Func<R> onAborted)
    {
        // Unfiltered continuation: it runs for every outcome and dispatches on the final status.
        return task.ContinueWith(t => t.Status switch
        {
            RanToCompletion => onSuccess(t.Result),
            Faulted         => onError(t.Exception!),
            _               => onAborted(),
        });
    }


    // Drafts below are not compiled; kept for reference.

    // public static Task<T> Do<T>(this Task<T> task, 
    //                                Action<T> onSuccess,
    //                        Action<Exception> onError,
    //                                   Action onAborted)
    // {
    //     return task.ContinueWith(t => t.Status switch
    //     {
    //         RanToCompletion => onSuccess(t.Result),
    //         Faulted         => onError(t.Exception!),
    //         _               => onAborted(),
    //     });
    // }



    // public static Task<R> MatchAsync<T, R>(this Task<T> task, 
    //                                    Func<T, Task<R>>? onSuccess,
    //                            Func<Exception, Task<R>>? onError,
    //                                       Func<Task<R>>? onAborted)
    // {
    //     return task.ContinueWith(t => t.Status switch
    //     {
    //         RanToCompletion => onSuccess is not null ? onSuccess(t.Result) : ,
    //         Faulted         => onError(t.Exception!),
    //         _               => onAborted()
    //     }).Unwrap();
    // }


    /// <summary>
    /// If <paramref name="task"/> succeeds, yields <paramref name="r"/>;
    /// a fault or cancellation of <paramref name="task"/> is propagated.
    /// </summary>
    public static Task<R> Return<R>(this Task task, R r)
    {
        return task
              .ContinueWith(t => t.Status == RanToCompletion ? Task.FromResult(r) : Failed<R>(t))
              .Unwrap();
    }


    /// <summary>
    /// If <paramref name="task"/> faults, runs <paramref name="onException"/> with <c>task.Exception</c>
    /// and completes successfully (the fault is swallowed). Success and cancellation are passed through.
    /// </summary>
    public static Task Catch(this Task task, Action<Exception> onException)
    {
        return task
              .ContinueWith(t =>
               {
                   if (!t.IsFaulted)
                       return t;   // nothing to handle: mirror the antecedent

                   onException(t.Exception!);
                   return Task.CompletedTask;
               })
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the fallback value <paramref name="onException"/>;
    /// otherwise the result or cancellation of <paramref name="task"/> is passed through.
    /// </summary>
    public static Task<T> Catch<T>(this Task<T> task, T onException)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? Task.FromResult(onException) : t)
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the value computed by <paramref name="onException"/>
    /// from <c>task.Exception</c>; otherwise the result or cancellation is passed through.
    /// </summary>
    public static Task<T> Catch<T>(this Task<T> task, Func<Exception, T> onException)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? Task.FromResult(onException(t.Exception!)) : t)
              .Unwrap();
    }


    /// <summary>
    /// Creates a <c>Task&lt;R&gt;</c> that ends the same way as the faulted or canceled
    /// <paramref name="antecedent"/>. Must only be called with a task that did not run to completion.
    /// </summary>
    private static Task<R> Failed<R>(Task antecedent)
    {
        var tcs = new TaskCompletionSource<R>();

        if (antecedent.IsFaulted)
            tcs.SetException(antecedent.Exception!.InnerExceptions);   // keep all inner exceptions, avoid nesting
        else
            tcs.SetCanceled();

        return tcs.Task;
    }

}