238 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			C#
		
	
	
	
			
		
		
	
	
			238 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			C#
		
	
	
	
| #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
 | |
| 
 | |
| using System;
 | |
| using System.Collections.Generic;
 | |
| using System.Threading;
 | |
| using Cysharp.Threading.Tasks.Internal;
 | |
| 
 | |
| namespace Cysharp.Threading.Tasks
 | |
| {
 | |
|     public partial struct UniTask
 | |
|     {
 | |
/// <summary>
/// Combines the given tasks into a single task whose result array has one slot per
/// input task; slot <c>i</c> receives the result of <c>tasks[i]</c>.
/// </summary>
/// <param name="tasks">Tasks to wait for.</param>
/// <returns>
/// For an empty array, <c>UniTask.FromResult(Array.Empty&lt;T&gt;())</c>; otherwise a task
/// backed by a new <c>WhenAllPromise&lt;T&gt;</c>.
/// </returns>
public static UniTask<T[]> WhenAll<T>(params UniTask<T>[] tasks)
{
    // Only allocate a promise when there is actually something to wait for.
    if (tasks.Length != 0)
    {
        var promise = new WhenAllPromise<T>(tasks, tasks.Length);
        return new UniTask<T[]>(promise, 0);
    }

    return UniTask.FromResult(Array.Empty<T>());
}
 | |
| 
 | |
/// <summary>
/// Sequence overload: materializes <paramref name="tasks"/> into an array via
/// <c>ArrayPoolUtil.Materialize</c> and waits for the first <c>Length</c> entries of it.
/// </summary>
/// <param name="tasks">Tasks to wait for.</param>
/// <returns>A task backed by a new <c>WhenAllPromise&lt;T&gt;</c>.</returns>
public static UniTask<T[]> WhenAll<T>(IEnumerable<UniTask<T>> tasks)
{
    using (var materialized = ArrayPoolUtil.Materialize(tasks))
    {
        // The promise constructor pulls an awaiter out of every element and does not keep
        // a reference to the array, so the buffer can be disposed when this block exits.
        return new UniTask<T[]>(new WhenAllPromise<T>(materialized.Array, materialized.Length), 0);
    }
}
 | |
| 
 | |
/// <summary>
/// Combines the given result-less tasks into a single task.
/// </summary>
/// <param name="tasks">Tasks to wait for.</param>
/// <returns>
/// <c>UniTask.CompletedTask</c> for an empty array; otherwise a task backed by a new
/// <c>WhenAllPromise</c>.
/// </returns>
public static UniTask WhenAll(params UniTask[] tasks)
{
    // Only allocate a promise when there is actually something to wait for.
    if (tasks.Length != 0)
    {
        var promise = new WhenAllPromise(tasks, tasks.Length);
        return new UniTask(promise, 0);
    }

    return UniTask.CompletedTask;
}
 | |
| 
 | |
/// <summary>
/// Sequence overload: materializes <paramref name="tasks"/> into an array via
/// <c>ArrayPoolUtil.Materialize</c> and waits for the first <c>Length</c> entries of it.
/// </summary>
/// <param name="tasks">Tasks to wait for.</param>
/// <returns>A task backed by a new <c>WhenAllPromise</c>.</returns>
public static UniTask WhenAll(IEnumerable<UniTask> tasks)
{
    using (var materialized = ArrayPoolUtil.Materialize(tasks))
    {
        // The promise constructor pulls an awaiter out of every element and does not keep
        // a reference to the array, so the buffer can be disposed when this block exits.
        return new UniTask(new WhenAllPromise(materialized.Array, materialized.Length), 0);
    }
}
 | |
| 
 | |
// Task source behind WhenAll<T>. Each input task writes its result into its own slot of
// the results array; the source is completed with that array once the number of
// successful completions equals the array length. Any exception thrown while obtaining
// or reading an awaiter is forwarded to the core through TrySetException instead.
sealed class WhenAllPromise<T> : IUniTaskSource<T[]>
{
    T[] results;
    int finishedCount;
    UniTaskCompletionSourceCore<T[]> core; // don't reset(called after GetResult, will invoke TrySetException.)

    // Only the first tasksLength entries of tasks are read; the array itself is not stored.
    public WhenAllPromise(UniTask<T>[] tasks, int tasksLength)
    {
        TaskTracker.TrackActiveTask(this, 3);

        this.finishedCount = 0;

        if (tasksLength == 0)
        {
            // Nothing to wait for: complete immediately with an empty array.
            this.results = Array.Empty<T>();
            core.TrySetResult(results);
            return;
        }

        this.results = new T[tasksLength];

        for (int index = 0; index < tasksLength; index++)
        {
            UniTask<T>.Awaiter pending;
            try
            {
                pending = tasks[index].GetAwaiter();
            }
            catch (Exception ex)
            {
                // Report the failure, then keep going so the remaining tasks are still hooked up.
                core.TrySetException(ex);
                continue;
            }

            if (!pending.IsCompleted)
            {
                // Not finished yet: pass the promise, awaiter and slot index through the
                // state object so the callback lambda captures nothing.
                pending.SourceOnCompleted(boxed =>
                {
                    using (var tuple = (StateTuple<WhenAllPromise<T>, UniTask<T>.Awaiter, int>)boxed)
                    {
                        TryInvokeContinuation(tuple.Item1, tuple.Item2, tuple.Item3);
                    }
                }, StateTuple.Create(this, pending, index));
                continue;
            }

            // Already finished: collect the result right away.
            TryInvokeContinuation(this, pending, index);
        }
    }

    // Stores the awaiter's result in slot i and completes the promise when the last slot
    // has been filled. A throwing GetResult is reported and is not counted as a completion.
    static void TryInvokeContinuation(WhenAllPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
    {
        try
        {
            self.results[i] = awaiter.GetResult();
        }
        catch (Exception ex)
        {
            self.core.TrySetException(ex);
            return;
        }

        // Interlocked because continuations may run concurrently (assumption: the callers
        // of SourceOnCompleted callbacks are not visible here).
        var finished = Interlocked.Increment(ref self.finishedCount);
        if (finished == self.results.Length)
        {
            self.core.TrySetResult(self.results);
        }
    }

    // Stops tracking this promise, then hands back whatever the core holds for the token.
    public T[] GetResult(short token)
    {
        TaskTracker.RemoveTracking(this);
        GC.SuppressFinalize(this);
        return core.GetResult(token);
    }

    // Non-generic interface entry point; the result array is discarded.
    void IUniTaskSource.GetResult(short token) => GetResult(token);

    public UniTaskStatus GetStatus(short token) => core.GetStatus(token);

    public UniTaskStatus UnsafeGetStatus() => core.UnsafeGetStatus();

    public void OnCompleted(Action<object> continuation, object state, short token) =>
        core.OnCompleted(continuation, state, token);
}
 | |
| 
 | |
// Task source behind the result-less WhenAll. Counts successful completions and
// completes the core with AsyncUnit.Default once the count reaches the expected number
// of tasks. Any exception thrown while obtaining or reading an awaiter is forwarded to
// the core through TrySetException instead.
sealed class WhenAllPromise : IUniTaskSource
{
    int finishedCount;
    int expectedCount;
    UniTaskCompletionSourceCore<AsyncUnit> core; // don't reset(called after GetResult, will invoke TrySetException.)

    // Only the first tasksLength entries of tasks are read; the array itself is not stored.
    public WhenAllPromise(UniTask[] tasks, int tasksLength)
    {
        TaskTracker.TrackActiveTask(this, 3);

        this.expectedCount = tasksLength;
        this.finishedCount = 0;

        if (tasksLength == 0)
        {
            // Nothing to wait for: complete immediately.
            core.TrySetResult(AsyncUnit.Default);
            return;
        }

        for (int index = 0; index < tasksLength; index++)
        {
            UniTask.Awaiter pending;
            try
            {
                pending = tasks[index].GetAwaiter();
            }
            catch (Exception ex)
            {
                // Report the failure, then keep going so the remaining tasks are still hooked up.
                core.TrySetException(ex);
                continue;
            }

            if (!pending.IsCompleted)
            {
                // Not finished yet: pass the promise and awaiter through the state object
                // so the callback lambda captures nothing.
                pending.SourceOnCompleted(boxed =>
                {
                    using (var tuple = (StateTuple<WhenAllPromise, UniTask.Awaiter>)boxed)
                    {
                        TryInvokeContinuation(tuple.Item1, tuple.Item2);
                    }
                }, StateTuple.Create(this, pending));
                continue;
            }

            // Already finished: account for it right away.
            TryInvokeContinuation(this, pending);
        }
    }

    // Observes the awaiter's outcome and completes the promise when the last expected task
    // has finished. A throwing GetResult is reported and is not counted as a completion.
    static void TryInvokeContinuation(WhenAllPromise self, in UniTask.Awaiter awaiter)
    {
        try
        {
            awaiter.GetResult();
        }
        catch (Exception ex)
        {
            self.core.TrySetException(ex);
            return;
        }

        // Interlocked because continuations may run concurrently (assumption: the callers
        // of SourceOnCompleted callbacks are not visible here).
        var finished = Interlocked.Increment(ref self.finishedCount);
        if (finished == self.expectedCount)
        {
            self.core.TrySetResult(AsyncUnit.Default);
        }
    }

    // Stops tracking this promise, then lets the core validate the token and surface any
    // stored outcome.
    public void GetResult(short token)
    {
        TaskTracker.RemoveTracking(this);
        GC.SuppressFinalize(this);
        core.GetResult(token);
    }

    public UniTaskStatus GetStatus(short token) => core.GetStatus(token);

    public UniTaskStatus UnsafeGetStatus() => core.UnsafeGetStatus();

    public void OnCompleted(Action<object> continuation, object state, short token) =>
        core.OnCompleted(continuation, state, token);
}
 | |
|     }
 | |
| }
 |