From 1ac8f98d90f053bc2b0c9d30e975e851950a4e70 Mon Sep 17 00:00:00 2001 From: Laura Andelare Date: Sun, 15 Oct 2023 13:11:32 +0000 Subject: [PATCH] Add TArray> overloads to WhenAny and WhenAll --- .../UE5Coro/Private/AggregateAwaiters.cpp | 26 ++++++++++ .../Public/UE5Coro/AggregateAwaiters.h | 18 ++++++- .../Private/AggregateAwaiterTest.cpp | 52 +++++++++++++++++++ 3 files changed, 95 insertions(+), 1 deletion(-) diff --git a/Plugins/UE5Coro/Source/UE5Coro/Private/AggregateAwaiters.cpp b/Plugins/UE5Coro/Source/UE5Coro/Private/AggregateAwaiters.cpp index 484ab9e3..9a418b9d 100644 --- a/Plugins/UE5Coro/Source/UE5Coro/Private/AggregateAwaiters.cpp +++ b/Plugins/UE5Coro/Source/UE5Coro/Private/AggregateAwaiters.cpp @@ -41,6 +41,18 @@ int FAggregateAwaiter::GetResumerIndex() const return Data->Index; } +template +FAggregateAwaiter::FAggregateAwaiter(T All, const TArray>& Coroutines) + : Data(std::make_shared(All.value ? Coroutines.Num() : !!Coroutines.Num())) +{ + for (int i = 0; i < Coroutines.Num(); ++i) + Consume(Data, i, Coroutines[i]); +} +template UE5CORO_API FAggregateAwaiter::FAggregateAwaiter( + std::false_type, const TArray>&); +template UE5CORO_API FAggregateAwaiter::FAggregateAwaiter( + std::true_type, const TArray>&); + bool FAggregateAwaiter::await_ready() { checkf(Data, TEXT("Attempting to await moved-from aggregate awaiter")); @@ -64,11 +76,25 @@ void FAggregateAwaiter::Suspend(FPromise& Promise) Data->Lock.unlock(); } +#if UE5CORO_CPP20 +FAnyAwaiter UE5Coro::WhenAny(const TArray>& Coroutines) +{ + return FAnyAwaiter(std::false_type(), Coroutines); +} +#endif + FRaceAwaiter UE5Coro::Race(TArray> Array) { return FRaceAwaiter(std::move(Array)); } +#if UE5CORO_CPP20 +FAllAwaiter UE5Coro::WhenAll(const TArray>& Coroutines) +{ + return FAllAwaiter(std::true_type(), Coroutines); +} +#endif + FRaceAwaiter::FRaceAwaiter(TArray>&& Array) : Data(std::make_shared(std::move(Array))) { diff --git a/Plugins/UE5Coro/Source/UE5Coro/Public/UE5Coro/AggregateAwaiters.h b/Plugins/UE5Coro/Source/UE5Coro/Public/UE5Coro/AggregateAwaiters.h index 0d4bb33e..257c6f89 100644 --- a/Plugins/UE5Coro/Source/UE5Coro/Public/UE5Coro/AggregateAwaiters.h +++ b/Plugins/UE5Coro/Source/UE5Coro/Public/UE5Coro/AggregateAwaiters.h @@ -35,6 +35,7 @@ #include "UE5Coro/Definitions.h" #include #include "UE5Coro/AsyncCoroutine.h" +#include "UE5Coro/CoroutineAwaiters.h" #include "UE5Coro/Private.h" namespace UE5Coro::Private @@ -61,12 +62,19 @@ concept TAggregateAwaitable = namespace UE5Coro { /** co_awaits all parameters, resumes its own awaiting coroutine when the first - * one of them finishes. + * one of them finishes.
* The result of the co_await expression is the index of the parameter that * finished first. */ template Private::FAnyAwaiter WhenAny(T&&...); +#if UE5CORO_CPP20 +/** Resumes the awaiting coroutine when all other coroutines have completed.
+ * The result of the co_await expression is the index of the parameter that + * finished first. */ +UE5CORO_API Private::FAnyAwaiter WhenAny(const TArray>&); +#endif + /** co_awaits all coroutines in the array. * The first one to finish cancels the others and resumes the caller. * The result of the co_await expression is the array index of the coroutine @@ -84,6 +92,11 @@ Private::FRaceAwaiter Race(TCoroutine... Args); * of them finish. */ template Private::FAllAwaiter WhenAll(T&&...); + +#if UE5CORO_CPP20 +/** Resumes the awaiting coroutine when all other coroutines have completed. */ +UE5CORO_API Private::FAllAwaiter WhenAll(const TArray>&); +#endif } namespace UE5Coro::Private @@ -118,6 +131,9 @@ class [[nodiscard]] UE5CORO_API FAggregateAwaiter (Consume(Data, Idx++, std::forward(Awaiters)), ...); } + template + explicit FAggregateAwaiter(T, const TArray>& Coroutines); + bool await_ready(); void Suspend(FPromise&); }; diff --git a/Plugins/UE5Coro/Source/UE5CoroTests/Private/AggregateAwaiterTest.cpp b/Plugins/UE5Coro/Source/UE5CoroTests/Private/AggregateAwaiterTest.cpp index 12553deb..6cd876da 100644 --- a/Plugins/UE5Coro/Source/UE5CoroTests/Private/AggregateAwaiterTest.cpp +++ b/Plugins/UE5Coro/Source/UE5CoroTests/Private/AggregateAwaiterTest.cpp @@ -34,6 +34,7 @@ #include "Misc/AutomationTest.h" #include "UE5Coro/AggregateAwaiters.h" #include "UE5Coro/CoroutineAwaiters.h" +#include "UE5Coro/Threading.h" using namespace UE5Coro; using namespace UE5Coro::Private::Test; @@ -232,6 +233,57 @@ void DoTest(FAutomationTestBase& Test) Test.TestEqual(TEXT("State"), State, 2); Test.TestEqual(TEXT("Return value"), Coro.GetResult(), 1); } + +#if UE5CORO_CPP20 + { + int State = 0; + FAwaitableEvent Event(EEventMode::ManualReset); + World.Run(CORO + { + TArray> Coros; + for (int i = 0; i < 10; ++i) + Coros.Add(World.Run(CORO + { + ++State; + co_await Event; + ++State; + })); + Test.TestEqual(TEXT("Initial state inside"), State, 10); + co_await WhenAll(Coros); + Test.TestEqual(TEXT("Final state inside"), State, 20); + ++State; + }); + Test.TestEqual(TEXT("Initial state outside"), State, 10); + Event.Trigger(); + Test.TestEqual(TEXT("Final state outside"), State, 21); + } + + { + int State = 0; + FAwaitableEvent Event(EEventMode::AutoReset); + World.Run(CORO + { + TArray> Coros; + for (int i = 0; i < 10; ++i) + Coros.Add(World.Run(CORO + { + ++State; + co_await Event; + ++State; + })); + Test.TestEqual(TEXT("Initial state inside"), State, 10); + co_await WhenAny(Coros); + Test.TestEqual(TEXT("Final state inside"), State, 11); + ++State; + }); + Test.TestEqual(TEXT("Initial state outside"), State, 10); + for (int i = 0; i < 10; ++i) + { + Event.Trigger(); + Test.TestEqual(TEXT("State outside"), State, i + 12); + } + } +#endif } }