From 10e7fcd3a8701cc1c27b03cf03f255b01b9e17d2 Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 00:33:47 -0600 Subject: [PATCH 1/9] some cleanup --- cmd/yatun/main.go | 56 +++++++++++++++-------- cmd/yatund/main.go | 21 ++++++--- internal/config/config.go | 13 ++++-- internal/message/message.go | 42 +++++------------- internal/server/server.go | 88 ++++++++++++++++++------------------- internal/tls/tlsutil.go | 7 +-- internal/tui/tui.go | 15 +++++++ 7 files changed, 136 insertions(+), 106 deletions(-) diff --git a/cmd/yatun/main.go b/cmd/yatun/main.go index fe5a05e..e4c8b55 100644 --- a/cmd/yatun/main.go +++ b/cmd/yatun/main.go @@ -1,7 +1,6 @@ package main import ( - "errors" "flag" "fmt" "io" @@ -18,34 +17,50 @@ import ( "github.com/hashicorp/yamux" ) -func sendMsg(con *yamux.Stream, m message.TransportMessage) { - byt := m.Encode() +func sendMsg(con *yamux.Stream, m message.TransportMessage) (err error) { + byt, err := m.Encode() + if err != nil { + return err + } + + _, err = con.Write(byt) + return - con.Write(byt) } -func clientServerComms(ses *yamux.Session, tuiP *tea.Program) { +func clientServerComms(ses *yamux.Session, tuiP *tea.Program) (err error) { con, err := ses.OpenStream() if err != nil { - tuiP.Kill() - panic(errors.New("failed to accept new stream, is the server running?")) - } - - dat := message.ConnectionDetailsMessageData{ - SubdomainName: "asd", + tuiP.Send(tui.SetState{ + State: tui.ErrorState, + Err: err, + }) + tuiP.Quit() + return } - sendMsg(con, message.TransportMessage{ - Type: message.ConnectionDetails, - Data: &dat, + err = sendMsg(con, message.TransportMessage{ + Type: message.OpenMsg, }) + if err != nil { + tuiP.Send(tui.SetState{ + State: tui.ErrorState, + Err: fmt.Errorf("failed to send initial message to server, is the server running?\n%v", err), + }) + tuiP.Quit() + return err + } go func() { for { msg, err := message.Decode(con) if err != nil { - tuiP.Kill() - panic(errors.New("the server closed unexpectedly")) + tuiP.Send(tui.SetState{ + State: tui.ErrorState, + Err: fmt.Errorf("failed at decoding server message\n%v", err), + }) + tuiP.Quit() + return } switch msg.Type { @@ -60,6 +75,7 @@ func clientServerComms(ses *yamux.Session, tuiP *tea.Program) { } } }() + return nil } func initializeServerConnection(tuiP *tea.Program, server string) (sess *yamux.Session, err error) { @@ -74,7 +90,7 @@ func initializeServerConnection(tuiP *tea.Program, server string) (sess *yamux.S return } - clientServerComms(sess, tuiP) + err = clientServerComms(sess, tuiP) return } @@ -108,6 +124,7 @@ func serverConnectionLoop(sess *yamux.Session, port *string, tuiP *tea.Program) Err: err, State: tui.ErrorState, }) + tuiP.Quit() return } @@ -175,9 +192,12 @@ func main() { sess, err := initializeServerConnection(tuiP, *server) if err != nil { go tuiP.Send(tui.SetState{ - Err: err, + + Err: &tui.FailedInitialConfigError{ + Err: err}, State: tui.ErrorState, }) + // tuiP.Quit() } if err == nil { diff --git a/cmd/yatund/main.go b/cmd/yatund/main.go index d8ce9fb..91148b6 100644 --- a/cmd/yatund/main.go +++ b/cmd/yatund/main.go @@ -9,7 +9,11 @@ import ( ) func main() { - conf := config.ReadFromEnv() + conf, err := config.ReadFromEnv() + if err != nil { + log.Printf("FATAL: Failed to parse configuration\n%v", err) + return + } // So, we need to open a tcp server for external connections and another one for the agent. // But the external should only open if an agent requests it, so the very first server is the agent one. @@ -17,7 +21,7 @@ func main() { // The agents listener listener, err := net.Listen("tcp", "0.0.0.0:5678") if err != nil { - panic(err) + log.Printf("Error binding to port 5678: %v", err) } log.Printf("Listening on %v", listener.Addr().String()) @@ -26,14 +30,21 @@ func main() { conn, err := listener.Accept() if err != nil { - panic(err) + log.Printf("Failed to accept client: %v", err) + continue } sconn, err := server.NewServerConnection(conn, conf) if err != nil { - panic(err) + log.Printf("Failed on the creation of a new server: %v", err) + continue } - go sconn.StartListeningAgents() + go func() { + err := sconn.StartListeningAgents() + if err != nil { + log.Printf("Failed starting the agent(s) setup and loop: %v", err) + } + }() } } diff --git a/internal/config/config.go b/internal/config/config.go index 4e1ec5c..257619f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,7 +17,7 @@ type ServerConfig struct { TLSStore tlsutil.TLSStore } -func ReadFromEnv() ServerConfig { +func ReadFromEnv() (ServerConfig, error) { c := ServerConfig{} domain, ok := os.LookupEnv("Domain") if ok { @@ -26,7 +26,8 @@ func ReadFromEnv() ServerConfig { } u, err := url.Parse(domain) if err != nil { - panic(err) + log.Printf("Failed to parse set domain: %v", err) + return c, err } log.Printf("Host: %v", u.Hostname()) c.Domain = new(u.Hostname()) @@ -35,8 +36,12 @@ func ReadFromEnv() ServerConfig { tls, ok := os.LookupEnv("TLS") if ok && tls != "0" { c.TLS = true - c.TLSStore = tlsutil.LoadTLSCerts() + store, err := tlsutil.LoadTLSCerts() + if err != nil { + return c, err + } + c.TLSStore = store } - return c + return c, nil } diff --git a/internal/message/message.go b/internal/message/message.go index e869e72..414c74a 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "log" "time" "io" @@ -14,25 +15,6 @@ type MessagePayload interface { Decode(b []byte) error } -type ConnectionDetailsMessageData struct { - SubdomainName string `json:"address"` -} - -func (c ConnectionDetailsMessageData) Encode() ([]byte, error) { - return json.Marshal(c) -} -func (c *ConnectionDetailsMessageData) Decode(b []byte) error { - var t ConnectionDetailsMessageData - err := json.Unmarshal(b, &t) - if err != nil { - return err - } - - // log.Printf("Subdom %v", t.SubdomainName) - c.SubdomainName = t.SubdomainName - return nil -} - type Response struct { Ok bool `json:"ok"` Address string `json:"address"` @@ -80,16 +62,13 @@ type TransportMessage struct { Data MessagePayload } -var TCPConnection MessageType = 'A' // Agent -> Server msg type, used for signaling the agent is asking for TCP, no data transfer. -var ConnectionDetails MessageType = 'B' // Server -> Agent msg type, used to transfer data about the server created, data is transferred (json, maybe change later) +var TCPConnection MessageType = 'A' // Agent -> Server msg type, used for signaling the agent is asking for TCP, no data transfer. +var OpenMsg MessageType = 'B' // Server -> Agent msg type, used to transfer data about the server created, data is transferred (json, maybe change later) var ResponseMessage MessageType = 'C' var PingMessageType MessageType = 'D' -func (m TransportMessage) GetConnectionDetails() (*MessagePayload, error) { - return &m.Data, nil -} - func ParseConnectionType(b byte) (*MessageType, error) { + log.Printf("Byte rec %v", b) switch b { case byte(TCPConnection): return &TCPConnection, nil @@ -99,7 +78,7 @@ func ParseConnectionType(b byte) (*MessageType, error) { } } -func (m TransportMessage) Encode() []byte { +func (m TransportMessage) Encode() ([]byte, error) { // Shape // [type: 1byte][dataLen: 2 byte][data: ?bytes] out := make([]byte, 1) @@ -110,7 +89,7 @@ func (m TransportMessage) Encode() []byte { dataBuf, err := m.Data.Encode() if err != nil { - panic(err) + return nil, err } l := uint16(len(dataBuf)) @@ -122,7 +101,7 @@ func (m TransportMessage) Encode() []byte { out = append(out, dataBuf...) // Data } - return out + return out, nil } func readSize(r io.Reader) (uint16, error) { @@ -159,6 +138,10 @@ func Decode(r io.Reader) (*TransportMessage, error) { return &m, nil } + if mType == OpenMsg { + return &m, nil + } + // Types that have some data, read and put it into the message for later Decoding. // Read the next data stuff bufSize, err := readSize(r) @@ -173,8 +156,7 @@ func Decode(r io.Reader) (*TransportMessage, error) { } var obj MessagePayload switch mType { - case ConnectionDetails: - obj = &ConnectionDetailsMessageData{} + case ResponseMessage: obj = &Response{} case PingMessageType: diff --git a/internal/server/server.go b/internal/server/server.go index 33af7b6..e7bb23a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -18,44 +18,20 @@ import ( ) type ServerConnection struct { - agentSession *yamux.Session - connectionType *message.MessageType - connectionDetails *message.ConnectionDetailsMessageData - config config.ServerConfig + agentSession *yamux.Session + connectionType *message.MessageType + + config config.ServerConfig } func (s *ServerConnection) handleInitialConfig(stream *yamux.Stream) error { - // Since this is supposed to be called once only, we don't need a for - // TODO: Ctx with timeout maybe? - // - - // buf := make([]byte, 1) - // _, err := stream.Read(buf) - // if err != nil { - // return err - // } - // + // TODO: Maybe this initial message could be used for something? I'm not quite sure what just yet though m, err := message.Decode(stream) if err != nil { return err } - if m.Type == message.ConnectionDetails { - details, err := m.GetConnectionDetails() - if err != nil { - return err - } - - if details == nil { - return errors.New("no details") - } - - // dr := *details - pDet := (*details).(*message.ConnectionDetailsMessageData) - log.Printf("Config with %v addr", pDet.SubdomainName) - - s.connectionDetails = pDet - } + log.Printf("Received message %v", m) s.connectionType = &m.Type return nil @@ -66,13 +42,6 @@ func (s *ServerConnection) setupServer() (net.Listener, error) { return net.Listen("tcp", "0.0.0.0:") } -func pipe(out io.Writer, in io.Reader) { - _, err := io.Copy(out, in) - if err != nil { - log.Printf("Error while copying %v", err) - } -} - func (s *ServerConnection) handleNewConn(conn net.Conn) { agConn, err := s.agentSession.OpenStream() if err != nil { @@ -188,7 +157,9 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List if s.config.Domain != nil { _, port, err := net.SplitHostPort(server.Addr().String()) if err != nil { - panic(err) + log.Printf("Error parsing the host port: %v", err) + return err + } addr = fmt.Sprintf("%v:%v", *s.config.Domain, port) @@ -202,7 +173,12 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List }, } - _, err := stream.Write(m.Encode()) + byt, err := m.Encode() + if err != nil { + return err + } + + _, err = stream.Write(byt) if err != nil { log.Printf("Error sending to the client response info") return err @@ -211,7 +187,7 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List } -func pingLoop(stream *yamux.Stream) { +func pingLoop(stream *yamux.Stream) error { t := time.NewTicker(time.Second * 5) // TODO: Change to something like 30s defer t.Stop() @@ -223,11 +199,17 @@ func pingLoop(stream *yamux.Stream) { }, } - _, err := stream.Write(m.Encode()) + byt, err := m.Encode() if err != nil { - return + return err + } + + _, err = stream.Write(byt) + if err != nil { + return err } } + return nil } @@ -245,14 +227,17 @@ func (s *ServerConnection) StartListeningAgents() error { } // When we receive an agent stream, listen and parse for the first handshake and then we can start copying over / creating new sessions - s.handleInitialConfig(stream) + err = s.handleInitialConfig(stream) + if err != nil { + return err + } // After config, we setup whatever server type needs to be opened, and start copying over // defer stream.Close() server, err := s.setupServer() if err != nil { - panic(err) + return err } err = s.sendAddressInfo(stream, server) @@ -261,11 +246,22 @@ func (s *ServerConnection) StartListeningAgents() error { return err } - go pingLoop(stream) + go func() { + err := pingLoop(stream) + if err != nil { + log.Printf("Error executing the ping loop\n%v", err) + s.agentSession.Close() + return + } + }() go func() { <-s.agentSession.CloseChan() - server.Close() + err := server.Close() + if err != nil { + log.Printf("Error closing the server: %v", err) + return + } log.Printf("Server %v closed", server.Addr()) }() diff --git a/internal/tls/tlsutil.go b/internal/tls/tlsutil.go index cd6895e..d97fa70 100644 --- a/internal/tls/tlsutil.go +++ b/internal/tls/tlsutil.go @@ -20,10 +20,11 @@ func (t TLSStore) Wrap(conn net.Conn) (*tls.Conn, error) { return serv, nil } -func LoadTLSCerts() TLSStore { +func LoadTLSCerts() (TLSStore, error) { cert, err := tls.LoadX509KeyPair("certs/cert.cer", "certs/cert_key.key") if err != nil { - panic(err) + log.Printf("Failed to load certificates: %v", err) + return TLSStore{}, err } conf := &tls.Config{ @@ -32,5 +33,5 @@ func LoadTLSCerts() TLSStore { return TLSStore{ conf: conf, - } + }, nil } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index c037ece..950f919 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,8 +1,10 @@ package tui import ( + "errors" "fmt" "image/color" + "strings" "time" @@ -12,6 +14,14 @@ import ( "github.com/KatIsCoding/yatun/internal/message" ) +type FailedInitialConfigError struct { + Err error +} + +func (f *FailedInitialConfigError) Error() string { + return f.Err.Error() +} + type ServerAddress struct { Addr string } @@ -96,6 +106,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case SetState: m.state = msg.State m.err = msg.Err + + // var x *FailedInitialConfigError + if _, ok := errors.AsType[*FailedInitialConfigError](msg.Err); ok { + return m, tea.Quit + } case tea.KeyPressMsg: switch msg.String() { case "a": From aa07d7f892af7f389fdec5753109dfacdd83a162 Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 17:31:33 -0600 Subject: [PATCH 2/9] Better error/ctx handling --- cmd/yatun/main.go | 5 ++- cmd/yatund/main.go | 24 +++++++++-- go.mod | 21 ++++------ go.sum | 41 +++---------------- internal/server/server.go | 83 ++++++++++++++++++++++++++------------- 5 files changed, 93 insertions(+), 81 deletions(-) diff --git a/cmd/yatun/main.go b/cmd/yatun/main.go index e4c8b55..5d15125 100644 --- a/cmd/yatun/main.go +++ b/cmd/yatun/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "io" + "sync/atomic" "os" "sync" @@ -208,7 +209,9 @@ func main() { } if _, err := tuiP.Run(); err != nil { - sess.Close() + if sess != nil { + sess.Close() + } os.Exit(1) } } diff --git a/cmd/yatund/main.go b/cmd/yatund/main.go index 91148b6..391dbde 100644 --- a/cmd/yatund/main.go +++ b/cmd/yatund/main.go @@ -1,18 +1,26 @@ package main import ( + "context" "log" "net" + "os/signal" + "syscall" + "github.com/KatIsCoding/yatun/internal/config" "github.com/KatIsCoding/yatun/internal/server" ) func main() { + ctx := context.Background() + + ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer cancel() + conf, err := config.ReadFromEnv() if err != nil { - log.Printf("FATAL: Failed to parse configuration\n%v", err) - return + log.Fatalf("FATAL: Failed to parse configuration\n%v", err) } // So, we need to open a tcp server for external connections and another one for the agent. @@ -22,14 +30,24 @@ func main() { listener, err := net.Listen("tcp", "0.0.0.0:5678") if err != nil { log.Printf("Error binding to port 5678: %v", err) + return } + go func() { + <-ctx.Done() + listener.Close() + }() + log.Printf("Listening on %v", listener.Addr().String()) for { conn, err := listener.Accept() if err != nil { + if ctx.Err() != nil { + log.Printf("Loop break, context canceled") + return + } log.Printf("Failed to accept client: %v", err) continue } @@ -41,7 +59,7 @@ func main() { } go func() { - err := sconn.StartListeningAgents() + err := sconn.StartListeningAgents(ctx) if err != nil { log.Printf("Failed starting the agent(s) setup and loop: %v", err) } diff --git a/go.mod b/go.mod index 6294ec6..9ede070 100644 --- a/go.mod +++ b/go.mod @@ -3,33 +3,26 @@ module github.com/KatIsCoding/yatun go 1.26.1 require ( - charm.land/bubbles/v2 v2.1.0 // indirect - charm.land/bubbletea/v2 v2.0.6 // indirect - charm.land/lipgloss/v2 v2.0.3 // indirect - github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/bubbletea v1.3.10 // indirect + charm.land/bubbles/v2 v2.1.0 + charm.land/bubbletea/v2 v2.0.6 + charm.land/lipgloss/v2 v2.0.3 + github.com/hashicorp/yamux v0.1.2 +) + +require ( github.com/charmbracelet/colorprofile v0.4.3 // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468 // indirect github.com/charmbracelet/x/ansi v0.11.7 // indirect - github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/term v0.2.2 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect - github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/hashicorp/yamux v0.1.2 // indirect github.com/lucasb-eyer/go-colorful v1.4.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.23 // indirect - github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.16.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.3.8 // indirect ) diff --git a/go.sum b/go.sum index 264bb70..a7ad83b 100644 --- a/go.sum +++ b/go.sum @@ -4,26 +4,16 @@ charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo= charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g= charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU= charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA= -github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= -github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= -github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o= +github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w= github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q= github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q= -github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= -github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468 h1:Q9fO0y1Zo5KB/5Vu8JZoLGm1N3RzF9bNj3Ao3xoR+Ac= github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468/go.mod h1:bAAz7dh/FTYfC+oiHavL4mX1tOIBZ0ZwYjSi3qE6ivM= -github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= -github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI= github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= -github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= -github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA= +github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= @@ -34,40 +24,21 @@ github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSE github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4= github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= -github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= -github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= -github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= diff --git a/internal/server/server.go b/internal/server/server.go index e7bb23a..018d661 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "context" "errors" "fmt" @@ -22,6 +23,8 @@ type ServerConnection struct { connectionType *message.MessageType config config.ServerConfig + + serverCancelFunc context.CancelFunc } func (s *ServerConnection) handleInitialConfig(stream *yamux.Stream) error { @@ -37,37 +40,60 @@ func (s *ServerConnection) handleInitialConfig(stream *yamux.Stream) error { return nil } -func (s *ServerConnection) setupServer() (net.Listener, error) { - // TODO: Expand this for more server options - return net.Listen("tcp", "0.0.0.0:") +func (s *ServerConnection) setupServer(ctx context.Context) (net.Listener, error) { + + server, err := net.Listen("tcp", "0.0.0.0:") + if err != nil { + log.Printf("Error starting TCP server: %v", err) + return server, err + } + ctx, s.serverCancelFunc = context.WithCancel(ctx) + + go func() { + <-ctx.Done() + + err := server.Close() + if err != nil { + log.Printf("Failed to close the server: %v", err) + } + }() + + return server, nil } -func (s *ServerConnection) handleNewConn(conn net.Conn) { +func (s *ServerConnection) handleNewConn(ctx context.Context, conn net.Conn) { agConn, err := s.agentSession.OpenStream() if err != nil { log.Printf("Error opening yamux stream %v", err) return } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + agConn.Close() + conn.Close() + }() + wg := sync.WaitGroup{} wg.Go(func() { + defer cancel() _, err := io.Copy(agConn, conn) log.Printf("Conn finished %v", err) - agConn.Close() }) wg.Go(func() { + defer cancel() _, err := io.Copy(conn, agConn) log.Printf("AgConn finished %v", err) - conn.Close() }) wg.Wait() log.Printf("WG finished") - conn.Close() - agConn.Close() } type bufferedConn struct { @@ -126,7 +152,7 @@ func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (net.Conn, error) { } -func (s *ServerConnection) handleExternalConnections(serv net.Listener) error { +func (s *ServerConnection) handleExternalConnections(ctx context.Context, serv net.Listener) error { // TODO: Accept a context so that if the connection with the agent is lost, we close the server! for { conn, err := serv.Accept() @@ -135,20 +161,24 @@ func (s *ServerConnection) handleExternalConnections(serv net.Listener) error { return err } - // if tcpConn, ok := conn.(*net.TCPConn); ok { - // log.Printf("TCPConn conv") - - // tcpConn.SetKeepAlive(true) - // tcpConn.SetKeepAlivePeriod(time.Second * 5) - // } - // + go func() { + <-ctx.Done() + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + } + }() // Check if the connection requires TLS and upgrade in case it does - buffered, err := s.peekAndUpgrade(conn) + buffered, err := s.peekAndUpgrade(conn) // This function will only return an error IF TLS handshake fails, but if the request doesn't even ask for TLS it will return with the original conn and no error + if err != nil { + log.Printf("Error upgrading connection: %v", err) + continue + } log.Printf("New req received") - go s.handleNewConn(buffered) + go s.handleNewConn(ctx, buffered) } } @@ -213,9 +243,9 @@ func pingLoop(stream *yamux.Stream) error { } -func (s *ServerConnection) StartListeningAgents() error { +func (s *ServerConnection) StartListeningAgents(ctx context.Context) error { - stream, err := s.agentSession.AcceptStream() + stream, err := s.agentSession.AcceptStreamWithContext(ctx) if err != nil { if errors.Is(err, io.EOF) { @@ -235,7 +265,7 @@ func (s *ServerConnection) StartListeningAgents() error { // After config, we setup whatever server type needs to be opened, and start copying over // defer stream.Close() - server, err := s.setupServer() + server, err := s.setupServer(ctx) if err != nil { return err } @@ -256,19 +286,16 @@ func (s *ServerConnection) StartListeningAgents() error { }() go func() { + // If the agent disconnects, also close the server <-s.agentSession.CloseChan() - err := server.Close() - if err != nil { - log.Printf("Error closing the server: %v", err) - return - } - log.Printf("Server %v closed", server.Addr()) + s.serverCancelFunc() + }() log.Printf("Port opened in %v", server.Addr().String()) go func() { - err := s.handleExternalConnections(server) + err := s.handleExternalConnections(ctx, server) if err != nil { log.Printf("Error handling external connection: %v", err) } From 67a80bf607a6966057a2c4afd444411874f69337 Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 18:20:58 -0600 Subject: [PATCH 3/9] better error handling --- cmd/yatun/main.go | 142 +++++++++++++++++++++++------------- internal/message/message.go | 5 +- internal/server/server.go | 57 +++++++-------- 3 files changed, 119 insertions(+), 85 deletions(-) diff --git a/cmd/yatun/main.go b/cmd/yatun/main.go index 5d15125..cd3ebce 100644 --- a/cmd/yatun/main.go +++ b/cmd/yatun/main.go @@ -1,12 +1,12 @@ package main import ( + "context" "flag" "fmt" "io" "sync/atomic" - "os" "sync" "time" @@ -21,7 +21,8 @@ import ( func sendMsg(con *yamux.Stream, m message.TransportMessage) (err error) { byt, err := m.Encode() if err != nil { - return err + + return fmt.Errorf("failed to encode message: %w", err) } _, err = con.Write(byt) @@ -100,83 +101,121 @@ type trafficMonitor struct { underlying io.ReadWriter tuiP *tea.Program streamType tui.TrafficDirection + + bytesTransferred *atomic.Int64 } func (c trafficMonitor) Read(p []byte) (n int, err error) { n, err = c.underlying.Read(p) - go c.tuiP.Send(tui.TrafficUpdate{ - Direction: c.streamType, - Bytes: n, - }) + c.bytesTransferred.Add(int64(n)) return } func (c trafficMonitor) Write(p []byte) (n int, err error) { return c.underlying.Write(p) + } -func serverConnectionLoop(sess *yamux.Session, port *string, tuiP *tea.Program) { - // TODO: After initial handshake is done, io.Copy from server (yatun) to internal target server - for { +func handleStream(ctx context.Context, stream *yamux.Stream, port *string, tuiP *tea.Program) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() - stream, err := sess.AcceptStream() - if err != nil { - // TODO: Maybe? send a message to the TUI so that the user knows it is having trouble getting new sessions from server - tuiP.Send(tui.SetState{ - Err: err, - State: tui.ErrorState, - }) - tuiP.Quit() - return - } + tuiP.Send(tui.LiveConnection) + defer tuiP.Send(tui.DeadConnection) + defer stream.Close() - go func() { - tuiP.Send(tui.LiveConnection) - defer tuiP.Send(tui.DeadConnection) - defer stream.Close() + localConn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%v", *port), time.Second*10) + if err != nil { + tuiP.Send(tui.LocalConnectionError) + // Notify to the TUI the error, maybe the server is down? + return + } + defer localConn.Close() - localConn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%v", *port), time.Second*10) - if err != nil { - tuiP.Send(tui.LocalConnectionError) - // Notify to the TUI the error, maybe the server is down? + go func() { + <-ctx.Done() + + localConn.Close() + stream.Close() + }() + + streamMonitor := trafficMonitor{ + underlying: stream, + tuiP: tuiP, + streamType: tui.Inbound, + bytesTransferred: &atomic.Int64{}, + } + localConnMonitor := trafficMonitor{ + underlying: localConn, + tuiP: tuiP, + streamType: tui.Outbound, + bytesTransferred: &atomic.Int64{}, + } + + go func() { + t := time.NewTicker(time.Second * 5) + defer t.Stop() + + for { + select { + case <-t.C: + // The TUI already sums the data internally, so the right call is Swap instead of load, this could also be used to measure throughput + tuiP.Send(tui.TrafficUpdate{ + Direction: streamMonitor.streamType, + Bytes: int(streamMonitor.bytesTransferred.Swap(0)), + }) + tuiP.Send(tui.TrafficUpdate{ + Direction: localConnMonitor.streamType, + Bytes: int(localConnMonitor.bytesTransferred.Swap(0)), + }) + + case <-ctx.Done(): return } - defer localConn.Close() + } + }() - streamCopier := trafficMonitor{ - underlying: stream, - tuiP: tuiP, - streamType: tui.Inbound, - } - localConnCopier := trafficMonitor{ - underlying: localConn, - tuiP: tuiP, - streamType: tui.Outbound, - } + wg := sync.WaitGroup{} - wg := sync.WaitGroup{} + wg.Go(func() { - wg.Go(func() { + io.Copy(streamMonitor, localConnMonitor) + cancel() - io.Copy(streamCopier, localConnCopier) - localConn.Close() + }) - }) + wg.Go(func() { + io.Copy(localConnMonitor, streamMonitor) + cancel() - wg.Go(func() { - io.Copy(localConnCopier, streamCopier) - stream.Close() + }) - }) + wg.Wait() +} - wg.Wait() +func serverConnectionLoop(ctx context.Context, sess *yamux.Session, port *string, tuiP *tea.Program) { + // TODO: After initial handshake is done, io.Copy from server (yatun) to internal target server + for { - }() + stream, err := sess.AcceptStreamWithContext(ctx) + if err != nil { + // TODO: Maybe? send a message to the TUI so that the user knows it is having trouble getting new sessions from server + tuiP.Send(tui.SetState{ + Err: err, + State: tui.ErrorState, + }) + tuiP.Quit() + return + } + go handleStream(ctx, stream, port, tuiP) } } func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + port := flag.String("port", "", "--port") server := flag.String("server", "yatun.snowdev.one", "--server") @@ -205,13 +244,14 @@ func main() { go tuiP.Send(tui.SetState{ State: tui.OnlineState, }) - go serverConnectionLoop(sess, port, tuiP) + go serverConnectionLoop(ctx, sess, port, tuiP) } if _, err := tuiP.Run(); err != nil { if sess != nil { sess.Close() } - os.Exit(1) + // os.Exit(1) + cancel() } } diff --git a/internal/message/message.go b/internal/message/message.go index 414c74a..b4b7cbf 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "fmt" "log" "time" @@ -28,7 +29,7 @@ func (r *Response) Decode(b []byte) error { var rT Response err := json.Unmarshal(b, &rT) if err != nil { - return err + return fmt.Errorf("failed to unmarshall message json: %w", err) } r.Address = rT.Address @@ -48,7 +49,7 @@ func (r *PingMessage) Decode(b []byte) error { var rT PingMessage err := json.Unmarshal(b, &rT) if err != nil { - return err + return fmt.Errorf("failed to unmarshall message: %w", err) } r.Time = rT.Time diff --git a/internal/server/server.go b/internal/server/server.go index 018d661..eb97c27 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -31,7 +31,7 @@ func (s *ServerConnection) handleInitialConfig(stream *yamux.Stream) error { // TODO: Maybe this initial message could be used for something? I'm not quite sure what just yet though m, err := message.Decode(stream) if err != nil { - return err + return fmt.Errorf("failed to decode initial config message: %w", err) } log.Printf("Received message %v", m) @@ -44,8 +44,7 @@ func (s *ServerConnection) setupServer(ctx context.Context) (net.Listener, error server, err := net.Listen("tcp", "0.0.0.0:") if err != nil { - log.Printf("Error starting TCP server: %v", err) - return server, err + return server, fmt.Errorf("error starting tcp server: %w", err) } ctx, s.serverCancelFunc = context.WithCancel(ctx) @@ -105,7 +104,7 @@ func (b *bufferedConn) Read(p []byte) (int, error) { return b.buf.Read(p) } -func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (net.Conn, error) { +func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (*bufferedConn, error) { buf := bufio.NewReader(conn) ob := bufferedConn{ @@ -114,41 +113,34 @@ func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (net.Conn, error) { } if !s.config.TLS { - return conn, nil + return &ob, nil } fB, err := buf.Peek(1) if err != nil { log.Printf("Error peeking into the connection %v", err) - return conn, err + return &ob, err } - log.Printf("TCP first bytes %v", fB) - if fB[0] == 0x16 { log.Printf("Initiating TLS") - if !s.config.TLS { - - log.Printf("The incoming request attempted to initiate a TLS connection, but TLS is disabled.") - return conn, nil - } upgraded, err := s.config.TLSStore.Wrap(&ob) if err != nil { - return conn, err + return &ob, err } - // buf := bufio.NewReader(upgraded) - // ob.buf = buf - // ob.Conn = upgraded + buf := bufio.NewReader(upgraded) + ob.buf = buf + ob.Conn = upgraded log.Printf("TLS Upgrade success") - return upgraded, nil + return &ob, nil } - return conn, nil + return &ob, nil } @@ -157,8 +149,8 @@ func (s *ServerConnection) handleExternalConnections(ctx context.Context, serv n for { conn, err := serv.Accept() if err != nil { - log.Printf("Error accepting external conn %v", err) - return err + + return fmt.Errorf("error accepting external connection: %w", err) } go func() { @@ -187,8 +179,8 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List if s.config.Domain != nil { _, port, err := net.SplitHostPort(server.Addr().String()) if err != nil { - log.Printf("Error parsing the host port: %v", err) - return err + + return fmt.Errorf("error parsing host/port from address: %w", err) } @@ -205,13 +197,12 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List byt, err := m.Encode() if err != nil { - return err + return fmt.Errorf("error encoding address info message: %w", err) } _, err = stream.Write(byt) if err != nil { - log.Printf("Error sending to the client response info") - return err + return fmt.Errorf("error writing to stream: %w", err) } return nil @@ -231,11 +222,13 @@ func pingLoop(stream *yamux.Stream) error { byt, err := m.Encode() if err != nil { + log.Printf("Failed to encode ping message for agent: %v", err) return err } _, err = stream.Write(byt) if err != nil { + log.Printf("Failed to send ping notification to agent: %v", err) return err } } @@ -252,14 +245,14 @@ func (s *ServerConnection) StartListeningAgents(ctx context.Context) error { // This means the agent stream disconnected, it is fine to just break here return errors.New("sess disconnected") } - log.Printf("Unrecognized err: %v", err) - return err + + return fmt.Errorf("unrecoginzed err when listening to agents: %w", err) } // When we receive an agent stream, listen and parse for the first handshake and then we can start copying over / creating new sessions err = s.handleInitialConfig(stream) if err != nil { - return err + return fmt.Errorf("setting up the initial config for agent failed: %w", err) } // After config, we setup whatever server type needs to be opened, and start copying over @@ -267,13 +260,13 @@ func (s *ServerConnection) StartListeningAgents(ctx context.Context) error { server, err := s.setupServer(ctx) if err != nil { - return err + return fmt.Errorf("failed to setup tcp server: %w", err) } err = s.sendAddressInfo(stream, server) if err != nil { - log.Printf("Error sending address information %v", err) - return err + return fmt.Errorf("failed to send initial information to the agent: %w", err) + } go func() { From cb2b3efe72743f079d0d79ec9c4f7655c035a0d2 Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 20:32:23 -0600 Subject: [PATCH 4/9] More error handling, cleanup and ctx --- .gitignore | 4 ++++ cmd/yatun/main.go | 3 ++- internal/config/config.go | 2 +- internal/message/message.go | 2 +- internal/server/server.go | 3 +++ 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 4c49bd7..49c8451 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ .env +/yatun +/yatund +*.exe + diff --git a/cmd/yatun/main.go b/cmd/yatun/main.go index cd3ebce..a2697f4 100644 --- a/cmd/yatun/main.go +++ b/cmd/yatun/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "io" + "os" "sync/atomic" "sync" @@ -251,7 +252,7 @@ func main() { if sess != nil { sess.Close() } - // os.Exit(1) cancel() + os.Exit(1) } } diff --git a/internal/config/config.go b/internal/config/config.go index 257619f..496ccd0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,7 +19,7 @@ type ServerConfig struct { func ReadFromEnv() (ServerConfig, error) { c := ServerConfig{} - domain, ok := os.LookupEnv("Domain") + domain, ok := os.LookupEnv("DOMAIN") if ok { if !strings.Contains(domain, "//") { domain = "//" + domain diff --git a/internal/message/message.go b/internal/message/message.go index b4b7cbf..fa232a8 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -121,7 +121,7 @@ func readSize(r io.Reader) (uint16, error) { func Decode(r io.Reader) (*TransportMessage, error) { b := make([]byte, 1) - n, err := r.Read(b) + n, err := io.ReadFull(r, b) if err != nil { return nil, err diff --git a/internal/server/server.go b/internal/server/server.go index eb97c27..b51b8b5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -149,6 +149,9 @@ func (s *ServerConnection) handleExternalConnections(ctx context.Context, serv n for { conn, err := serv.Accept() if err != nil { + if ctx.Err() != nil { + return nil + } return fmt.Errorf("error accepting external connection: %w", err) } From 84921bc04c378665b319fde7060a53cfcd48159a Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 20:36:52 -0600 Subject: [PATCH 5/9] Tests added, fixed nil pointer issue --- internal/config/config_test.go | 81 ++++++++++ internal/message/message.go | 11 +- internal/message/message_test.go | 247 +++++++++++++++++++++++++++++++ 3 files changed, 334 insertions(+), 5 deletions(-) create mode 100644 internal/config/config_test.go create mode 100644 internal/message/message_test.go diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..fadff5e --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,81 @@ +package config + +import ( + "os" + "testing" +) + +func TestReadFromEnvDefaults(t *testing.T) { + os.Unsetenv("DOMAIN") + os.Unsetenv("TLS") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.Domain != nil { + t.Errorf("expected nil Domain, got %q", *cfg.Domain) + } + if cfg.TLS { + t.Errorf("expected TLS to be false by default") + } +} + +func TestReadFromEnvDomain(t *testing.T) { + os.Unsetenv("TLS") + t.Setenv("DOMAIN", "example.com") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.Domain == nil { + t.Fatal("expected Domain to be set") + } + if *cfg.Domain != "example.com" { + t.Errorf("expected example.com, got %q", *cfg.Domain) + } +} + +func TestReadFromEnvDomainWithScheme(t *testing.T) { + os.Unsetenv("TLS") + t.Setenv("DOMAIN", "https://sub.example.com") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.Domain == nil { + t.Fatal("expected Domain to be set") + } + if *cfg.Domain != "sub.example.com" { + t.Errorf("expected sub.example.com, got %q", *cfg.Domain) + } +} + +func TestReadFromEnvDomainInvalidURL(t *testing.T) { + os.Unsetenv("TLS") + t.Setenv("DOMAIN", "://bad host") + + _, err := ReadFromEnv() + if err == nil { + t.Fatal("expected error for invalid domain URL") + } +} + +func TestReadFromEnvTLSDisabled(t *testing.T) { + os.Unsetenv("DOMAIN") + t.Setenv("TLS", "0") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.TLS { + t.Errorf("expected TLS to be false when set to 0") + } +} diff --git a/internal/message/message.go b/internal/message/message.go index fa232a8..0f976bd 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -164,12 +164,13 @@ func Decode(r io.Reader) (*TransportMessage, error) { obj = &PingMessage{} } - err = obj.Decode(buf[:n]) - if err != nil { - return nil, err + if obj != nil { + err = obj.Decode(buf[:n]) + if err != nil { + return nil, err + } + m.Data = obj } - m.Data = obj - return &m, nil } diff --git a/internal/message/message_test.go b/internal/message/message_test.go new file mode 100644 index 0000000..d2650a9 --- /dev/null +++ b/internal/message/message_test.go @@ -0,0 +1,247 @@ +package message + +import ( + "bytes" + "io" + "strings" + "testing" + "time" +) + +func TestEncodeDecodeResponse(t *testing.T) { + original := TransportMessage{ + Type: ResponseMessage, + Data: &Response{ + Ok: true, + Address: "example.com:443", + }, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Type != original.Type { + t.Errorf("type: got %c, want %c", decoded.Type, original.Type) + } + + resp, ok := decoded.Data.(*Response) + if !ok { + t.Fatalf("expected *Response, got %T", decoded.Data) + } + if resp.Ok != true { + t.Errorf("Ok: got %v, want true", resp.Ok) + } + if resp.Address != "example.com:443" { + t.Errorf("Address: got %q, want %q", resp.Address, "example.com:443") + } +} + +func TestEncodeDecodePingMessage(t *testing.T) { + now := time.Now().Truncate(time.Millisecond) + + original := TransportMessage{ + Type: PingMessageType, + Data: &PingMessage{Time: now}, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Type != PingMessageType { + t.Errorf("type: got %c, want %c", decoded.Type, PingMessageType) + } + + ping, ok := decoded.Data.(*PingMessage) + if !ok { + t.Fatalf("expected *PingMessage, got %T", decoded.Data) + } + if !ping.Time.Equal(now) { + t.Errorf("Time: got %v, want %v", ping.Time, now) + } +} + +func TestEncodeDecodeNoData(t *testing.T) { + for _, mt := range []MessageType{TCPConnection, OpenMsg} { + original := TransportMessage{ + Type: mt, + Data: nil, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("encode %c: %v", mt, err) + } + + if len(encoded) != 1 { + t.Errorf("expected 1 byte for data-less message, got %d", len(encoded)) + } + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode %c: %v", mt, err) + } + + if decoded.Type != mt { + t.Errorf("type: got %c, want %c", decoded.Type, mt) + } + if decoded.Data != nil { + t.Errorf("expected nil Data for type %c, got %v", mt, decoded.Data) + } + } +} + +func TestDecodeUnknownType(t *testing.T) { + // Create a message with an unknown type 'X' and a valid-length JSON payload + encoded := []byte{'X', 0, 2, '{', '}'} + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Type != MessageType('X') { + t.Errorf("type: got %c, want X", decoded.Type) + } + if decoded.Data != nil { + t.Errorf("expected nil Data for unknown type, got %v", decoded.Data) + } +} + +func TestDecodeTruncatedType(t *testing.T) { + _, err := Decode(strings.NewReader("")) + if err != io.EOF { + t.Errorf("expected io.EOF for empty input, got %v", err) + } +} + +func TestDecodeTruncatedSize(t *testing.T) { + // Type byte present, but no size bytes + _, err := Decode(bytes.NewReader([]byte{'C'})) + if err == nil { + t.Fatal("expected error for truncated size field") + } +} + +func TestDecodeTruncatedPayload(t *testing.T) { + // Type 'C', size says 100 bytes, but only 5 follow + encoded := bytes.NewBuffer([]byte{'C', 0, 100}) + encoded.Write([]byte("short")) + + _, err := Decode(bytes.NewReader(encoded.Bytes())) + if err == nil { + t.Fatal("expected error for truncated payload") + } +} + +func TestDecodeGarbagePayload(t *testing.T) { + // Type 'C', size says 5 bytes of garbage (not valid JSON) + encoded := bytes.NewBuffer([]byte{'C', 0, 5}) + encoded.Write([]byte("!!!!!")) + + _, err := Decode(bytes.NewReader(encoded.Bytes())) + if err == nil { + t.Fatal("expected error for garbage JSON payload") + } +} + +func TestEncodeWithNilData(t *testing.T) { + m := TransportMessage{ + Type: PingMessageType, + Data: nil, + } + + encoded, err := m.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + if len(encoded) != 1 { + t.Errorf("expected 1 byte with nil Data, got %d", len(encoded)) + } + + if encoded[0] != byte(PingMessageType) { + t.Errorf("type byte: got %c, want %c", encoded[0], PingMessageType) + } +} + +func TestParseConnectionType(t *testing.T) { + mt, err := ParseConnectionType('A') + if err != nil { + t.Fatalf("ParseConnectionType('A'): %v", err) + } + if *mt != TCPConnection { + t.Errorf("got %c (%d), want %c (%d)", *mt, *mt, TCPConnection, TCPConnection) + } + + _, err = ParseConnectionType('Z') + if err == nil { + t.Fatal("expected error for unknown type") + } +} + +func TestResponseEncodeDecode(t *testing.T) { + original := Response{Ok: false, Address: "127.0.0.1:3000"} + + data, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + var decoded Response + if err := decoded.Decode(data); err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Ok != original.Ok || decoded.Address != original.Address { + t.Errorf("mismatch: got %+v, want %+v", decoded, original) + } +} + +func TestPingMessageEncodeDecode(t *testing.T) { + now := time.Now().Truncate(time.Millisecond) + original := PingMessage{Time: now} + + data, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + var decoded PingMessage + if err := decoded.Decode(data); err != nil { + t.Fatalf("decode: %v", err) + } + + if !decoded.Time.Equal(original.Time) { + t.Errorf("Time mismatch: got %v, want %v", decoded.Time, original.Time) + } +} + +func TestResponseDecodeInvalidJSON(t *testing.T) { + var r Response + err := r.Decode([]byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestPingMessageDecodeInvalidJSON(t *testing.T) { + var p PingMessage + err := p.Decode([]byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} From 547b3e77aef064e3c803b5f7e75fb85a53598a6c Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 22:03:10 -0600 Subject: [PATCH 6/9] Add test step to CI/CD --- .github/workflows/build.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c9486af..20a8798 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,6 +7,17 @@ on: branches: [main] jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.26' + cache: true + + name: Run tests + run: go test ./... -v build: strategy: matrix: From 5705e848b98de1fb1492bc5da7386e4301036816 Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 22:08:52 -0600 Subject: [PATCH 7/9] Move the test and build jobs to another pipeline that runs only on PRs, the build with artifact uploading stays the same --- .github/workflows/build.yml | 10 ---------- .github/workflows/test.yml | 39 +++++++++++++++++++++++++++++++++++++ README.md | 4 ++-- 3 files changed, 41 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 20a8798..71e782a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,17 +7,7 @@ on: branches: [main] jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.26' - cache: true - name: Run tests - run: go test ./... -v build: strategy: matrix: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..5a19640 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,39 @@ +name: Test & Build + +on: + pull_request: + branches: [main, dev] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.26' + cache: true + + name: Run tests + run: go test ./... -v + build: + strategy: + matrix: + arch: [amd64, arm64] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.26' + cache: true + + - name: Build yatund + run: | + GOOS=linux GOARCH=${{ matrix.arch }} CGO_ENABLED=0 go build -o build/yatund-${{ matrix.arch }} ./cmd/yatund/ + + - name: Build yatun + run: | + GOOS=linux GOARCH=${{ matrix.arch }} CGO_ENABLED=0 go build -o build/yatun-${{ matrix.arch }} ./cmd/yatun/ + diff --git a/README.md b/README.md index 8fabcaf..8cbade6 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ ``` ┌──────────┐ TCP ┌────────────────┐ yamux ┌──────────┐ TCP ┌─────────┐ -│ Internet │───────────▶│ yatund │◀══════════▶│ yatun │────────▶│ local │ -│ Client │ random port│ (relay server) │ session │ (agent) │ :port │ service │ +│ Internet │───────────▶│ yatund │◀═════════▶│ yatun │────────▶│ local │ +│ Client │ random port│ (relay server)│ session │ (agent) │ :port │ service │ └──────────┘ └────────────────┘ └──────────┘ └─────────┘ ``` From 689e9705abcc8e0f56cb6d5cb5ab14e31c8e6a4f Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 22:12:15 -0600 Subject: [PATCH 8/9] Add dash to the job name --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5a19640..5f1a7ec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: go-version: '1.26' cache: true - name: Run tests + - name: Run tests run: go test ./... -v build: strategy: From cad232c4c9b6d1c0ecbbecea83c051db7e5b7905 Mon Sep 17 00:00:00 2001 From: Fabrizio Gomez Date: Fri, 15 May 2026 22:13:11 -0600 Subject: [PATCH 9/9] Do not execute the jobs intended for release on PRs --- .github/workflows/build.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 71e782a..e1c3d71 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,9 +3,7 @@ name: Build on: push: tags: ['v*'] - pull_request: - branches: [main] - + jobs: build: