Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func (d *destinationConn) run(ctx context.Context) {
}

func (d *destinationConn) runErr(ctx context.Context) error {
// TODO do we need graceful shutdown of these streams?
for {
stream, err := d.conn.AcceptStream(ctx)
if err != nil {
Expand Down
39 changes: 39 additions & 0 deletions pkg/reliable/drain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package reliable

import (
"context"
"fmt"
"sync"
"time"
)

type Drain struct {
ctx context.Context
wg sync.WaitGroup
}

func NewDrain(ctx context.Context) *Drain {
return &Drain{ctx: ctx}
}

func (d *Drain) Go(f func(context.Context)) {
d.wg.Go(func() { f(d.ctx) })
}

func (d *Drain) Wait(timeout time.Duration) error {
done := make(chan struct{})
go func() {
d.wg.Wait()
close(done)
}()

drainCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

select {
case <-done:
return nil
case <-drainCtx.Done():
return fmt.Errorf("drain timeout")
}
}
9 changes: 7 additions & 2 deletions pkg/statusc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ func Run[T any](ctx context.Context, addr *net.TCPAddr, f func(ctx context.Conte

go func() {
<-ctx.Done()
if err := srv.Close(); err != nil {
slogc.FineDefault("error closing status server", "err", err)

shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if err := srv.Shutdown(shutdownCtx); err != nil {
slogc.FineDefault("error shutting down status server", "err", err)
}
}()

// TODO return no error on server shutdown?
return srv.ListenAndServe()
}
25 changes: 21 additions & 4 deletions server/control/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,18 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
}
}()

drain := reliable.NewDrain(ctx)
defer func() {
if err := drain.Wait(10 * time.Second); err != nil {
s.logger.Warn("clients drain", "err", err)
}
}()

s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr())
for {
conn, err := l.Accept(ctx)
if err != nil {
// TODO return no error on context cancel?
slogc.Fine(s.logger, "accept error", "err", err)
return fmt.Errorf("client server quic accept: %w", err)
}
Expand All @@ -352,7 +360,7 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
conn: conn,
logger: s.logger,
}
go cc.run(ctx)
drain.Go(cc.run)
}
}

Expand Down Expand Up @@ -479,6 +487,7 @@ type clientConnAuth struct {
func (c *clientConn) run(ctx context.Context) {
c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())
defer func() {
// TODO richer errors
if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
slogc.Fine(c.logger, "error closing connection", "err", err)
}
Expand Down Expand Up @@ -517,6 +526,13 @@ func (c *clientConn) runErr(ctx context.Context) error {
}
}()

drain := reliable.NewDrain(ctx)
defer func() {
if err := drain.Wait(5 * time.Second); err != nil {
c.logger.Warn("client streams drain", "err", err)
}
}()

for {
stream, err := c.conn.AcceptStream(ctx)
if err != nil {
Expand All @@ -527,7 +543,7 @@ func (c *clientConn) runErr(ctx context.Context) error {
conn: c,
stream: stream,
}
go cs.run(ctx)
drain.Go(cs.run)
}
}

Expand Down Expand Up @@ -606,6 +622,7 @@ type clientStream struct {

func (s *clientStream) run(ctx context.Context) {
defer func() {
// TODO richer errors
if err := s.stream.Close(); err != nil {
slogc.Fine(s.conn.logger, "error closing client stream", "err", err)
}
Expand Down Expand Up @@ -669,8 +686,8 @@ func (s *clientStream) announce(ctx context.Context, req *pbclient.Request_Annou
return err
}
defer func() {
if s.conn.server.endpointExpiry > 0 && s.conn.conn.Context().Err() != nil {
// Connection dead — mark as expired, consumer will delete after timeout
if s.conn.server.endpointExpiry > 0 && (ctx.Err() != nil || s.conn.conn.Context().Err() != nil) {
// Server shutting down or connection dead — mark as expired, consumer will delete after timeout
if err := s.conn.server.expire(endpoint, role, s.conn.id, s.conn.connID); err != nil {
s.conn.logger.Warn("failed to expire peer", "id", s.conn.id, "err", err)
}
Expand Down
11 changes: 10 additions & 1 deletion server/control/relays.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"maps"
"net"
"sync"
"time"

"github.com/connet-dev/connet/model"
"github.com/connet-dev/connet/pkg/iterc"
Expand Down Expand Up @@ -279,10 +280,18 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
}
}()

drain := reliable.NewDrain(ctx)
defer func() {
if err := drain.Wait(10 * time.Second); err != nil {
s.logger.Warn("relays drain", "err", err)
}
}()

s.logger.Info("accepting relay connections", "addr", transport.Conn.LocalAddr())
for {
conn, err := l.Accept(ctx)
if err != nil {
// TODO return no error on context cancel?
slogc.Fine(s.logger, "accept error", "err", err)
return fmt.Errorf("relay server quic accept: %w", err)
}
Expand All @@ -292,7 +301,7 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
conn: conn,
logger: s.logger,
}
go rc.run(ctx)
drain.Go(rc.run)
}
}

Expand Down
14 changes: 13 additions & 1 deletion server/relay/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"slices"
"sync"
"time"

"github.com/connet-dev/connet/model"
"github.com/connet-dev/connet/pkg/certc"
Expand Down Expand Up @@ -233,10 +234,18 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error {
}
}()

drain := reliable.NewDrain(ctx)
defer func() {
if err := drain.Wait(10 * time.Second); err != nil {
s.logger.Warn("clients drain", "err", err)
}
}()

s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr())
for {
conn, err := l.Accept(ctx)
if err != nil {
// TODO return no error on context cancel?
slogc.Fine(s.logger, "accept error", "err", err)
return fmt.Errorf("client server quic accept: %w", err)
}
Expand All @@ -246,7 +255,7 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error {
conn: conn,
logger: s.logger,
}
go rc.run(ctx)
drain.Go(rc.run)
}
}

Expand All @@ -261,6 +270,7 @@ type clientConn struct {
func (c *clientConn) run(ctx context.Context) {
c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())
defer func() {
// TODO richer errors
if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
slogc.Fine(c.logger, "error closing connection", "err", err)
}
Expand Down Expand Up @@ -353,6 +363,7 @@ func (c *clientConn) runSource(ctx context.Context) error {
})

g.Go(func(ctx context.Context) error {
// TODO do we need graceful shutdown of these streams?
for {
stream, err := c.conn.AcceptStream(ctx)
if err != nil {
Expand All @@ -367,6 +378,7 @@ func (c *clientConn) runSource(ctx context.Context) error {

func (c *clientConn) runSourceStream(ctx context.Context, stream *quic.Stream, fcs *endpointClients) {
defer func() {
// TODO richer errors
if err := stream.Close(); err != nil {
slogc.Fine(c.logger, "error closing source stream", "err", err)
}
Expand Down