Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions modules/graceful/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ func initManager(ctx context.Context) {
// The Cancel function should stop the Run function in predictable time.
func (g *Manager) RunWithCancel(rc RunCanceler) {
g.RunAtShutdown(context.Background(), rc.Cancel)
g.runningServerWaitGroup.Add(1)
if !g.runningServerWaitGroup.AddIfRunning() {
// Shutdown already in progress, do not start this server
return
}
defer g.runningServerWaitGroup.Done()
defer func() {
if err := recover(); err != nil {
Expand All @@ -87,7 +90,10 @@ func (g *Manager) RunWithCancel(rc RunCanceler) {
// After the provided context is Done(), the main function must return once shutdown is complete.
// (Optionally the HammerContext may be obtained and waited for however, this should be avoided if possible.)
func (g *Manager) RunWithShutdownContext(run func(context.Context)) {
g.runningServerWaitGroup.Add(1)
if !g.runningServerWaitGroup.AddIfRunning() {
// Shutdown already in progress, do not start this server
return
}
defer g.runningServerWaitGroup.Done()
defer func() {
if err := recover(); err != nil {
Expand All @@ -102,7 +108,10 @@ func (g *Manager) RunWithShutdownContext(run func(context.Context)) {

// RunAtTerminate adds to the terminate wait group and creates a go-routine to run the provided function at termination
func (g *Manager) RunAtTerminate(terminate func()) {
g.terminateWaitGroup.Add(1)
if !g.terminateWaitGroup.AddIfRunning() {
// Termination already in progress, do not add this function
return
}
g.lock.Lock()
defer g.lock.Unlock()
g.toRunAtTerminate = append(g.toRunAtTerminate,
Expand Down Expand Up @@ -155,12 +164,12 @@ func (g *Manager) doShutdown() {
go g.doHammerTime(setting.GracefulHammerTime)
}
go func() {
g.runningServerWaitGroup.Wait()
g.runningServerWaitGroup.Shutdown()
// Mop up any remaining unclosed events.
g.doHammerTime(0)
<-time.After(1 * time.Second)
g.doTerminate()
g.terminateWaitGroup.Wait()
g.terminateWaitGroup.Shutdown()
g.lock.Lock()
g.managerCtxCancel()
g.lock.Unlock()
Expand Down
9 changes: 6 additions & 3 deletions modules/graceful/manager_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ type Manager struct {
hammerCtxCancel context.CancelFunc
terminateCtxCancel context.CancelFunc
managerCtxCancel context.CancelFunc
runningServerWaitGroup sync.WaitGroup
terminateWaitGroup sync.WaitGroup
runningServerWaitGroup SafeWaitGroup
terminateWaitGroup SafeWaitGroup
createServerCond sync.Cond
createdServer int
shutdownRequested chan struct{}
Expand Down Expand Up @@ -106,5 +106,8 @@ func (g *Manager) DoGracefulShutdown() {
// Any call to RegisterServer must be matched by a call to ServerDone
func (g *Manager) RegisterServer() {
KillParent()
g.runningServerWaitGroup.Add(1)
if !g.runningServerWaitGroup.AddIfRunning() {
// Shutdown already in progress, server should not start
return
}
}
12 changes: 8 additions & 4 deletions modules/graceful/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Server struct {
network string
address string
listener net.Listener
wg sync.WaitGroup
wg SafeWaitGroup
state state
lock *sync.RWMutex
BeforeBegin func(network, address string)
Expand All @@ -50,7 +50,6 @@ func NewServer(network, address, name string) *Server {
log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
}
srv := &Server{
wg: sync.WaitGroup{},
state: stateInit,
lock: &sync.RWMutex{},
network: network,
Expand Down Expand Up @@ -154,7 +153,8 @@ func (srv *Server) Serve(serve ServeFunction) error {
GetManager().RegisterServer()
err := serve(srv.listener)
log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
srv.wg.Wait()
// Shutdown the waitgroup and wait for all connections to finish
srv.wg.Shutdown()
srv.setState(stateTerminate)
GetManager().ServerDone()
// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
Expand Down Expand Up @@ -224,7 +224,11 @@ func (wl *wrappedListener) Accept() (net.Conn, error) {
perWritePerKbTimeout: wl.server.PerWritePerKbTimeout,
}

wl.server.wg.Add(1)
if !wl.server.wg.AddIfRunning() {
// Server is shutting down, reject new connection
_ = c.Close()
return nil, net.ErrClosed
}
return c, nil
}

Expand Down
25 changes: 7 additions & 18 deletions modules/graceful/server_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,15 @@
}

func (srv *Server) doHammer() {
defer func() {
// We call srv.wg.Done() until it panics.
// This happens if we call Done() when the WaitGroup counter is already at 0
// So if it panics -> we're done, Serve() will return and the
// parent goroutine will exit.
if r := recover(); r != nil {
log.Error("WaitGroup at 0: Error: %v", r)
}
}()
if srv.getState() != stateShuttingDown {
return
}
log.Warn("Forcefully shutting down parent")
for {
if srv.getState() == stateTerminate {
break
}
srv.wg.Done()

// Give other goroutines a chance to finish before we forcibly stop them.
runtime.Gosched()
}

Check failure on line 55 in modules/graceful/server_hooks.go

View workflow job for this annotation

GitHub Actions / lint-backend

File is not properly formatted (gofmt)

Check failure on line 55 in modules/graceful/server_hooks.go

View workflow job for this annotation

GitHub Actions / lint-go-windows

File is not properly formatted (gofmt)

Check failure on line 55 in modules/graceful/server_hooks.go

View workflow job for this annotation

GitHub Actions / lint-go-gogit

File is not properly formatted (gofmt)
// Shutdown the waitgroup to prevent new connections
// and wait for existing connections to finish
srv.wg.Shutdown()
Comment on lines +56 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the current design would work, or that it can be improved.

doHammer means that it will stop the server immediately without waiting for any running tasks (for example: existing user request connection).

The graceful package's design overall looks problematic and it seems unfixable to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying; I understand what you mean now. That makes sense: doHammer should remain immediate and not wait for any running connections; it should just forcefully stop the server.

I'll keep this note in mind, and if there is a future plan to refactor or redesign the graceful manager to address the deeper issues, I'd be happy to participate or help with that whenever it's appropriate.


// Give other goroutines a chance to finish
runtime.Gosched()
}
49 changes: 49 additions & 0 deletions modules/graceful/wg_safe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2025 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package graceful

import (
"sync"
)

// SafeWaitGroup pairs a sync.WaitGroup with a "shutting down" flag so that
// new work can be refused once Shutdown has been requested.
//
// AddIfRunning increments the counter while holding the same mutex that
// Shutdown takes to set the flag, so every accepted add is ordered before
// the Wait performed by Shutdown. Note that the plain Wait method does not
// set the flag and therefore does not stop later adds from being accepted.
//
// The zero value is ready to use. Because it embeds a mutex and a WaitGroup,
// a SafeWaitGroup must not be copied after first use.
type SafeWaitGroup struct {
	mu       sync.Mutex     // guards shutting and serialises adds against Shutdown
	wg       sync.WaitGroup // counts the units of work accepted by AddIfRunning
	shutting bool           // set by Shutdown; never reset
}

// AddIfRunning registers one unit of work unless shutdown has begun.
// It reports true when the unit was registered (the caller then owes exactly
// one Done) and false when the add was refused.
func (w *SafeWaitGroup) AddIfRunning() bool {
	w.mu.Lock()
	defer w.mu.Unlock()

	accepted := !w.shutting
	if accepted {
		// Incremented under the lock so it cannot interleave with Shutdown
		// flipping the flag.
		w.wg.Add(1)
	}
	return accepted
}

// Done marks one unit of work as finished.
// It must only be called to balance an AddIfRunning call that returned true.
func (w *SafeWaitGroup) Done() {
	w.wg.Done()
}

// Wait blocks until all registered work has finished.
// It does not mark the group as shutting down.
func (w *SafeWaitGroup) Wait() {
	w.wg.Wait()
}

// Shutdown flags the group as shutting down, so that AddIfRunning returns
// false from then on, and then blocks until all previously registered work
// has finished. It may be called more than once.
func (w *SafeWaitGroup) Shutdown() {
	func() {
		w.mu.Lock()
		defer w.mu.Unlock()
		w.shutting = true
	}()
	w.Wait()
}
178 changes: 178 additions & 0 deletions modules/graceful/wg_safe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright 2025 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package graceful

import (
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// TestSafeWaitGroup_AddIfRunning checks the basic gate: adds are accepted
// before Shutdown and refused afterwards.
func TestSafeWaitGroup_AddIfRunning(t *testing.T) {
	var group SafeWaitGroup

	// Before shutdown the add is accepted; balance it straight away so the
	// following Shutdown has nothing to wait for.
	beforeShutdown := group.AddIfRunning()
	assert.True(t, beforeShutdown, "AddIfRunning should succeed before shutdown")
	group.Done()

	// Once Shutdown has run, further adds must be refused.
	group.Shutdown()
	afterShutdown := group.AddIfRunning()
	assert.False(t, afterShutdown, "AddIfRunning should fail after shutdown")
}

// TestSafeWaitGroup_Shutdown verifies that Shutdown blocks until every unit
// of registered work has called Done.
//
// The check is driven by explicit Done calls and a completion channel rather
// than by comparing elapsed wall-clock time against sleep durations, so a
// slow or coarse-grained scheduler cannot make it fail spuriously.
func TestSafeWaitGroup_Shutdown(t *testing.T) {
	var swg SafeWaitGroup

	// Register two units of work that stay outstanding until released below.
	assert.True(t, swg.AddIfRunning())
	assert.True(t, swg.AddIfRunning())

	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		swg.Shutdown()
	}()

	// stillBlocked gives Shutdown a short window in which to (incorrectly)
	// return, and fails the test if it does.
	stillBlocked := func(msg string) {
		select {
		case <-shutdownDone:
			t.Fatal(msg)
		case <-time.After(20 * time.Millisecond):
		}
	}

	stillBlocked("Shutdown returned before any work completed")

	swg.Done()
	stillBlocked("Shutdown returned while work was still outstanding")

	// Releasing the last unit must let Shutdown return.
	swg.Done()
	select {
	case <-shutdownDone:
	case <-time.After(5 * time.Second):
		t.Fatal("Shutdown should return once all work has completed")
	}
}

// TestSafeWaitGroup_ConcurrentAddAndShutdown races many would-be workers
// against a single Shutdown and checks that every accepted add is completed
// and that nothing is accepted once Shutdown has returned.
func TestSafeWaitGroup_ConcurrentAddAndShutdown(t *testing.T) {
	const numGoroutines = 100

	var (
		swg      SafeWaitGroup
		accepted atomic.Int32 // adds that AddIfRunning let through
		finished atomic.Int32 // accepted adds that went on to call Done
		workers  sync.WaitGroup
	)

	// worker tries to register once; if accepted it simulates a little work.
	worker := func() {
		defer workers.Done()
		if !swg.AddIfRunning() {
			return
		}
		accepted.Add(1)
		time.Sleep(10 * time.Millisecond)
		swg.Done()
		finished.Add(1)
	}

	workers.Add(numGoroutines)
	for range numGoroutines {
		go worker()
	}

	// Let some of the workers get in before shutting down.
	time.Sleep(5 * time.Millisecond)
	swg.Shutdown()

	// Every worker goroutine, accepted or not, must have returned.
	workers.Wait()

	assert.Equal(t, accepted.Load(), finished.Load(), "All successful adds should complete")

	// The gate stays closed after shutdown.
	assert.False(t, swg.AddIfRunning(), "No adds should succeed after shutdown")
}

// TestSafeWaitGroup_MultipleShutdowns checks that calling Shutdown again on
// an already shut-down group does not panic.
func TestSafeWaitGroup_MultipleShutdowns(t *testing.T) {
	var group SafeWaitGroup

	// Initial shutdown of an empty group.
	group.Shutdown()

	// A repeated shutdown must be harmless.
	assert.NotPanics(t, group.Shutdown, "Multiple shutdowns should not panic")
}

// TestSafeWaitGroup_WaitWithoutAdd checks that Wait on a group with no
// registered work returns promptly.
func TestSafeWaitGroup_WaitWithoutAdd(t *testing.T) {
	var group SafeWaitGroup

	// Run Wait in the background and signal when it comes back.
	returned := make(chan struct{})
	go func() {
		defer close(returned)
		group.Wait()
	}()

	timeout := time.After(100 * time.Millisecond)
	select {
	case <-timeout:
		t.Fatal("Wait should not block when no work is added")
	case <-returned:
		// Wait came back as expected.
	}
}

// TestSafeWaitGroup_PreventsPanic exercises the sequence
// Add -> Wait (in a goroutine) -> Add, which is a misuse of a bare
// sync.WaitGroup, and checks that SafeWaitGroup rejects the late add instead
// of panicking.
//
// One unit of work is held open for the whole probe so that Shutdown is
// guaranteed to still be waiting, and the test polls through the public API
// until the shutdown flag is observed rather than assuming the Shutdown
// goroutine has been scheduled within a fixed sleep.
func TestSafeWaitGroup_PreventsPanic(t *testing.T) {
	var swg SafeWaitGroup

	// Outstanding work that keeps Shutdown blocked until released below.
	assert.True(t, swg.AddIfRunning())

	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		swg.Shutdown()
	}()

	rejected := false
	assert.NotPanics(t, func() {
		deadline := time.Now().Add(5 * time.Second)
		for time.Now().Before(deadline) {
			if !swg.AddIfRunning() {
				rejected = true
				return
			}
			// Shutdown has not started yet: undo the probe add and retry.
			// The counter never reaches zero here because of the unit held
			// open above.
			swg.Done()
			time.Sleep(time.Millisecond)
		}
	}, "AddIfRunning should not panic during shutdown")
	assert.True(t, rejected, "Add should be rejected during shutdown")

	// Shutdown must still be waiting for the outstanding unit.
	select {
	case <-shutdownDone:
		t.Fatal("Shutdown returned while work was still outstanding")
	default:
	}

	// Release the work and make sure the Shutdown goroutine finishes, so it
	// does not outlive the test.
	swg.Done()
	select {
	case <-shutdownDone:
	case <-time.After(5 * time.Second):
		t.Fatal("Shutdown should return once the outstanding work has completed")
	}
}

// TestSafeWaitGroup_RaceCondition hammers a SafeWaitGroup with concurrent
// adders and concurrent Shutdown callers. It makes no assertions of its own;
// it is meant to be run under the race detector (go test -race).
func TestSafeWaitGroup_RaceCondition(t *testing.T) {
	const numWorkers = 50

	var (
		swg SafeWaitGroup
		all sync.WaitGroup
	)
	// One slot per worker plus one per shutdown goroutine.
	all.Add(numWorkers * 2)

	// work repeatedly tries to register and complete a short unit of work.
	work := func() {
		defer all.Done()
		for range 10 {
			if !swg.AddIfRunning() {
				continue
			}
			time.Sleep(time.Millisecond)
			swg.Done()
		}
	}

	// stop waits briefly and then requests shutdown.
	stop := func() {
		defer all.Done()
		time.Sleep(5 * time.Millisecond)
		swg.Shutdown()
	}

	for range numWorkers {
		go work()
	}
	for range numWorkers {
		go stop()
	}

	all.Wait()
}
Loading