Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ web/management/node_modules/
web/management/dist/
web/management/test-results/
web/management/playwright-report/
docs/.plans/
SECURITY_AUDIT.md
20 changes: 12 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,21 @@ kill:
@-repo=$$(pwd); \
pkill -TERM -f '[g]o tool air -c .air.toml|/go-build/.*/[a]ir -c .air.toml|[b]un run --bun vite --host 127.0.0.1 --port 5173|[n]ode .*vite --host 127.0.0.1 --port 5173' 2>/dev/null || true; \
pkill -TERM -f "$$repo/bin/p2pstream|$$repo/tmp/p2pstream-dev|$$repo/tmp/p2pstream-agent-dev|[g]o run main.go agent|/go-build/.*/[m]ain agent" 2>/dev/null || true; \
for port in 8081 8088 8089 5173; do \
pids=$$(ss -H -ltnp "sport = :$$port" 2>/dev/null | sed -n 's/.*pid=\([0-9][0-9]*\).*/\1/p' | sort -u); \
[ -z "$$pids" ] || kill -TERM $$pids 2>/dev/null || true; \
done; \
if command -v ss >/dev/null 2>&1; then \
for port in 8081 8088 8089 5173; do \
pids=$$(ss -H -ltnp "sport = :$$port" 2>/dev/null | sed -n 's/.*pid=\([0-9][0-9]*\).*/\1/p' | sort -u); \
[ -z "$$pids" ] || kill -TERM $$pids 2>/dev/null || true; \
done; \
fi; \
sleep 0.5; \
pkill -KILL -f '[g]o tool air -c .air.toml|/go-build/.*/[a]ir -c .air.toml|[b]un run --bun vite --host 127.0.0.1 --port 5173|[n]ode .*vite --host 127.0.0.1 --port 5173' 2>/dev/null || true; \
pkill -KILL -f "$$repo/bin/p2pstream|$$repo/tmp/p2pstream-dev|$$repo/tmp/p2pstream-agent-dev|[g]o run main.go agent|/go-build/.*/[m]ain agent" 2>/dev/null || true; \
for port in 8081 8088 8089 5173; do \
pids=$$(ss -H -ltnp "sport = :$$port" 2>/dev/null | sed -n 's/.*pid=\([0-9][0-9]*\).*/\1/p' | sort -u); \
[ -z "$$pids" ] || kill -KILL $$pids 2>/dev/null || true; \
done
if command -v ss >/dev/null 2>&1; then \
for port in 8081 8088 8089 5173; do \
pids=$$(ss -H -ltnp "sport = :$$port" 2>/dev/null | sed -n 's/.*pid=\([0-9][0-9]*\).*/\1/p' | sort -u); \
[ -z "$$pids" ] || kill -KILL $$pids 2>/dev/null || true; \
done; \
fi

clean:
@echo "Cleaning up..."
Expand Down
1 change: 1 addition & 0 deletions agent_tunnel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func dialAgentTunnel(ctx context.Context, managementURL string, publicID string,
return nil, resp, err
}
if resp.StatusCode != http.StatusSwitchingProtocols {
_ = resp.Body.Close()
return nil, resp, fmt.Errorf("expected tunnel upgrade status 101, got %d", resp.StatusCode)
}
body, ok := resp.Body.(io.ReadWriteCloser)
Expand Down
277 changes: 0 additions & 277 deletions docs/.plans/agent-yamux-overhaul.md

This file was deleted.

13 changes: 9 additions & 4 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ var (
agentReconnectBackoffMax = 30 * time.Second
)

// agentTunnelResponseHeaderTimeout is the fallback value assigned to the
// tunnel HTTP transport's ResponseHeaderTimeout when that field is still zero,
// so the tunnel request does not wait without limit for response headers.
const agentTunnelResponseHeaderTimeout = 15 * time.Second

