diff --git a/include/async/generator.hpp b/include/async/generator.hpp new file mode 100644 index 0000000..44a36a3 --- /dev/null +++ b/include/async/generator.hpp @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include + +namespace async { + +template +struct generator { + struct promise_type; + using handle_type = corons::coroutine_handle; + + struct continuation { + virtual void complete() = 0; + }; + + struct promise_type { + std::optional current_value; + continuation *cont = nullptr; + + generator get_return_object() { + return generator{corons::coroutine_handle::from_promise(*this)}; + } + + corons::suspend_always initial_suspend() noexcept { return {}; } + + struct yield_awaiter { + bool await_ready() noexcept { return false; } + + void await_suspend(handle_type h) noexcept { + auto cont = h.promise().cont; + h.promise().cont = nullptr; + if (cont) { + cont->complete(); + return; + } + FRG_INTF(panic)("Generator yielded but no consumer is waiting"); + } + + void await_resume() noexcept {} + }; + + yield_awaiter yield_value(T value) noexcept { + current_value = std::move(value); + return yield_awaiter{}; + } + + yield_awaiter final_suspend() noexcept { + return yield_awaiter{}; + } + + void return_void() {} + + void unhandled_exception() { + FRG_INTF(panic)("Unhandled exception in generator coroutine"); + } + }; + + explicit generator(corons::coroutine_handle h) : h_(h) {} + + generator(generator &&other) : h_(std::exchange(other.h_, nullptr)) {} + + generator &operator=(generator &&other) { + auto h = std::exchange(other.h_, nullptr); + if (h_) + h_.destroy(); + h_ = h; + return *this; + } + + ~generator() { + if (h_) + h_.destroy(); + } + + template + struct next_operation : continuation { + next_operation(handle_type h, Receiver r) + : h_{h}, r_{std::move(r)} {} + + void start() { + h_.promise().cont = this; + h_.resume(); + } + + void complete() override { + auto val = std::move(h_.promise().current_value); + h_.promise().current_value = {}; + execution::set_value(std::move(r_), std::move(val)); + } + + handle_type h_; + Receiver r_; + }; + + struct next_sender { + handle_type h_; + using value_type = std::optional; + + friend sender_awaiter operator co_await (next_sender s) { + return {s}; + } + + template + next_operation connect(Receiver r) { + return {h_, std::move(r)}; + } + }; + + next_sender next() { + return {h_}; + } + +private: + corons::coroutine_handle h_; +}; + +} // namespace async diff --git a/meson.build b/meson.build index 7ee5f6e..c8d1a21 100644 --- a/meson.build +++ b/meson.build @@ -24,6 +24,7 @@ if get_option('install_headers') 'include/async/result.hpp', 'include/async/sequenced-event.hpp', 'include/async/wait-group.hpp', + 'include/async/generator.hpp', subdir : 'async/') pkgconfig.generate( diff --git a/tests/generator.cpp b/tests/generator.cpp new file mode 100644 index 0000000..f55fd0f --- /dev/null +++ b/tests/generator.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include + +namespace { + +async::generator generate_nothing() { + co_return; +} + +async::generator generate_ints() { + co_yield 1; + co_yield 2; + co_yield 3; +} + +async::generator> generate_unique_ptrs() { + co_yield std::make_unique(1); + co_yield std::make_unique(2); + co_yield std::make_unique(3); +} + +} // anonymous namespace + +TEST(Generator, YieldNothing) { + async::run([]() -> async::result { + auto gen = generate_nothing(); + + auto v = co_await gen.next(); + EXPECT_FALSE(v.has_value()); + }()); +} + +TEST(Generator, YieldInts) { + async::run([]() -> async::result { + auto gen = generate_ints(); + + auto v1 = co_await gen.next(); + EXPECT_TRUE(v1.has_value()); + EXPECT_EQ(*v1, 1); + + auto v2 = co_await gen.next(); + EXPECT_TRUE(v2.has_value()); + EXPECT_EQ(*v2, 2); + + auto v3 = co_await gen.next(); + EXPECT_TRUE(v3.has_value()); + EXPECT_EQ(*v3, 3); + + auto v4 = co_await gen.next(); + EXPECT_FALSE(v4.has_value()); + }()); +} + +TEST(Generator, YieldMoveOnly) { + async::run([]() -> async::result { + auto gen = generate_unique_ptrs(); + + auto v1 = co_await gen.next(); + EXPECT_TRUE(v1.has_value()); + EXPECT_EQ(**v1, 1); + + auto v2 = co_await gen.next(); + EXPECT_TRUE(v2.has_value()); + EXPECT_EQ(**v2, 2); + + auto v3 = co_await gen.next(); + EXPECT_TRUE(v3.has_value()); + EXPECT_EQ(**v3, 3); + + auto v4 = co_await gen.next(); + EXPECT_FALSE(v4.has_value()); + }()); +} diff --git a/tests/meson.build b/tests/meson.build index f5711b6..529c858 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -23,6 +23,7 @@ sources = files( 'sequenced.cpp', 'post-ack.cpp', 'with_cancel_cb.cpp', + 'generator.cpp', ) exe = executable('gtests',