using static System.Threading.Tasks.TaskContinuationOptions;
using static System.Threading.Tasks.TaskStatus;

namespace InnovEnergy.Lib.Utils;

/// <summary>
/// Extension methods for composing <see cref="Task"/>s: timeouts, success
/// continuations (<c>Then</c>), fault handlers (<c>OnError</c>/<c>Catch</c>)
/// and outcome matching (<c>Match</c>).
/// </summary>
/// <remarks>
/// Handlers that take an <see cref="Exception"/> receive the antecedent's
/// <see cref="Task.Exception"/>, i.e. the <see cref="AggregateException"/>,
/// not the unwrapped inner exception.
/// Outcomes a method does not handle are passed through unchanged: a fault
/// stays a fault (same inner exceptions), a cancellation stays a cancellation
/// and, for the error handlers, a successful result stays the result.
/// </remarks>
public static class TaskUtils
{
    /// <summary>
    /// Awaits <paramref name="task"/>, but gives up after <paramref name="timeout"/>.
    /// </summary>
    /// <param name="task">The task to wait for.</param>
    /// <param name="timeout">Maximum time to wait.</param>
    /// <returns>The result of <paramref name="task"/> if it finished first.</returns>
    /// <exception cref="TimeoutException">
    /// The delay elapsed before <paramref name="task"/> completed.
    /// The task itself is not cancelled, it is only no longer awaited.
    /// </exception>
    public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
    {
        // The token source lets us cancel the delay once it is no longer needed;
        // the using declaration disposes it on every exit path.
        using var cts = new CancellationTokenSource();

        var timeoutTask = Task.Delay(timeout, cts.Token);

        var completedTask = await Task
                                 .WhenAny(task, timeoutTask)
                                 .ConfigureAwait(false);

        if (completedTask != task)
            throw new TimeoutException($"Task timed out after {timeout}");

        cts.Cancel(); // cancel the delay so its timer is released right away

        // Await the (already completed) task to bubble up its result or error.
        return await task.ConfigureAwait(false);
    }

    /// <summary>Does nothing; calling it consumes the task expression of a deliberately un-awaited call.</summary>
    public static void SupressAwaitWarning(this Task task)
    {
    }

    /// <summary>Does nothing; calling it consumes the task expression of a deliberately un-awaited call.</summary>
    public static void SupressAwaitWarning<T>(this Task<T> task)
    {
    }

    /// <summary>
    /// Returns a task that mirrors whichever of <paramref name="tasks"/> completes first,
    /// including its fault or cancellation.
    /// </summary>
    public static Task WhenAny(this IEnumerable<Task> tasks)
    {
        return Task.WhenAny(tasks).Unwrap();
    }

    /// <summary>
    /// Returns a task that mirrors whichever of <paramref name="tasks"/> completes first,
    /// including its result, fault or cancellation.
    /// </summary>
    public static Task<TResult> WhenAny<TResult>(this IEnumerable<Task<TResult>> tasks)
    {
        return Task.WhenAny(tasks).Unwrap();
    }

    /// <summary>Fluent form of <see cref="Task.WhenAll{TResult}(IEnumerable{Task{TResult}})"/>.</summary>
    public static Task<TResult[]> WhenAll<TResult>(this IEnumerable<Task<TResult>> tasks)
    {
        return Task.WhenAll(tasks);
    }

    /// <summary>Fluent form of <see cref="Task.WhenAll(IEnumerable{Task})"/>.</summary>
    public static Task WhenAll(this IEnumerable<Task> tasks)
    {
        return Task.WhenAll(tasks);
    }