type Options struct {
ManagementURL string
PublicID string
Expand Down Expand Up @@ -248,6 +250,9 @@ func managementTunnelHTTPClient(base *http.Client) (*http.Client, error) {
}
transport.ForceAttemptHTTP2 = false
transport.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
if transport.ResponseHeaderTimeout == 0 {
transport.ResponseHeaderTimeout = agentTunnelResponseHeaderTimeout
}
protocols := new(http.Protocols)
protocols.SetHTTP1(true)
transport.Protocols = protocols
Expand Down Expand Up @@ -288,25 +293,25 @@ func connectAndServe(client *http.Client, tunnelURL string, agentPublicID string
if resp.Body != nil {
data, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
body = strings.TrimSpace(string(data))
resp.Body.Close()
_ = resp.Body.Close()
}
if body != "" {
return fmt.Errorf("agent tunnel upgrade failed: status %d: %s", resp.StatusCode, body)
}
return fmt.Errorf("agent tunnel upgrade failed: status %d", resp.StatusCode)
}
if got := resp.Header.Get("Upgrade"); !strings.EqualFold(got, tunnel.UpgradeToken) {
resp.Body.Close()
_ = resp.Body.Close()
return fmt.Errorf("agent tunnel upgrade response header = %q", got)
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
resp.Body.Close()
_ = resp.Body.Close()
return fmt.Errorf("agent tunnel response body is %T, want io.ReadWriteCloser", resp.Body)
}
session, err := yamux.Client(rwc, tunnel.DefaultYamuxConfig(nil))
if err != nil {
rwc.Close()
_ = rwc.Close()
return fmt.Errorf("failed to initialize tunnel session: %w", err)
}
defer session.Close()
Expand Down
6 changes: 3 additions & 3 deletions internal/server/agent_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ func (a *App) UpdateAgent(
if err != nil {
return nil, publicDBError(err)
}
if err := a.refreshPublicProxySnapshot(ctx); err != nil {
return nil, err
}
if a.AgentTransports != nil {
a.AgentTransports.closeAgent(req.Msg.Id)
}
if err := a.refreshPublicProxySnapshot(ctx); err != nil {
return nil, err
}
return connect.NewResponse(&p2pstreamv1.UpdateAgentResponse{Agent: a.agentToProto(ctx, agent)}), nil
}

Expand Down
4 changes: 3 additions & 1 deletion internal/server/public_backend_health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func TestAgentPoolHealthCheckRunsThroughAssignedAgent(t *testing.T) {
t.Fatalf("connect agent: %v", err)
}
defer app.AgentHub.disconnect(agent)
defer fake.close()

snap := &publicProxySnapshot{
Backends: map[int64]publicBackendConfig{backend.ID: backend},
Expand Down Expand Up @@ -231,12 +232,13 @@ func TestAgentHealthTraceRecordsSuccessAndDebugAttributes(t *testing.T) {

backend := testHealthBackend(t, 102, publicBackendForwardModeAgentPool, upstream.URL)
backend.AgentAssignments = []publicBackendAgentConfig{{BackendID: backend.ID, AgentID: 7, Position: 0, Weight: 100, Enabled: true}}
agent, _ := newFakeYamuxAgent(t, 7, "agent-7")
agent, fake := newFakeYamuxAgent(t, 7, "agent-7")
agent.Name = "Agent Seven"
if err := app.AgentHub.connect(agent); err != nil {
t.Fatalf("connect agent: %v", err)
}
defer app.AgentHub.disconnect(agent)
defer fake.close()
app.BackendHealth.reconcile(app, &publicProxySnapshot{
Backends: map[int64]publicBackendConfig{backend.ID: backend},
Agents: map[int64]publicAgentConfig{7: {ID: 7, PublicID: "agent-7", Name: "Agent Seven", Enabled: true}},
Expand Down
45 changes: 43 additions & 2 deletions internal/server/public_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,9 +974,27 @@ func (a *App) dialViaAgent(ctx context.Context, agent *AgentConn, network string
openCh := make(chan struct {
conn net.Conn
err error
})
}, 1)
openDone := make(chan struct{})
stopOpenWatch := func() {
select {
case <-openDone:
default:
close(openDone)
}
}
session := agent.Session
go func() {
select {
case <-ctx.Done():
_ = session.Close()
case <-agent.Done:
_ = session.Close()
case <-openDone:
}
}()
go func() {
conn, err := agent.Session.Open()
conn, err := session.Open()
result := struct {
conn net.Conn
err error
Expand All @@ -997,6 +1015,7 @@ func (a *App) dialViaAgent(ctx context.Context, agent *AgentConn, network string
var conn net.Conn
select {
case result := <-openCh:
stopOpenWatch()
if result.err != nil {
if agent != nil {
log.Debug().
Expand All @@ -1010,6 +1029,8 @@ func (a *App) dialViaAgent(ctx context.Context, agent *AgentConn, network string
}
conn = result.conn
case <-ctx.Done():
_ = agent.Session.Close()
stopOpenWatch()
if agent != nil {
log.Debug().
Err(ctx.Err()).
Expand All @@ -1020,6 +1041,8 @@ func (a *App) dialViaAgent(ctx context.Context, agent *AgentConn, network string
}
return nil, ctx.Err()
case <-agent.Done:
_ = agent.Session.Close()
stopOpenWatch()
log.Debug().
Str("request_id", requestID).
Str("agent", agent.PublicID).
Expand All @@ -1031,6 +1054,24 @@ func (a *App) dialViaAgent(ctx context.Context, agent *AgentConn, network string
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
handshakeDone := make(chan struct{})
stopHandshakeWatch := func() {
select {
case <-handshakeDone:
default:
close(handshakeDone)
}
}
go func() {
select {
case <-ctx.Done():
_ = conn.Close()
case <-agent.Done:
_ = conn.Close()
case <-handshakeDone:
}
}()
defer stopHandshakeWatch()
req := tunnel.NewOpenRequest(requestID, network, address)
if err := tunnel.WriteOpenRequest(conn, req); err != nil {
_ = conn.Close()
Expand Down
83 changes: 46 additions & 37 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -284,7 +285,8 @@ func (a *App) agentTunnelHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "agent tunnel upgrade required", http.StatusBadRequest)
return
}
if r.Header.Get(tunnel.TunnelVersionHeader) != "1" {
version := strconv.Itoa(tunnel.ProtocolVersion)
if r.Header.Get(tunnel.TunnelVersionHeader) != version {
http.Error(w, "unsupported tunnel version", http.StatusUpgradeRequired)
return
}
Expand All @@ -294,73 +296,80 @@ func (a *App) agentTunnelHandler(w http.ResponseWriter, r *http.Request) {
return
}

var connID int64
if a.DB != nil {
id, err := a.DB.InsertConnection(r.Context(), sql.NullInt64{Int64: agentRow.ID, Valid: true})
if err == nil {
connID = id
if err := a.DB.MarkAgentConnected(r.Context(), agentRow.ID); err != nil {
log.Warn().Err(err).Str("agent", agentRow.PublicID).Msg("Failed to update agent connected timestamp")
}
} else {
log.Warn().Err(err).Msg("Failed to insert connection into DB")
}
if existing := a.AgentHub.connectedByID(agentRow.ID); existing != nil {
log.Warn().Str("agent", agentRow.PublicID).Msg("Rejecting duplicate agent connection")
http.Error(w, "agent is already connected", http.StatusConflict)
return
}

agent := &AgentConn{
AgentID: agentRow.ID,
PublicID: agentRow.PublicID,
Name: agentRow.Name,
Done: make(chan struct{}),
ConnectedAt: time.Now(),
ConnectionDBID: connID,
}

if err := a.AgentHub.connect(agent); err != nil {
log.Warn().Err(err).Str("agent", agent.PublicID).Msg("Rejecting duplicate agent connection")
http.Error(w, err.Error(), http.StatusConflict)
return
}
cleanupAgent := func() {
a.AgentHub.disconnect(agent)
if a.BackendHealth != nil {
a.BackendHealth.recordAgentDisconnectedForAll(agent.AgentID)
}
AgentID: agentRow.ID,
PublicID: agentRow.PublicID,
Name: agentRow.Name,
Done: make(chan struct{}),
ConnectedAt: time.Now(),
}

rawConn, rw, err := hijacker.Hijack()
if err != nil {
cleanupAgent()
log.Error().Err(err).Str("agent", agent.PublicID).Msg("Failed to hijack agent tunnel")
return
}
if rw.Reader.Buffered() > 0 {
_, _ = rw.WriteString("HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\nunexpected buffered tunnel data\n")
_ = rw.Flush()
_ = rawConn.Close()
cleanupAgent()
return
}
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
_, _ = rw.WriteString("Connection: Upgrade\r\n")
_, _ = rw.WriteString("Upgrade: " + tunnel.UpgradeToken + "\r\n")
_, _ = rw.WriteString(tunnel.TunnelVersionHeader + ": 1\r\n")
_, _ = rw.WriteString(tunnel.TunnelVersionHeader + ": " + version + "\r\n")
_, _ = rw.WriteString("\r\n")
if err := rw.Flush(); err != nil {
_ = rawConn.Close()
cleanupAgent()
log.Error().Err(err).Str("agent", agent.PublicID).Msg("Failed to write agent tunnel upgrade response")
return
}

session, err := yamux.Server(rawConn, tunnel.DefaultYamuxConfig(nil))
if err != nil {
_ = rawConn.Close()
cleanupAgent()
log.Error().Err(err).Str("agent", agent.PublicID).Msg("Failed to initialize agent tunnel session")
return
}
agent.Session = session

if a.DB != nil {
id, err := a.DB.InsertConnection(r.Context(), sql.NullInt64{Int64: agentRow.ID, Valid: true})
if err == nil {
agent.ConnectionDBID = id
if err := a.DB.MarkAgentConnected(r.Context(), agentRow.ID); err != nil {
log.Warn().Err(err).Str("agent", agentRow.PublicID).Msg("Failed to update agent connected timestamp")
}
} else {
log.Warn().Err(err).Msg("Failed to insert connection into DB")
}
}
if err := a.AgentHub.connect(agent); err != nil {
_ = session.Close()
if a.DB != nil && agent.ConnectionDBID > 0 {
if err := a.DB.UpdateConnectionDisconnected(context.Background(), agent.ConnectionDBID); err != nil {
log.Warn().Err(err).Msg("Failed to update rejected connection disconnection time")
}
if err := a.DB.MarkAgentDisconnected(context.Background(), agent.AgentID); err != nil {
log.Warn().Err(err).Str("agent", agent.PublicID).Msg("Failed to update rejected agent disconnected timestamp")
}
}
log.Warn().Err(err).Str("agent", agent.PublicID).Msg("Rejecting duplicate agent connection")
return
}
cleanupAgent := func() {
a.AgentHub.disconnect(agent)
if a.BackendHealth != nil {
a.BackendHealth.recordAgentDisconnectedForAll(agent.AgentID)
}
}
if a.BackendHealth != nil {
a.BackendHealth.recordAgentConnectedForAll(agent.AgentID, agent.PublicID)
}
Expand All @@ -383,8 +392,8 @@ func (a *App) agentTunnelHandler(w http.ResponseWriter, r *http.Request) {
Int64("duration_ms", time.Since(agent.ConnectedAt).Milliseconds()).
Int64("active_requests", agent.ActiveRequests.Load()).
Msg("Agent tunnel disconnected")
if a.DB != nil && connID > 0 {
if err := a.DB.UpdateConnectionDisconnected(context.Background(), connID); err != nil {
if a.DB != nil && agent.ConnectionDBID > 0 {
if err := a.DB.UpdateConnectionDisconnected(context.Background(), agent.ConnectionDBID); err != nil {
log.Warn().Err(err).Msg("Failed to update disconnection time")
}
if err := a.DB.MarkAgentDisconnected(context.Background(), agent.AgentID); err != nil {
Expand Down
Loading
Loading