diff --git a/pkg/reliable/group.go b/pkg/reliable/group.go
index 4a1d881a..d0cab8a0 100644
--- a/pkg/reliable/group.go
+++ b/pkg/reliable/group.go
@@ -2,6 +2,8 @@ package reliable
 
 import (
 	"context"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/sync/errgroup"
@@ -9,6 +11,30 @@ import (
 
 type RunFn func(context.Context) error
 
+// ReadyNotifier returns a func(error) that sends to ch exactly once and then closes it:
+// nil after n calls with a nil error, or the first non-nil error immediately.
+// If n <= 0 there is nothing to wait for, so nil is sent (and ch closed) right away;
+// ch should therefore be buffered so construction does not block.
+func ReadyNotifier(n int, ch chan<- error) func(error) {
+	var once sync.Once
+	var remaining atomic.Int32
+	remaining.Store(int32(n))
+	send := func(err error) {
+		once.Do(func() { ch <- err; close(ch) })
+	}
+	if n <= 0 {
+		send(nil)
+	}
+	return func(err error) {
+		if err != nil {
+			send(err)
+			return
+		}
+		if remaining.Add(-1) <= 0 {
+			send(nil)
+		}
+	}
+}
+
 func Bind[T any](t T, fn func(context.Context, T) error) RunFn {
 	return func(ctx context.Context) error {
 		return fn(ctx, t)
diff --git a/server/config.go b/server/config.go
index d4a30e98..1ab6eb1e 100644
--- a/server/config.go
+++ b/server/config.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"fmt"
 	"log/slog"
 	"net"
@@ -20,6 +21,8 @@ type serverConfig struct {
 
 	relayIngresses []relay.Ingress
 
+	drainTimeout func(ctx context.Context) time.Duration
+
 	dir    string
 	logger *slog.Logger
 }
@@ -143,6 +146,13 @@ func Logger(logger *slog.Logger) Option {
 	}
 }
 
+func ServerDrainTimeout(fn func(ctx context.Context) time.Duration) Option {
+	return func(cfg *serverConfig) error {
+		cfg.drainTimeout = fn
+		return nil
+	}
+}
+
 func StoreDirFromEnvPrefixed(prefix string) (string, error) {
 	if stateDir := os.Getenv("CONNET_STATE_DIR"); stateDir != "" {
 		// Support direct override if necessary, currently used in docker
diff --git a/server/control/clients.go b/server/control/clients.go
index fd46fa1a..b149fe74 100644
--- a/server/control/clients.go
+++ b/server/control/clients.go
@@ -172,6 +172,8 @@ type clientServer struct {
 	peersMu sync.RWMutex
 
 	endpointExpiry time.Duration
+
+	connsWg sync.WaitGroup
 }
 
 type peerKey struct {
@@ -276,11 +278,12 @@ func (s *clientServer) listen(ctx context.Context, endpoint model.Endpoint, role
 	}
 }
 
-func (s *clientServer) run(ctx context.Context) error {
+func (s *clientServer) run(ctx context.Context, notifyReady func(error)) error {
 	g := reliable.NewGroup(ctx)
 	for _, ingress := range s.ingresses {
-		g.Go(reliable.Bind(ingress, s.runListener))
+		ingress := ingress
+		g.Go(func(ctx context.Context) error { return s.runListener(ctx, ingress, notifyReady) })
 	}
 
 	g.Go(s.runPeerCache)
 	if s.endpointExpiry > 0 {
@@ -293,7 +296,16 @@ func (s *clientServer) run(ctx context.Context) error {
 	return g.Wait()
 }
 
-func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
+func (s *clientServer) waitDrain(ctx context.Context) {
+	done := make(chan struct{})
+	go func() { s.connsWg.Wait(); close(done) }()
+	select {
+	case <-done:
+	case <-ctx.Done():
+	}
+}
+
+func (s *clientServer) runListener(ctx context.Context, ingress Ingress, notifyReady func(error)) error {
 	s.logger.Debug("start udp listener", "addr", ingress.Addr)
 	udpConn, err := net.ListenUDP("udp", ingress.Addr)
 	if err != nil {
@@ -331,6 +343,7 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
 
 	l, err := transport.Listen(tlsConf, quicConf)
 	if err != nil {
+		notifyReady(fmt.Errorf("client server quic listen: %w", err))
 		return fmt.Errorf("client server quic listen: %w", err)
 	}
 	defer func() {
@@ -340,6 +353,7 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
 	}()
 
 	s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr())
+	notifyReady(nil)
 	for {
 		conn, err := l.Accept(ctx)
 		if err != nil {
@@ -352,7 +366,8 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
 			conn:   conn,
 			logger: s.logger,
 		}
-		go cc.run(ctx)
+		s.connsWg.Add(1)
+		go func() { defer s.connsWg.Done(); cc.run(ctx) }()
 	}
 }
 
@@ -467,6 +482,8 @@ type clientConn struct {
 
 	connID ConnID
 	clientConnAuth
+
+	streamsWg sync.WaitGroup
 }
 
 type clientConnAuth struct {
@@ -478,15 +495,15 @@ type clientConnAuth struct {
 
 func (c *clientConn) run(ctx context.Context) {
 	c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())
 
-	defer func() {
-		if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
-			slogc.Fine(c.logger, "error closing connection", "err", err)
-		}
-	}()
 	if err := c.runErr(ctx); err != nil {
 		c.logger.Debug("error while running client conn", "err", err)
 	}
+
+	if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
+		slogc.Fine(c.logger, "error closing connection", "err", err)
+	}
+	c.streamsWg.Wait()
 }
 
 func (c *clientConn) runErr(ctx context.Context) error {
@@ -527,7 +544,8 @@ func (c *clientConn) runErr(ctx context.Context) error {
 			conn:   c,
 			stream: stream,
 		}
-		go cs.run(ctx)
+		c.streamsWg.Add(1)
+		go func() { defer c.streamsWg.Done(); cs.run(ctx) }()
 	}
 }
diff --git a/server/control/relays.go b/server/control/relays.go
index 36a52d28..47d34a3d 100644
--- a/server/control/relays.go
+++ b/server/control/relays.go
@@ -132,6 +132,8 @@ type relayServer struct {
 	connsCache  map[RelayID]cachedRelay
 	connsOffset int64
 	connsMu     sync.RWMutex
+
+	connsWg sync.WaitGroup
 }
 
 type cachedRelay struct {
@@ -220,11 +222,12 @@ func (s *relayServer) Relays(ctx context.Context, endpoint model.Endpoint, role
 	}
 }
 
-func (s *relayServer) run(ctx context.Context) error {
+func (s *relayServer) run(ctx context.Context, notifyReady func(error)) error {
 	g := reliable.NewGroup(ctx)
 	for _, ingress := range s.ingresses {
-		g.Go(reliable.Bind(ingress, s.runListener))
+		ingress := ingress
+		g.Go(func(ctx context.Context) error { return s.runListener(ctx, ingress, notifyReady) })
 	}
 
 	g.Go(s.runConnsCache)
 
@@ -233,7 +236,16 @@ func (s *relayServer) run(ctx context.Context) error {
 	return g.Wait()
 }
 
-func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
+func (s *relayServer) waitDrain(ctx context.Context) {
+	done := make(chan struct{})
+	go func() { s.connsWg.Wait(); close(done) }()
+	select {
+	case <-done:
+	case <-ctx.Done():
+	}
+}
+
+func (s *relayServer) runListener(ctx context.Context, ingress Ingress, notifyReady func(error)) error {
 	s.logger.Debug("start udp listener", "addr", ingress.Addr)
 	udpConn, err := net.ListenUDP("udp", ingress.Addr)
 	if err != nil {
@@ -271,6 +283,7 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
 
 	l, err := transport.Listen(tlsConf, quicConf)
 	if err != nil {
+		notifyReady(fmt.Errorf("relay server quic listen: %w", err))
 		return fmt.Errorf("relay server quic listen: %w", err)
 	}
 	defer func() {
@@ -280,6 +293,7 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
 	}()
 
 	s.logger.Info("accepting relay connections", "addr", transport.Conn.LocalAddr())
+	notifyReady(nil)
 	for {
 		conn, err := l.Accept(ctx)
 		if err != nil {
@@ -292,7 +306,8 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
 			conn:   conn,
 			logger: s.logger,
 		}
-		go rc.run(ctx)
+		s.connsWg.Add(1)
+		go func() { defer s.connsWg.Done(); rc.run(ctx) }()
 	}
 }
 
@@ -358,16 +373,15 @@ type relayConnAuth struct {
 }
 
 func (c *relayConn) run(ctx context.Context) {
-	defer func() {
-		if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
-			slogc.Fine(c.logger, "error closing connection", "err", err)
-		}
-	}()
 	c.logger.Debug("new relay connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())
 
 	if err := c.runErr(ctx); err != nil {
 		c.logger.Debug("error while running relay conn", "err", err)
 	}
+
+	if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
+		slogc.Fine(c.logger, "error closing connection", "err", err)
+	}
 }
 
 func (c *relayConn) runErr(ctx context.Context) error {
diff --git a/server/control/server.go b/server/control/server.go
index 1d6e4a15..f7a510ca 100644
--- a/server/control/server.go
+++ b/server/control/server.go
@@ -13,6 +13,14 @@ import (
 	"github.com/connet-dev/connet/pkg/restr"
 )
 
+func newDrainCtx(ctx context.Context, fn func(context.Context) time.Duration) (context.Context, context.CancelFunc) {
+	d := 30 * time.Second
+	if fn != nil {
+		d = fn(ctx)
+	}
+	return context.WithTimeout(context.Background(), d)
+}
+
 type Config struct {
 	ClientsIngress []Ingress
 	ClientsAuth    ClientAuthenticator
@@ -24,6 +32,8 @@ type Config struct {
 
 	Stores Stores
 	Logger *slog.Logger
+
+	DrainTimeout func(ctx context.Context) time.Duration
 }
 
 func NewServer(cfg Config) (*Server, error) {
@@ -51,6 +61,10 @@ func NewServer(cfg Config) (*Server, error) {
 		clients: clients,
 		relays:  relays,
 		config:  configStore,
+
+		drainTimeout: cfg.DrainTimeout,
+		readyCh:      make(chan error, 1),
+		doneCh:       make(chan error, 1),
 	}, nil
 }
 
@@ -59,14 +73,42 @@ type Server struct {
 	clients *clientServer
 	relays  *relayServer
 	config  logc.KV[ConfigKey, ConfigValue]
+
+	drainTimeout func(ctx context.Context) time.Duration
+	readyCh      chan error
+	doneCh       chan error
+}
+
+// Ready returns a channel that receives nil when the server is ready to accept traffic,
+// or a non-nil error if startup failed. The channel is then closed.
+func (s *Server) Ready() <-chan error {
+	return s.readyCh
+}
+
+// Done returns a channel that receives nil on clean shutdown, or an error on unclean shutdown.
+// The channel is then closed. Sent after all connections and streams have drained.
+func (s *Server) Done() <-chan error {
+	return s.doneCh
 }
 
 func (s *Server) Run(ctx context.Context) error {
-	return reliable.RunGroup(ctx,
-		s.relays.run,
-		s.clients.run,
+	n := len(s.clients.ingresses) + len(s.relays.ingresses)
+	notifyReady := reliable.ReadyNotifier(n, s.readyCh)
+
+	err := reliable.RunGroup(ctx,
+		func(ctx context.Context) error { return s.relays.run(ctx, notifyReady) },
+		func(ctx context.Context) error { return s.clients.run(ctx, notifyReady) },
 		logc.ScheduleCompact(s.config),
 	)
+
+	drainCtx, cancel := newDrainCtx(ctx, s.drainTimeout)
+	defer cancel()
+	s.clients.waitDrain(drainCtx)
+	s.relays.waitDrain(drainCtx)
+
+	s.doneCh <- err
+	close(s.doneCh)
+	return err
 }
 
 func (s *Server) Status(ctx context.Context) (Status, error) {
diff --git a/server/relay/clients.go b/server/relay/clients.go
index e5395860..b4a84db9 100644
--- a/server/relay/clients.go
+++ b/server/relay/clients.go
@@ -47,6 +47,8 @@ type clientsServer struct {
 	endpointsMu sync.RWMutex
 
 	logger *slog.Logger
+
+	connsWg sync.WaitGroup
 }
 
 func newClientsServer(cfg Config, cert *certc.Cert, auth ClientAuthenticator) (*clientsServer, error) {
@@ -189,7 +191,7 @@ type clientsServerCfg struct {
 	removeTransport func(*quic.Transport)
 }
 
-func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error {
+func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg, notifyReady func(error)) error {
 	s.logger.Debug("start udp listener", "addr", cfg.ingress.Addr)
 	udpConn, err := net.ListenUDP("udp", cfg.ingress.Addr)
 	if err != nil {
@@ -225,6 +227,7 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error {
 
 	l, err := transport.Listen(s.tlsConf, quicConf)
 	if err != nil {
+		notifyReady(fmt.Errorf("client server udp listen: %w", err))
 		return fmt.Errorf("client server udp listen: %w", err)
 	}
 	defer func() {
@@ -234,6 +237,7 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error {
 	}()
 
 	s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr())
+	notifyReady(nil)
 	for {
 		conn, err := l.Accept(ctx)
 		if err != nil {
@@ -246,7 +250,17 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error {
 		rc := &clientConn{
 			conn:   conn,
 			logger: s.logger,
 		}
-		go rc.run(ctx)
+		s.connsWg.Add(1)
+		go func() { defer s.connsWg.Done(); rc.run(ctx) }()
+	}
+}
+
+func (s *clientsServer) waitDrain(ctx context.Context) {
+	done := make(chan struct{})
+	go func() { s.connsWg.Wait(); close(done) }()
+	select {
+	case <-done:
+	case <-ctx.Done():
 	}
 }
@@ -256,19 +270,21 @@ type clientConn struct {
 	logger *slog.Logger
 
 	auth *clientAuth
+
+	streamsWg sync.WaitGroup
 }
 
 func (c *clientConn) run(ctx context.Context) {
 	c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())
 
-	defer func() {
-		if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
-			slogc.Fine(c.logger, "error closing connection", "err", err)
-		}
-	}()
 	if err := c.runErr(ctx); err != nil {
 		c.logger.Debug("error while running client conn", "err", err)
 	}
+
+	if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
+		slogc.Fine(c.logger, "error closing connection", "err", err)
+	}
+	c.streamsWg.Wait()
 }
 
 var errNotRecognizedClient = errors.New("client not recognized as a destination or a source")
@@ -358,7 +374,8 @@ func (c *clientConn) runSource(ctx context.Context) error {
 		if err != nil {
 			return fmt.Errorf("accept source stream: %w", err)
 		}
-		go c.runSourceStream(ctx, stream, fcs)
+		c.streamsWg.Add(1)
+		go func() { defer c.streamsWg.Done(); c.runSourceStream(ctx, stream, fcs) }()
 	}
 })
 
diff --git a/server/relay/server.go b/server/relay/server.go
index c885da91..a4797b2c 100644
--- a/server/relay/server.go
+++ b/server/relay/server.go
@@ -20,6 +20,14 @@ import (
 	"github.com/quic-go/quic-go"
 )
 
+func newDrainCtx(ctx context.Context, fn func(context.Context) time.Duration) (context.Context, context.CancelFunc) {
+	d := 30 * time.Second
+	if fn != nil {
+		d = fn(ctx)
+	}
+	return context.WithTimeout(context.Background(), d)
+}
+
 type Config struct {
 	Metadata string
 
@@ -35,6 +43,8 @@ type Config struct {
 
 	Stores Stores
 	Logger *slog.Logger
+
+	DrainTimeout func(ctx context.Context) time.Duration
 }
 
 func NewServer(cfg Config) (*Server, error) {
@@ -92,6 +102,10 @@ func NewServer(cfg Config) (*Server, error) {
 
 		control: control,
 		clients: clients,
+
+		drainTimeout: cfg.DrainTimeout,
+		readyCh:      make(chan error, 1),
+		doneCh:       make(chan error, 1),
 	}, nil
 }
 
@@ -101,9 +115,40 @@ type Server struct {
 
 	control *controlClient
 	clients *clientsServer
+
+	drainTimeout func(ctx context.Context) time.Duration
+	readyCh      chan error
+	doneCh       chan error
+}
+
+// Ready returns a channel that receives nil when the server is ready to accept traffic,
+// or a non-nil error if startup failed. The channel is then closed.
+func (s *Server) Ready() <-chan error {
+	return s.readyCh
+}
+
+// Done returns a channel that receives nil on clean shutdown, or an error on unclean shutdown.
+// The channel is then closed. Sent after all connections and streams have drained.
+func (s *Server) Done() <-chan error {
+	return s.doneCh
 }
 
 func (s *Server) Run(ctx context.Context) error {
+	n := len(s.ingress)
+	notifyReady := reliable.ReadyNotifier(n, s.readyCh)
+
+	err := s.run(ctx, notifyReady)
+
+	drainCtx, cancel := newDrainCtx(ctx, s.drainTimeout)
+	defer cancel()
+	s.clients.waitDrain(drainCtx)
+
+	s.doneCh <- err
+	close(s.doneCh)
+	return err
+}
+
+func (s *Server) run(ctx context.Context, notifyReady func(error)) error {
 	transports := notify.NewEmpty[[]*quic.Transport]()
 	var waitForTransport TransportsFn = func(ctx context.Context) ([]*quic.Transport, error) {
 		t, _, err := transports.GetAny(ctx)
@@ -123,7 +168,7 @@ func (s *Server) Run(ctx context.Context) error {
 				notify.SliceRemove(transports, t)
 			},
 		}
-		g.Go(reliable.Bind(cfg, s.clients.run))
+		g.Go(func(ctx context.Context) error { return s.clients.run(ctx, cfg, notifyReady) })
 	}
 
 	g.Go(reliable.Bind(waitForTransport, s.control.run))
diff --git a/server/server.go b/server/server.go
index a14ffa66..183a2a24 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,11 +15,14 @@ import (
 	"github.com/connet-dev/connet/server/selfhosted"
 )
 
 type Server struct {
 	serverConfig
 
 	control *control.Server
 	relay   *relay.Server
+
+	readyCh chan error
+	doneCh  chan error
 }
 
 func New(opts ...Option) (*Server, error) {
@@ -71,6 +74,8 @@ func New(opts ...Option) (*Server, error) {
 		Stores: control.NewFileStores(filepath.Join(cfg.dir, "control")),
 		Logger: cfg.logger,
+
+		DrainTimeout: cfg.drainTimeout,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("create control server: %w", err)
@@ -88,6 +93,8 @@ func New(opts ...Option) (*Server, error) {
 		Stores: relay.NewFileStores(filepath.Join(cfg.dir, "relay")),
 		Logger: cfg.logger,
+
+		DrainTimeout: cfg.drainTimeout,
 	})
 	if err != nil {
 		return nil, fmt.Errorf("create relay server: %w", err)
@@ -98,14 +105,39 @@ func New(opts ...Option) (*Server, error) {
 	return &Server{
 		control: control,
 		relay:   relay,
+
+		readyCh: make(chan error, 1),
+		doneCh:  make(chan error, 1),
 	}, nil
 }
 
+// Ready returns a channel that receives nil when the server is ready to accept traffic,
+// or a non-nil error if startup failed. The channel is then closed.
+func (s *Server) Ready() <-chan error {
+	return s.readyCh
+}
+
+// Done returns a channel that receives nil on clean shutdown, or an error on unclean shutdown.
+// The channel is then closed. Sent after all connections and streams have drained.
+func (s *Server) Done() <-chan error {
+	return s.doneCh
+}
+
 func (s *Server) Run(ctx context.Context) error {
+	notifyReady := reliable.ReadyNotifier(2, s.readyCh)
+	for _, ch := range []<-chan error{s.control.Ready(), s.relay.Ready()} {
+		ch := ch
+		go func() { notifyReady(<-ch) }()
+	}
+
 	g := reliable.NewGroup(ctx)
 	g.Go(s.control.Run)
 	g.Go(s.relay.Run)
-	return g.Wait()
+
+	err := g.Wait()
+	s.doneCh <- err
+	close(s.doneCh)
+	return err
 }
 
 func (s *Server) Status(ctx context.Context) (ServerStatus, error) {