diff --git a/pkg/statusc/server.go b/pkg/statusc/server.go index 68e76ec..a4cd46d 100644 --- a/pkg/statusc/server.go +++ b/pkg/statusc/server.go @@ -3,6 +3,7 @@ package statusc import ( "context" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -34,10 +35,15 @@ func Run[T any](ctx context.Context, addr *net.TCPAddr, f func(ctx context.Conte go func() { <-ctx.Done() - if err := srv.Close(); err != nil { - slogc.FineDefault("error closing status server", "err", err) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + slogc.FineDefault("error shutting down status server", "err", err) } }() - return srv.ListenAndServe() + if err := srv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil } diff --git a/server/control/clients.go b/server/control/clients.go index fd46fa1..dd6c9db 100644 --- a/server/control/clients.go +++ b/server/control/clients.go @@ -172,6 +172,7 @@ type clientServer struct { peersMu sync.RWMutex endpointExpiry time.Duration + connsWg sync.WaitGroup } type peerKey struct { @@ -339,10 +340,28 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { } }() + defer func() { + done := make(chan struct{}) + go func() { + s.connsWg.Wait() + close(done) + }() + drainCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + select { + case <-done: + case <-drainCtx.Done(): + s.logger.Debug("connection drain timeout") + } + }() + s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr()) for { conn, err := l.Accept(ctx) if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } slogc.Fine(s.logger, "accept error", "err", err) return fmt.Errorf("client server quic accept: %w", err) } @@ -352,7 +371,9 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { conn: conn, logger: s.logger, } - go cc.run(ctx) + s.connsWg.Go(func() { + cc.run(ctx) + }) } } @@ -527,7 +548,9 @@ func (c *clientConn) runErr(ctx context.Context) error { conn: c, stream: stream, } - go cs.run(ctx) + c.server.connsWg.Go(func() { + cs.run(ctx) + }) } } diff --git a/server/control/relays.go b/server/control/relays.go index 36a52d2..9bdd798 100644 --- a/server/control/relays.go +++ b/server/control/relays.go @@ -11,6 +11,7 @@ import ( "maps" "net" "sync" + "time" "github.com/connet-dev/connet/model" "github.com/connet-dev/connet/pkg/iterc" @@ -279,10 +280,29 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error { } }() + var wg sync.WaitGroup + defer func() { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + drainCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + select { + case <-done: + case <-drainCtx.Done(): + s.logger.Debug("connection drain timeout") + } + }() + s.logger.Info("accepting relay connections", "addr", transport.Conn.LocalAddr()) for { conn, err := l.Accept(ctx) if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } slogc.Fine(s.logger, "accept error", "err", err) return fmt.Errorf("relay server quic accept: %w", err) } @@ -292,7 +312,11 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error { conn: conn, logger: s.logger, } - go rc.run(ctx) + wg.Add(1) + go func() { + defer wg.Done() + rc.run(ctx) + }() } } diff --git a/server/relay/clients.go b/server/relay/clients.go index e539586..be7299c 100644 --- a/server/relay/clients.go +++ b/server/relay/clients.go @@ -12,6 +12,7 @@ import ( "net" "slices" "sync" + "time" "github.com/connet-dev/connet/model" "github.com/connet-dev/connet/pkg/certc" @@ -233,10 +234,29 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error { } }() + var wg sync.WaitGroup + defer func() { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + drainCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + select { + case <-done: + case <-drainCtx.Done(): + s.logger.Debug("connection drain timeout") + } + }() + s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr()) for { conn, err := l.Accept(ctx) if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } slogc.Fine(s.logger, "accept error", "err", err) return fmt.Errorf("client server quic accept: %w", err) } @@ -246,7 +266,11 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error { conn: conn, logger: s.logger, } - go rc.run(ctx) + wg.Add(1) + go func() { + defer wg.Done() + rc.run(ctx) + }() } }