using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.ExceptionServices;
using System.Threading;
using Cysharp.Threading.Tasks.Internal;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Merges two async sequences into a single sequence that yields elements from either source
        /// in the order they become available. The merged sequence ends once both sources have completed;
        /// an exception from either source is surfaced through the merged sequence.
        /// </summary>
        /// <param name="first">First source sequence; must not be null.</param>
        /// <param name="second">Second source sequence; must not be null.</param>
        public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));

            return new Merge<T>(new[] { first, second });
        }

        /// <summary>
        /// Merges three async sequences into a single sequence that yields elements from any source
        /// in the order they become available, ending once all three have completed.
        /// </summary>
        /// <param name="first">First source sequence; must not be null.</param>
        /// <param name="second">Second source sequence; must not be null.</param>
        /// <param name="third">Third source sequence; must not be null.</param>
        public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));
            Error.ThrowArgumentNullException(third, nameof(third));

            return new Merge<T>(new[] { first, second, third });
        }

        /// <summary>
        /// Merges every sequence in <paramref name="sources"/> into a single sequence that yields elements
        /// in the order they become available, ending once all sources have completed.
        /// </summary>
        /// <param name="sources">The sequences to merge; must not be null and must contain at least one element.</param>
        public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
        {
            Error.ThrowArgumentNullException(sources, nameof(sources));

            // Use an existing array directly instead of copying it.
            return sources is IUniTaskAsyncEnumerable<T>[] array
                ? new Merge<T>(array)
                : new Merge<T>(sources.ToArray());
        }

        /// <summary>
        /// Merges the given sequences into a single sequence that yields elements in the order they
        /// become available, ending once all sources have completed.
        /// </summary>
        /// <param name="sources">The sequences to merge; must not be null and must contain at least one element.</param>
        public static IUniTaskAsyncEnumerable<T> Merge<T>(params IUniTaskAsyncEnumerable<T>[] sources)
        {
            Error.ThrowArgumentNullException(sources, nameof(sources));

            return new Merge<T>(sources);
        }
    }

    internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
    {
        readonly IUniTaskAsyncEnumerable<T>[] sources;

        /// <param name="sources">Non-empty array of sources. The array is stored as-is, not copied.</param>
        public Merge(IUniTaskAsyncEnumerable<T>[] sources)
        {
            if (sources.Length <= 0)
            {
                Error.ThrowArgumentException("No source async enumerable to merge");
            }

            this.sources = sources;
        }

        public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
            => new _Merge(sources, cancellationToken);

        /// <summary>Per-source lifecycle tracked by <see cref="_Merge"/>.</summary>
        enum MergeSourceState
        {
            /// <summary>Ready for the next MoveNextAsync call on the source.</summary>
            Pending,
            /// <summary>A MoveNextAsync call on the source is in flight.</summary>
            Running,
            /// <summary>The source reported no more elements.</summary>
            Completed,
        }

        sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator<T>
        {
            static readonly Action<object> GetResultAtAction = GetResultAt;

            readonly int length;
            readonly IUniTaskAsyncEnumerator<T>[] enumerators;
            readonly MergeSourceState[] states;

            // Results that arrived while the current MoveNextAsync had already been satisfied:
            // (value, exception, hasNext). Drained one per MoveNextAsync call.
            readonly Queue<(T, Exception, bool)> queuedResult = new Queue<(T, Exception, bool)>();
            readonly CancellationToken cancellationToken;

            // Guards `states` and `queuedResult`. A dedicated object is used because `states` is a
            // pooled array that is handed back to the shared pool on dispose.
            readonly object gate = new object();

            // 0 while the current MoveNextAsync is still waiting for a result, 1 once some source has
            // claimed it (via CompareExchange). Later arrivals are queued instead.
            int moveNextCompleted;

            public T Current { get; private set; }

            public _Merge(IUniTaskAsyncEnumerable<T>[] sources, CancellationToken cancellationToken)
            {
                this.cancellationToken = cancellationToken;
                length = sources.Length;

                // Only the first `length` slots of the rented buffers are used.
                states = ArrayPool<MergeSourceState>.Shared.Rent(length);
                enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length);
                for (var i = 0; i < length; i++)
                {
                    enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
                    states[i] = MergeSourceState.Pending;
                }
            }

            /// <summary>
            /// Delivers a previously queued result if one exists; otherwise starts MoveNextAsync on every
            /// source that is not already running or completed, and completes when the first one reports.
            /// </summary>
            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();

                completionSource.Reset();
                Interlocked.Exchange(ref moveNextCompleted, 0);

                if (HasQueuedResult() && Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0)
                {
                    (T, Exception, bool) value;
                    lock (gate)
                    {
                        value = queuedResult.Dequeue();
                    }

                    var resultValue = value.Item1;
                    var exception = value.Item2;
                    var hasNext = value.Item3;
                    if (exception != null)
                    {
                        completionSource.TrySetException(exception);
                    }
                    else
                    {
                        Current = resultValue;
                        completionSource.TrySetResult(hasNext);
                    }

                    return new UniTask<bool>(this, completionSource.Version);
                }

                for (var i = 0; i < length; i++)
                {
                    lock (gate)
                    {
                        if (states[i] == MergeSourceState.Pending)
                        {
                            states[i] = MergeSourceState.Running;
                        }
                        else
                        {
                            continue;
                        }
                    }

                    var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
                    if (awaiter.IsCompleted)
                    {
                        GetResultAt(i, awaiter);
                    }
                    else
                    {
                        awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
                    }
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            /// <summary>
            /// Disposes every source enumerator (continuing past failures), returns the pooled buffers,
            /// then rethrows: the original exception if exactly one source failed, otherwise an
            /// <see cref="AggregateException"/> of all failures.
            /// </summary>
            public async UniTask DisposeAsync()
            {
                List<Exception> errors = null;
                for (var i = 0; i < length; i++)
                {
                    try
                    {
                        await enumerators[i].DisposeAsync();
                    }
                    catch (Exception ex)
                    {
                        // Keep going so that every source is disposed; report the failure afterwards.
                        if (errors == null)
                        {
                            errors = new List<Exception>();
                        }
                        errors.Add(ex);
                    }
                }

                ArrayPool<MergeSourceState>.Shared.Return(states, true);
                ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true);

                if (errors != null)
                {
                    if (errors.Count == 1)
                    {
                        // Preserve the original exception type and stack trace for the common case.
                        ExceptionDispatchInfo.Capture(errors[0]).Throw();
                    }
                    throw new AggregateException(errors);
                }
            }

            static void GetResultAt(object state)
            {
                using (var tuple = (StateTuple<_Merge, int, UniTask<bool>.Awaiter>)state)
                {
                    tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3);
                }
            }

            /// <summary>
            /// Handles the completion of source <paramref name="index"/>: either satisfies the pending
            /// MoveNextAsync or, if it was already satisfied, queues the result for the next call.
            /// </summary>
            void GetResultAt(int index, UniTask<bool>.Awaiter awaiter)
            {
                bool hasNext;
                bool completedAll;
                try
                {
                    hasNext = awaiter.GetResult();
                }
                catch (Exception ex)
                {
                    if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0)
                    {
                        completionSource.TrySetException(ex);
                    }
                    else
                    {
                        lock (gate)
                        {
                            queuedResult.Enqueue((default, ex, default));
                        }
                    }

                    return;
                }

                // Read the value BEFORE marking the source Pending: once it is Pending, a concurrent
                // MoveNextAsync may restart this enumerator and overwrite its Current.
                var current = hasNext ? enumerators[index].Current : default;

                lock (gate)
                {
                    states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed;
                    completedAll = !hasNext && IsCompletedAll();
                }

                // A source that finished while others are still running produces no result by itself.
                if (hasNext || completedAll)
                {
                    if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0)
                    {
                        Current = current;
                        completionSource.TrySetResult(!completedAll);
                    }
                    else
                    {
                        lock (gate)
                        {
                            queuedResult.Enqueue((current, null, !completedAll));
                        }
                    }
                }
            }

            bool HasQueuedResult()
            {
                lock (gate)
                {
                    return queuedResult.Count > 0;
                }
            }

            bool IsCompletedAll()
            {
                lock (gate)
                {
                    for (var i = 0; i < length; i++)
                    {
                        if (states[i] != MergeSourceState.Completed)
                        {
                            return false;
                        }
                    }
                }

                return true;
            }
        }
    }
}