Skip to content

Commit 760ad13

Browse files
committed
Refactored data_generator implementation to allow alternating the internal 'receivers'-collection implementation. Basically to allow injecting a thread safe variant that holds the receivers.
Added 'guarded_data_generator' which uses the exact same implementation as data_generator, besides that the receivers collection is guarded (lock when adding/removing receivers, and atomically swap so that no locking is required when sending data to receivers).
1 parent 6857cf9 commit 760ad13

File tree

5 files changed

+240
-56
lines changed

5 files changed

+240
-56
lines changed

CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ export(
3434
)
3535

3636
if(PIPEABLE_BUILD_TESTS)
37-
# ----------- TESTS -----------
37+
38+
find_package(Threads REQUIRED)
39+
3840
enable_testing()
3941

4042
add_subdirectory(catch2)
@@ -43,10 +45,12 @@ if(PIPEABLE_BUILD_TESTS)
4345
"tests/pipeable_tests.cpp"
4446
"tests/data_source_tests.cpp"
4547
"tests/data_generator_tests.cpp"
48+
"tests/guarded_data_generator_tests.cpp"
4649
)
4750
target_link_libraries( pipeable_tests
4851
pipeable
4952
catch2
53+
Threads::Threads
5054
)
5155
add_test(
5256
NAME pipeable_tests

include/pipeable/data_generator.hpp

Lines changed: 93 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,38 @@ namespace pipeable
1414

1515
namespace impl
1616
{
17-
template<typename output_t>
17+
template<typename T>
18+
struct non_threadsafe_receivers : std::vector<T>
19+
{
20+
template<typename callback_t>
21+
void for_each(callback_t&& callback) const
22+
{
23+
for (auto&& downstream : *this)
24+
{
25+
std::forward<callback_t>(callback)(std::forward<decltype(downstream)>(downstream));
26+
}
27+
}
28+
29+
template<typename callback_t>
30+
void modify_list(callback_t&& callback)
31+
{
32+
std::forward<callback_t>(callback)(*this);
33+
}
34+
};
35+
36+
template<typename receiver_t, typename>
37+
struct receivers_t
38+
{
39+
};
40+
41+
struct non_thread_safe{};
42+
template<typename receiver_t>
43+
struct receivers_t<receiver_t, non_thread_safe>
44+
{
45+
using type_t = non_threadsafe_receivers<receiver_t>;
46+
};
47+
48+
template<typename output_t, typename collection_type_tag_t>
1849
struct data_generator_impl : impl::custom_pipeable_tag
1950
{
2051
using downstream_t = std::function<void(output_t)>;
@@ -23,19 +54,18 @@ namespace pipeable
2354
concepts::IsConvertible<arg_t, output_t> = nullptr>
2455
void operator()(arg_t&& arg) const
2556
{
26-
for (auto& downstream : receivers_)
27-
{
57+
receivers_.for_each([&](auto&& downstream) {
2858
if constexpr (std::is_rvalue_reference_v<output_t>)
2959
{
30-
downstream.second(FWD(arg));
60+
std::forward<decltype(downstream)>(downstream).second(std::move(arg));
3161
}
3262
else
3363
{
3464
// Intentionally don't forward here, since if we have multiple receivers and the value is "moved from"
3565
// into the first, remaining receivers will (if it can be moved from) not get the value.
36-
downstream.second(arg);
66+
std::forward<decltype(downstream)>(downstream).second(arg);
3767
}
38-
}
68+
});
3969
}
4070

4171
template<typename callable_t,
@@ -47,20 +77,25 @@ namespace pipeable
4777
{
4878
invocation::invoke(FWD(downstream), FWD(arg));
4979
};
50-
receivers_.emplace_back(id, receiverCall);
80+
receivers_.modify_list([&](auto& receivers) {
81+
receivers.emplace_back(id, receiverCall);
82+
});
5183
}
5284

5385
template<typename callable_t,
5486
concepts::IsInvocable<callable_t, output_t> = nullptr>
5587
void operator-=(callable_t&& downstream)
5688
{
57-
receivers_.erase(std::remove_if(receivers_.begin(), receivers_.end(), [addr = identifier(downstream)](auto&& receiver) {
58-
return receiver.first == addr;
59-
}), receivers_.end());
89+
receivers_.modify_list([&](auto& receivers){
90+
receivers.erase(std::remove_if(receivers.begin(), receivers.end(), [addr = identifier(downstream)](auto&& receiver) {
91+
return receiver.first == addr;
92+
}), receivers.end());
93+
});
6094
}
6195

6296
private:
63-
std::vector<std::pair<const void*, downstream_t>> receivers_;
97+
98+
typename receivers_t<std::pair<const void*, downstream_t>, collection_type_tag_t>::type_t receivers_;
6499

65100
template<typename callable_t>
66101
static constexpr const void* identifier(const callable_t& callable)
@@ -83,53 +118,57 @@ namespace pipeable
83118
using bases_t::operator+=...;
84119
using bases_t::operator-=...;
85120
};
86-
}
87121

