diff --git a/destination.go b/destination.go
index 08236713..06e78f05 100644
--- a/destination.go
+++ b/destination.go
@@ -203,6 +203,7 @@ func (d *destinationConn) run(ctx context.Context) {
 }
 
 func (d *destinationConn) runErr(ctx context.Context) error {
+	// TODO do we need graceful shutdown of these streams?
 	for {
 		stream, err := d.conn.AcceptStream(ctx)
 		if err != nil {
diff --git a/pkg/reliable/drain.go b/pkg/reliable/drain.go
new file mode 100644
index 00000000..19aaa9a3
--- /dev/null
+++ b/pkg/reliable/drain.go
@@ -0,0 +1,45 @@
+package reliable
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"time"
+)
+
+// Drain runs goroutines bound to a shared context and lets the owner
+// wait, with a timeout, for all of them to finish.
+type Drain struct {
+	ctx context.Context
+	wg  sync.WaitGroup
+}
+
+// NewDrain returns a Drain whose goroutines receive ctx.
+func NewDrain(ctx context.Context) *Drain {
+	return &Drain{ctx: ctx}
+}
+
+// Go starts f in a new goroutine tracked by the drain, passing it the
+// drain's context.
+func (d *Drain) Go(f func(context.Context)) {
+	d.wg.Go(func() { f(d.ctx) })
+}
+
+// Wait blocks until every goroutine started via Go has finished, or
+// returns an error if they have not finished within timeout.
+func (d *Drain) Wait(timeout time.Duration) error {
+	done := make(chan struct{})
+	go func() {
+		d.wg.Wait()
+		close(done)
+	}()
+
+	drainCtx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	select {
+	case <-done:
+		return nil
+	case <-drainCtx.Done():
+		return errors.New("drain timeout")
+	}
+}
diff --git a/pkg/statusc/server.go b/pkg/statusc/server.go
index 68e76ec2..9f119ead 100644
--- a/pkg/statusc/server.go
+++ b/pkg/statusc/server.go
@@ -34,10 +34,15 @@ func Run[T any](ctx context.Context, addr *net.TCPAddr, f func(ctx context.Conte
 
 	go func() {
 		<-ctx.Done()
-		if err := srv.Close(); err != nil {
-			slogc.FineDefault("error closing status server", "err", err)
+
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+
+		if err := srv.Shutdown(shutdownCtx); err != nil {
+			slogc.FineDefault("error shutting down status server", "err", err)
 		}
 	}()
 
+	// TODO return no error on server shutdown?
return srv.ListenAndServe() } diff --git a/server/control/clients.go b/server/control/clients.go index fd46fa1a..7d27cc9c 100644 --- a/server/control/clients.go +++ b/server/control/clients.go @@ -339,10 +339,18 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { } }() + drain := reliable.NewDrain(ctx) + defer func() { + if err := drain.Wait(10 * time.Second); err != nil { + s.logger.Warn("clients drain", "err", err) + } + }() + s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr()) for { conn, err := l.Accept(ctx) if err != nil { + // TODO return no error on context cancel? slogc.Fine(s.logger, "accept error", "err", err) return fmt.Errorf("client server quic accept: %w", err) } @@ -352,7 +360,7 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { conn: conn, logger: s.logger, } - go cc.run(ctx) + drain.Go(cc.run) } } @@ -479,6 +487,7 @@ type clientConnAuth struct { func (c *clientConn) run(ctx context.Context) { c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr()) defer func() { + // TODO richer errors if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil { slogc.Fine(c.logger, "error closing connection", "err", err) } @@ -517,6 +526,13 @@ func (c *clientConn) runErr(ctx context.Context) error { } }() + drain := reliable.NewDrain(ctx) + defer func() { + if err := drain.Wait(5 * time.Second); err != nil { + c.logger.Warn("client streams drain", "err", err) + } + }() + for { stream, err := c.conn.AcceptStream(ctx) if err != nil { @@ -527,7 +543,7 @@ func (c *clientConn) runErr(ctx context.Context) error { conn: c, stream: stream, } - go cs.run(ctx) + drain.Go(cs.run) } } @@ -606,6 +622,7 @@ type clientStream struct { func (s *clientStream) run(ctx context.Context) { defer func() { + // TODO richer errors if err := s.stream.Close(); err != nil 
{ slogc.Fine(s.conn.logger, "error closing client stream", "err", err) } @@ -669,8 +686,8 @@ func (s *clientStream) announce(ctx context.Context, req *pbclient.Request_Annou return err } defer func() { - if s.conn.server.endpointExpiry > 0 && s.conn.conn.Context().Err() != nil { - // Connection dead — mark as expired, consumer will delete after timeout + if s.conn.server.endpointExpiry > 0 && (ctx.Err() != nil || s.conn.conn.Context().Err() != nil) { + // Server shutting down or connection dead — mark as expired, consumer will delete after timeout if err := s.conn.server.expire(endpoint, role, s.conn.id, s.conn.connID); err != nil { s.conn.logger.Warn("failed to expire peer", "id", s.conn.id, "err", err) } diff --git a/server/control/relays.go b/server/control/relays.go index 36a52d28..615df0ad 100644 --- a/server/control/relays.go +++ b/server/control/relays.go @@ -11,6 +11,7 @@ import ( "maps" "net" "sync" + "time" "github.com/connet-dev/connet/model" "github.com/connet-dev/connet/pkg/iterc" @@ -279,10 +280,18 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error { } }() + drain := reliable.NewDrain(ctx) + defer func() { + if err := drain.Wait(10 * time.Second); err != nil { + s.logger.Warn("relays drain", "err", err) + } + }() + s.logger.Info("accepting relay connections", "addr", transport.Conn.LocalAddr()) for { conn, err := l.Accept(ctx) if err != nil { + // TODO return no error on context cancel? 
slogc.Fine(s.logger, "accept error", "err", err) return fmt.Errorf("relay server quic accept: %w", err) } @@ -292,7 +301,7 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error { conn: conn, logger: s.logger, } - go rc.run(ctx) + drain.Go(rc.run) } } diff --git a/server/relay/clients.go b/server/relay/clients.go index e5395860..9f2197fb 100644 --- a/server/relay/clients.go +++ b/server/relay/clients.go @@ -12,6 +12,7 @@ import ( "net" "slices" "sync" + "time" "github.com/connet-dev/connet/model" "github.com/connet-dev/connet/pkg/certc" @@ -233,10 +234,18 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error { } }() + drain := reliable.NewDrain(ctx) + defer func() { + if err := drain.Wait(10 * time.Second); err != nil { + s.logger.Warn("clients drain", "err", err) + } + }() + s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr()) for { conn, err := l.Accept(ctx) if err != nil { + // TODO return no error on context cancel? slogc.Fine(s.logger, "accept error", "err", err) return fmt.Errorf("client server quic accept: %w", err) } @@ -246,7 +255,7 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error { conn: conn, logger: s.logger, } - go rc.run(ctx) + drain.Go(rc.run) } } @@ -261,6 +270,7 @@ type clientConn struct { func (c *clientConn) run(ctx context.Context) { c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr()) defer func() { + // TODO richer errors if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil { slogc.Fine(c.logger, "error closing connection", "err", err) } @@ -353,6 +363,7 @@ func (c *clientConn) runSource(ctx context.Context) error { }) g.Go(func(ctx context.Context) error { + // TODO do we need graceful shutdown of these streams? 
for { stream, err := c.conn.AcceptStream(ctx) if err != nil { @@ -367,6 +378,7 @@ func (c *clientConn) runSource(ctx context.Context) error { func (c *clientConn) runSourceStream(ctx context.Context, stream *quic.Stream, fcs *endpointClients) { defer func() { + // TODO richer errors if err := stream.Close(); err != nil { slogc.Fine(c.logger, "error closing source stream", "err", err) }