    /// <summary>
    /// If <paramref name="task"/> faults, replaces it with the task produced by
    /// <paramref name="action"/>; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Func<Exception, Task<T>> action)
    {
        return Recover(task, t => action(t.Exception!));
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the value computed by
    /// <paramref name="onError"/> from the exception; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Func<Exception, T> onError)
    {
        return Recover(task, t => Task.FromResult(onError(t.Exception!)));
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the value computed by
    /// <paramref name="onError"/>; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Func<T> onError)
    {
        return Recover(task, _ => Task.FromResult(onError()));
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the fallback value
    /// <paramref name="onError"/>; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, T onError)
    {
        return Recover(task, _ => Task.FromResult(onError));
    }

    /// <summary>
    /// Runs <paramref name="onError"/> as a side effect if <paramref name="task"/> faults.
    /// The returned task still carries the original outcome (it stays faulted).
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Action onError)
    {
        return Recover(task, t =>
        {
            onError();
            return t; // keep the original fault
        });
    }

    /// <summary>
    /// Runs <paramref name="onError"/> with the exception as a side effect if
    /// <paramref name="task"/> faults. The returned task still carries the
    /// original outcome (it stays faulted).
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Action<Exception> onError)
    {
        return Recover(task, t =>
        {
            onError(t.Exception!);
            return t; // keep the original fault
        });
    }

    /// <summary>
    /// If <paramref name="task"/> faults, replaces its exception with the one
    /// returned by <paramref name="action"/>; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> OnError<T>(this Task<T> task, Func<Exception, Exception> action)
    {
        return Recover(task, t => Task.FromException<T>(action(t.Exception!)));
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, continues with the task produced by
    /// <paramref name="func"/> from its result; a fault or cancellation is forwarded.
    /// </summary>
    public static Task<R> Then<T, R>(this Task<T> task, Func<T, Task<R>> func)
    {
        return Continue<T, R>(task, t => func(t.Result));
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, maps its result with
    /// <paramref name="func"/>; a fault or cancellation is forwarded.
    /// </summary>
    public static Task<R> Then<T, R>(this Task<T> task, Func<T, R> func)
    {
        return Continue<T, R>(task, t => Task.FromResult(func(t.Result)));
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, yields the value computed by
    /// <paramref name="func"/>; a fault or cancellation is forwarded.
    /// </summary>
    public static Task<R> Then<T, R>(this Task<T> task, Func<R> func)
    {
        return Continue<T, R>(task, _ => Task.FromResult(func()));
    }

    /// <summary>
    /// If <paramref name="task"/> succeeds, yields <paramref name="r"/>;
    /// a fault or cancellation is forwarded.
    /// </summary>
    public static Task<R> Then<T, R>(this Task<T> task, R r)
    {
        return Continue<T, R>(task, _ => Task.FromResult(r));
    }

    /// <summary>
    /// Runs <paramref name="action"/> as a side effect if <paramref name="task"/>
    /// succeeds and keeps the original result; a fault or cancellation is forwarded.
    /// </summary>
    public static Task<T> Then<T>(this Task<T> task, Action action)
    {
        return Continue<T, T>(task, t =>
        {
            action();
            return t;
        });
    }

    /// <summary>
    /// Runs <paramref name="action"/> with the result as a side effect if
    /// <paramref name="task"/> succeeds and keeps the original result;
    /// a fault or cancellation is forwarded.
    /// </summary>
    public static Task<T> Then<T>(this Task<T> task, Action<T> action)
    {
        return Continue<T, T>(task, t =>
        {
            action(t.Result);
            return t;
        });
    }

    /// <summary>
    /// Folds every possible outcome of <paramref name="task"/> into a value.
    /// </summary>
    /// <param name="task">The task to inspect once it has completed.</param>
    /// <param name="onSuccess">Maps the result of a successful task.</param>
    /// <param name="onError">Maps the exception of a faulted task.</param>
    /// <param name="onAborted">Supplies the value for any other final state (cancellation).</param>
    public static Task<R> Match<T, R>(this Task<T> task,
                                      Func<T, R> onSuccess,
                                      Func<Exception, R> onError,
                                      Func<R> onAborted)
    {
        return task.ContinueWith(t => t.Status switch
        {
            RanToCompletion => onSuccess(t.Result),
            Faulted         => onError(t.Exception!),
            _               => onAborted(),
        });
    }

    // Disabled drafts, kept for reference:
    //
    // public static Task<T> Do<T>(this Task<T> task,
    //                             Action<T> onSuccess,
    //                             Action<Exception> onError,
    //                             Action onAborted)
    // {
    //     return task.ContinueWith(t => t.Status switch
    //     {
    //         RanToCompletion => onSuccess(t.Result),
    //         Faulted         => onError(t.Exception!),
    //         _               => onAborted(),
    //     });
    // }
    //
    // public static Task<R> MatchAsync<T, R>(this Task<T> task,
    //                                        Func<T, Task<R>>? onSuccess,
    //                                        Func<Exception, Task<R>>? onError,
    //                                        Func<Task<R>>? onAborted)
    // {
    //     return task.ContinueWith(t => t.Status switch
    //     {
    //         RanToCompletion => onSuccess is not null ? onSuccess(t.Result) : ,
    //         Faulted         => onError(t.Exception!),
    //         _               => onAborted()
    //     }).Unwrap();
    // }

    /// <summary>
    /// If <paramref name="task"/> succeeds, yields <paramref name="r"/>;
    /// a fault or cancellation is forwarded.
    /// </summary>
    public static Task<R> Return<R>(this Task task, R r)
    {
        return task
              .ContinueWith(t => t.Status == RanToCompletion
                                     ? Task.FromResult(r)
                                     : Forward<R>(t))
              .Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, runs <paramref name="onException"/> and
    /// completes successfully; otherwise passes the outcome through.
    /// </summary>
    public static Task Catch(this Task task, Action<Exception> onException)
    {
        return task.ContinueWith(t =>
        {
            if (!t.IsFaulted)
                return t; // success or cancellation: nothing to catch

            onException(t.Exception!);
            return Task.CompletedTask;
        }).Unwrap();
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the fallback value
    /// <paramref name="onException"/>; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> Catch<T>(this Task<T> task, T onException)
    {
        return Recover(task, _ => Task.FromResult(onException));
    }

    /// <summary>
    /// If <paramref name="task"/> faults, yields the value computed by
    /// <paramref name="onException"/> from the exception; otherwise passes its outcome through.
    /// </summary>
    public static Task<T> Catch<T>(this Task<T> task, Func<Exception, T> onException)
    {
        return Recover(task, t => Task.FromResult(onException(t.Exception!)));
    }

    // Shared core of the fault handlers: invokes onFaulted only for a faulted
    // antecedent and hands every other outcome through untouched.
    // A continuation filtered with OnlyOnFaulted would instead be *cancelled*
    // whenever the antecedent succeeds, discarding the successful result.
    private static Task<T> Recover<T>(Task<T> task, Func<Task<T>, Task<T>> onFaulted)
    {
        return task
              .ContinueWith(t => t.IsFaulted ? onFaulted(t) : t)
              .Unwrap();
    }

    // Shared core of the success continuations: invokes onSuccess only for a
    // successfully completed antecedent and forwards a fault or cancellation.
    // A continuation filtered with OnlyOnRanToCompletion would instead be
    // *cancelled* on failure, hiding the real exception.
    private static Task<R> Continue<T, R>(Task<T> task, Func<Task<T>, Task<R>> onSuccess)
    {
        return task
              .ContinueWith(t => t.Status == RanToCompletion ? onSuccess(t) : Forward<R>(t))
              .Unwrap();
    }

    // Creates a Task<R> that ends the same way as the given unsuccessful task:
    // faulted with the same inner exceptions, or cancelled.
    private static Task<R> Forward<R>(Task failed)
    {
        var tcs = new TaskCompletionSource<R>();

        if (failed.IsFaulted)
            tcs.SetException(failed.Exception!.InnerExceptions);
        else
            tcs.SetCanceled();

        return tcs.Task;
    }
}