88-
template<typename... outputs_t>
89-
struct data_generator : impl::multi_generator_impl<impl::data_generator_impl<outputs_t>...>
90-
{
91-
template<typename callable_t,
92-
concepts::IsInvocableWithAny<callable_t, outputs_t...> = nullptr>
93-
void operator+=(callable_t&& downstream)
122+
template<typename threading_t, typename... outputs_t>
123+
struct multi_generator : impl::multi_generator_impl<impl::data_generator_impl<outputs_t, threading_t>...>
94124
{
95-
// Recursively call += for each output type to find exact base/data_generator
96-
auto callback = [](auto&& base, auto&& downstream){
97-
base->operator+=(FWD(downstream));
98-
};
99-
do_for_each_matching_generator<callable_t, outputs_t...>(FWD(downstream), callback);
100-
}
101-
102-
template<typename callable_t,
103-
concepts::IsInvocableWithAny<callable_t, outputs_t...> = nullptr>
104-
void operator-=(callable_t&& downstream)
105-
{
106-
// Recursively call -= for each output type to find exact base/data_generator
107-
auto callback = [](auto&& base, auto&& downstream)
125+
template<typename callable_t,
126+
concepts::IsInvocableWithAny<callable_t, outputs_t...> = nullptr>
127+
void operator+=(callable_t&& downstream)
108128
{
109-
base->operator-=(FWD(downstream));
110-
};
111-
do_for_each_matching_generator<callable_t, outputs_t...>(FWD(downstream), callback);
112-
}
113-
114-
private:
115-
template<typename callable_t, typename output_t, typename callback_t>
116-
void do_for_each_matching_generator(callable_t&& downstream, callback_t&& callback)
117-
{
118-
// Register callable for matching output only
119-
if constexpr (meta::is_invocable_v<callable_t, output_t>)
129+
// Recursively call += for each output type to find exact base/data_generator
130+
auto callback = [](auto&& base, auto&& downstream) {
131+
base->operator+=(FWD(downstream));
132+
};
133+
do_for_each_matching_generator<callable_t, outputs_t...>(FWD(downstream), callback);
134+
}
135+
136+
template<typename callable_t,
137+
concepts::IsInvocableWithAny<callable_t, outputs_t...> = nullptr>
138+
void operator-=(callable_t&& downstream)
120139
{
121-
// Explicitly call correct base to avoid ambiguity
122-
using exact_base_t = impl::data_generator_impl<output_t>;
123-
callback(static_cast<exact_base_t*>(this), FWD(downstream));
140+
// Recursively call -= for each output type to find exact base/data_generator
141+
auto callback = [](auto&& base, auto&& downstream)
142+
{
143+
base->operator-=(FWD(downstream));
144+
};
145+
do_for_each_matching_generator<callable_t, outputs_t...>(FWD(downstream), callback);
124146
}
125-
}
126147

