Skip to content

Commit aff927b

Browse files
authored
Write yield in C++ like Python (#4167)
See test for example code.
1 parent 485c023 commit aff927b

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed

csrc/utils.h

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <c10/core/thread_pool.h>
2222

2323
#include <concepts>
24+
#include <coroutine>
2425
#include <deque>
2526
#include <iterator>
2627
#include <memory>
@@ -843,4 +844,136 @@ using views::zip;
843844

844845
#endif // C++23
845846

847+
// Helper: turn T into reference_wrapper<U> if T is reference
848+
template <typename T>
849+
using Yielded = std::conditional_t<
850+
std::is_reference_v<T>,
851+
std::reference_wrapper<std::remove_reference_t<T>>,
852+
T>;
853+
854+
// Writing yield in C++20 just like Python:
855+
// See NVFuserTest.Generator[1-5] for usage examples
856+
template <typename T>
857+
class Generator : public std::ranges::view_interface<Generator<T>> {
858+
public:
859+
struct promise_type;
860+
using handle_type = std::coroutine_handle<promise_type>;
861+
using stored_type = Yielded<T>;
862+
863+
Generator(handle_type h) : coroutine_(h) {}
864+
Generator(Generator&& other) noexcept : coroutine_(other.coroutine_) {
865+
other.coroutine_ = nullptr;
866+
}
867+
Generator& operator=(Generator&& other) noexcept {
868+
if (this != &other) {
869+
if (coroutine_) {
870+
coroutine_.destroy();
871+
}
872+
coroutine_ = other.coroutine_;
873+
other.coroutine_ = nullptr;
874+
}
875+
return *this;
876+
}
877+
~Generator() {
878+
if (coroutine_) {
879+
coroutine_.destroy();
880+
}
881+
}
882+
Generator(const Generator&) = delete;
883+
Generator& operator=(const Generator&) = delete;
884+
885+
struct iterator {
886+
using value_type = std::remove_reference_t<T>;
887+
using reference = T;
888+
using difference_type = std::ptrdiff_t;
889+
using iterator_category = std::input_iterator_tag;
890+
891+
iterator() = default;
892+
explicit iterator(handle_type h) : coroutine(h) {
893+
++(*this);
894+
}
895+
896+
reference operator*() const {
897+
if constexpr (std::is_reference_v<T>) {
898+
return value->get(); // unwrap reference_wrapper<T>
899+
} else {
900+
return *value;
901+
}
902+
}
903+
904+
iterator& operator++() {
905+
coroutine.resume();
906+
if (coroutine.done()) {
907+
if (coroutine.promise().exception) {
908+
std::rethrow_exception(coroutine.promise().exception);
909+
}
910+
value.reset();
911+
} else {
912+
value = std::ref(coroutine.promise().current_value);
913+
}
914+
return *this;
915+
}
916+
917+
iterator operator++(int) {
918+
auto tmp = *this;
919+
++(*this);
920+
return tmp;
921+
}
922+
bool operator==(std::default_sentinel_t) const {
923+
return !value.has_value();
924+
}
925+
bool operator!=(std::default_sentinel_t) const {
926+
return value.has_value();
927+
}
928+
friend bool operator==(std::default_sentinel_t s, const iterator& it) {
929+
return it == s;
930+
}
931+
friend bool operator!=(std::default_sentinel_t s, const iterator& it) {
932+
return it != s;
933+
}
934+
935+
handle_type coroutine = nullptr;
936+
std::optional<stored_type> value;
937+
};
938+
939+
iterator begin() const {
940+
return iterator{coroutine_};
941+
}
942+
std::default_sentinel_t end() const {
943+
return {};
944+
}
945+
946+
private:
947+
handle_type coroutine_;
948+
949+
public:
950+
struct promise_type {
951+
std::optional<stored_type> current_value;
952+
std::exception_ptr exception;
953+
954+
auto get_return_object() {
955+
return Generator{handle_type::from_promise(*this)};
956+
}
957+
std::suspend_always initial_suspend() {
958+
return {};
959+
}
960+
std::suspend_always final_suspend() noexcept {
961+
return {};
962+
}
963+
std::suspend_always yield_value(T value) {
964+
if constexpr (std::is_reference_v<T>) {
965+
current_value = std::ref(value); // wraps T& as reference_wrapper
966+
} else {
967+
current_value = std::move(value);
968+
}
969+
return {};
970+
}
971+
972+
void return_void() {}
973+
void unhandled_exception() {
974+
exception = std::current_exception();
975+
}
976+
};
977+
};
978+
846979
} // namespace nvfuser

