Skip to content

Commit b761511

Browse files
committed
Fixes for concurrency in WS streams
1 parent d9cf8fb commit b761511

File tree

5 files changed

+100
-55
lines changed

5 files changed

+100
-55
lines changed

src/eglt/actions/action.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,9 @@ class Action : public std::enable_shared_from_this<Action> {
333333
AsyncNode* node = node_map_->Get(GetInputId(name));
334334
if (stream_ != nullptr &&
335335
bind_stream.value_or(bind_streams_on_inputs_default_)) {
336-
absl::flat_hash_map<std::string_view,
337-
std::shared_ptr<EvergreenWireStream>>
336+
absl::flat_hash_map<std::string, std::shared_ptr<EvergreenWireStream>>
338337
peers;
339-
peers.insert({stream_->GetId(), stream_});
338+
peers.insert({std::string(stream_->GetId()), stream_});
340339
node->BindPeers(std::move(peers));
341340
nodes_with_bound_streams_.insert(node);
342341
}
@@ -375,10 +374,9 @@ class Action : public std::enable_shared_from_this<Action> {
375374
AsyncNode* node = node_map_->Get(GetOutputId(name));
376375
if (stream_ != nullptr &&
377376
bind_stream.value_or(bind_streams_on_outputs_default_)) {
378-
absl::flat_hash_map<std::string_view,
379-
std::shared_ptr<EvergreenWireStream>>
377+
absl::flat_hash_map<std::string, std::shared_ptr<EvergreenWireStream>>
380378
peers;
381-
peers.insert({stream_->GetId(), stream_});
379+
peers.insert({std::string(stream_->GetId()), stream_});
382380
node->BindPeers(std::move(peers));
383381
nodes_with_bound_streams_.insert(node);
384382
}

src/eglt/concurrency/concurrency.h

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,7 @@ class ABSL_LOCKABLE ABSL_ATTRIBUTE_WARN_UNUSED ExclusiveAccessGuard {
142142
mutex_->Unlock();
143143
}
144144

145-
void StartBlockingPendingOperation() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
146-
++pending_operations_;
147-
}
145+
void StartBlockingPendingOperation() { ++pending_operations_; }
148146

149147
void FinishPendingOperation() ABSL_LOCKS_EXCLUDED(mutex_)
150148
ABSL_EXCLUSIVE_LOCK_FUNCTION(mutex_) {
@@ -153,7 +151,7 @@ class ABSL_LOCKABLE ABSL_ATTRIBUTE_WARN_UNUSED ExclusiveAccessGuard {
153151
cv_->SignalAll();
154152
}
155153

156-
void FinishBlockingPendingOperation() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
154+
void FinishBlockingPendingOperation() {
157155
--pending_operations_;
158156
cv_->SignalAll();
159157
}
@@ -163,8 +161,8 @@ class ABSL_LOCKABLE ABSL_ATTRIBUTE_WARN_UNUSED ExclusiveAccessGuard {
163161
friend class PreventExclusiveAccess;
164162

165163
Mutex* absl_nonnull const mutex_;
166-
CondVar* absl_nonnull const cv_ ABSL_GUARDED_BY(mutex_);
167-
int pending_operations_ ABSL_GUARDED_BY(mutex_) = 0;
164+
CondVar* absl_nonnull const cv_;
165+
int pending_operations_ = 0;
168166
};
169167

170168
class ABSL_SCOPED_LOCKABLE ABSL_ATTRIBUTE_WARN_UNUSED EnsureExclusiveAccess {
@@ -242,14 +240,34 @@ inline Case OnCancel() {
242240
return impl::OnCancel();
243241
}
244242

245-
inline int Select(const CaseArray& cases) {
243+
inline int Select(const CaseArray& cases) noexcept {
246244
return impl::Select(cases);
247245
}
248246

249-
inline int SelectUntil(const absl::Time deadline, const CaseArray& cases) {
247+
inline int SelectWithScopedUnlock(Mutex* absl_nonnull mu,
248+
const CaseArray& cases)
249+
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) {
250+
mu->Unlock();
251+
const int selected = concurrency::Select(cases);
252+
mu->Lock();
253+
return selected;
254+
}
255+
256+
inline int SelectUntil(const absl::Time deadline,
257+
const CaseArray& cases) noexcept {
250258
return impl::SelectUntil(deadline, cases);
251259
}
252260

261+
inline int SelectWithScopedUnlockUntil(Mutex* absl_nonnull mu,
262+
const absl::Time deadline,
263+
const CaseArray& cases)
264+
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) {
265+
mu->Unlock();
266+
const int selected = concurrency::SelectUntil(deadline, cases);
267+
mu->Lock();
268+
return selected;
269+
}
270+
253271
inline void Detach(std::unique_ptr<Fiber> fiber) {
254272
impl::Detach(std::move(fiber));
255273
}

src/eglt/nodes/async_node.h

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,11 @@ class AsyncNode {
5050
AsyncNode& operator=(AsyncNode& other) = delete;
5151
AsyncNode& operator=(AsyncNode&& other) noexcept;
5252

53-
~AsyncNode() {
54-
concurrency::MutexLock lock(&mutex_);
55-
finalized_ = true;
56-
concurrency::EnsureExclusiveAccess guard(&finalization_guard_);
57-
}
53+
~AsyncNode() { concurrency::MutexLock lock(&mutex_); }
5854

59-
void BindPeers(absl::flat_hash_map<std::string_view,
60-
std::shared_ptr<EvergreenWireStream>>
61-
peers) {
55+
void BindPeers(
56+
absl::flat_hash_map<std::string, std::shared_ptr<EvergreenWireStream>>
57+
peers) {
6258
concurrency::MutexLock lock(&mutex_);
6359
peers_ = std::move(peers);
6460
}
@@ -163,13 +159,10 @@ class AsyncNode {
163159

164160
mutable concurrency::Mutex mutex_;
165161
mutable concurrency::CondVar cv_ ABSL_GUARDED_BY(mutex_);
166-
concurrency::ExclusiveAccessGuard finalization_guard_ ABSL_GUARDED_BY(mutex_){
167-
&mutex_, &cv_};
168-
bool finalized_ ABSL_GUARDED_BY(mutex_) = false;
169162
std::unique_ptr<ChunkStoreReader> default_reader_ ABSL_GUARDED_BY(mutex_);
170163
std::unique_ptr<ChunkStoreWriter> default_writer_ ABSL_GUARDED_BY(mutex_);
171-
absl::flat_hash_map<std::string_view, std::shared_ptr<EvergreenWireStream>>
172-
peers_ ABSL_GUARDED_BY(mutex_);
164+
absl::flat_hash_map<std::string, std::shared_ptr<EvergreenWireStream>> peers_
165+
ABSL_GUARDED_BY(mutex_);
173166
};
174167

175168
template <>
@@ -187,6 +180,9 @@ inline auto AsyncNode::Put<Chunk>(Chunk value, int seq_id, bool final)
187180
template <>
188181
inline auto AsyncNode::Put(NodeFragment value, int seq_id, bool final)
189182
-> absl::Status {
183+
if (seq_id == -1) {
184+
seq_id = value.seq;
185+
}
190186
return PutFragment(std::move(value), seq_id);
191187
}
192188

src/eglt/sdk/serving/websockets.h

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define BOOST_ASIO_NO_DEPRECATED
88

99
#include <boost/asio/ip/tcp.hpp>
10+
#include <boost/asio/strand.hpp>
1011
#include <boost/asio/thread_pool.hpp>
1112
#include <boost/beast/core.hpp>
1213
#include <boost/beast/websocket.hpp>
@@ -87,41 +88,52 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
8788
}
8889

8990
absl::Status Send(SessionMessage message) override {
90-
91+
auto message_bytes = cppack::Pack(std::move(message));
9192
concurrency::MutexLock lock(&mutex_);
9293
if (!status_.ok()) {
9394
return status_;
9495
}
9596

97+
if (concurrency::Cancelled()) {
98+
return absl::CancelledError("Send cancelled");
99+
}
100+
101+
while (send_pending_) {
102+
cv_.Wait(&mutex_);
103+
}
96104
send_pending_ = true;
97-
mutex_.Unlock();
98-
boost::system::error_code error;
99105
stream_.binary(true);
100-
RunInAsioContext(
101-
[this, &error, &message]() {
102-
auto message_bytes = cppack::Pack(std::move(message));
103-
stream_.write(asio::buffer(message_bytes), error);
104-
},
105-
{concurrency::OnCancel()});
106+
107+
boost::system::error_code error;
108+
concurrency::PermanentEvent write_done;
109+
stream_.async_write(asio::buffer(message_bytes),
110+
[&error, &write_done, this](
111+
const boost::system::error_code& ec, std::size_t) {
112+
concurrency::MutexLock lock(&mutex_);
113+
error = ec;
114+
write_done.Notify();
115+
send_pending_ = false;
116+
cv_.SignalAll();
117+
});
118+
mutex_.Unlock();
119+
concurrency::Select({write_done.OnEvent(), concurrency::OnCancel()});
106120
mutex_.Lock();
107-
send_pending_ = false;
108-
last_send_status_ = absl::OkStatus();
109121

122+
absl::Status status = absl::OkStatus();
110123
if (concurrency::Cancelled()) {
111-
stream_.next_layer().shutdown(asio::socket_base::shutdown_send);
124+
stream_.next_layer().shutdown(asio::socket_base::shutdown_send, error);
112125
DLOG(INFO) << absl::StrFormat("WESt %s Send cancelled", id_);
113-
last_send_status_ = absl::CancelledError("Send cancelled");
126+
status = absl::CancelledError("Send cancelled");
114127
}
115128

129+
// send_pending_ = false;
130+
// cv_.SignalAll();
116131
if (error) {
117-
last_send_status_ = absl::InternalError(error.message());
118-
DLOG(INFO) << absl::StrFormat("WESt %s Send failed: %v", id_,
119-
last_send_status_);
132+
status = absl::InternalError(error.message());
133+
DLOG(INFO) << absl::StrFormat("WESt %s Send failed: %v", id_, status);
120134
}
121135

122-
cv_.SignalAll();
123-
124-
return last_send_status_;
136+
return status;
125137
}
126138

127139
std::optional<SessionMessage> Receive() override {
@@ -135,6 +147,14 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
135147
return std::nullopt;
136148
}
137149

150+
if (concurrency::Cancelled()) {
151+
return std::nullopt;
152+
}
153+
154+
while (recv_pending_) {
155+
cv_.Wait(&mutex_);
156+
}
157+
138158
recv_pending_ = true;
139159
mutex_.Unlock();
140160
boost::system::error_code error;
@@ -144,11 +164,23 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
144164
stream_.read(dynamic_buffer, error);
145165
},
146166
{concurrency::OnCancel()});
167+
168+
if (concurrency::Cancelled()) {
169+
stream_.next_layer().shutdown(asio::socket_base::shutdown_receive,
170+
error);
171+
if (error && error != asio::error::not_connected) {
172+
LOG(ERROR) << absl::StrFormat(
173+
"WESt %s Cannot shut down receive on socket: %v", id_,
174+
error.message());
175+
}
176+
}
177+
147178
mutex_.Lock();
148179
recv_pending_ = false;
149180

181+
cv_.SignalAll();
182+
150183
if (concurrency::Cancelled()) {
151-
stream_.next_layer().shutdown(asio::socket_base::shutdown_receive);
152184
DLOG(INFO) << absl::StrFormat("WESt %s Receive cancelled", id_);
153185
return std::nullopt;
154186
}
@@ -158,8 +190,6 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
158190
error.message());
159191
return std::nullopt;
160192
}
161-
162-
cv_.SignalAll();
163193
}
164194