127-
template<typename callable_t, typename output_t, typename tail_t, typename... tails_t, typename callback_t>
128-
void do_for_each_matching_generator(callable_t&& downstream, callback_t&& callback)
129-
{
130-
do_for_each_matching_generator<callable_t, output_t>(FWD(downstream), FWD(callback));
131-
// Recursively register callable for each (matching) output type
132-
do_for_each_matching_generator<callable_t, tail_t, tails_t...>(FWD(downstream), FWD(callback));
133-
}
134-
};
148+
private:
149+
template<typename callable_t, typename output_t, typename callback_t>
150+
void do_for_each_matching_generator(callable_t&& downstream, callback_t&& callback)
151+
{
152+
// Register callable for matching output only
153+
if constexpr (meta::is_invocable_v<callable_t, output_t>)
154+
{
155+
// Explicitly call correct base to avoid ambiguity
156+
using exact_base_t = impl::data_generator_impl<output_t, threading_t>;
157+
callback(static_cast<exact_base_t*>(this), FWD(downstream));
158+
}
159+
}
160+
161+
template<typename callable_t, typename output_t, typename tail_t, typename... tails_t, typename callback_t>
162+
void do_for_each_matching_generator(callable_t&& downstream, callback_t&& callback)
163+
{
164+
do_for_each_matching_generator<callable_t, output_t>(FWD(downstream), FWD(callback));
165+
// Recursively register callable for each (matching) output type
166+
do_for_each_matching_generator<callable_t, tail_t, tails_t...>(FWD(downstream), FWD(callback));
167+
}
168+
};
169+
}
170+
171+
template<typename... outputs_t>
172+
struct data_generator : impl::multi_generator<impl::non_thread_safe, outputs_t...>
173+
{};
135174
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#pragma once
2+
3+
#include <pipeable/data_generator.hpp>
4+
#include <memory>
5+
#include <mutex>
6+
#include <vector>
7+
8+
namespace pipeable
9+
{
10+
namespace impl
11+
{
12+
template<typename T>
13+
struct threadsafe_receivers
14+
{
15+
template<typename callback_t>
16+
void for_each(callback_t&& callback) const
17+
{
18+
auto tmp = std::atomic_load(&receivers_);
19+
for (auto&& downstream : *tmp)
20+
{
21+
std::forward<callback_t>(callback)(std::forward<decltype(downstream)>(downstream));
22+
}
23+
}
24+
25+
template<typename callback_t>
26+
void modify_list(callback_t&& callback)
27+
{
28+
std::scoped_lock lock{ mutex_ };
29+
container_t copy = *receivers_;
30+
std::forward<callback_t>(callback)(copy);
31+
std::atomic_store(&receivers_, std::make_shared<container_t>(std::move(copy)));
32+
}
33+
34+
private:
35+
using container_t = std::vector<T>;
36+
37+
std::mutex mutex_;
38+
std::shared_ptr<container_t> receivers_ = std::make_shared<container_t>();
39+
};
40+
41+
struct thread_safe {};
42+
template<typename receiver_t>
43+
struct receivers_t<receiver_t, thread_safe>
44+
{
45+
using type_t = threadsafe_receivers<receiver_t>;
46+
};
47+
}
48+
49+
template<typename... outputs_t>
50+
struct guarded_data_generator : impl::multi_generator<impl::thread_safe, outputs_t...>
51+
{};
52+
}

tests/data_generator_tests.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include <pipeable/pipeable.hpp>
21
#include <pipeable/data_generator.hpp>
32

43
#include <catch2/catch.hpp>
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#include <pipeable/guarded_data_generator.hpp>
2+
3+
#include <catch2/catch.hpp>
4+
#include <atomic>
5+
#include <chrono>
6+
#include <thread>
7+
8+
using namespace pipeable;
9+
10+
namespace
11+
{
12+
struct int_to_int
13+
{
14+
int operator()(int val)
15+
{
16+
receivedValue = true;
17+
return val;
18+
}
19+
bool receivedValue = false;
20+
};
21+
struct int_and_string_receiver
22+
{
23+
int receivedInt = 0;
24+
std::string receivedStr = "";
25+
void operator()(int val)
26+
{
27+
receivedInt = val;
28+
}
29+
void operator()(const std::string& val)
30+
{
31+
receivedStr = val;
32+
}
33+
};
34+
}
35+
36+
SCENARIO("Thread safe data generator")
37+
{
38+
GIVEN("a thread safe data generator")
39+
{
40+
guarded_data_generator<int> generator; // Change this to data_generator<...> and it will crash as expected
41+
WHEN("one thread is continuously generating data")
42+
{
43+
std::atomic_bool start = false;
44+
const auto waitForStart = [&] {
45+
while (!start) { std::this_thread::sleep_for(std::chrono::milliseconds{ 1 }); }
46+
};
47+
48+
const auto sendCount = 100;
49+
auto sendingThread = std::thread([&] {
50+
waitForStart();
51+
52+
for (auto i = 0; i < sendCount; ++i)
53+
{
54+
generator(1);
55+
}
56+
});
57+
58+
AND_WHEN("multiple threads are continuously adding and removing receivers")
59+
{
60+
std::vector<std::thread> threads;
61+
const auto threadCount = 10;
62+
const auto registerCount = 100;
63+
for (auto i = 0; i < threadCount; ++i)
64+
{
65+
auto thread = std::thread([&] {
66+
waitForStart();
67+
68+
for (auto j = 0; j < registerCount; ++j)
69+
{
70+
int_to_int receiver;
71+
generator += &receiver;
72+
generator -= &receiver;
73+
}
74+
});
75+
threads.push_back(std::move(thread));
76+
}
77+
78+
THEN("it doesn't crash")
79+
{
80+
start = true;
81+
}
82+
for (auto& thread : threads)
83+
{
84+
thread.join();
85+
}
86+
}
87+
sendingThread.join();
88+
}
89+
}
90+
}

0 commit comments

Comments
 (0)