tests/cpp/test_utils.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <scheduler/vectorize_helper.h>
1919
#include <tests/cpp/utils.h>
2020
#include <tests/cpp/validator.h>
21+
#include <utils.h>
2122

2223
#include <cstdlib>
2324
#include <filesystem>
@@ -2072,4 +2073,101 @@ TEST_F(TestCpp23BackPort, Enumerate) {
20722073
// bidirectional
20732074
}
20742075

2076+
namespace {
2077+
2078+
// Generator that yields integers from 0 to n-1
2079+
Generator<int> zeroToN(int n) {
2080+
for (int i = 0; i < n; ++i) {
2081+
co_yield i;
2082+
}
2083+
}
2084+
2085+
// Generator that yields integers from n to 2*n - 1
2086+
Generator<int> nTo2N(int n) {
2087+
for (int i = n; i < 2 * n; ++i) {
2088+
co_yield i;
2089+
}
2090+
}
2091+
2092+
// Generator that yields integers from m to m + 2*n - 1
2093+
Generator<int> mTo2NplusM(int n, int m) {
2094+
for (auto x : zeroToN(n)) {
2095+
co_yield x + m;
2096+
}
2097+
for (auto x : nTo2N(n)) {
2098+
co_yield x + m;
2099+
}
2100+
}
2101+
2102+
// Generator that yields references
2103+
Generator<int&> items(std::vector<int>& v) {
2104+
for (auto& x : v) {
2105+
co_yield x;
2106+
}
2107+
}
2108+
2109+
} // namespace
2110+
2111+
TEST_F(NVFuserTest, Generator1) {
2112+
static_assert(std::ranges::view<decltype(zeroToN(10))>);
2113+
std::vector<int> generated;
2114+
for (auto x : zeroToN(10) |
2115+
std::views::filter([](int x) { return x % 2 == 0; }) |
2116+
std::views::transform([](int x) { return x * x; })) {
2117+
generated.push_back(x);
2118+
}
2119+
std::vector<int> expect{0, 4, 16, 36, 64};
2120+
EXPECT_EQ(generated, expect);
2121+
}
2122+
2123+
TEST_F(NVFuserTest, Generator2) {
2124+
static_assert(std::ranges::view<decltype(mTo2NplusM(10, 10))>);
2125+
std::vector<int> generated;
2126+
for (auto x : mTo2NplusM(10, 10)) {
2127+
generated.push_back(x);
2128+
}
2129+
std::vector<int> expect{10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
2130+
20, 21, 22, 23, 24, 25, 26, 27, 28, 29};
2131+
EXPECT_EQ(generated, expect);
2132+
}
2133+
2134+
TEST_F(NVFuserTest, Generator3) {
2135+
std::vector<int> v{0, 0, 0, 0, 0};
2136+
for (auto&& [i, x] : enumerate(items(v))) {
2137+
x = i * 10;
2138+
}
2139+
std::vector<int> expect{0, 10, 20, 30, 40};
2140+
EXPECT_EQ(v, expect);
2141+
}
2142+
2143+
TEST_F(NVFuserTest, Generator4) {
2144+
auto one2five = []() -> Generator<int> {
2145+
for (int i = 1; i <= 5; ++i) {
2146+
co_yield i;
2147+
}
2148+
};
2149+
std::vector<int> v;
2150+
for (auto x : one2five()) {
2151+
v.push_back(x);
2152+
}
2153+
std::vector<int> expect{1, 2, 3, 4, 5};
2154+
EXPECT_EQ(v, expect);
2155+
}
2156+
2157+
TEST_F(NVFuserTest, Generator5) {
2158+
auto excepted_exception = []() -> Generator<int> {
2159+
co_yield 1;
2160+
throw std::runtime_error("Hello, world!");
2161+
co_yield 2;
2162+
};
2163+
auto run_generator = [&]() {
2164+
for (auto x : excepted_exception()) {
2165+
EXPECT_EQ(x, 1);
2166+
}
2167+
};
2168+
EXPECT_THAT(
2169+
run_generator,
2170+
::testing::ThrowsMessage<std::runtime_error>("Hello, world!"));
2171+
}
2172+
20752173
} // namespace nvfuser

0 commit comments

Comments
 (0)