Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pkg/reliable/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,35 @@ package reliable

import (
"context"
"sync"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"
)

type RunFn func(context.Context) error

// ReadyNotifier returns a func that waits for n nil calls before sending nil to ch.
// Any non-nil error is sent immediately instead. Exactly one value is ever sent,
// after which ch is closed; subsequent calls are no-ops.
// If n <= 0 there is nothing to wait for, so nil is sent right away (before the
// returned func is ever called) — callers should therefore pass a buffered channel.
func ReadyNotifier(n int, ch chan<- error) func(error) {
	var once sync.Once
	var remaining atomic.Int32
	remaining.Store(int32(n))
	// send delivers at most one value and closes ch exactly once.
	send := func(err error) {
		once.Do(func() { ch <- err; close(ch) })
	}
	if n <= 0 {
		// No successes required: report readiness immediately. Without this,
		// a zero count would make remaining.Add(-1) go negative, skip 0, and
		// never notify the caller.
		send(nil)
	}
	return func(err error) {
		if err != nil {
			send(err)
			return
		}
		// Exactly one caller observes the decrement hitting 0; once guards
		// against any extra calls after the count has been exhausted.
		if remaining.Add(-1) == 0 {
			send(nil)
		}
	}
}
}

func Bind[T any](t T, fn func(context.Context, T) error) RunFn {
return func(ctx context.Context) error {
return fn(ctx, t)
Expand Down
10 changes: 10 additions & 0 deletions server/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"fmt"
"log/slog"
"net"
Expand All @@ -20,6 +21,8 @@ type serverConfig struct {

relayIngresses []relay.Ingress

drainTimeout func(ctx context.Context) time.Duration

dir string
logger *slog.Logger
}
Expand Down Expand Up @@ -143,6 +146,13 @@ func Logger(logger *slog.Logger) Option {
}
}

// ServerDrainTimeout sets the function used to decide how long the server
// waits for in-flight connections to drain during shutdown. The function is
// invoked at shutdown time with the shutdown context.
func ServerDrainTimeout(fn func(ctx context.Context) time.Duration) Option {
	apply := func(cfg *serverConfig) error {
		cfg.drainTimeout = fn
		return nil
	}
	return apply
}

func StoreDirFromEnvPrefixed(prefix string) (string, error) {
if stateDir := os.Getenv("CONNET_STATE_DIR"); stateDir != "" {
// Support direct override if necessary, currently used in docker
Expand Down
38 changes: 28 additions & 10 deletions server/control/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ type clientServer struct {
peersMu sync.RWMutex

endpointExpiry time.Duration

connsWg sync.WaitGroup
}

type peerKey struct {
Expand Down Expand Up @@ -276,11 +278,12 @@ func (s *clientServer) listen(ctx context.Context, endpoint model.Endpoint, role
}
}

func (s *clientServer) run(ctx context.Context) error {
func (s *clientServer) run(ctx context.Context, notifyReady func(error)) error {
g := reliable.NewGroup(ctx)

for _, ingress := range s.ingresses {
g.Go(reliable.Bind(ingress, s.runListener))
ingress := ingress
g.Go(func(ctx context.Context) error { return s.runListener(ctx, ingress, notifyReady) })
}
g.Go(s.runPeerCache)
if s.endpointExpiry > 0 {
Expand All @@ -293,7 +296,16 @@ func (s *clientServer) run(ctx context.Context) error {
return g.Wait()
}

func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
// waitDrain blocks until every tracked client-connection goroutine has
// finished, or until ctx is done — whichever happens first. The helper
// goroutine is not leaked: it exits as soon as connsWg reaches zero.
func (s *clientServer) waitDrain(ctx context.Context) {
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		s.connsWg.Wait()
	}()
	select {
	case <-ctx.Done():
	case <-drained:
	}
}

func (s *clientServer) runListener(ctx context.Context, ingress Ingress, notifyReady func(error)) error {
s.logger.Debug("start udp listener", "addr", ingress.Addr)
udpConn, err := net.ListenUDP("udp", ingress.Addr)
if err != nil {
Expand Down Expand Up @@ -331,6 +343,7 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {

l, err := transport.Listen(tlsConf, quicConf)
if err != nil {
notifyReady(fmt.Errorf("client server quic listen: %w", err))
return fmt.Errorf("client server quic listen: %w", err)
}
defer func() {
Expand All @@ -340,6 +353,7 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
}()

s.logger.Info("accepting client connections", "addr", transport.Conn.LocalAddr())
notifyReady(nil)
for {
conn, err := l.Accept(ctx)
if err != nil {
Expand All @@ -352,7 +366,8 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error {
conn: conn,
logger: s.logger,
}
go cc.run(ctx)
s.connsWg.Add(1)
go func() { defer s.connsWg.Done(); cc.run(ctx) }()
}
}

Expand Down Expand Up @@ -467,6 +482,8 @@ type clientConn struct {
connID ConnID

clientConnAuth

streamsWg sync.WaitGroup
}

type clientConnAuth struct {
Expand All @@ -478,15 +495,15 @@ type clientConnAuth struct {

func (c *clientConn) run(ctx context.Context) {
c.logger.Debug("new client connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())
defer func() {
if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
slogc.Fine(c.logger, "error closing connection", "err", err)
}
}()

if err := c.runErr(ctx); err != nil {
c.logger.Debug("error while running client conn", "err", err)
}

if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
slogc.Fine(c.logger, "error closing connection", "err", err)
}
c.streamsWg.Wait()
}

func (c *clientConn) runErr(ctx context.Context) error {
Expand Down Expand Up @@ -527,7 +544,8 @@ func (c *clientConn) runErr(ctx context.Context) error {
conn: c,
stream: stream,
}
go cs.run(ctx)
c.streamsWg.Add(1)
go func() { defer c.streamsWg.Done(); cs.run(ctx) }()
}
}

Expand Down
32 changes: 23 additions & 9 deletions server/control/relays.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ type relayServer struct {
connsCache map[RelayID]cachedRelay
connsOffset int64
connsMu sync.RWMutex

connsWg sync.WaitGroup
}

type cachedRelay struct {
Expand Down Expand Up @@ -220,11 +222,12 @@ func (s *relayServer) Relays(ctx context.Context, endpoint model.Endpoint, role
}
}

func (s *relayServer) run(ctx context.Context) error {
func (s *relayServer) run(ctx context.Context, notifyReady func(error)) error {
g := reliable.NewGroup(ctx)

for _, ingress := range s.ingresses {
g.Go(reliable.Bind(ingress, s.runListener))
ingress := ingress
g.Go(func(ctx context.Context) error { return s.runListener(ctx, ingress, notifyReady) })
}
g.Go(s.runConnsCache)

Expand All @@ -233,7 +236,16 @@ func (s *relayServer) run(ctx context.Context) error {
return g.Wait()
}

func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
// waitDrain waits for all tracked relay-connection goroutines to complete,
// giving up once ctx expires. The spawned goroutine terminates on its own
// when connsWg drains, regardless of which select case fires first.
func (s *relayServer) waitDrain(ctx context.Context) {
	allDone := make(chan struct{})
	go func() {
		s.connsWg.Wait()
		close(allDone)
	}()
	select {
	case <-allDone:
	case <-ctx.Done():
	}
}

func (s *relayServer) runListener(ctx context.Context, ingress Ingress, notifyReady func(error)) error {
s.logger.Debug("start udp listener", "addr", ingress.Addr)
udpConn, err := net.ListenUDP("udp", ingress.Addr)
if err != nil {
Expand Down Expand Up @@ -271,6 +283,7 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {

l, err := transport.Listen(tlsConf, quicConf)
if err != nil {
notifyReady(fmt.Errorf("relay server quic listen: %w", err))
return fmt.Errorf("relay server quic listen: %w", err)
}
defer func() {
Expand All @@ -280,6 +293,7 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
}()

s.logger.Info("accepting relay connections", "addr", transport.Conn.LocalAddr())
notifyReady(nil)
for {
conn, err := l.Accept(ctx)
if err != nil {
Expand All @@ -292,7 +306,8 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
conn: conn,
logger: s.logger,
}
go rc.run(ctx)
s.connsWg.Add(1)
go func() { defer s.connsWg.Done(); rc.run(ctx) }()
}
}

Expand Down Expand Up @@ -358,16 +373,15 @@ type relayConnAuth struct {
}

func (c *relayConn) run(ctx context.Context) {
defer func() {
if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
slogc.Fine(c.logger, "error closing connection", "err", err)
}
}()
c.logger.Debug("new relay connection", "proto", c.conn.ConnectionState().TLS.NegotiatedProtocol, "remote", c.conn.RemoteAddr())

if err := c.runErr(ctx); err != nil {
c.logger.Debug("error while running relay conn", "err", err)
}

if err := c.conn.CloseWithError(quic.ApplicationErrorCode(pberror.Code_Unknown), "connection closed"); err != nil {
slogc.Fine(c.logger, "error closing connection", "err", err)
}
}

func (c *relayConn) runErr(ctx context.Context) error {
Expand Down
48 changes: 45 additions & 3 deletions server/control/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ import (
"github.com/connet-dev/connet/pkg/restr"
)

func newDrainCtx(ctx context.Context, fn func(context.Context) time.Duration) (context.Context, context.CancelFunc) {
d := 30 * time.Second
if fn != nil {
d = fn(ctx)
}
return context.WithTimeout(context.Background(), d)
}

type Config struct {
ClientsIngress []Ingress
ClientsAuth ClientAuthenticator
Expand All @@ -24,6 +32,8 @@ type Config struct {
Stores Stores

Logger *slog.Logger

DrainTimeout func(ctx context.Context) time.Duration
}

func NewServer(cfg Config) (*Server, error) {
Expand Down Expand Up @@ -51,6 +61,10 @@ func NewServer(cfg Config) (*Server, error) {
relays: relays,

config: configStore,

drainTimeout: cfg.DrainTimeout,
readyCh: make(chan error, 1),
doneCh: make(chan error, 1),
}, nil
}

Expand All @@ -59,14 +73,42 @@ type Server struct {
relays *relayServer

config logc.KV[ConfigKey, ConfigValue]

drainTimeout func(ctx context.Context) time.Duration
readyCh chan error
doneCh chan error
}

// Ready returns a channel that reports startup readiness: it receives nil once
// the server is ready to accept traffic, or a non-nil error if startup failed.
// Exactly one value is delivered, after which the channel is closed, so later
// receives return nil immediately.
func (s *Server) Ready() <-chan error {
	return s.readyCh
}

// Done returns a channel that reports shutdown: it receives nil on clean
// shutdown, or an error on unclean shutdown, and is then closed. The value is
// sent only after all connections and streams have drained.
func (s *Server) Done() <-chan error {
	return s.doneCh
}

func (s *Server) Run(ctx context.Context) error {
return reliable.RunGroup(ctx,
s.relays.run,
s.clients.run,
n := len(s.clients.ingresses) + len(s.relays.ingresses)
notifyReady := reliable.ReadyNotifier(n, s.readyCh)

err := reliable.RunGroup(ctx,
func(ctx context.Context) error { return s.relays.run(ctx, notifyReady) },
func(ctx context.Context) error { return s.clients.run(ctx, notifyReady) },
logc.ScheduleCompact(s.config),
)

drainCtx, cancel := newDrainCtx(ctx, s.drainTimeout)
defer cancel()
s.clients.waitDrain(drainCtx)
s.relays.waitDrain(drainCtx)

s.doneCh <- err
close(s.doneCh)
return err
}

func (s *Server) Status(ctx context.Context) (Status, error) {
Expand Down
Loading