165195
if (auto unpacked = cppack::Unpack<SessionMessage>(buffer); unpacked.ok()) {
@@ -185,6 +215,7 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
185215
beast::http::field::server,
186216
"Action Engine / Evergreen Light 0.1.0 WebsocketEvergreenServer");
187217
}));
218+
stream_.write_buffer_bytes(16);
188219

189220
boost::system::error_code error;
190221

@@ -214,7 +245,7 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
214245
status_ = absl::InternalError(error.message());
215246
}
216247
status_ = absl::CancelledError("Cancelled");
217-
while (send_pending_ || recv_pending_) {
248+
while (send_pending_) {
218249
cv_.Wait(&mutex_);
219250
}
220251
stream_.next_layer().wait(tcp::socket::wait_error, error);
@@ -235,7 +266,7 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
235266
}
236267
}
237268

238-
absl::Status GetStatus() const override { return last_send_status_; }
269+
absl::Status GetStatus() const override { return status_; }
239270

240271
[[nodiscard]] std::string_view GetId() const override { return id_; }
241272

@@ -255,7 +286,6 @@ class WebsocketEvergreenWireStream final : public EvergreenWireStream {
255286
std::string id_;
256287

257288
absl::Status status_;
258-
absl::Status last_send_status_;
259289

260290
mutable concurrency::Mutex mutex_;
261291
mutable concurrency::CondVar cv_ ABSL_GUARDED_BY(mutex_);
@@ -317,7 +347,8 @@ class WebsocketEvergreenServer {
317347
concurrency::MutexLock lock(&mutex_);
318348
main_loop_ = concurrency::NewTree({}, [this]() {
319349
while (!concurrency::Cancelled()) {
320-
tcp::socket socket{*GetDefaultAsioExecutionContext()};
350+
tcp::socket socket{
351+
asio::make_strand(*GetDefaultAsioExecutionContext())};
321352

322353
DLOG(INFO) << "WES waiting for connection.";
323354
boost::system::error_code error;
@@ -499,6 +530,8 @@ MakeWebsocketEvergreenWireStream(std::string_view address = "127.0.0.1",
499530
return absl::CancelledError("Cancelled");
500531
}
501532

533+
ws_stream.write_buffer_bytes(16);
534+
502535
if (prepare_stream) {
503536
if (auto status = std::move(prepare_stream)(&ws_stream); !status.ok()) {
504537
return status;

src/thread_on_boost/thread_on_boost/fiber.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ void Fiber::Body() {
155155
// MarkFinished returns whether the fiber was detached when finished.
156156
// Detached fibers are self-joining.
157157
InternalJoin();
158+
context_.detach();
158159
delete this;
159160
}
160161
}
@@ -222,9 +223,8 @@ void Fiber::Join() {
222223
CHECK(parent_ == current_fiber) << "Join() called from non-parent fiber";
223224
}
224225

225-
context_ = nullptr;
226-
227226
InternalJoin();
227+
context_->join();
228228
}
229229

230230
// Update *this to a FINISHED state. Preparing it to be Join()-ed (and

0 commit comments

Comments
 (0)