From 4f8ca72dc9d683afe3997ec25927d747d9d0ff2b Mon Sep 17 00:00:00 2001 From: Alexander van der Grinten Date: Sat, 7 Feb 2026 17:29:36 +0100 Subject: [PATCH] wait-group: Only decrement ctr to zero with mutex held This fixes a use-after-free that could previously happen if a wait() call ran after reaching zero but before the mutex was taken by done(). --- include/async/wait-group.hpp | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/include/async/wait-group.hpp b/include/async/wait-group.hpp index 2a4fda9..e01b06f 100644 --- a/include/async/wait-group.hpp +++ b/include/async/wait-group.hpp @@ -46,20 +46,19 @@ struct wait_group { > > items; - while (1) { - auto v = ctr_.load(std::memory_order_acquire); - if (ctr_.compare_exchange_strong(v, v - 1, std::memory_order_acq_rel, std::memory_order_acquire)) { - assert(v > 0); - if (v != 1) { - return; - } else { - break; - } - } + // We can decrement outside of the lock as long as we do not reach zero. + auto v = ctr_.load(std::memory_order_relaxed); + while (v > 1) { + if (ctr_.compare_exchange_strong(v, v - 1, std::memory_order_acq_rel, std::memory_order_relaxed)) + return; } + assert(v > 0); { frg::unique_lock lock(mutex_); + // Only wake waiters if we reach zero. + if (ctr_.fetch_sub(1, std::memory_order_acq_rel) > 1) + return; items.splice(items.end(), queue_); } @@ -89,7 +88,8 @@ struct wait_group { { frg::unique_lock lock(wg_->mutex_); - if(wg_->ctr_.load(std::memory_order_acquire) > 0) { + // Relaxed since non-zero -> zero transitions cannot happen while the mutex is held. + if(wg_->ctr_.load(std::memory_order_relaxed) > 0) { if(!cobs_.try_set(ct_)) { cancelled = true; }else{ @@ -108,7 +108,8 @@ struct wait_group { { frg::unique_lock lock(wg_->mutex_); - if(wg_->ctr_.load(std::memory_order_acquire) > 0) { + // Relaxed since non-zero -> zero transitions cannot happen while the mutex is held. + if(wg_->ctr_.load(std::memory_order_relaxed) > 0) { cancelled = true; auto it = wg_->queue_.iterator_to(this); wg_->queue_.erase(it);