diff --git a/Makefile b/Makefile index e519a75a2..e144fa3ee 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,8 @@ BRANCH := $(shell git rev-parse --abbrev-ref HEAD) BUILDDATE := $(shell date -u +%FT%T%z) BUILDTS := $(shell date -u +%s) REVISION := $(shell git rev-parse HEAD) -VERSION_DEV := 0.4.9-dev$(shell date +%Y%m%d%H%M) -VERSION := 0.4.8 +VERSION_DEV ?= 0.4.9-dev$(shell date +%Y%m%d%H%M) +VERSION ?= 0.4.8 PROMETHEUS_TAG := github.com/prometheus/common/version KVM_PKG_NAME := github.com/jetkvm/kvm diff --git a/cloud.go b/cloud.go index dbbd3bbcc..9d805a4a6 100644 --- a/cloud.go +++ b/cloud.go @@ -197,6 +197,24 @@ func wsResetMetrics(established bool, sourceType string, source string) { } func handleCloudRegister(c *gin.Context) { + sessionID, _ := c.Cookie("sessionId") + authToken, _ := c.Cookie("authToken") + + // Require authentication for this endpoint + if authToken == "" || authToken != config.LocalAuthToken { + c.JSON(401, gin.H{"error": "Authentication required"}) + return + } + + // Check session permissions if session exists + if sessionID != "" { + session := sessionManager.GetSession(sessionID) + if session != nil && !session.HasPermission(PermissionSettingsWrite) { + c.JSON(403, gin.H{"error": "Permission denied: settings modify permission required"}) + return + } + } + var req CloudRegisterRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -426,8 +444,15 @@ func handleSessionRequest( req WebRTCSessionRequest, isCloudConnection bool, source string, + connectionID string, scopedLogger *zerolog.Logger, -) error { +) (returnErr error) { + defer func() { + if r := recover(); r != nil { + websocketLogger.Error().Interface("panic", r).Msg("PANIC in handleSessionRequest") + returnErr = fmt.Errorf("panic: %v", r) + } + }() var sourceType string if isCloudConnection { sourceType = "cloud" @@ -453,6 +478,7 @@ func handleSessionRequest( IsCloud: isCloudConnection, LocalIP: req.IP, ICEServers: req.ICEServers, + UserAgent: req.UserAgent, Logger: scopedLogger, }) if err != nil { @@ -462,26 +488,73 @@ func handleSessionRequest( sd, err := session.ExchangeOffer(req.Sd) if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to exchange offer") _ = wsjson.Write(context.Background(), c, gin.H{"error": err}) return err } - if currentSession != nil { - writeJSONRPCEvent("otherSessionConnected", nil, currentSession) - peerConn := currentSession.peerConnection - go func() { - time.Sleep(1 * time.Second) - _ = peerConn.Close() - }() + session.Source = source + + if isCloudConnection && req.OidcGoogle != "" { + session.Identity = config.GoogleIdentity + + // Use client-provided sessionId for reconnection, otherwise generate new one + // This enables multi-tab support while preserving reconnection on refresh + if req.SessionId != "" { + session.ID = req.SessionId + scopedLogger.Info().Str("sessionId", session.ID).Msg("Cloud session reconnecting with client-provided ID") + } else { + session.ID = connectionID + scopedLogger.Info().Str("sessionId", session.ID).Msg("New cloud session established") + } + } else { + session.ID = connectionID + scopedLogger.Info().Str("sessionId", session.ID).Msg("Local session established") + } + + if sessionManager == nil { + scopedLogger.Error().Msg("sessionManager is nil") + _ = wsjson.Write(context.Background(), c, gin.H{"error": "session manager not initialized"}) + return fmt.Errorf("session manager not initialized") + } + + err = sessionManager.AddSession(session, req.SessionSettings) + if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to add session to session manager") + if err == ErrMaxSessionsReached { + _ = wsjson.Write(context.Background(), c, gin.H{"error": "maximum sessions reached"}) + } else { + _ = wsjson.Write(context.Background(), c, gin.H{"error": err.Error()}) + } + return err + } + + if session.HasPermission(PermissionPaste) { + cancelKeyboardMacro() } - cloudLogger.Info().Interface("session", session).Msg("new session accepted") - cloudLogger.Trace().Interface("session", session).Msg("new session accepted") + requireNickname := false + requireApproval := false + if currentSessionSettings != nil { + requireNickname = currentSessionSettings.RequireNickname + requireApproval = currentSessionSettings.RequireApproval + } - // Cancel any ongoing keyboard macro when session changes - cancelKeyboardMacro() + err = wsjson.Write(context.Background(), c, gin.H{ + "type": "answer", + "data": sd, + "sessionId": session.ID, + "mode": session.Mode, + "nickname": session.Nickname, + "requireNickname": requireNickname, + "requireApproval": requireApproval, + }) + if err != nil { + return err + } - currentSession = session - _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd}) + if session.flushCandidates != nil { + session.flushCandidates() + } return nil } diff --git a/config.go b/config.go index 26f54a45b..15698761e 100644 --- a/config.go +++ b/config.go @@ -78,11 +78,21 @@ func (m *KeyboardMacro) Validate() error { return nil } +// MultiSessionConfig defines settings for multi-session support +type MultiSessionConfig struct { + Enabled bool `json:"enabled"` + MaxSessions int `json:"max_sessions"` + PrimaryTimeout int `json:"primary_timeout_seconds"` + AllowCloudOverride bool `json:"allow_cloud_override"` + RequireAuthTransfer bool `json:"require_auth_transfer"` +} + type Config struct { CloudURL string `json:"cloud_url"` CloudAppURL string `json:"cloud_app_url"` CloudToken string `json:"cloud_token"` GoogleIdentity string `json:"google_identity"` + MultiSession *MultiSessionConfig `json:"multi_session"` JigglerEnabled bool `json:"jiggler_enabled"` JigglerConfig *JigglerConfig `json:"jiggler_config"` AutoUpdateEnabled bool `json:"auto_update_enabled"` @@ -105,6 +115,7 @@ type Config struct { UsbDevices *usbgadget.Devices `json:"usb_devices"` NetworkConfig *types.NetworkConfig `json:"network_config"` DefaultLogLevel string `json:"default_log_level"` + SessionSettings *SessionSettings `json:"session_settings"` VideoSleepAfterSec int `json:"video_sleep_after_sec"` VideoQualityFactor float64 `json:"video_quality_factor"` } @@ -156,17 +167,31 @@ var ( func getDefaultConfig() Config { return Config{ - CloudURL: "https://api.jetkvm.com", - CloudAppURL: "https://app.jetkvm.com", - AutoUpdateEnabled: true, // Set a default value - ActiveExtension: "", + CloudURL: "https://api.jetkvm.com", + CloudAppURL: "https://app.jetkvm.com", + AutoUpdateEnabled: true, // Set a default value + ActiveExtension: "", + MultiSession: &MultiSessionConfig{ + Enabled: true, // Enable by default for new features + MaxSessions: 10, // Reasonable default + PrimaryTimeout: 300, // 5 minutes + AllowCloudOverride: true, // Cloud sessions can take control + RequireAuthTransfer: false, // Don't require auth by default + }, KeyboardMacros: []KeyboardMacro{}, DisplayRotation: "270", KeyboardLayout: "en-US", DisplayMaxBrightness: 64, DisplayDimAfterSec: 120, // 2 minutes DisplayOffAfterSec: 1800, // 30 minutes - JigglerEnabled: false, + SessionSettings: &SessionSettings{ + RequireApproval: false, + RequireNickname: false, + ReconnectGrace: 10, + PrivateKeystrokes: false, + MaxRejectionAttempts: 3, + }, + JigglerEnabled: false, // This is the "Standard" jiggler option in the UI JigglerConfig: func() *JigglerConfig { c := defaultJigglerConfig; return &c }(), TLSMode: "", @@ -248,6 +273,14 @@ func LoadConfig() { loadedConfig.JigglerConfig = getDefaultConfig().JigglerConfig } + if loadedConfig.MultiSession == nil { + loadedConfig.MultiSession = getDefaultConfig().MultiSession + } + + if loadedConfig.SessionSettings == nil { + loadedConfig.SessionSettings = getDefaultConfig().SessionSettings + } + // fixup old keyboard layout value if loadedConfig.KeyboardLayout == "en_US" { loadedConfig.KeyboardLayout = "en-US" diff --git a/datachannel_helpers.go b/datachannel_helpers.go new file mode 100644 index 000000000..8edfd0952 --- /dev/null +++ b/datachannel_helpers.go @@ -0,0 +1,11 @@ +package kvm + +import "github.com/pion/webrtc/v4" + +func handlePermissionDeniedChannel(d *webrtc.DataChannel, message string) { + d.OnOpen(func() { + _ = d.SendText(message + "\r\n") + d.Close() + }) + d.OnMessage(func(msg webrtc.DataChannelMessage) {}) +} diff --git a/display.go b/display.go index 042bf122b..70fb72c37 100644 --- a/display.go +++ b/display.go @@ -70,7 +70,7 @@ func updateDisplay() { nativeInstance.UpdateLabelIfChanged("hdmi_status_label", "Disconnected") _, _ = nativeInstance.UIObjClearState("hdmi_status_label", "LV_STATE_CHECKED") } - nativeInstance.UpdateLabelIfChanged("cloud_status_label", fmt.Sprintf("%d active", actionSessions)) + nativeInstance.UpdateLabelIfChanged("cloud_status_label", fmt.Sprintf("%d active", getActiveSessions())) if networkManager != nil && networkManager.IsUp() { nativeInstance.UISetVar("main_screen", "home_screen") diff --git a/errors.go b/errors.go new file mode 100644 index 000000000..b287f9382 --- /dev/null +++ b/errors.go @@ -0,0 +1,10 @@ +package kvm + +import "errors" + +var ( + ErrPermissionDeniedKeyboard = errors.New("permission denied: keyboard input") + ErrPermissionDeniedMouse = errors.New("permission denied: mouse input") + ErrNotPrimarySession = errors.New("operation requires primary session") + ErrSessionNotFound = errors.New("session not found") +) diff --git a/hidrpc.go b/hidrpc.go index ebe03daab..bcc0272e8 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -16,6 +16,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { switch message.Type() { case hidrpc.TypeHandshake: + if !session.HasPermission(PermissionVideoView) { + logger.Debug(). + Str("sessionID", session.ID). + Str("mode", string(session.Mode)). + Msg("handshake blocked: session lacks PermissionVideoView") + return + } message, err := hidrpc.NewHandshakeMessage().Marshal() if err != nil { logger.Warn().Err(err).Msg("failed to marshal handshake message") @@ -27,8 +34,18 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } session.hidRPCAvailable = true case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: + if !session.HasPermission(PermissionKeyboardInput) { + logger.Debug(). + Str("sessionID", session.ID). + Str("mode", string(session.Mode)). + Msg("keyboard input blocked: session lacks PermissionKeyboardInput") + return + } rpcErr = handleHidRPCKeyboardInput(message) case hidrpc.TypeKeyboardMacroReport: + if !session.HasPermission(PermissionPaste) { + return + } keyboardMacroReport, err := message.KeyboardMacroReport() if err != nil { logger.Warn().Err(err).Msg("failed to get keyboard macro report") @@ -36,11 +53,24 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } rpcErr = rpcExecuteKeyboardMacro(keyboardMacroReport.Steps) case hidrpc.TypeCancelKeyboardMacroReport: + if !session.HasPermission(PermissionPaste) { + return + } rpcCancelKeyboardMacro() return case hidrpc.TypeKeypressKeepAliveReport: + if !session.HasPermission(PermissionKeyboardInput) { + return + } rpcErr = handleHidRPCKeypressKeepAlive(session) case hidrpc.TypePointerReport: + if !session.HasPermission(PermissionMouseInput) { + logger.Debug(). + Str("sessionID", session.ID). + Str("mode", string(session.Mode)). + Msg("pointer report blocked: session lacks PermissionMouseInput") + return + } pointerReport, err := message.PointerReport() if err != nil { logger.Warn().Err(err).Msg("failed to get pointer report") @@ -48,6 +78,13 @@ func handleHidRPCMessage(message hidrpc.Message, session *Session) { } rpcErr = rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) case hidrpc.TypeMouseReport: + if !session.HasPermission(PermissionMouseInput) { + logger.Debug(). + Str("sessionID", session.ID). + Str("mode", string(session.Mode)). + Msg("mouse report blocked: session lacks PermissionMouseInput") + return + } mouseReport, err := message.MouseReport() if err != nil { logger.Warn().Err(err).Msg("failed to get mouse report") @@ -116,14 +153,15 @@ const baseExtension = expectedRate + maxLateness // 100ms extension on perfect t const maxStaleness = 225 * time.Millisecond // discard ancient packets outright func handleHidRPCKeypressKeepAlive(session *Session) error { + // NOTE: Do NOT update LastActive here - jiggler keep-alives are automated, + // not human input. Only actual keyboard/mouse input should prevent timeout. + session.keepAliveJitterLock.Lock() defer session.keepAliveJitterLock.Unlock() now := time.Now() - // 1) Staleness guard: ensures packets that arrive far beyond the life of a valid key hold - // (e.g. after a network stall, retransmit burst, or machine sleep) are ignored outright. - // This prevents “zombie” keepalives from reviving a key that should already be released. + // Staleness guard: discard ancient packets after network stall/machine sleep if !session.lastTimerResetTime.IsZero() && now.Sub(session.lastTimerResetTime) > maxStaleness { return nil } diff --git a/internal/session/permissions.go b/internal/session/permissions.go new file mode 100644 index 000000000..6db9316e8 --- /dev/null +++ b/internal/session/permissions.go @@ -0,0 +1,306 @@ +package session + +import "fmt" + +// Permission represents a specific action that can be performed +type Permission string + +const ( + // Video/Display permissions + PermissionVideoView Permission = "video.view" + + // Input permissions + PermissionKeyboardInput Permission = "keyboard.input" + PermissionMouseInput Permission = "mouse.input" + PermissionPaste Permission = "clipboard.paste" + + // Session management permissions + PermissionSessionTransfer Permission = "session.transfer" + PermissionSessionApprove Permission = "session.approve" + PermissionSessionKick Permission = "session.kick" + PermissionSessionRequestPrimary Permission = "session.request_primary" + PermissionSessionReleasePrimary Permission = "session.release_primary" + PermissionSessionManage Permission = "session.manage" + + // Power/USB control permissions + PermissionPowerControl Permission = "power.control" + PermissionUSBControl Permission = "usb.control" + + // Mount/Media permissions + PermissionMountMedia Permission = "mount.media" + PermissionUnmountMedia Permission = "mount.unmedia" + PermissionMountList Permission = "mount.list" + + // Extension permissions + PermissionExtensionManage Permission = "extension.manage" + + // Terminal/Serial permissions + PermissionTerminalAccess Permission = "terminal.access" + PermissionSerialAccess Permission = "serial.access" + PermissionExtensionATX Permission = "extension.atx" + PermissionExtensionDC Permission = "extension.dc" + PermissionExtensionSerial Permission = "extension.serial" + PermissionExtensionWOL Permission = "extension.wol" + + // Settings permissions + PermissionSettingsRead Permission = "settings.read" + PermissionSettingsWrite Permission = "settings.write" + PermissionSettingsAccess Permission = "settings.access" // Access control settings + + // System permissions + PermissionSystemReboot Permission = "system.reboot" + PermissionSystemUpdate Permission = "system.update" + PermissionSystemNetwork Permission = "system.network" +) + +// PermissionSet represents a set of permissions +type PermissionSet map[Permission]bool + +// RolePermissions defines permissions for each session mode +var RolePermissions = map[SessionMode]PermissionSet{ + SessionModePrimary: { + // Primary has all permissions + PermissionVideoView: true, + PermissionKeyboardInput: true, + PermissionMouseInput: true, + PermissionPaste: true, + PermissionSessionTransfer: true, + PermissionSessionApprove: true, + PermissionSessionKick: true, + PermissionSessionReleasePrimary: true, + PermissionMountMedia: true, + PermissionUnmountMedia: true, + PermissionMountList: true, + PermissionExtensionManage: true, + PermissionExtensionATX: true, + PermissionExtensionDC: true, + PermissionExtensionSerial: true, + PermissionExtensionWOL: true, + PermissionSettingsRead: true, + PermissionSettingsWrite: true, + PermissionSettingsAccess: true, // Only primary can access settings UI + PermissionSystemReboot: true, + PermissionSystemUpdate: true, + PermissionSystemNetwork: true, + PermissionTerminalAccess: true, + PermissionSerialAccess: true, + PermissionPowerControl: true, + PermissionUSBControl: true, + PermissionSessionManage: true, + PermissionSessionRequestPrimary: false, // Primary doesn't need to request + }, + SessionModeObserver: { + // Observers can only view + PermissionVideoView: true, + PermissionSessionRequestPrimary: true, + PermissionMountList: true, // Can see what's mounted but not mount/unmount + }, + SessionModeQueued: { + // Queued sessions can view and request primary + PermissionVideoView: true, + PermissionSessionRequestPrimary: true, + }, + SessionModePending: { + // Pending sessions have NO permissions until approved + // This prevents unauthorized video access + }, +} + +// CheckPermission checks if a session mode has a specific permission +func CheckPermission(mode SessionMode, perm Permission) bool { + permissions, exists := RolePermissions[mode] + if !exists { + return false + } + return permissions[perm] +} + +// GetPermissionsForMode returns all permissions for a session mode +func GetPermissionsForMode(mode SessionMode) PermissionSet { + permissions, exists := RolePermissions[mode] + if !exists { + return PermissionSet{} + } + + // Return a copy to prevent modification + result := make(PermissionSet) + for k, v := range permissions { + result[k] = v + } + return result +} + +// RequirePermissionForMode is a middleware-like function for RPC handlers +func RequirePermissionForMode(mode SessionMode, perm Permission) error { + if !CheckPermission(mode, perm) { + return fmt.Errorf("permission denied: %s", perm) + } + return nil +} + +// GetPermissionsResponse is the response structure for getPermissions RPC +type GetPermissionsResponse struct { + Mode string `json:"mode"` + Permissions map[string]bool `json:"permissions"` +} + +// MethodPermissions maps RPC methods to required permissions +var MethodPermissions = map[string]Permission{ + // Power/hardware control + "setATXPowerAction": PermissionPowerControl, + "setDCPowerState": PermissionPowerControl, + "setDCRestoreState": PermissionPowerControl, + + // USB device control + "setUsbDeviceState": PermissionUSBControl, + "setUsbDevices": PermissionUSBControl, + + // Mount operations + "mountUsb": PermissionMountMedia, + "unmountUsb": PermissionMountMedia, + "mountBuiltInImage": PermissionMountMedia, + "rpcMountBuiltInImage": PermissionMountMedia, + "unmountImage": PermissionMountMedia, + "mountWithHTTP": PermissionMountMedia, + "mountWithStorage": PermissionMountMedia, + "checkMountUrl": PermissionMountMedia, + "startStorageFileUpload": PermissionMountMedia, + "deleteStorageFile": PermissionMountMedia, + + // Settings operations + "setDevModeState": PermissionSettingsWrite, + "setDevChannelState": PermissionSettingsWrite, + "setAutoUpdateState": PermissionSettingsWrite, + "tryUpdate": PermissionSettingsWrite, + "reboot": PermissionSettingsWrite, + "resetConfig": PermissionSettingsWrite, + "setNetworkSettings": PermissionSettingsWrite, + "setLocalLoopbackOnly": PermissionSettingsWrite, + "renewDHCPLease": PermissionSettingsWrite, + "setSSHKeyState": PermissionSettingsWrite, + "setTLSState": PermissionSettingsWrite, + "setVideoBandwidth": PermissionSettingsWrite, + "setVideoFramerate": PermissionSettingsWrite, + "setVideoResolution": PermissionSettingsWrite, + "setVideoEncoderQuality": PermissionSettingsWrite, + "setVideoSignal": PermissionSettingsWrite, + "setSerialBitrate": PermissionSettingsWrite, + "setSerialSettings": PermissionSettingsWrite, + "setSessionSettings": PermissionSessionManage, + "updateSessionSettings": PermissionSessionManage, + + // Display settings + "setEDID": PermissionSettingsWrite, + "setStreamQualityFactor": PermissionSettingsWrite, + "setDisplayRotation": PermissionSettingsWrite, + "setBacklightSettings": PermissionSettingsWrite, + + // USB/HID settings + "setUsbEmulationState": PermissionSettingsWrite, + "setUsbConfig": PermissionSettingsWrite, + "setKeyboardLayout": PermissionSettingsWrite, + "setJigglerState": PermissionSettingsWrite, + "setJigglerConfig": PermissionSettingsWrite, + "setMassStorageMode": PermissionSettingsWrite, + "setKeyboardMacros": PermissionSettingsWrite, + "setWakeOnLanDevices": PermissionSettingsWrite, + + // Cloud settings + "setCloudUrl": PermissionSettingsWrite, + "deregisterDevice": PermissionSettingsWrite, + + // Active extension control + "setActiveExtension": PermissionExtensionManage, + + // Input operations (already handled in other places but for consistency) + "keyboardReport": PermissionKeyboardInput, + "keypressReport": PermissionKeyboardInput, + "absMouseReport": PermissionMouseInput, + "relMouseReport": PermissionMouseInput, + "wheelReport": PermissionMouseInput, + "executeKeyboardMacro": PermissionPaste, + "cancelKeyboardMacro": PermissionPaste, + + // Session operations + "approveNewSession": PermissionSessionApprove, + "denyNewSession": PermissionSessionApprove, + "transferSession": PermissionSessionTransfer, + "transferPrimary": PermissionSessionTransfer, + "requestPrimary": PermissionSessionRequestPrimary, + "releasePrimary": PermissionSessionReleasePrimary, + + // Extension operations + "activateExtension": PermissionExtensionManage, + "deactivateExtension": PermissionExtensionManage, + "sendWOLMagicPacket": PermissionExtensionWOL, + + // Read operations - require appropriate read permissions + "getSessionSettings": PermissionSettingsRead, + "getSessionConfig": PermissionSettingsRead, + "getSessionData": PermissionVideoView, + "getNetworkSettings": PermissionSettingsRead, + "getSerialSettings": PermissionSettingsRead, + "getBacklightSettings": PermissionSettingsRead, + "getDisplayRotation": PermissionSettingsRead, + "getEDID": PermissionSettingsRead, + "get_edid": PermissionSettingsRead, + "getKeyboardLayout": PermissionSettingsRead, + "getJigglerConfig": PermissionSettingsRead, + "getJigglerState": PermissionSettingsRead, + "getStreamQualityFactor": PermissionSettingsRead, + "getVideoSettings": PermissionSettingsRead, + "getVideoBandwidth": PermissionSettingsRead, + "getVideoFramerate": PermissionSettingsRead, + "getVideoResolution": PermissionSettingsRead, + "getVideoEncoderQuality": PermissionSettingsRead, + "getVideoSignal": PermissionSettingsRead, + "getSerialBitrate": PermissionSettingsRead, + "getDevModeState": PermissionSettingsRead, + "getDevChannelState": PermissionSettingsRead, + "getAutoUpdateState": PermissionSettingsRead, + "getLocalLoopbackOnly": PermissionSettingsRead, + "getSSHKeyState": PermissionSettingsRead, + "getTLSState": PermissionSettingsRead, + "getCloudUrl": PermissionSettingsRead, + "getCloudState": PermissionSettingsRead, + "getNetworkState": PermissionSettingsRead, + + // Mount/media read operations + "getMassStorageMode": PermissionMountList, + "getUsbState": PermissionMountList, + "getUSBState": PermissionMountList, + "listStorageFiles": PermissionMountList, + "getStorageSpace": PermissionMountList, + + // Extension read operations + "getActiveExtension": PermissionSettingsRead, + + // Power state reads + "getATXState": PermissionSettingsRead, + "getDCPowerState": PermissionSettingsRead, + "getDCRestoreState": PermissionSettingsRead, + + // Device info reads (these should be accessible to all) + "getDeviceID": PermissionVideoView, + "getLocalVersion": PermissionVideoView, + "getVideoState": PermissionVideoView, + "getKeyboardLedState": PermissionVideoView, + "getKeyDownState": PermissionVideoView, + "ping": PermissionVideoView, + "getTimezones": PermissionVideoView, + "getSessions": PermissionVideoView, + "getUpdateStatus": PermissionSettingsRead, + "isUpdatePending": PermissionSettingsRead, + "getUsbEmulationState": PermissionSettingsRead, + "getUsbConfig": PermissionSettingsRead, + "getUsbDevices": PermissionSettingsRead, + "getKeyboardMacros": PermissionSettingsRead, + "getWakeOnLanDevices": PermissionSettingsRead, + "getVirtualMediaState": PermissionMountList, +} + +// GetMethodPermission returns the required permission for an RPC method +func GetMethodPermission(method string) (Permission, bool) { + perm, exists := MethodPermissions[method] + return perm, exists +} diff --git a/internal/session/types.go b/internal/session/types.go new file mode 100644 index 000000000..1a983bdb6 --- /dev/null +++ b/internal/session/types.go @@ -0,0 +1,11 @@ +package session + +// SessionMode represents the role/mode of a session +type SessionMode string + +const ( + SessionModePrimary SessionMode = "primary" + SessionModeObserver SessionMode = "observer" + SessionModeQueued SessionMode = "queued" + SessionModePending SessionMode = "pending" +) diff --git a/jsonrpc.go b/jsonrpc.go index 2c06f12b8..af67d24de 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -10,7 +10,9 @@ import ( "os/exec" "path/filepath" "reflect" + "regexp" "strconv" + "strings" "sync" "time" @@ -23,6 +25,40 @@ import ( "github.com/jetkvm/kvm/internal/utils" ) +// nicknameRegex defines the valid pattern for nicknames (matching frontend validation) +var nicknameRegex = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.@]+$`) + +// isValidNickname checks if a nickname contains only valid characters +func isValidNickname(nickname string) bool { + return nicknameRegex.MatchString(nickname) +} + +// Global RPC rate limiting (protects against coordinated DoS from multiple sessions) +var ( + globalRPCRateLimitMu sync.Mutex + globalRPCRateLimit int + globalRPCRateLimitWin time.Time +) + +func checkGlobalRPCRateLimit() bool { + const ( + maxGlobalRPCPerSecond = 2000 + rateLimitWindow = time.Second + ) + + globalRPCRateLimitMu.Lock() + defer globalRPCRateLimitMu.Unlock() + + now := time.Now() + if now.Sub(globalRPCRateLimitWin) > rateLimitWindow { + globalRPCRateLimit = 0 + globalRPCRateLimitWin = now + } + + globalRPCRateLimit++ + return globalRPCRateLimit <= maxGlobalRPCPerSecond +} + type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` @@ -54,11 +90,16 @@ type BacklightSettings struct { } func writeJSONRPCResponse(response JSONRPCResponse, session *Session) { + if session == nil || session.RPCChannel == nil { + return + } + responseBytes, err := json.Marshal(response) if err != nil { jsonRpcLogger.Warn().Err(err).Msg("Error marshalling JSONRPC response") return } + err = session.RPCChannel.SendText(string(responseBytes)) if err != nil { jsonRpcLogger.Warn().Err(err).Msg("Error sending JSONRPC response") @@ -67,6 +108,11 @@ func writeJSONRPCResponse(response JSONRPCResponse, session *Session) { } func writeJSONRPCEvent(event string, params any, session *Session) { + // Defensive checks: skip if session or RPC channel is not ready + if session == nil || session.RPCChannel == nil { + return // Channel not ready or already closed - this is expected during cleanup + } + request := JSONRPCEvent{ JSONRPC: "2.0", Method: event, @@ -77,10 +123,6 @@ func writeJSONRPCEvent(event string, params any, session *Session) { jsonRpcLogger.Warn().Err(err).Msg("Error marshalling JSONRPC event") return } - if session == nil || session.RPCChannel == nil { - jsonRpcLogger.Info().Msg("RPC channel not available") - return - } requestString := string(requestBytes) scopedLogger := jsonRpcLogger.With(). @@ -91,12 +133,59 @@ func writeJSONRPCEvent(event string, params any, session *Session) { err = session.RPCChannel.SendText(requestString) if err != nil { - scopedLogger.Warn().Err(err).Msg("error sending JSONRPC event") + // Check if it's a closed/closing error (expected during reconnection) + errStr := err.Error() + if strings.Contains(errStr, "closed") || strings.Contains(errStr, "closing") { + scopedLogger.Debug().Err(err).Str("event", event).Msg("Could not send JSONRPC event (channel closing)") + } else { + // Other errors (buffer full, protocol errors) should be visible + scopedLogger.Warn().Err(err).Str("event", event).Msg("Failed to send JSONRPC event") + } return } } +func broadcastJSONRPCEvent(event string, params any) { + sessionManager.ForEachSession(func(s *Session) { + writeJSONRPCEvent(event, params, s) + }) +} + func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { + // Global rate limit check (protects against coordinated DoS from multiple sessions) + if !checkGlobalRPCRateLimit() { + jsonRpcLogger.Warn(). + Str("sessionId", session.ID). + Msg("Global RPC rate limit exceeded") + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]any{ + "code": -32000, + "message": "Global rate limit exceeded", + }, + ID: 0, + } + writeJSONRPCResponse(errorResponse, session) + return + } + + // Per-session rate limit check (DoS protection) + if !session.CheckRPCRateLimit() { + jsonRpcLogger.Warn(). + Str("sessionId", session.ID). + Msg("RPC rate limit exceeded") + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]any{ + "code": -32000, + "message": "Rate limit exceeded", + }, + ID: 0, + } + writeJSONRPCResponse(errorResponse, session) + return + } + var request JSONRPCRequest err := json.Unmarshal(message.Data, &request) if err != nil { @@ -124,21 +213,62 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { scopedLogger.Trace().Msg("Received RPC request") - handler, ok := rpcHandlers[request.Method] - if !ok { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]any{ - "code": -32601, - "message": "Method not found", - }, - ID: request.ID, + var result any + var handlerErr error + + // Handle session management RPC methods + switch request.Method { + case "approvePrimaryRequest", "denyPrimaryRequest": + result, handlerErr = handleSessionTransferRPC(request.Method, request.Params, session) + case "approveNewSession", "denyNewSession": + result, handlerErr = handleSessionApprovalRPC(request.Method, request.Params, session) + case "requestSessionApproval": + result, handlerErr = handleRequestSessionApprovalRPC(session) + case "updateSessionNickname": + result, handlerErr = handleUpdateSessionNicknameRPC(request.Params, session) + case "getSessions": + result = sessionManager.GetAllSessions() + case "getPermissions": + result, handlerErr = handleGetPermissionsRPC(session) + case "getSessionSettings", "setSessionSettings": + result, handlerErr = handleSessionSettingsRPC(request.Method, request.Params, session) + case "generateNickname": + result, handlerErr = handleGenerateNicknameRPC(request.Params) + default: + // Check method permissions using centralized permission system + if requiredPerm, exists := GetMethodPermission(request.Method); exists { + if !session.HasPermission(requiredPerm) { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]any{ + "code": -32603, + "message": fmt.Sprintf("Permission denied: %s required", requiredPerm), + }, + ID: request.ID, + } + writeJSONRPCResponse(errorResponse, session) + return + } } - writeJSONRPCResponse(errorResponse, session) - return + + // Fall back to regular handlers + handler, ok := rpcHandlers[request.Method] + if !ok { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]any{ + "code": -32601, + "message": "Method not found", + }, + ID: request.ID, + } + writeJSONRPCResponse(errorResponse, session) + return + } + result, handlerErr = callRPCHandler(scopedLogger, handler, request.Params) } - result, err := callRPCHandler(scopedLogger, handler, request.Params) + err = handlerErr if err != nil { scopedLogger.Error().Err(err).Msg("Error calling RPC handler") errorResponse := JSONRPCResponse{ @@ -154,7 +284,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger.Trace().Interface("result", result).Msg("RPC handler returned") + scopedLogger.Info().Interface("result", result).Msg("RPC handler returned successfully") response := JSONRPCResponse{ JSONRPC: "2.0", @@ -175,7 +305,7 @@ func rpcGetDeviceID() (string, error) { func rpcReboot(force bool) error { logger.Info().Msg("Got reboot request from JSONRPC, rebooting...") - writeJSONRPCEvent("willReboot", nil, currentSession) + broadcastJSONRPCEvent("willReboot", nil) // Wait for the JSONRPCEvent to be sent time.Sleep(1 * time.Second) @@ -1088,6 +1218,78 @@ func rpcSetLocalLoopbackOnly(enabled bool) error { return nil } +func rpcGetSessions() ([]SessionData, error) { + return sessionManager.GetAllSessions(), nil +} + +func rpcGetSessionData(sessionId string) (SessionData, error) { + session := sessionManager.GetSession(sessionId) + if session == nil { + return SessionData{}, ErrSessionNotFound + } + return SessionData{ + ID: session.ID, + Mode: session.Mode, + Source: session.Source, + Identity: session.Identity, + CreatedAt: session.CreatedAt, + LastActive: session.LastActive, + }, nil +} + +func rpcRequestPrimary(sessionId string) map[string]interface{} { + err := sessionManager.RequestPrimary(sessionId) + if err != nil { + return map[string]interface{}{ + "status": "error", + "message": err.Error(), + } + } + + // Check if the session was immediately promoted or queued + session := sessionManager.GetSession(sessionId) + if session == nil { + return map[string]interface{}{ + "status": "error", + "message": "session not found", + } + } + + return map[string]interface{}{ + "status": "success", + "mode": string(session.Mode), + } +} + +func rpcReleasePrimary(sessionId string) error { + return sessionManager.ReleasePrimary(sessionId) +} + +func rpcTransferPrimary(fromId string, toId string) error { + return sessionManager.TransferPrimary(fromId, toId) +} + +func rpcGetSessionConfig() (map[string]interface{}, error) { + maxSessions := 10 + primaryTimeout := 300 + + if config != nil && config.MultiSession != nil { + if config.MultiSession.MaxSessions > 0 { + maxSessions = config.MultiSession.MaxSessions + } + if config.MultiSession.PrimaryTimeout > 0 { + primaryTimeout = config.MultiSession.PrimaryTimeout + } + } + + return map[string]interface{}{ + "enabled": true, + "maxSessions": maxSessions, + "primaryTimeout": primaryTimeout, + "allowCloudOverride": true, + }, nil +} + var ( keyboardMacroCancel context.CancelFunc keyboardMacroLock sync.Mutex @@ -1123,8 +1325,9 @@ func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { IsPaste: true, } - if currentSession != nil { - currentSession.reportHidRPCKeyboardMacroState(s) + // Report to primary session if exists + if primarySession := sessionManager.GetPrimarySession(); primarySession != nil { + primarySession.reportHidRPCKeyboardMacroState(s) } err := rpcDoExecuteKeyboardMacro(ctx, macro) @@ -1132,8 +1335,8 @@ func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { setKeyboardMacroCancel(nil) s.State = false - if currentSession != nil { - currentSession.reportHidRPCKeyboardMacroState(s) + if primarySession := sessionManager.GetPrimarySession(); primarySession != nil { + primarySession.reportHidRPCKeyboardMacroState(s) } return err @@ -1273,4 +1476,10 @@ var rpcHandlers = map[string]RPCHandler{ "setKeyboardMacros": {Func: setKeyboardMacros, Params: []string{"params"}}, "getLocalLoopbackOnly": {Func: rpcGetLocalLoopbackOnly}, "setLocalLoopbackOnly": {Func: rpcSetLocalLoopbackOnly, Params: []string{"enabled"}}, + "getSessions": {Func: rpcGetSessions}, + "getSessionData": {Func: rpcGetSessionData, Params: []string{"sessionId"}}, + "getSessionConfig": {Func: rpcGetSessionConfig}, + "requestPrimary": {Func: rpcRequestPrimary, Params: []string{"sessionId"}}, + "releasePrimary": {Func: rpcReleasePrimary, Params: []string{"sessionId"}}, + "transferPrimary": {Func: rpcTransferPrimary, Params: []string{"fromId", "toId"}}, } diff --git a/jsonrpc_session_handlers.go b/jsonrpc_session_handlers.go new file mode 100644 index 000000000..e744b3d0e --- /dev/null +++ b/jsonrpc_session_handlers.go @@ -0,0 +1,223 @@ +package kvm + +import ( + "errors" + "fmt" +) + +// handleSessionTransferRPC handles primary control transfer requests (approve/deny) +func handleSessionTransferRPC(method string, params map[string]any, session *Session) (any, error) { + requesterID, ok := params["requesterID"].(string) + if !ok { + return nil, errors.New("invalid requesterID parameter") + } + + if err := RequirePermission(session, PermissionSessionTransfer); err != nil { + return nil, err + } + + var err error + switch method { + case "approvePrimaryRequest": + err = sessionManager.ApprovePrimaryRequest(session.ID, requesterID) + if err == nil { + return map[string]interface{}{"status": "approved"}, nil + } + case "denyPrimaryRequest": + err = sessionManager.DenyPrimaryRequest(session.ID, requesterID) + if err == nil { + return map[string]interface{}{"status": "denied"}, nil + } + } + return nil, err +} + +// handleSessionApprovalRPC handles new session approval requests (approve/deny) +func handleSessionApprovalRPC(method string, params map[string]any, session *Session) (any, error) { + sessionID, ok := params["sessionId"].(string) + if !ok { + return nil, errors.New("invalid sessionId parameter") + } + + if err := RequirePermission(session, PermissionSessionApprove); err != nil { + return nil, err + } + + var err error + switch method { + case "approveNewSession": + err = sessionManager.ApproveSession(sessionID) + if err == nil { + go sessionManager.broadcastSessionListUpdate() + return map[string]interface{}{"status": "approved"}, nil + } + case "denyNewSession": + err = sessionManager.DenySession(sessionID) + if err == nil { + if targetSession := sessionManager.GetSession(sessionID); targetSession != nil { + go func() { + writeJSONRPCEvent("sessionAccessDenied", map[string]interface{}{ + "message": "Access denied by primary session", + }, targetSession) + sessionManager.broadcastSessionListUpdate() + }() + } + return map[string]interface{}{"status": "denied"}, nil + } + } + return nil, err +} + +// handleRequestSessionApprovalRPC handles pending sessions requesting approval from primary +func handleRequestSessionApprovalRPC(session *Session) (any, error) { + if session.Mode != SessionModePending { + return nil, errors.New("only pending sessions can request approval") + } + + if currentSessionSettings == nil || !currentSessionSettings.RequireApproval { + return nil, errors.New("session approval not required") + } + + primary := sessionManager.GetPrimarySession() + if primary == nil { + return nil, errors.New("no primary session available") + } + + go func() { + writeJSONRPCEvent("newSessionPending", map[string]interface{}{ + "sessionId": session.ID, + "source": session.Source, + "identity": session.Identity, + "nickname": session.Nickname, + }, primary) + }() + + return map[string]interface{}{"status": "requested"}, nil +} + +func handleUpdateSessionNicknameRPC(params map[string]any, session *Session) (any, error) { + sessionID, _ := params["sessionId"].(string) + nickname, _ := params["nickname"].(string) + + if err := sessionManager.validateNickname(nickname); err != nil { + return nil, err + } + + targetSession := sessionManager.GetSession(sessionID) + if targetSession == nil { + return nil, errors.New("session not found") + } + + if targetSession.ID != session.ID && !session.HasPermission(PermissionSessionManage) { + return nil, errors.New("permission denied: can only update own nickname") + } + + if err := sessionManager.UpdateSessionNickname(sessionID, nickname); err != nil { + return nil, err + } + + // If session is pending and approval is required, send the approval request now that we have a nickname + if targetSession.Mode == SessionModePending && currentSessionSettings != nil && currentSessionSettings.RequireApproval { + if primary := sessionManager.GetPrimarySession(); primary != nil { + go func() { + writeJSONRPCEvent("newSessionPending", map[string]interface{}{ + "sessionId": targetSession.ID, + "source": targetSession.Source, + "identity": targetSession.Identity, + "nickname": targetSession.Nickname, + }, primary) + }() + } + } + + sessionManager.broadcastSessionListUpdate() + return map[string]interface{}{"status": "updated"}, nil +} + +// handleGetPermissionsRPC returns permissions for the current session +func handleGetPermissionsRPC(session *Session) (any, error) { + permissions := session.GetPermissions() + permMap := make(map[string]bool) + for perm, allowed := range permissions { + permMap[string(perm)] = allowed + } + return GetPermissionsResponse{ + Mode: string(session.Mode), + Permissions: permMap, + }, nil +} + +// handleSessionSettingsRPC handles getting or setting session settings +func handleSessionSettingsRPC(method string, params map[string]any, session *Session) (any, error) { + switch method { + case "getSessionSettings": + if err := RequirePermission(session, PermissionSettingsRead); err != nil { + return nil, err + } + return currentSessionSettings, nil + + case "setSessionSettings": + if err := RequirePermission(session, PermissionSessionManage); err != nil { + return nil, err + } + + settings, ok := params["settings"].(map[string]interface{}) + if !ok { + return nil, errors.New("invalid settings parameter") + } + + if requireApproval, ok := settings["requireApproval"].(bool); ok { + currentSessionSettings.RequireApproval = requireApproval + } + if requireNickname, ok := settings["requireNickname"].(bool); ok { + currentSessionSettings.RequireNickname = requireNickname + } + if reconnectGrace, ok := settings["reconnectGrace"].(float64); ok { + currentSessionSettings.ReconnectGrace = int(reconnectGrace) + } + if primaryTimeout, ok := settings["primaryTimeout"].(float64); ok { + currentSessionSettings.PrimaryTimeout = int(primaryTimeout) + } + if privateKeystrokes, ok := settings["privateKeystrokes"].(bool); ok { + currentSessionSettings.PrivateKeystrokes = privateKeystrokes + } + if maxRejectionAttempts, ok := settings["maxRejectionAttempts"].(float64); ok { + currentSessionSettings.MaxRejectionAttempts = int(maxRejectionAttempts) + } + if maxSessions, ok := settings["maxSessions"].(float64); ok { + currentSessionSettings.MaxSessions = int(maxSessions) + } + if observerTimeout, ok := settings["observerTimeout"].(float64); ok { + currentSessionSettings.ObserverTimeout = int(observerTimeout) + } + + if sessionManager != nil { + sessionManager.updateAllSessionNicknames() + } + + if err := SaveConfig(); err != nil { + return nil, errors.New("failed to save session settings") + } + return currentSessionSettings, nil + } + + return nil, fmt.Errorf("unknown session settings method: %s", method) +} + +// handleGenerateNicknameRPC generates a nickname based on user agent +func handleGenerateNicknameRPC(params map[string]any) (any, error) { + userAgent := "" + if params != nil { + if ua, ok := params["userAgent"].(string); ok { + userAgent = ua + } + } + + if userAgent == "" { + userAgent = "Mozilla/5.0 (Unknown) Browser" + } + + return map[string]string{ + "nickname": generateNicknameFromUserAgent(userAgent), + }, nil +} diff --git a/main.go b/main.go index 2648b68d9..2da059da3 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,22 @@ var appCtx context.Context func Main() { LoadConfig() + // Initialize currentSessionSettings to use config's persistent SessionSettings + if config.SessionSettings == nil { + config.SessionSettings = &SessionSettings{ + RequireApproval: false, + RequireNickname: false, + ReconnectGrace: 10, + PrivateKeystrokes: false, + MaxRejectionAttempts: 3, + } + _ = SaveConfig() + } + currentSessionSettings = config.SessionSettings + + // Initialize global session manager (must be called after config and logger are ready) + initSessionManager() + var cancel context.CancelFunc appCtx, cancel = context.WithCancel(context.Background()) defer cancel() @@ -92,7 +108,8 @@ func Main() { continue } - if currentSession != nil { + // Skip update if there's an active primary session + if primarySession := sessionManager.GetPrimarySession(); primarySession != nil { logger.Debug().Msg("skipping update since a session is active") time.Sleep(1 * time.Minute) continue diff --git a/native.go b/native.go index 5f26c0145..fb116c303 100644 --- a/native.go +++ b/native.go @@ -51,12 +51,24 @@ func initNative(systemVersion *semver.Version, appVersion *semver.Version) { } }, OnVideoFrameReceived: func(frame []byte, duration time.Duration) { - if currentSession != nil { - err := currentSession.VideoTrack.WriteSample(media.Sample{Data: frame, Duration: duration}) - if err != nil { - nativeLogger.Warn().Err(err).Msg("error writing sample") + sessionManager.ForEachSession(func(s *Session) { + if !sessionManager.CanReceiveVideo(s, currentSessionSettings) { + return } - } + + if s.VideoTrack != nil { + err := s.VideoTrack.WriteSample(media.Sample{Data: frame, Duration: duration}) + if err != nil { + nativeLogger.Warn(). + Str("sessionID", s.ID). + Err(err). + Msg("error writing sample to session") + } else { + // Update LastActive when video frame successfully sent (prevents observer timeout) + sessionManager.UpdateLastActive(s.ID) + } + } + }) }, }) diff --git a/network.go b/network.go index ff071460f..099fbe06c 100644 --- a/network.go +++ b/network.go @@ -108,9 +108,7 @@ func networkStateChanged(_ string, state types.InterfaceState) { // do not block the main thread go waitCtrlAndRequestDisplayUpdate(true, "network_state_changed") - if currentSession != nil { - writeJSONRPCEvent("networkState", state.ToRpcInterfaceState(), currentSession) - } + broadcastJSONRPCEvent("networkState", state.ToRpcInterfaceState()) if state.Online { networkLogger.Info().Msg("network state changed to online, triggering time sync") @@ -261,7 +259,7 @@ func rpcSetNetworkSettings(settings RpcNetworkSettings) (*RpcNetworkSettings, er // If reboot required, send willReboot event before applying network config if rebootRequired { l.Info().Msg("Sending willReboot event before applying network config") - writeJSONRPCEvent("willReboot", postRebootAction, currentSession) + broadcastJSONRPCEvent("willReboot", postRebootAction) } _ = setHostname(networkManager, netConfig.Hostname.String, netConfig.Domain.String) diff --git a/ota.go b/ota.go index 7063c7ffe..ce3064d12 100644 --- a/ota.go +++ b/ota.go @@ -302,11 +302,7 @@ var otaState = OTAState{} func triggerOTAStateUpdate() { go func() { - if currentSession == nil { - logger.Info().Msg("No active RPC session, skipping update state update") - return - } - writeJSONRPCEvent("otaState", otaState, currentSession) + broadcastJSONRPCEvent("otaState", otaState) }() } diff --git a/serial.go b/serial.go index 5439d135a..c0702eae8 100644 --- a/serial.go +++ b/serial.go @@ -57,12 +57,10 @@ func runATXControl() { newBtnRSTState := line[2] == '1' newBtnPWRState := line[3] == '1' - if currentSession != nil { - writeJSONRPCEvent("atxState", ATXState{ - Power: newLedPWRState, - HDD: newLedHDDState, - }, currentSession) - } + broadcastJSONRPCEvent("atxState", ATXState{ + Power: newLedPWRState, + HDD: newLedHDDState, + }) if newLedHDDState != ledHDDState || newLedPWRState != ledPWRState || @@ -210,9 +208,7 @@ func runDCControl() { // Update Prometheus metrics updateDCMetrics(dcState) - if currentSession != nil { - writeJSONRPCEvent("dcState", dcState, currentSession) - } + broadcastJSONRPCEvent("dcState", dcState) } } @@ -284,9 +280,16 @@ func reopenSerialPort() error { return nil } -func handleSerialChannel(d *webrtc.DataChannel) { +func handleSerialChannel(d *webrtc.DataChannel, session *Session) { scopedLogger := serialLogger.With(). - Uint16("data_channel_id", *d.ID()).Logger() + Uint16("data_channel_id", *d.ID()). + Str("session_id", session.ID).Logger() + + // Check serial access permission + if !session.HasPermission(PermissionSerialAccess) { + handlePermissionDeniedChannel(d, "Serial port access denied: Permission required") + return + } d.OnOpen(func() { go func() { diff --git a/session_cleanup_handlers.go b/session_cleanup_handlers.go new file mode 100644 index 000000000..1d65b70b5 --- /dev/null +++ b/session_cleanup_handlers.go @@ -0,0 +1,363 @@ +package kvm + +import ( + "time" + + "github.com/pion/webrtc/v4" +) + +// emergencyPromotionContext holds context for emergency promotion attempts +type emergencyPromotionContext struct { + triggerSessionID string + triggerReason string + now time.Time +} + +// attemptEmergencyPromotion tries to promote a session using emergency or normal promotion logic +// Returns (promotedSessionID, isEmergency, shouldSkip) +func (sm *SessionManager) attemptEmergencyPromotion(ctx emergencyPromotionContext, excludeSessionID string) (string, bool, bool) { + // Check if emergency promotion is needed + if currentSessionSettings == nil || !currentSessionSettings.RequireApproval { + // Normal promotion - reset consecutive counter + sm.consecutiveEmergencyPromotions = 0 + promotedID := sm.findNextSessionToPromote() + return promotedID, false, false + } + + sm.emergencyWindowMutex.Lock() + defer sm.emergencyWindowMutex.Unlock() + + // CRITICAL: Bypass all rate limits if no primary exists to prevent deadlock + // System availability takes priority over DoS protection + noPrimaryExists := (sm.primarySessionID == "") + if noPrimaryExists { + sm.logger.Info(). + Str("triggerSessionID", ctx.triggerSessionID). + Str("triggerReason", ctx.triggerReason). + Msg("Bypassing emergency promotion rate limits - no primary exists") + + // Find best session, excluding the specified session if provided + var promotedSessionID string + if excludeSessionID != "" { + bestSessionID := "" + bestScore := -1 + for id, session := range sm.sessions { + if id != excludeSessionID && + !sm.isSessionBlacklisted(id) && + (session.Mode == SessionModeObserver || session.Mode == SessionModeQueued) { + score := sm.getSessionTrustScore(id) + if score > bestScore { + bestScore = score + bestSessionID = id + } + } + } + promotedSessionID = bestSessionID + } else { + promotedSessionID = sm.findMostTrustedSessionForEmergency() + } + return promotedSessionID, true, false + } + + const slidingWindowDuration = 60 * time.Second + const maxEmergencyPromotionsPerMinute = 3 + + cutoff := ctx.now.Add(-slidingWindowDuration) + validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow)) + for _, t := range sm.emergencyPromotionWindow { + if t.After(cutoff) { + validEntries = append(validEntries, t) + } + } + sm.emergencyPromotionWindow = validEntries + + if len(sm.emergencyPromotionWindow) >= maxEmergencyPromotionsPerMinute { + sm.logger.Error(). + Str("triggerSessionID", ctx.triggerSessionID). + Int("promotionsInLastMinute", len(sm.emergencyPromotionWindow)). + Msg("Emergency promotion rate limit exceeded - potential attack") + return "", false, true + } + + if ctx.now.Sub(sm.lastEmergencyPromotion) < 10*time.Second { + sm.logger.Warn(). + Str("triggerSessionID", ctx.triggerSessionID). + Dur("timeSinceLastEmergency", ctx.now.Sub(sm.lastEmergencyPromotion)). + Msg("Emergency promotion cooldown active") + return "", false, true + } + + if sm.consecutiveEmergencyPromotions >= 3 { + sm.logger.Error(). + Str("triggerSessionID", ctx.triggerSessionID). + Int("consecutiveCount", sm.consecutiveEmergencyPromotions). + Msg("Too many consecutive emergency promotions - blocking") + return "", false, true + } + + // Find best session for emergency promotion + var promotedSessionID string + if excludeSessionID != "" { + // Need to exclude a specific session (e.g., timed-out session) + bestSessionID := "" + bestScore := -1 + for id, session := range sm.sessions { + if id != excludeSessionID && + !sm.isSessionBlacklisted(id) && + (session.Mode == SessionModeObserver || session.Mode == SessionModeQueued) { + score := sm.getSessionTrustScore(id) + if score > bestScore { + bestScore = score + bestSessionID = id + } + } + } + promotedSessionID = bestSessionID + } else { + promotedSessionID = sm.findMostTrustedSessionForEmergency() + } + + return promotedSessionID, true, false +} + +// handleGracePeriodExpiration checks and handles expired grace periods +// Returns true if any grace period expired +func (sm *SessionManager) handleGracePeriodExpiration(now time.Time) bool { + gracePeriodExpired := false + for sessionID, graceTime := range sm.reconnectGrace { + if now.After(graceTime) { + delete(sm.reconnectGrace, sessionID) + gracePeriodExpired = true + + wasHoldingPrimarySlot := (sm.lastPrimaryID == sessionID) + + if wasHoldingPrimarySlot { + sm.primarySessionID = "" + sm.lastPrimaryID = "" + + sm.logger.Info(). + Str("expiredSessionID", sessionID). + Msg("Primary session grace period expired - slot now available") + + // Promote next eligible session using emergency logic if needed + sm.promoteAfterGraceExpiration(sessionID, now) + } else { + sm.logger.Debug(). + Str("expiredSessionID", sessionID). + Msg("Non-primary session grace period expired") + } + + delete(sm.reconnectInfo, sessionID) + } + } + return gracePeriodExpired +} + +// promoteAfterGraceExpiration handles promotion after grace period expiration +func (sm *SessionManager) promoteAfterGraceExpiration(expiredSessionID string, now time.Time) { + ctx := emergencyPromotionContext{ + triggerSessionID: expiredSessionID, + triggerReason: "grace_expiration", + now: now, + } + + promotedSessionID, isEmergency, shouldSkip := sm.attemptEmergencyPromotion(ctx, "") + if shouldSkip { + return + } + + if promotedSessionID != "" { + reason := "grace_expiration_promotion" + if isEmergency { + reason = "emergency_promotion_deadlock_prevention" + sm.emergencyWindowMutex.Lock() + sm.emergencyPromotionWindow = append(sm.emergencyPromotionWindow, now) + sm.emergencyWindowMutex.Unlock() + sm.lastEmergencyPromotion = now + sm.consecutiveEmergencyPromotions++ + + sm.logger.Warn(). + Str("expiredSessionID", expiredSessionID). + Str("promotedSessionID", promotedSessionID). + Bool("requireApproval", true). + Int("consecutiveEmergencyPromotions", sm.consecutiveEmergencyPromotions). + Int("trustScore", sm.getSessionTrustScore(promotedSessionID)). + Msg("EMERGENCY: Bypassing approval requirement to prevent deadlock") + } + + err := sm.transferPrimaryRole("", promotedSessionID, reason, "primary grace period expired") + if err == nil { + logEvent := sm.logger.Info() + if isEmergency { + logEvent = sm.logger.Warn() + } + logEvent. + Str("expiredSessionID", expiredSessionID). + Str("promotedSessionID", promotedSessionID). + Str("reason", reason). + Bool("isEmergencyPromotion", isEmergency). + Msg("Auto-promoted session after primary grace period expiration") + } else { + sm.logger.Error(). + Err(err). + Str("expiredSessionID", expiredSessionID). + Str("promotedSessionID", promotedSessionID). + Str("reason", reason). + Bool("isEmergencyPromotion", isEmergency). + Msg("Failed to promote session after grace period expiration") + } + } else { + logLevel := sm.logger.Info() + if isEmergency { + logLevel = sm.logger.Error() + } + logLevel. + Str("expiredSessionID", expiredSessionID). + Bool("isEmergencyPromotion", isEmergency). + Msg("Primary grace period expired but no eligible sessions to promote") + } +} + +// handlePendingSessionTimeout removes timed-out pending sessions only if disconnected +// Connected pending sessions remain visible for approval (consistent UX) +// This prevents resource leaks while maintaining good user experience +func (sm *SessionManager) handlePendingSessionTimeout(now time.Time) bool { + toDelete := make([]string, 0) + for id, session := range sm.sessions { + if session.Mode == SessionModePending && + now.Sub(session.CreatedAt) > defaultPendingSessionTimeout { + // Only remove if the connection is closed/failed + // This prevents resource leaks while keeping connected sessions visible + if session.peerConnection != nil { + connectionState := session.peerConnection.ConnectionState() + if connectionState == webrtc.PeerConnectionStateClosed || + connectionState == webrtc.PeerConnectionStateFailed || + connectionState == webrtc.PeerConnectionStateDisconnected { + websocketLogger.Debug(). + Str("sessionId", id). + Dur("age", now.Sub(session.CreatedAt)). + Str("connectionState", connectionState.String()). + Msg("Removing timed-out disconnected pending session") + toDelete = append(toDelete, id) + } + } + } + } + for _, id := range toDelete { + delete(sm.sessions, id) + } + return len(toDelete) > 0 +} + +// handleObserverSessionCleanup removes inactive observer sessions with closed RPC channels +// Returns true if any observer session was removed +func (sm *SessionManager) handleObserverSessionCleanup(now time.Time) bool { + observerTimeout := defaultObserverSessionTimeout + if currentSessionSettings != nil && currentSessionSettings.ObserverTimeout > 0 { + observerTimeout = time.Duration(currentSessionSettings.ObserverTimeout) * time.Second + } + + toDelete := make([]string, 0) + for id, session := range sm.sessions { + if session.Mode == SessionModeObserver { + if session.RPCChannel == nil && now.Sub(session.LastActive) > observerTimeout { + sm.logger.Debug(). + Str("sessionId", id). + Dur("inactiveFor", now.Sub(session.LastActive)). + Dur("observerTimeout", observerTimeout). + Msg("Removing inactive observer session with closed RPC channel") + toDelete = append(toDelete, id) + } + } + } + for _, id := range toDelete { + delete(sm.sessions, id) + } + return len(toDelete) > 0 +} + +// handlePrimarySessionTimeout checks and handles primary session timeout +// Returns true if primary session was timed out and cleanup is needed +func (sm *SessionManager) handlePrimarySessionTimeout(now time.Time) bool { + if sm.primarySessionID == "" { + return false + } + + primary, exists := sm.sessions[sm.primarySessionID] + if !exists { + sm.primarySessionID = "" + return true + } + + currentTimeout := sm.getCurrentPrimaryTimeout() + if now.Sub(primary.LastActive) <= currentTimeout { + return false + } + + // Timeout detected - demote primary + timedOutSessionID := primary.ID + primary.Mode = SessionModeObserver + sm.primarySessionID = "" + + sm.logger.Info(). + Str("sessionID", timedOutSessionID). + Dur("inactiveFor", now.Sub(primary.LastActive)). + Dur("timeout", currentTimeout). + Msg("Primary session timed out due to inactivity - demoted to observer") + + ctx := emergencyPromotionContext{ + triggerSessionID: timedOutSessionID, + triggerReason: "timeout", + now: now, + } + + promotedSessionID, isEmergency, shouldSkip := sm.attemptEmergencyPromotion(ctx, timedOutSessionID) + if shouldSkip { + sm.logger.Info().Msg("Promotion skipped after timeout - session demoted but no promotion") + return true // Still need to broadcast the demotion + } + + if promotedSessionID != "" { + reason := "timeout_promotion" + if isEmergency { + reason = "emergency_timeout_promotion" + sm.emergencyWindowMutex.Lock() + sm.emergencyPromotionWindow = append(sm.emergencyPromotionWindow, now) + sm.emergencyWindowMutex.Unlock() + sm.lastEmergencyPromotion = now + sm.consecutiveEmergencyPromotions++ + + sm.logger.Warn(). + Str("timedOutSessionID", timedOutSessionID). + Str("promotedSessionID", promotedSessionID). + Bool("requireApproval", true). + Int("trustScore", sm.getSessionTrustScore(promotedSessionID)). + Msg("EMERGENCY: Timeout promotion bypassing approval requirement") + } + + err := sm.transferPrimaryRole(timedOutSessionID, promotedSessionID, reason, "primary session timeout") + if err == nil { + logEvent := sm.logger.Info() + if isEmergency { + logEvent = sm.logger.Warn() + } + logEvent. + Str("timedOutSessionID", timedOutSessionID). + Str("promotedSessionID", promotedSessionID). + Bool("isEmergencyPromotion", isEmergency). + Msg("Auto-promoted session after primary timeout") + return true + } else { + sm.logger.Error().Err(err). + Str("timedOutSessionID", timedOutSessionID). + Str("promotedSessionID", promotedSessionID). + Msg("Failed to promote session after timeout - primary demoted") + return true // Still broadcast the demotion even if promotion failed + } + } + + sm.logger.Info(). + Str("timedOutSessionID", timedOutSessionID). + Msg("Primary session timed out - demoted to observer, no eligible sessions to promote") + return true // Broadcast the demotion even if no promotion +} diff --git a/session_manager.go b/session_manager.go new file mode 100644 index 000000000..2f9fca1e5 --- /dev/null +++ b/session_manager.go @@ -0,0 +1,1777 @@ +package kvm + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" +) + +// SessionMode and constants are now imported from internal/session via session_permissions.go + +// Session validation constants +const ( + minNicknameLength = 2 + maxNicknameLength = 30 + maxIdentityLength = 256 +) + +// Timing constants for session management +const ( + // Broadcast throttling (DoS protection) + globalBroadcastDelay = 100 * time.Millisecond // Minimum time between global session broadcasts + sessionBroadcastDelay = 50 * time.Millisecond // Minimum time between broadcasts to a single session + broadcastQueueCapacity = 100 // Maximum pending broadcasts before drops occur + + // Session timeout defaults + defaultPendingSessionTimeout = 1 * time.Minute // Timeout for pending sessions (DoS protection) + defaultObserverSessionTimeout = 2 * time.Minute // Timeout for inactive observer sessions + disabledTimeoutValue = 24 * time.Hour // Value used when timeout is disabled (0 setting) + + // Transfer and blacklist settings + transferBlacklistDuration = 60 * time.Second // Duration to blacklist sessions after manual transfer + + // Grace period limits + maxGracePeriodEntries = 10 // Maximum number of grace period entries to prevent memory exhaustion + + // Emergency promotion limits (DoS protection) + emergencyWindowDuration = 60 * time.Second // Sliding window duration for emergency promotion rate limiting + maxEmergencyPromotionsPerMinute = 3 // Maximum emergency promotions allowed within the sliding window + emergencyPromotionCooldown = 10 * time.Second // Minimum time between individual emergency promotions + maxConsecutiveEmergencyPromotions = 3 // Maximum consecutive emergency promotions before blocking + emergencyPromotionWindowCleanupAge = 60 * time.Second // Age at which emergency window entries are cleaned up + + // Trust scoring constants + invalidSessionTrustScore = -1000 // Trust score for non-existent sessions +) + +var ( + ErrMaxSessionsReached = errors.New("maximum number of sessions reached") +) + +type SessionData struct { + ID string `json:"id"` + Mode SessionMode `json:"mode"` + Source string `json:"source"` + Identity string `json:"identity"` + Nickname string `json:"nickname,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastActive time.Time `json:"last_active"` +} + +// Event types for JSON-RPC notifications +type ( + SessionsUpdateEvent struct { + Sessions []SessionData `json:"sessions"` + YourMode SessionMode `json:"yourMode"` + } + + NewSessionPendingEvent struct { + SessionID string `json:"sessionId"` + Source string `json:"source"` + Identity string `json:"identity"` + Nickname string `json:"nickname,omitempty"` + } + + PrimaryRequestEvent struct { + RequestID string `json:"requestId"` + Source string `json:"source"` + Identity string `json:"identity"` + Nickname string `json:"nickname,omitempty"` + } +) + +// TransferBlacklistEntry prevents recently demoted sessions from immediately becoming primary again +type TransferBlacklistEntry struct { + SessionID string + ExpiresAt time.Time +} + +type SessionManager struct { + mu sync.RWMutex + primaryPromotionLock sync.Mutex + primaryTimeout time.Duration + logger *zerolog.Logger + sessions map[string]*Session + nicknameIndex map[string]*Session + reconnectGrace map[string]time.Time + reconnectInfo map[string]*SessionData + transferBlacklist []TransferBlacklistEntry + queueOrder []string + primarySessionID string + lastPrimaryID string + maxSessions int + cleanupCancel context.CancelFunc + + lastEmergencyPromotion time.Time + consecutiveEmergencyPromotions int + emergencyPromotionWindow []time.Time + emergencyWindowMutex sync.Mutex + + lastBroadcast time.Time + broadcastMutex sync.Mutex + broadcastQueue chan struct{} + broadcastPending atomic.Bool +} + +// NewSessionManager creates a new session manager +func NewSessionManager(logger *zerolog.Logger) *SessionManager { + // Use configuration values if available + maxSessions := 10 + primaryTimeout := 5 * time.Minute + + if config != nil && config.MultiSession != nil { + if config.MultiSession.MaxSessions > 0 { + maxSessions = config.MultiSession.MaxSessions + } + if config.MultiSession.PrimaryTimeout > 0 { + primaryTimeout = time.Duration(config.MultiSession.PrimaryTimeout) * time.Second + } + } + + // Override with session settings if available + if currentSessionSettings != nil { + if currentSessionSettings.PrimaryTimeout > 0 { + primaryTimeout = time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second + } + if currentSessionSettings.MaxSessions > 0 { + maxSessions = currentSessionSettings.MaxSessions + } + } + + sm := &SessionManager{ + sessions: make(map[string]*Session), + nicknameIndex: make(map[string]*Session), + reconnectGrace: make(map[string]time.Time), + reconnectInfo: make(map[string]*SessionData), + transferBlacklist: make([]TransferBlacklistEntry, 0), + queueOrder: make([]string, 0), + logger: logger, + maxSessions: maxSessions, + primaryTimeout: primaryTimeout, + broadcastQueue: make(chan struct{}, broadcastQueueCapacity), + } + + ctx, cancel := context.WithCancel(context.Background()) + sm.cleanupCancel = cancel + go sm.cleanupInactiveSessions(ctx) + go sm.broadcastWorker(ctx) + + return sm +} + +func (sm *SessionManager) AddSession(session *Session, clientSettings *SessionSettings) error { + if session == nil { + sm.logger.Error().Msg("AddSession: session is nil") + return errors.New("session cannot be nil") + } + + if session.Nickname != "" { + if err := sm.validateNickname(session.Nickname); err != nil { + return err + } + } + if len(session.Identity) > maxIdentityLength { + return fmt.Errorf("identity too long (max %d characters)", maxIdentityLength) + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + nicknameReserved := false + defer func() { + if r := recover(); r != nil { + if nicknameReserved && session.Nickname != "" { + if sm.nicknameIndex[session.Nickname] == session { + delete(sm.nicknameIndex, session.Nickname) + } + } + sm.logger.Error().Interface("panic", r).Str("sessionID", session.ID).Msg("Recovered from panic in AddSession") + } + }() + + if session.Nickname != "" { + if existingSession, exists := sm.nicknameIndex[session.Nickname]; exists { + if existingSession.ID != session.ID { + return fmt.Errorf("nickname '%s' is already in use by another session", session.Nickname) + } + } + sm.nicknameIndex[session.Nickname] = session + nicknameReserved = true + } + + wasWithinGracePeriod := false + wasPreviouslyPrimary := false + wasPreviouslyPending := false + if graceTime, exists := sm.reconnectGrace[session.ID]; exists { + if time.Now().Before(graceTime) { + wasWithinGracePeriod = true + wasPreviouslyPrimary = (sm.lastPrimaryID == session.ID) + if reconnectInfo, hasInfo := sm.reconnectInfo[session.ID]; hasInfo { + wasPreviouslyPending = (reconnectInfo.Mode == SessionModePending) + } + } + } + + if existing, exists := sm.sessions[session.ID]; exists { + if existing.Identity != session.Identity || existing.Source != session.Source { + return fmt.Errorf("session ID already in use by different user (identity mismatch)") + } + + if existing.peerConnection != nil { + existing.peerConnection.Close() + } + + existing.peerConnection = session.peerConnection + existing.VideoTrack = session.VideoTrack + existing.ControlChannel = session.ControlChannel + existing.RPCChannel = session.RPCChannel + existing.HidChannel = session.HidChannel + existing.flushCandidates = session.flushCandidates + session.Mode = existing.Mode + session.Nickname = existing.Nickname + session.CreatedAt = existing.CreatedAt + + sm.ensureNickname(session) + + if !nicknameReserved && session.Nickname != "" { + sm.nicknameIndex[session.Nickname] = session + } + + sm.sessions[session.ID] = session + + if existing.Mode == SessionModePrimary { + isBlacklisted := sm.isSessionBlacklisted(session.ID) + // SECURITY: Prevent dual-primary - check actual mode, not just existence + primaryExists := false + if sm.primarySessionID != "" { + if existingPrimary, ok := sm.sessions[sm.primarySessionID]; ok && existingPrimary.Mode == SessionModePrimary { + primaryExists = true + } + } + if sm.lastPrimaryID == session.ID && !isBlacklisted && !primaryExists { + sm.primarySessionID = session.ID + sm.lastPrimaryID = "" + delete(sm.reconnectGrace, session.ID) + } else { + session.Mode = SessionModeObserver + } + } + + go sm.broadcastSessionListUpdate() + return nil + } + + if len(sm.sessions) >= sm.maxSessions { + return ErrMaxSessionsReached + } + + if session.ID == "" { + session.ID = uuid.New().String() + } + + if clientSettings != nil && clientSettings.Nickname != "" { + session.Nickname = clientSettings.Nickname + } + + globalSettings := currentSessionSettings + + primaryExists := sm.primarySessionID != "" && sm.sessions[sm.primarySessionID] != nil + + hasActivePrimaryGracePeriod := false + if sm.lastPrimaryID != "" && sm.lastPrimaryID != session.ID { + if graceTime, exists := sm.reconnectGrace[sm.lastPrimaryID]; exists { + if time.Now().Before(graceTime) { + if reconnectInfo, hasInfo := sm.reconnectInfo[sm.lastPrimaryID]; hasInfo { + if reconnectInfo.Mode == SessionModePrimary { + hasActivePrimaryGracePeriod = true + } + } + } + } + } + + isBlacklisted := sm.isSessionBlacklisted(session.ID) + isOnlySession := len(sm.sessions) == 0 + + canBecomePrimary := !primaryExists && !hasActivePrimaryGracePeriod + isReconnectingPrimary := wasWithinGracePeriod && wasPreviouslyPrimary + isNewEligibleSession := !wasWithinGracePeriod && (!isBlacklisted || isOnlySession) + + shouldBecomePrimary := canBecomePrimary && (isReconnectingPrimary || isNewEligibleSession) + + if shouldBecomePrimary { + if sm.primarySessionID == "" || sm.sessions[sm.primarySessionID] == nil { + session.Mode = SessionModePrimary + sm.primarySessionID = session.ID + sm.lastPrimaryID = "" + + // Clear grace periods when new primary is established + for oldSessionID := range sm.reconnectGrace { + delete(sm.reconnectGrace, oldSessionID) + } + for oldSessionID := range sm.reconnectInfo { + delete(sm.reconnectInfo, oldSessionID) + } + + session.hidRPCAvailable = false + } else { + session.Mode = SessionModeObserver + } + } else if wasPreviouslyPending { + session.Mode = SessionModePending + } else if globalSettings != nil && globalSettings.RequireApproval && primaryExists && !wasWithinGracePeriod { + session.Mode = SessionModePending + // Notify primary about the pending session, but only if nickname is not required OR already provided + if primary := sm.sessions[sm.primarySessionID]; primary != nil { + // Check if nickname is required and missing + requiresNickname := globalSettings.RequireNickname + hasNickname := session.Nickname != "" && len(session.Nickname) > 0 + + if !requiresNickname || hasNickname { + go func() { + writeJSONRPCEvent("newSessionPending", map[string]interface{}{ + "sessionId": session.ID, + "source": session.Source, + "identity": session.Identity, + "nickname": session.Nickname, + }, primary) + }() + } + } + } else { + session.Mode = SessionModeObserver + } + + session.CreatedAt = time.Now() + session.LastActive = time.Now() + + sm.sessions[session.ID] = session + + sm.logger.Info(). + Str("sessionID", session.ID). + Str("mode", string(session.Mode)). + Int("totalSessions", len(sm.sessions)). + Msg("Session added to manager") + + sm.ensureNickname(session) + + if !nicknameReserved && session.Nickname != "" { + sm.nicknameIndex[session.Nickname] = session + } + + sm.validateSinglePrimary() + + // Clean up grace period after validation completes + if wasWithinGracePeriod { + delete(sm.reconnectGrace, session.ID) + delete(sm.reconnectInfo, session.ID) + } + + // Notify all sessions about the new connection + go sm.broadcastSessionListUpdate() + + return nil +} + +// RemoveSession removes a session from the manager +func (sm *SessionManager) RemoveSession(sessionID string) { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return + } + + wasPrimary := session.Mode == SessionModePrimary + delete(sm.sessions, sessionID) + + if session.Nickname != "" { + if sm.nicknameIndex[session.Nickname] == session { + delete(sm.nicknameIndex, session.Nickname) + } + } + + sm.logger.Info(). + Str("sessionID", sessionID). + Bool("wasPrimary", wasPrimary). + Int("remainingSessions", len(sm.sessions)). + Msg("Session removed from manager") + + sm.removeFromQueue(sessionID) + + // Check if this session was marked for immediate removal (intentional logout) + isIntentionalLogout := false + if graceTime, exists := sm.reconnectGrace[sessionID]; exists { + if time.Now().After(graceTime) { + isIntentionalLogout = true + delete(sm.reconnectGrace, sessionID) + delete(sm.reconnectInfo, sessionID) + } + } + + // Determine grace period duration (used for logging even if intentional logout) + gracePeriod := 10 + if currentSessionSettings != nil && currentSessionSettings.ReconnectGrace > 0 { + gracePeriod = currentSessionSettings.ReconnectGrace + } + + // Only add grace period if this is NOT an intentional logout + if !isIntentionalLogout { + // Limit grace period entries to prevent memory exhaustion + // Evict entries ONLY when full, and only evict one entry + if len(sm.reconnectGrace) >= maxGracePeriodEntries { + var evictID string + var earliestExpiration time.Time + for id, graceTime := range sm.reconnectGrace { + // Find the grace period that expires first (earliest time) + if earliestExpiration.IsZero() || graceTime.Before(earliestExpiration) { + evictID = id + earliestExpiration = graceTime + } + } + if evictID != "" { + delete(sm.reconnectGrace, evictID) + delete(sm.reconnectInfo, evictID) + sm.logger.Debug(). + Str("evictedSessionID", evictID). + Msg("Evicted oldest grace period entry due to limit") + } else { + // Defensive: if we couldn't evict, don't add grace period + sm.logger.Error(). + Int("graceCount", len(sm.reconnectGrace)). + Msg("Failed to evict grace period entry, skipping grace period for this session") + goto skipGracePeriod + } + } + + sm.reconnectGrace[sessionID] = time.Now().Add(time.Duration(gracePeriod) * time.Second) + + // Store session info for potential reconnection + sm.reconnectInfo[sessionID] = &SessionData{ + ID: session.ID, + Mode: session.Mode, + Source: session.Source, + Identity: session.Identity, + Nickname: session.Nickname, + CreatedAt: session.CreatedAt, + } + } + +skipGracePeriod: + + // If this was the primary session, clear primary slot and track for grace period + if wasPrimary { + if isIntentionalLogout { + // Intentional logout: clear immediately and promote right away + sm.primarySessionID = "" + sm.lastPrimaryID = "" + sm.logger.Info(). + Str("sessionID", sessionID). + Int("remainingSessions", len(sm.sessions)). + Msg("Primary session removed via intentional logout - immediate promotion") + } else { + // Accidental disconnect: use grace period + sm.lastPrimaryID = sessionID // Remember this was the primary for grace period + sm.primarySessionID = "" // Clear primary slot so other sessions can be promoted + + // Clear all blacklists to allow promotion after grace period expires + if len(sm.transferBlacklist) > 0 { + sm.transferBlacklist = make([]TransferBlacklistEntry, 0) + } + + sm.logger.Info(). + Str("sessionID", sessionID). + Dur("gracePeriod", time.Duration(gracePeriod)*time.Second). + Int("remainingSessions", len(sm.sessions)). + Msg("Primary session removed, grace period active") + } + + // Trigger validation for potential promotion + if len(sm.sessions) > 0 { + sm.validateSinglePrimary() + } + } + + // Notify remaining sessions + go sm.broadcastSessionListUpdate() +} + +// GetSession returns a session by ID +func (sm *SessionManager) GetSession(sessionID string) *Session { + sm.mu.RLock() + session := sm.sessions[sessionID] + sm.mu.RUnlock() + return session +} + +// IsValidReconnection checks if a session ID can be reused for reconnection +func (sm *SessionManager) IsValidReconnection(sessionID, source, identity string) bool { + sm.mu.RLock() + defer sm.mu.RUnlock() + + // Check if session is in reconnect grace period + if info, exists := sm.reconnectInfo[sessionID]; exists { + // Verify the source and identity match + return info.Source == source && info.Identity == identity + } + + return false +} + +// IsInGracePeriod checks if a session ID is within the reconnection grace period +func (sm *SessionManager) IsInGracePeriod(sessionID string) bool { + sm.mu.RLock() + defer sm.mu.RUnlock() + + if graceTime, exists := sm.reconnectGrace[sessionID]; exists { + return time.Now().Before(graceTime) + } + return false +} + +// ClearGracePeriod removes the grace period for a session (for intentional logout/disconnect) +// This marks the session for immediate removal without grace period protection +// Actual promotion will happen in RemoveSession when it detects no grace period +func (sm *SessionManager) ClearGracePeriod(sessionID string) { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Clear grace period and reconnect info to prevent grace period from being added + delete(sm.reconnectGrace, sessionID) + delete(sm.reconnectInfo, sessionID) + + // Mark this session with a special "immediate removal" grace period (already expired) + // This signals to RemoveSession that this was intentional and should skip grace period + sm.reconnectGrace[sessionID] = time.Now().Add(-1 * time.Second) // Already expired + + sm.logger.Info(). + Str("sessionID", sessionID). + Str("lastPrimaryID", sm.lastPrimaryID). + Str("primarySessionID", sm.primarySessionID). + Msg("Marked session for immediate removal (intentional logout)") +} + +// isSessionBlacklisted checks if a session was recently demoted via transfer and should not become primary +func (sm *SessionManager) isSessionBlacklisted(sessionID string) bool { + now := time.Now() + isBlacklisted := false + + // Clean expired entries in-place (zero allocations) + writeIndex := 0 + for readIndex := 0; readIndex < len(sm.transferBlacklist); readIndex++ { + entry := sm.transferBlacklist[readIndex] + if now.Before(entry.ExpiresAt) { + // Keep this entry - still valid + sm.transferBlacklist[writeIndex] = entry + writeIndex++ + if entry.SessionID == sessionID { + isBlacklisted = true + } + } + // Expired entries are automatically skipped (not copied forward) + } + // Truncate to only valid entries + sm.transferBlacklist = sm.transferBlacklist[:writeIndex] + + return isBlacklisted +} + +// GetPrimarySession returns the current primary session +func (sm *SessionManager) GetPrimarySession() *Session { + sm.mu.RLock() + if sm.primarySessionID == "" { + sm.mu.RUnlock() + return nil + } + session := sm.sessions[sm.primarySessionID] + sm.mu.RUnlock() + return session +} + +// SetPrimarySession sets a session as primary +func (sm *SessionManager) SetPrimarySession(sessionID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return ErrSessionNotFound + } + + session.Mode = SessionModePrimary + sm.primarySessionID = sessionID + sm.lastPrimaryID = "" + return nil +} + +// CanReceiveVideo checks if a session is allowed to receive video +// Sessions in pending state cannot receive video +// Sessions that require nickname but don't have one also cannot receive video (if enforced) +func (sm *SessionManager) CanReceiveVideo(session *Session, settings *SessionSettings) bool { + if !session.HasPermission(PermissionVideoView) { + return false + } + + if settings != nil && settings.RequireNickname && session.Nickname == "" { + return false + } + + return true +} + +// GetAllSessions returns information about all active sessions +func (sm *SessionManager) GetAllSessions() []SessionData { + sm.mu.RLock() + defer sm.mu.RUnlock() + + // Don't run validation on every getSessions call + // This was causing immediate demotion during transfers and page refreshes + // Validation should only run during state changes, not data queries + + infos := make([]SessionData, 0, len(sm.sessions)) + for _, session := range sm.sessions { + infos = append(infos, SessionData{ + ID: session.ID, + Mode: session.Mode, + Source: session.Source, + Identity: session.Identity, + Nickname: session.Nickname, + CreatedAt: session.CreatedAt, + LastActive: session.LastActive, + }) + } + return infos +} + +// RequestPrimary requests primary control for a session +func (sm *SessionManager) RequestPrimary(sessionID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return ErrSessionNotFound + } + + // If already primary, nothing to do + if session.Mode == SessionModePrimary { + return nil + } + + // Check if there's a primary in grace period before promoting + if sm.primarySessionID == "" { + // Don't promote immediately if there's a primary waiting in grace period + if sm.lastPrimaryID != "" { + // Check if grace period is still active + if graceTime, exists := sm.reconnectGrace[sm.lastPrimaryID]; exists { + if time.Now().Before(graceTime) { + // Primary is in grace period, queue this request instead + sm.queueOrder = append(sm.queueOrder, sessionID) + session.Mode = SessionModeQueued + sm.logger.Info(). + Str("sessionID", sessionID). + Str("gracePrimaryID", sm.lastPrimaryID). + Msg("Request queued - primary session in grace period") + go sm.broadcastSessionListUpdate() + return nil + } + } + } + + // No grace period conflict, promote immediately using centralized system + err := sm.transferPrimaryRole("", sessionID, "initial_promotion", "first session auto-promotion") + if err == nil { + // Send mode change event after promoting + writeJSONRPCEvent("modeChanged", map[string]string{"mode": "primary"}, session) + go sm.broadcastSessionListUpdate() + } + return err + } + + // Notify the primary session about the request + if primarySession, exists := sm.sessions[sm.primarySessionID]; exists { + event := PrimaryRequestEvent{ + RequestID: sessionID, + Identity: session.Identity, + Source: session.Source, + Nickname: session.Nickname, + } + writeJSONRPCEvent("primaryControlRequested", event, primarySession) + } + + // Add to queue if not already there + if session.Mode != SessionModeQueued { + session.Mode = SessionModeQueued + sm.queueOrder = append(sm.queueOrder, sessionID) + } + + // Broadcast update in goroutine to avoid deadlock + go sm.broadcastSessionListUpdate() + return nil +} + +// ReleasePrimary releases primary control from a session +func (sm *SessionManager) ReleasePrimary(sessionID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return ErrSessionNotFound + } + + if session.Mode != SessionModePrimary { + return nil + } + + // Check if there are other sessions that could take control + hasOtherEligibleSessions := false + for id, s := range sm.sessions { + if id != sessionID && (s.Mode == SessionModeObserver || s.Mode == SessionModeQueued) { + hasOtherEligibleSessions = true + break + } + } + + // Don't allow releasing primary if no one else can take control + if !hasOtherEligibleSessions { + return errors.New("cannot release primary control - no other sessions available") + } + + // Demote to observer + session.Mode = SessionModeObserver + sm.primarySessionID = "" + + // Clear any active input state + sm.clearInputState() + + // Find the next session to promote (excluding the current primary) + // For voluntary releases, ignore blacklisting since this is user-initiated + promotedSessionID := sm.findNextSessionToPromoteExcludingIgnoreBlacklist(sessionID) + + // If we found someone to promote, use centralized transfer + if promotedSessionID != "" { + err := sm.transferPrimaryRole(sessionID, promotedSessionID, "release_transfer", "primary release and auto-promotion") + if err != nil { + sm.logger.Error(). + Str("error", err.Error()). + Str("releasedBySessionID", sessionID). + Str("promotedSessionID", promotedSessionID). + Msg("Failed to transfer primary role after release") + return err + } + + sm.logger.Info(). + Str("releasedBySessionID", sessionID). + Str("promotedSessionID", promotedSessionID). + Msg("Primary control released and transferred to observer") + + // Send mode change event for promoted session + go func() { + if promotedSession := sessionManager.GetSession(promotedSessionID); promotedSession != nil { + writeJSONRPCEvent("modeChanged", map[string]string{"mode": "primary"}, promotedSession) + } + }() + } else { + sm.logger.Warn(). + Str("releasedBySessionID", sessionID). + Msg("Primary control released but no eligible sessions found for promotion") + } + + // Broadcast update in goroutine to avoid deadlock + go sm.broadcastSessionListUpdate() + return nil +} + +// TransferPrimary transfers primary control from one session to another +func (sm *SessionManager) TransferPrimary(fromID, toID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // SECURITY: Verify fromID is the actual current primary + if sm.primarySessionID != fromID { + return fmt.Errorf("transfer denied: %s is not the current primary (current primary: %s)", fromID, sm.primarySessionID) + } + + fromSession, exists := sm.sessions[fromID] + if !exists { + return ErrSessionNotFound + } + + if fromSession.Mode != SessionModePrimary { + return errors.New("transfer denied: from session is not in primary mode") + } + + // Use centralized transfer method + err := sm.transferPrimaryRole(fromID, toID, "direct_transfer", "manual transfer request") + if err != nil { + return err + } + + // Send events in goroutines to avoid holding lock + go func() { + if fromSession := sessionManager.GetSession(fromID); fromSession != nil { + writeJSONRPCEvent("modeChanged", map[string]string{"mode": "observer"}, fromSession) + } + }() + + go func() { + if toSession := sessionManager.GetSession(toID); toSession != nil { + writeJSONRPCEvent("modeChanged", map[string]string{"mode": "primary"}, toSession) + } + sm.broadcastSessionListUpdate() + }() + + return nil +} + +// ApprovePrimaryRequest approves a pending primary control request +func (sm *SessionManager) ApprovePrimaryRequest(currentPrimaryID, requesterID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Log the approval request + sm.logger.Info(). + Str("currentPrimaryID", currentPrimaryID). + Str("requesterID", requesterID). + Str("actualPrimaryID", sm.primarySessionID). + Msg("ApprovePrimaryRequest called") + + // Verify current primary is correct + if sm.primarySessionID != currentPrimaryID { + sm.logger.Error(). + Str("currentPrimaryID", currentPrimaryID). + Str("actualPrimaryID", sm.primarySessionID). + Msg("Not the primary session") + return errors.New("not the primary session") + } + + // SECURITY: Verify requester session exists and is in Queued mode + requesterSession, exists := sm.sessions[requesterID] + if !exists { + sm.logger.Error(). + Str("requesterID", requesterID). + Msg("Requester session not found") + return errors.New("requester session not found") + } + + if requesterSession.Mode != SessionModeQueued { + sm.logger.Error(). + Str("requesterID", requesterID). + Str("actualMode", string(requesterSession.Mode)). + Msg("Requester session is not in queued mode") + return fmt.Errorf("requester session is not in queued mode (current mode: %s)", requesterSession.Mode) + } + + // Remove requester from queue + sm.removeFromQueue(requesterID) + + // Use centralized transfer method + err := sm.transferPrimaryRole(currentPrimaryID, requesterID, "approval_transfer", "primary approval request") + if err != nil { + return err + } + + // Send events after releasing lock to avoid deadlock + go func() { + if demotedSession := sessionManager.GetSession(currentPrimaryID); demotedSession != nil { + writeJSONRPCEvent("modeChanged", map[string]string{"mode": "observer"}, demotedSession) + } + }() + + go func() { + if promotedSession := sessionManager.GetSession(requesterID); promotedSession != nil { + writeJSONRPCEvent("modeChanged", map[string]string{"mode": "primary"}, promotedSession) + } + sm.broadcastSessionListUpdate() + }() + + return nil +} + +// DenyPrimaryRequest denies a pending primary control request +func (sm *SessionManager) DenyPrimaryRequest(currentPrimaryID, requesterID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Verify current primary is correct + if sm.primarySessionID != currentPrimaryID { + return errors.New("not the primary session") + } + + requester, exists := sm.sessions[requesterID] + if !exists { + return ErrSessionNotFound + } + + // Move requester back to observer + requester.Mode = SessionModeObserver + sm.removeFromQueue(requesterID) + + // Validate session consistency after mode change + sm.validateSinglePrimary() + + // Notify requester of denial in goroutine + go func() { + writeJSONRPCEvent("primaryControlDenied", map[string]interface{}{}, requester) + sm.broadcastSessionListUpdate() + }() + + return nil +} + +// ApproveSession approves a pending session (thread-safe) +func (sm *SessionManager) ApproveSession(sessionID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return ErrSessionNotFound + } + + if session.Mode != SessionModePending { + return errors.New("session is not in pending mode") + } + + // Promote session to observer + session.Mode = SessionModeObserver + + sm.logger.Info(). + Str("sessionID", sessionID). + Msg("Session approved and promoted to observer") + + return nil +} + +// DenySession denies a pending session (thread-safe) +func (sm *SessionManager) DenySession(sessionID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[sessionID] + if !exists { + return ErrSessionNotFound + } + + if session.Mode != SessionModePending { + return errors.New("session is not in pending mode") + } + + sm.logger.Info(). + Str("sessionID", sessionID). + Msg("Session denied - notifying session") + + return nil +} + +// ForEachSession executes a function for each active session +func (sm *SessionManager) ForEachSession(fn func(*Session)) { + sm.mu.RLock() + // Create a copy of sessions to avoid holding lock during callbacks + sessionsCopy := make([]*Session, 0, len(sm.sessions)) + for _, session := range sm.sessions { + sessionsCopy = append(sessionsCopy, session) + } + sm.mu.RUnlock() + + // Call function outside of lock to prevent deadlocks + for _, session := range sessionsCopy { + fn(session) + } +} + +// UpdateLastActive updates the last active time for a session +func (sm *SessionManager) UpdateLastActive(sessionID string) { + sm.mu.Lock() + if session, exists := sm.sessions[sessionID]; exists { + session.LastActive = time.Now() + } + sm.mu.Unlock() +} + +// UpdateSessionNickname atomically updates a session's nickname with uniqueness check +func (sm *SessionManager) UpdateSessionNickname(sessionID, nickname string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + targetSession, exists := sm.sessions[sessionID] + if !exists { + return errors.New("session not found") + } + + // Check nickname uniqueness under lock + if existingSession, nicknameInUse := sm.nicknameIndex[nickname]; nicknameInUse { + if existingSession.ID != sessionID { + return fmt.Errorf("nickname '%s' is already in use by another session", nickname) + } + } + + // Remove old nickname from index + if targetSession.Nickname != "" { + delete(sm.nicknameIndex, targetSession.Nickname) + } + + // Update nickname and index atomically + targetSession.Nickname = nickname + sm.nicknameIndex[nickname] = targetSession + + return nil +} + +// Internal helper methods + +// validateSinglePrimary ensures there's only one primary session and fixes any inconsistencies +func (sm *SessionManager) validateSinglePrimary() { + primarySessions := make([]*Session, 0) + + // Find all sessions that think they're primary + for _, session := range sm.sessions { + if session.Mode == SessionModePrimary { + primarySessions = append(primarySessions, session) + } + } + + // If we have multiple primaries, fix it + if len(primarySessions) > 1 { + sm.logger.Error(). + Int("primaryCount", len(primarySessions)). + Msg("Multiple primary sessions detected, fixing") + + // Keep the first one as primary, demote the rest + for i, session := range primarySessions { + if i == 0 { + sm.primarySessionID = session.ID + } else { + session.Mode = SessionModeObserver + } + } + } + + // Ensure manager's primarySessionID matches reality + if len(primarySessions) == 1 && sm.primarySessionID != primarySessions[0].ID { + sm.logger.Warn(). + Str("managerPrimaryID", sm.primarySessionID). + Str("actualPrimaryID", primarySessions[0].ID). + Msg("Manager primary ID mismatch, fixing...") + sm.primarySessionID = primarySessions[0].ID + } + + // Don't clear primary slot if there's a grace period active + if len(primarySessions) == 0 && sm.primarySessionID != "" { + if sm.lastPrimaryID == sm.primarySessionID { + if graceTime, exists := sm.reconnectGrace[sm.primarySessionID]; exists { + if time.Now().Before(graceTime) { + return // Keep primary slot reserved during grace period + } + } + } + sm.primarySessionID = "" + } + + // Check if there's an active grace period for any primary session + hasActivePrimaryGracePeriod := false + for sessionID, graceTime := range sm.reconnectGrace { + if time.Now().Before(graceTime) { + if reconnectInfo, hasInfo := sm.reconnectInfo[sessionID]; hasInfo { + if reconnectInfo.Mode == SessionModePrimary { + hasActivePrimaryGracePeriod = true + break + } + } + } + } + + // Auto-promote if there are NO primary sessions at all AND no active grace period + if len(primarySessions) == 0 && sm.primarySessionID == "" && len(sm.sessions) > 0 && !hasActivePrimaryGracePeriod { + // Find a session to promote to primary + nextSessionID := sm.findNextSessionToPromote() + if nextSessionID != "" { + sm.logger.Info(). + Str("promotedSessionID", nextSessionID). + Msg("Auto-promoting observer to primary - no primary sessions exist and no grace period active") + + // Use the centralized promotion logic + err := sm.transferPrimaryRole("", nextSessionID, "emergency_auto_promotion", "no primary sessions detected") + if err != nil { + sm.logger.Error(). + Err(err). + Str("sessionID", nextSessionID). + Msg("Failed to auto-promote session to primary") + } + } else { + sm.logger.Warn(). + Msg("No eligible session found for emergency auto-promotion") + } + } +} + +func (sm *SessionManager) transferPrimaryRole(fromSessionID, toSessionID, transferType, context string) error { + sm.primaryPromotionLock.Lock() + defer sm.primaryPromotionLock.Unlock() + + // Validate sessions exist + toSession, toExists := sm.sessions[toSessionID] + if !toExists { + return ErrSessionNotFound + } + + // SECURITY: Prevent promoting a session that's already primary + if toSession.Mode == SessionModePrimary { + sm.logger.Warn(). + Str("sessionID", toSessionID). + Str("transferType", transferType). + Msg("Attempted to promote session that is already primary") + return errors.New("target session is already primary") + } + + var fromSession *Session + var fromExists bool + if fromSessionID != "" { + fromSession, fromExists = sm.sessions[fromSessionID] + if !fromExists { + return ErrSessionNotFound + } + } + + // Demote existing primary if specified + if fromExists && fromSession.Mode == SessionModePrimary { + fromSession.Mode = SessionModeObserver + fromSession.hidRPCAvailable = false + + // Always delete grace period when demoting - no exceptions + // If a session times out or is manually transferred, it should not auto-reclaim primary + delete(sm.reconnectGrace, fromSessionID) + delete(sm.reconnectInfo, fromSessionID) + + sm.logger.Info(). + Str("demotedSessionID", fromSessionID). + Str("transferType", transferType). + Str("context", context). + Msg("Demoted existing primary session") + } + + primaryCount := 0 + var existingPrimaryID string + for id, sess := range sm.sessions { + if sess.Mode == SessionModePrimary { + primaryCount++ + if id != toSessionID { + existingPrimaryID = id + } + } + } + + if primaryCount > 1 || (primaryCount == 1 && existingPrimaryID != "" && existingPrimaryID != sm.primarySessionID) { + sm.logger.Error(). + Int("primaryCount", primaryCount). + Str("existingPrimaryID", existingPrimaryID). + Str("targetPromotionID", toSessionID). + Str("managerPrimaryID", sm.primarySessionID). + Str("transferType", transferType). + Msg("CRITICAL: Dual-primary corruption detected - forcing fix") + + for id, sess := range sm.sessions { + if sess.Mode == SessionModePrimary { + if id != sm.primarySessionID && id != toSessionID { + sess.Mode = SessionModeObserver + sm.logger.Warn(). + Str("demotedSessionID", id). + Msg("Force-demoted session due to dual-primary corruption") + } + } + } + + if sm.primarySessionID != "" && sm.sessions[sm.primarySessionID] != nil { + if sm.sessions[sm.primarySessionID].Mode != SessionModePrimary { + sm.primarySessionID = "" + } + } + + existingPrimaryID = "" + for id, sess := range sm.sessions { + if id != toSessionID && sess.Mode == SessionModePrimary { + existingPrimaryID = id + break + } + } + + if existingPrimaryID != "" { + sm.logger.Error(). + Str("existingPrimaryID", existingPrimaryID). + Str("targetPromotionID", toSessionID). + Msg("CRITICAL: Cannot fix dual-primary corruption - blocking promotion") + return fmt.Errorf("cannot promote: dual-primary corruption detected and fix failed (%s)", existingPrimaryID) + } + } else if existingPrimaryID != "" { + sm.logger.Error(). + Str("existingPrimaryID", existingPrimaryID). + Str("targetPromotionID", toSessionID). + Str("transferType", transferType). + Msg("CRITICAL: Attempted to create second primary - blocking promotion") + return fmt.Errorf("cannot promote: another primary session exists (%s)", existingPrimaryID) + } + + toSession.Mode = SessionModePrimary + toSession.hidRPCAvailable = false + if transferType == "emergency_timeout_promotion" { + toSession.LastActive = time.Now() + } + sm.primarySessionID = toSessionID + + // ALWAYS set lastPrimaryID to the new primary to support WebRTC reconnections + // This allows the newly promoted session to handle page refreshes correctly + // The blacklist system prevents unwanted takeovers during manual transfers + sm.lastPrimaryID = toSessionID + + // Clear input state + sm.clearInputState() + + // Reset consecutive emergency promotion counter on successful manual transfer + if fromSessionID != "" && transferType != "emergency_promotion_deadlock_prevention" && transferType != "emergency_timeout_promotion" { + sm.consecutiveEmergencyPromotions = 0 + } + + // Apply bidirectional blacklisting - protect newly promoted session + // Only apply blacklisting for MANUAL transfers, not emergency promotions + // Emergency promotions need to happen immediately without blacklist interference + isManualTransfer := (transferType == "direct_transfer" || transferType == "approval_transfer" || transferType == "release_transfer") + now := time.Now() + blacklistedCount := 0 + + if isManualTransfer { + // First, clear any existing blacklist entries for the newly promoted session + cleanedBlacklist := make([]TransferBlacklistEntry, 0) + for _, entry := range sm.transferBlacklist { + if entry.SessionID != toSessionID { // Remove any old blacklist entries for the new primary + cleanedBlacklist = append(cleanedBlacklist, entry) + } + } + sm.transferBlacklist = cleanedBlacklist + + // Then blacklist all other sessions + for sessionID := range sm.sessions { + if sessionID != toSessionID { // Don't blacklist the newly promoted session + sm.transferBlacklist = append(sm.transferBlacklist, TransferBlacklistEntry{ + SessionID: sessionID, + ExpiresAt: now.Add(transferBlacklistDuration), + }) + blacklistedCount++ + } + } + } + + // Grace periods are cleared for demoted sessions (line 519-520) to prevent them from + // auto-reclaiming primary after manual transfer. New grace periods are created when + // sessions reconnect via RemoveSession. The blacklist provides additional protection + // during the transfer window, while lastPrimaryID allows the newly promoted session + // to safely handle browser refreshes and reclaim primary if disconnected. + + sm.logger.Info(). + Str("fromSessionID", fromSessionID). + Str("toSessionID", toSessionID). + Str("transferType", transferType). + Str("context", context). + Int("blacklistedSessions", blacklistedCount). + Dur("blacklistDuration", transferBlacklistDuration). + Msg("Primary role transferred with bidirectional protection") + + // DON'T validate here - causes recursive calls and map iteration issues + // The caller (AddSession, RemoveSession, etc.) will validate after we return + // sm.validateSinglePrimary() // REMOVED to prevent recursion + + // Send reconnection signal for emergency promotions via WebSocket (more reliable than RPC when channel is stale) + if toExists && (transferType == "emergency_timeout_promotion" || transferType == "emergency_auto_promotion") { + go func() { + time.Sleep(globalBroadcastDelay) + + eventData := map[string]interface{}{ + "sessionId": toSessionID, + "newMode": string(toSession.Mode), + "reason": "session_promotion", + "action": "reconnect_required", + "timestamp": time.Now().Unix(), + } + + err := toSession.sendWebSocketSignal("connectionModeChanged", eventData) + if err != nil { + sm.logger.Warn().Err(err).Str("sessionId", toSessionID).Msg("WebSocket signal failed, using RPC") + writeJSONRPCEvent("connectionModeChanged", eventData, toSession) + } + + sm.logger.Info().Str("sessionId", toSessionID).Str("transferType", transferType).Msg("Sent reconnection signal") + }() + } + + return nil +} + +// findNextSessionToPromote finds the next eligible session for promotion +// Replicates the logic from promoteNextSession but just returns the session ID +func (sm *SessionManager) findNextSessionToPromote() string { + return sm.findNextSessionToPromoteExcluding("", true) +} + +func (sm *SessionManager) findNextSessionToPromoteExcluding(excludeSessionID string, checkBlacklist bool) string { + // First, check if there are queued sessions (excluding the specified session) + if len(sm.queueOrder) > 0 { + nextID := sm.queueOrder[0] + if nextID != excludeSessionID { + if _, exists := sm.sessions[nextID]; exists { + if !checkBlacklist || !sm.isSessionBlacklisted(nextID) { + return nextID + } + } + } + } + + // Otherwise, find any observer session (excluding the specified session) + for id, session := range sm.sessions { + if id != excludeSessionID && session.Mode == SessionModeObserver { + if !checkBlacklist || !sm.isSessionBlacklisted(id) { + return id + } + } + } + + // If still no primary and there are pending sessions (edge case: all sessions are pending) + // This can happen if RequireApproval was enabled but primary left + for id, session := range sm.sessions { + if id != excludeSessionID && session.Mode == SessionModePending { + if !checkBlacklist || !sm.isSessionBlacklisted(id) { + return id + } + } + } + + return "" // No eligible session found +} + +func (sm *SessionManager) findNextSessionToPromoteExcludingIgnoreBlacklist(excludeSessionID string) string { + return sm.findNextSessionToPromoteExcluding(excludeSessionID, false) +} + +func (sm *SessionManager) removeFromQueue(sessionID string) { + // In-place removal is more efficient + for i, id := range sm.queueOrder { + if id == sessionID { + sm.queueOrder = append(sm.queueOrder[:i], sm.queueOrder[i+1:]...) + return + } + } +} + +func (sm *SessionManager) clearInputState() { + // Clear keyboard state + if gadget != nil { + _ = gadget.KeyboardReport(0, []byte{0, 0, 0, 0, 0, 0}) + } +} + +// getCurrentPrimaryTimeout returns the current primary timeout duration +func (sm *SessionManager) getCurrentPrimaryTimeout() time.Duration { + // Use session settings if available + if currentSessionSettings != nil { + if currentSessionSettings.PrimaryTimeout == 0 { + return disabledTimeoutValue + } else if currentSessionSettings.PrimaryTimeout > 0 { + return time.Duration(currentSessionSettings.PrimaryTimeout) * time.Second + } + } + // Fall back to config or default + return sm.primaryTimeout +} + +// getSessionTrustScore calculates a trust score for session selection during emergency promotion +func (sm *SessionManager) getSessionTrustScore(sessionID string) int { + session, exists := sm.sessions[sessionID] + if !exists { + return invalidSessionTrustScore + } + + score := 0 + now := time.Now() + + // Longer session duration = more trust (up to 100 points for 100+ minutes) + sessionAge := now.Sub(session.CreatedAt) + sessionAgeMinutes := sessionAge.Minutes() + if sessionAgeMinutes > 100 { + score += 100 + } else { + score += int(sessionAgeMinutes) + } + + // Recently successful primary sessions get higher trust + if sm.lastPrimaryID == sessionID { + score += 50 + } + + // Observer mode is more trustworthy than queued/pending for emergency promotion + switch session.Mode { + case SessionModeObserver: + score += 20 + case SessionModeQueued: + score += 10 + case SessionModePending: + // Pending sessions get no bonus and are less preferred + score += 0 + } + + // Check if session has nickname when required (shows engagement) + if currentSessionSettings != nil && currentSessionSettings.RequireNickname { + if session.Nickname != "" { + score += 15 + } else { + score -= 30 // Penalize sessions without required nickname + } + } + + return score +} + +// findMostTrustedSessionForEmergency finds the most trustworthy session for emergency promotion +func (sm *SessionManager) findMostTrustedSessionForEmergency() string { + bestSessionID := "" + bestScore := -1 + + for sessionID, session := range sm.sessions { + if sm.isSessionBlacklisted(sessionID) || + session.Mode == SessionModePrimary || + (session.Mode != SessionModeObserver && session.Mode != SessionModeQueued) { + continue + } + + score := sm.getSessionTrustScore(sessionID) + if score > bestScore { + bestScore = score + bestSessionID = sessionID + } + } + + if bestSessionID != "" { + sm.logger.Info(). + Str("selectedSession", bestSessionID). + Int("trustScore", bestScore). + Msg("Selected most trusted session for emergency promotion") + } + + return bestSessionID +} + +// extractBrowserFromUserAgent extracts browser name from user agent string +func extractBrowserFromUserAgent(userAgent string) *string { + ua := strings.ToLower(userAgent) + + // Check for common browsers (order matters - Chrome contains Safari, etc.) + // Optimize Safari check by caching Chrome detection + hasChrome := strings.Contains(ua, "chrome") + + if strings.Contains(ua, "edg/") || strings.Contains(ua, "edge") { + return &BrowserEdge + } + if strings.Contains(ua, "firefox") { + return &BrowserFirefox + } + if hasChrome { + return &BrowserChrome + } + if strings.Contains(ua, "safari") { + return &BrowserSafari + } + if strings.Contains(ua, "opera") || strings.Contains(ua, "opr/") { + return &BrowserOpera + } + + return &BrowserUnknown +} + +// generateAutoNickname creates a user-friendly auto-generated nickname +func generateAutoNickname(session *Session) string { + // Use browser type from session, fallback to "user" if not set + browser := "user" + if session.Browser != nil { + browser = *session.Browser + } + + // Use last 4 chars of session ID for uniqueness (lowercase) + sessionID := strings.ToLower(session.ID) + shortID := sessionID[len(sessionID)-4:] + + // Generate contextual lowercase nickname + return fmt.Sprintf("u-%s-%s", browser, shortID) +} + +// generateNicknameFromUserAgent creates a nickname from user agent (for frontend use) +func generateNicknameFromUserAgent(userAgent string) string { + // Extract browser info + browserPtr := extractBrowserFromUserAgent(userAgent) + browser := "user" + if browserPtr != nil { + browser = *browserPtr + } + + // Generate a random 4-character ID (lowercase) + shortID := strings.ToLower(fmt.Sprintf("%04x", time.Now().UnixNano()%0xFFFF)) + + // Generate contextual lowercase nickname + return fmt.Sprintf("u-%s-%s", browser, shortID) +} + +// ensureNickname ensures session has a nickname, auto-generating if needed +func (sm *SessionManager) validateNickname(nickname string) error { + if len(nickname) < minNicknameLength { + return fmt.Errorf("nickname must be at least %d characters", minNicknameLength) + } + if len(nickname) > maxNicknameLength { + return fmt.Errorf("nickname must be %d characters or less", maxNicknameLength) + } + if !isValidNickname(nickname) { + return errors.New("nickname can only contain letters, numbers, spaces, and - _ . @") + } + + for i, r := range nickname { + if r < 32 || r == 127 { + return fmt.Errorf("nickname contains control character at position %d", i) + } + if r >= 0x200B && r <= 0x200D { + return errors.New("nickname contains zero-width character") + } + } + + trimmed := "" + for _, r := range nickname { + trimmed += string(r) + } + if trimmed != nickname { + return errors.New("nickname contains disallowed unicode") + } + + return nil +} + +func (sm *SessionManager) ensureNickname(session *Session) { + // Skip if session already has a nickname + if session.Nickname != "" { + return + } + + // Skip if nickname is required (user must set manually) + if currentSessionSettings != nil && currentSessionSettings.RequireNickname { + return + } + + // Auto-generate nickname + session.Nickname = generateAutoNickname(session) + + sm.logger.Debug(). + Str("sessionID", session.ID). + Str("autoNickname", session.Nickname). + Msg("Auto-generated nickname for session") +} + +// updateAllSessionNicknames updates nicknames for all sessions when settings change +func (sm *SessionManager) updateAllSessionNicknames() { + sm.mu.Lock() + defer sm.mu.Unlock() + + updated := 0 + for _, session := range sm.sessions { + oldNickname := session.Nickname + sm.ensureNickname(session) + if session.Nickname != oldNickname { + updated++ + } + } + + if updated > 0 { + sm.logger.Info(). + Int("updatedSessions", updated). + Msg("Auto-generated nicknames for sessions after settings change") + + // Broadcast the update + go sm.broadcastSessionListUpdate() + } +} + +func (sm *SessionManager) broadcastWorker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-sm.broadcastQueue: + sm.broadcastPending.Store(false) + sm.executeBroadcast() + } + } +} + +func (sm *SessionManager) broadcastSessionListUpdate() { + if sm.broadcastPending.CompareAndSwap(false, true) { + select { + case sm.broadcastQueue <- struct{}{}: + default: + sm.logger.Warn(). + Int("queueLen", len(sm.broadcastQueue)). + Int("queueCap", cap(sm.broadcastQueue)). + Msg("Broadcast queue full, dropping update") + sm.broadcastPending.Store(false) + } + } +} + +func (sm *SessionManager) executeBroadcast() { + sm.broadcastMutex.Lock() + if time.Since(sm.lastBroadcast) < globalBroadcastDelay { + sm.broadcastMutex.Unlock() + return + } + sm.lastBroadcast = time.Now() + sm.broadcastMutex.Unlock() + + sm.mu.RLock() + infos := make([]SessionData, 0, len(sm.sessions)) + activeSessions := make([]*Session, 0, len(sm.sessions)) + + for _, session := range sm.sessions { + infos = append(infos, SessionData{ + ID: session.ID, + Mode: session.Mode, + Source: session.Source, + Identity: session.Identity, + Nickname: session.Nickname, + CreatedAt: session.CreatedAt, + LastActive: session.LastActive, + }) + + if session.RPCChannel != nil { + activeSessions = append(activeSessions, session) + } + } + sm.mu.RUnlock() + + for _, session := range activeSessions { + session.lastBroadcastMu.Lock() + shouldSkip := time.Since(session.LastBroadcast) < sessionBroadcastDelay + if !shouldSkip { + session.LastBroadcast = time.Now() + } + session.lastBroadcastMu.Unlock() + + if shouldSkip { + continue + } + + event := SessionsUpdateEvent{ + Sessions: infos, + YourMode: session.Mode, + } + writeJSONRPCEvent("sessionsUpdated", event, session) + } +} + +// Shutdown stops the session manager and cleans up resources +func (sm *SessionManager) Shutdown() { + if sm.cleanupCancel != nil { + sm.cleanupCancel() + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + close(sm.broadcastQueue) + + for id := range sm.sessions { + delete(sm.sessions, id) + } +} + +func (sm *SessionManager) cleanupInactiveSessions(ctx context.Context) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + validationCounter := 0 + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sm.mu.Lock() + now := time.Now() + needsBroadcast := false + + // Clean up expired emergency promotion window entries + sm.emergencyWindowMutex.Lock() + cutoff := now.Add(-emergencyPromotionWindowCleanupAge) + validEntries := make([]time.Time, 0, len(sm.emergencyPromotionWindow)) + for _, t := range sm.emergencyPromotionWindow { + if t.After(cutoff) { + validEntries = append(validEntries, t) + } + } + sm.emergencyPromotionWindow = validEntries + sm.emergencyWindowMutex.Unlock() + + // Handle expired grace periods + gracePeriodExpired := sm.handleGracePeriodExpiration(now) + if gracePeriodExpired { + needsBroadcast = true + } + + // Clean up timed-out pending sessions (DoS protection) + if sm.handlePendingSessionTimeout(now) { + needsBroadcast = true + } + + // Clean up inactive observer sessions + if sm.handleObserverSessionCleanup(now) { + needsBroadcast = true + } + + // Handle primary session timeout + if sm.handlePrimarySessionTimeout(now) { + needsBroadcast = true + } + + // Run validation immediately if grace period expired, otherwise periodically + if gracePeriodExpired { + sm.validateSinglePrimary() + } else { + validationCounter++ + if validationCounter >= 10 { + validationCounter = 0 + sm.validateSinglePrimary() + } + } + + sm.mu.Unlock() + + if needsBroadcast { + go sm.broadcastSessionListUpdate() + } + } + } +} + +// Global session manager instance +var ( + sessionManager *SessionManager + sessionManagerOnce sync.Once +) + +func initSessionManager() { + sessionManagerOnce.Do(func() { + sessionManager = NewSessionManager(websocketLogger) + }) +} + +// Global session settings - references config.SessionSettings for persistence +var currentSessionSettings *SessionSettings diff --git a/session_permissions.go b/session_permissions.go new file mode 100644 index 000000000..05a1bcbee --- /dev/null +++ b/session_permissions.go @@ -0,0 +1,77 @@ +package kvm + +import ( + "github.com/jetkvm/kvm/internal/session" +) + +type ( + Permission = session.Permission + PermissionSet = session.PermissionSet + SessionMode = session.SessionMode +) + +const ( + SessionModePrimary = session.SessionModePrimary + SessionModeObserver = session.SessionModeObserver + SessionModeQueued = session.SessionModeQueued + SessionModePending = session.SessionModePending + + PermissionVideoView = session.PermissionVideoView + PermissionKeyboardInput = session.PermissionKeyboardInput + PermissionMouseInput = session.PermissionMouseInput + PermissionPaste = session.PermissionPaste + PermissionSessionTransfer = session.PermissionSessionTransfer + PermissionSessionApprove = session.PermissionSessionApprove + PermissionSessionKick = session.PermissionSessionKick + PermissionSessionRequestPrimary = session.PermissionSessionRequestPrimary + PermissionSessionReleasePrimary = session.PermissionSessionReleasePrimary + PermissionSessionManage = session.PermissionSessionManage + PermissionPowerControl = session.PermissionPowerControl + PermissionUSBControl = session.PermissionUSBControl + PermissionMountMedia = session.PermissionMountMedia + PermissionUnmountMedia = session.PermissionUnmountMedia + PermissionMountList = session.PermissionMountList + PermissionExtensionManage = session.PermissionExtensionManage + PermissionExtensionATX = session.PermissionExtensionATX + PermissionExtensionDC = session.PermissionExtensionDC + PermissionExtensionSerial = session.PermissionExtensionSerial + PermissionExtensionWOL = session.PermissionExtensionWOL + PermissionTerminalAccess = session.PermissionTerminalAccess + PermissionSerialAccess = session.PermissionSerialAccess + PermissionSettingsRead = session.PermissionSettingsRead + PermissionSettingsWrite = session.PermissionSettingsWrite + PermissionSettingsAccess = session.PermissionSettingsAccess + PermissionSystemReboot = session.PermissionSystemReboot + PermissionSystemUpdate = session.PermissionSystemUpdate + PermissionSystemNetwork = session.PermissionSystemNetwork +) + +var ( + GetMethodPermission = session.GetMethodPermission +) + +type GetPermissionsResponse = session.GetPermissionsResponse + +func (s *Session) HasPermission(perm Permission) bool { + if s == nil { + return false + } + return session.CheckPermission(s.Mode, perm) +} + +func (s *Session) GetPermissions() PermissionSet { + if s == nil { + return PermissionSet{} + } + return session.GetPermissionsForMode(s.Mode) +} + +func RequirePermission(s *Session, perm Permission) error { + if s == nil { + return session.RequirePermissionForMode(SessionModePending, perm) + } + if !s.HasPermission(perm) { + return session.RequirePermissionForMode(s.Mode, perm) + } + return nil +} diff --git a/terminal.go b/terminal.go index e06e5cdc1..ea13087c9 100644 --- a/terminal.go +++ b/terminal.go @@ -16,9 +16,16 @@ type TerminalSize struct { Cols int `json:"cols"` } -func handleTerminalChannel(d *webrtc.DataChannel) { +func handleTerminalChannel(d *webrtc.DataChannel, session *Session) { scopedLogger := terminalLogger.With(). - Uint16("data_channel_id", *d.ID()).Logger() + Uint16("data_channel_id", *d.ID()). + Str("session_id", session.ID).Logger() + + // Check terminal access permission + if !session.HasPermission(PermissionTerminalAccess) { + handlePermissionDeniedChannel(d, "Terminal access denied: Permission required") + return + } var ptmx *os.File var cmd *exec.Cmd diff --git a/ui/src/api/sessionApi.ts b/ui/src/api/sessionApi.ts new file mode 100644 index 000000000..b6602fe46 --- /dev/null +++ b/ui/src/api/sessionApi.ts @@ -0,0 +1,132 @@ +import { SessionInfo } from "@/stores/sessionStore"; + +interface JsonRpcResponse { + result?: unknown; + error?: { message: string }; +} + +type RpcSendFunction = (method: string, params: Record, callback: (response: JsonRpcResponse) => void) => void; + +export const sessionApi = { + getSessions: async (sendFn: RpcSendFunction): Promise => { + return new Promise((resolve, reject) => { + sendFn("getSessions", {}, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve((response.result as SessionInfo[]) || []); + } + }); + }); + }, + + getSessionInfo: async (sendFn: RpcSendFunction, sessionId: string): Promise => { + return new Promise((resolve, reject) => { + sendFn("getSessionInfo", { sessionId }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(response.result as SessionInfo); + } + }); + }); + }, + + requestPrimary: async (sendFn: RpcSendFunction, sessionId: string): Promise<{ status: string; mode?: string; message?: string }> => { + return new Promise((resolve, reject) => { + sendFn("requestPrimary", { sessionId }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(response.result as { status: string; mode?: string; message?: string }); + } + }); + }); + }, + + releasePrimary: async (sendFn: RpcSendFunction, sessionId: string): Promise => { + return new Promise((resolve, reject) => { + sendFn("releasePrimary", { sessionId }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(); + } + }); + }); + }, + + transferPrimary: async ( + sendFn: RpcSendFunction, + fromId: string, + toId: string + ): Promise => { + return new Promise((resolve, reject) => { + sendFn("transferPrimary", { fromId, toId }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(); + } + }); + }); + }, + + updateNickname: async ( + sendFn: RpcSendFunction, + sessionId: string, + nickname: string + ): Promise => { + return new Promise((resolve, reject) => { + sendFn("updateSessionNickname", { sessionId, nickname }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(); + } + }); + }); + }, + + approveNewSession: async ( + sendFn: RpcSendFunction, + sessionId: string + ): Promise => { + return new Promise((resolve, reject) => { + sendFn("approveNewSession", { sessionId }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(); + } + }); + }); + }, + + denyNewSession: async ( + sendFn: RpcSendFunction, + sessionId: string + ): Promise => { + return new Promise((resolve, reject) => { + sendFn("denyNewSession", { sessionId }, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(); + } + }); + }); + }, + + requestSessionApproval: async (sendFn: RpcSendFunction): Promise => { + return new Promise((resolve, reject) => { + sendFn("requestSessionApproval", {}, (response: JsonRpcResponse) => { + if (response.error) { + reject(new Error(response.error.message)); + } else { + resolve(); + } + }); + }); + } +}; \ No newline at end of file diff --git a/ui/src/components/AccessDeniedOverlay.tsx b/ui/src/components/AccessDeniedOverlay.tsx new file mode 100644 index 000000000..a04f23cc4 --- /dev/null +++ b/ui/src/components/AccessDeniedOverlay.tsx @@ -0,0 +1,165 @@ +import { useEffect, useState, useCallback, useRef } from "react"; +import { useNavigate } from "react-router"; +import { XCircleIcon } from "@heroicons/react/24/outline"; + +import { DEVICE_API, CLOUD_API } from "@/ui.config"; +import { isOnDevice } from "@/main"; +import { useUserStore, useSettingsStore } from "@/hooks/stores"; +import { useSessionStore, useSharedSessionStore } from "@/stores/sessionStore"; +import api from "@/api"; + +import { Button } from "./Button"; + +interface AccessDeniedOverlayProps { + show: boolean; + message?: string; + onRetry?: () => void; + onRequestApproval?: () => void; +} + +export default function AccessDeniedOverlay({ + show, + message = "Your session access was denied", + onRetry, + onRequestApproval +}: AccessDeniedOverlayProps) { + const navigate = useNavigate(); + const setUser = useUserStore(state => state.setUser); + const { clearSession, rejectionCount, incrementRejectionCount } = useSessionStore(); + const { clearNickname } = useSharedSessionStore(); + const { maxRejectionAttempts } = useSettingsStore(); + const [countdown, setCountdown] = useState(10); + const [isRetrying, setIsRetrying] = useState(false); + const hasCountedRef = useRef(false); + + const handleLogout = useCallback(async () => { + try { + const logoutUrl = isOnDevice ? `${DEVICE_API}/auth/logout` : `${CLOUD_API}/logout`; + const res = await api.POST(logoutUrl); + if (!res.ok) { + console.warn("Logout API call failed, but continuing with local cleanup"); + } + } catch (error) { + console.error("Logout API call failed:", error); + } + + // Always clear local state and navigate, regardless of API call result + setUser(null); + clearSession(); + clearNickname(); + navigate("/"); + }, [navigate, setUser, clearSession, clearNickname]); + + useEffect(() => { + if (!show) { + hasCountedRef.current = false; + setCountdown(10); + return; + } + + // Only count rejection once per showing + if (hasCountedRef.current) return; + hasCountedRef.current = true; + + const newCount = incrementRejectionCount(); + + if (newCount >= maxRejectionAttempts) { + return; + } + + const timer = setInterval(() => { + setCountdown(prev => { + if (prev <= 1) { + clearInterval(timer); + handleLogout(); + return 0; + } + return prev - 1; + }); + }, 1000); + + return () => clearInterval(timer); + }, [show, handleLogout, incrementRejectionCount, maxRejectionAttempts]); + + if (!show) return null; + + if (rejectionCount >= maxRejectionAttempts) { + return null; + } + + return ( +
+
+
+ +
+

+ Access Denied +

+

+ {message} +

+
+
+ +
+
+

+ The primary session has denied your access request. This could be for security reasons + or because the session is restricted. +

+
+ + {rejectionCount < maxRejectionAttempts && ( +
+

+ Attempt {rejectionCount} of {maxRejectionAttempts}: {rejectionCount === maxRejectionAttempts - 1 + ? "This is your last attempt. Further rejections will hide this dialog." + : `You have ${maxRejectionAttempts - rejectionCount} attempt${maxRejectionAttempts - rejectionCount === 1 ? '' : 's'} remaining.` + } +

+
+ )} + +

+ Redirecting in {countdown} seconds... +

+ +
+ {(onRequestApproval || onRetry) && rejectionCount < maxRejectionAttempts && ( +
+
+
+
+ ); +} \ No newline at end of file diff --git a/ui/src/components/ActionBar.tsx b/ui/src/components/ActionBar.tsx index 4f79d7ed8..d978ef646 100644 --- a/ui/src/components/ActionBar.tsx +++ b/ui/src/components/ActionBar.tsx @@ -2,8 +2,8 @@ import { MdOutlineContentPasteGo } from "react-icons/md"; import { LuCable, LuHardDrive, LuMaximize, LuSettings, LuSignal } from "react-icons/lu"; import { FaKeyboard } from "react-icons/fa6"; import { Popover, PopoverButton, PopoverPanel } from "@headlessui/react"; -import { Fragment, useCallback, useRef } from "react"; -import { CommandLineIcon } from "@heroicons/react/20/solid"; +import { Fragment, useCallback, useRef, useEffect } from "react"; +import { CommandLineIcon, UserGroupIcon } from "@heroicons/react/20/solid"; import { Button } from "@components/Button"; import { @@ -11,14 +11,18 @@ import { useMountMediaStore, useSettingsStore, useUiStore, -} from "@/hooks/stores"; + useRTCStore } from "@/hooks/stores"; import Container from "@components/Container"; import { cx } from "@/cva.config"; import PasteModal from "@/components/popovers/PasteModal"; import WakeOnLanModal from "@/components/popovers/WakeOnLan/Index"; import MountPopopover from "@/components/popovers/MountPopover"; import ExtensionPopover from "@/components/popovers/ExtensionPopover"; +import SessionPopover from "@/components/popovers/SessionPopover"; import { useDeviceUiNavigation } from "@/hooks/useAppNavigation"; +import { useSessionStore } from "@/stores/sessionStore"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; export default function Actionbar({ requestFullscreen, @@ -33,6 +37,40 @@ export default function Actionbar({ state => state.remoteVirtualMediaState, ); const { developerMode } = useSettingsStore(); + const { currentMode, sessions, setSessions } = useSessionStore(); + const { rpcDataChannel } = useRTCStore(); + const { hasPermission } = usePermissions(); + + // Fetch sessions on mount if we have an RPC channel + useEffect(() => { + if (rpcDataChannel?.readyState === "open" && sessions.length === 0) { + const id = Math.random().toString(36).substring(2); + const message = JSON.stringify({ jsonrpc: "2.0", method: "getSessions", params: {}, id }); + + const handler = (event: MessageEvent) => { + try { + const response = JSON.parse(event.data); + if (response.id === id && response.result) { + setSessions(response.result); + } + } catch { + // Ignore parse errors for non-JSON messages + } + }; + + rpcDataChannel.addEventListener("message", handler); + rpcDataChannel.send(message); + + const timeoutId = setTimeout(() => { + rpcDataChannel.removeEventListener("message", handler); + }, 5000); + + return () => { + clearTimeout(timeoutId); + rpcDataChannel.removeEventListener("message", handler); + }; + } + }, [rpcDataChannel, sessions.length, setSessions]); // This is the only way to get a reliable state change for the popover // at time of writing this there is no mount, or unmount event for the popover @@ -44,7 +82,6 @@ export default function Actionbar({ if (!open) { setTimeout(() => { setDisableVideoFocusTrap(false); - console.debug("Popover is closing. Returning focus trap to video"); }, 0); } } @@ -60,7 +97,7 @@ export default function Actionbar({ className="flex flex-wrap items-center justify-between gap-x-4 gap-y-2 py-1.5" >
- {developerMode && ( + {developerMode && hasPermission(Permission.TERMINAL_ACCESS) && (
+ )} + {hasPermission(Permission.KEYBOARD_INPUT) && ( +
+ )} + + +
+ {/* Session Control */}
-
-
-
+ )} -
- - + {hasPermission(Permission.KEYBOARD_INPUT) && ( +
+
+ )}
-
-
+ {/* Only show Settings for sessions with settings access */} + {hasPermission(Permission.SETTINGS_ACCESS) && ( +
+
+ )}
diff --git a/ui/src/components/Header.tsx b/ui/src/components/Header.tsx index a650693f4..6fc7bb274 100644 --- a/ui/src/components/Header.tsx +++ b/ui/src/components/Header.tsx @@ -12,6 +12,7 @@ import LogoWhiteIcon from "@/assets/logo-white.svg"; import USBStateStatus from "@components/USBStateStatus"; import PeerConnectionStatusCard from "@components/PeerConnectionStatusCard"; import { CLOUD_API, DEVICE_API } from "@/ui.config"; +import { useSessionStore, useSharedSessionStore } from "@/stores/sessionStore"; import api from "../api"; import { isOnDevice } from "../main"; @@ -37,6 +38,8 @@ export default function DashboardNavbar({ }: NavbarProps) { const peerConnectionState = useRTCStore(state => state.peerConnectionState); const setUser = useUserStore(state => state.setUser); + const { clearSession } = useSessionStore(); + const { clearNickname } = useSharedSessionStore(); const navigate = useNavigate(); const onLogout = useCallback(async () => { const logoutUrl = isOnDevice ? `${DEVICE_API}/auth/logout` : `${CLOUD_API}/logout`; @@ -44,9 +47,12 @@ export default function DashboardNavbar({ if (!res.ok) return; setUser(null); + // Clear the stored session data via zustand + clearNickname(); + clearSession(); // The root route will redirect to appropriate login page, be it the local one or the cloud one navigate("/"); - }, [navigate, setUser]); + }, [navigate, setUser, clearNickname, clearSession]); const { usbState } = useHidStore(); diff --git a/ui/src/components/InfoBar.tsx b/ui/src/components/InfoBar.tsx index ce444d859..02b917bee 100644 --- a/ui/src/components/InfoBar.tsx +++ b/ui/src/components/InfoBar.tsx @@ -31,9 +31,16 @@ export default function InfoBar() { useEffect(() => { if (!rpcDataChannel) return; - rpcDataChannel.onclose = () => console.log("rpcDataChannel has closed"); - rpcDataChannel.onerror = (e: Event) => - console.error(`Error on DataChannel '${rpcDataChannel.label}': ${e}`); + rpcDataChannel.onclose = () => { + if (rpcDataChannel.readyState === "closed") { + console.debug("rpcDataChannel closed"); + } + }; + rpcDataChannel.onerror = (e: Event) => { + if (rpcDataChannel.readyState === "open" || rpcDataChannel.readyState === "connecting") { + console.error(`Error on DataChannel '${rpcDataChannel.label}':`, e); + } + }; }, [rpcDataChannel]); const { keyboardLedState, usbState } = useHidStore(); diff --git a/ui/src/components/MacroBar.tsx b/ui/src/components/MacroBar.tsx index 0ba8cf4f7..8726a778f 100644 --- a/ui/src/components/MacroBar.tsx +++ b/ui/src/components/MacroBar.tsx @@ -6,21 +6,26 @@ import Container from "@components/Container"; import { useMacrosStore } from "@/hooks/stores"; import useKeyboard from "@/hooks/useKeyboard"; import { useJsonRpc } from "@/hooks/useJsonRpc"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; export default function MacroBar() { const { macros, initialized, loadMacros, setSendFn } = useMacrosStore(); const { executeMacro } = useKeyboard(); const { send } = useJsonRpc(); + const { permissions, hasPermission } = usePermissions(); useEffect(() => { setSendFn(send); - - if (!initialized) { + + // Only load macros if user has permission to read settings + if (!initialized && permissions[Permission.SETTINGS_READ] === true) { loadMacros(); } - }, [initialized, loadMacros, setSendFn, send]); + }, [initialized, send, loadMacros, setSendFn, permissions]); - if (macros.length === 0) { + // Don't show macros if user can't provide keyboard input or if no macros exist + if (macros.length === 0 || !hasPermission(Permission.KEYBOARD_INPUT)) { return null; } diff --git a/ui/src/components/NicknameModal.tsx b/ui/src/components/NicknameModal.tsx new file mode 100644 index 000000000..5b9586216 --- /dev/null +++ b/ui/src/components/NicknameModal.tsx @@ -0,0 +1,263 @@ +import { useState, useEffect, useRef } from "react"; +import { Dialog, DialogPanel, DialogBackdrop } from "@headlessui/react"; +import { UserIcon, XMarkIcon } from "@heroicons/react/20/solid"; + +import { useSettingsStore , useRTCStore } from "@/hooks/stores"; +import { useJsonRpc } from "@/hooks/useJsonRpc"; +import { generateNickname } from "@/utils/nicknameGenerator"; + +import { Button } from "./Button"; + +type SessionRole = "primary" | "observer" | "queued" | "pending"; + +interface NicknameModalProps { + isOpen: boolean; + onSubmit: (nickname: string) => void | Promise; + onSkip?: () => void; + title?: string; + description?: string; + isRequired?: boolean; + expectedRole?: SessionRole; +} + +export default function NicknameModal({ + isOpen, + onSubmit, + onSkip, + title = "Set Your Session Nickname", + description = "Add a nickname to help identify your session to other users", + isRequired, + expectedRole = "observer" +}: NicknameModalProps) { + const [nickname, setNickname] = useState(""); + const [isSubmitting, setIsSubmitting] = useState(false); + const [error, setError] = useState(null); + const [generatedNickname, setGeneratedNickname] = useState(""); + const inputRef = useRef(null); + const { requireSessionNickname } = useSettingsStore(); + const { send } = useJsonRpc(); + const { rpcDataChannel } = useRTCStore(); + + const isNicknameRequired = isRequired ?? requireSessionNickname; + + // Role-based color coding + const getRoleColors = (role: SessionRole) => { + switch (role) { + case "primary": + return { + bg: "bg-green-100 dark:bg-green-900/30", + icon: "text-green-600 dark:text-green-400" + }; + case "observer": + return { + bg: "bg-blue-100 dark:bg-blue-900/30", + icon: "text-blue-600 dark:text-blue-400" + }; + case "queued": + return { + bg: "bg-yellow-100 dark:bg-yellow-900/30", + icon: "text-yellow-600 dark:text-yellow-400" + }; + case "pending": + return { + bg: "bg-orange-100 dark:bg-orange-900/30", + icon: "text-orange-600 dark:text-orange-400" + }; + default: + return { + bg: "bg-slate-100 dark:bg-slate-900/30", + icon: "text-slate-600 dark:text-slate-400" + }; + } + }; + + const roleColors = getRoleColors(expectedRole); + + // Generate nickname when modal opens and RPC is ready + useEffect(() => { + if (!isOpen || generatedNickname) return; + if (rpcDataChannel?.readyState !== "open") return; + + generateNickname(send).then(nickname => { + setGeneratedNickname(nickname); + }).catch((error) => { + console.error('Backend nickname generation failed:', error); + }); + }, [isOpen, generatedNickname, rpcDataChannel?.readyState, send]); + + // Focus input when modal opens + useEffect(() => { + if (isOpen) { + setTimeout(() => { + if (inputRef.current) { + inputRef.current.focus(); + } + }, 100); + } + }, [isOpen]); + + const validateNickname = (value: string): string | null => { + if (value.length < 2) { + return "Nickname must be at least 2 characters"; + } + if (value.length > 30) { + return "Nickname must be 30 characters or less"; + } + if (!/^[a-zA-Z0-9\s\-_.@]+$/.test(value)) { + return "Nickname can only contain letters, numbers, spaces, and - _ . @"; + } + return null; + }; + + const handleSubmit = async (e?: React.FormEvent) => { + e?.preventDefault(); + + // Use generated nickname if input is empty + const trimmedNickname = nickname.trim() || generatedNickname; + + // Validate + const validationError = validateNickname(trimmedNickname); + if (validationError) { + setError(validationError); + return; + } + + setIsSubmitting(true); + setError(null); + + try { + await onSubmit(trimmedNickname); + setNickname(""); + setGeneratedNickname(""); // Reset generated nickname after successful submit + } catch (error) { + setError(error instanceof Error ? error.message : "Failed to set nickname"); + setIsSubmitting(false); + } + }; + + const handleSkip = () => { + if (!isNicknameRequired && onSkip) { + onSkip(); + setNickname(""); + setError(null); + setGeneratedNickname(""); // Reset generated nickname when skipping + } + }; + + return ( + { + if (!isNicknameRequired && onSkip) { + onSkip(); + setNickname(""); + setError(null); + setGeneratedNickname(""); + } + }} + className="relative z-50" + > + +
+ +
+
+
+
+ +
+
+

+ {title} +

+

+ {description} +

+
+
+ {!isNicknameRequired && ( + + )} +
+ +
+
+ + { + setNickname(e.target.value); + setError(null); + }} + placeholder={generatedNickname || "e.g., John's Laptop, Office PC, etc."} + className="w-full px-3 py-2 border border-slate-300 dark:border-slate-600 rounded-md + bg-white dark:bg-slate-700 text-slate-900 dark:text-white + placeholder-slate-400 dark:placeholder-slate-500 + focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent" + maxLength={30} + /> +
+ {error ? ( +

{error}

+ ) : ( +
+

+ {nickname.trim() === "" && generatedNickname + ? `Leave empty to use: ${generatedNickname}` + : "2-30 characters, letters, numbers, spaces, and - _ . @ allowed"} +

+
+ )} + + {nickname.length}/30 + +
+
+ + {isNicknameRequired && ( +
+

+ Required: A nickname is required by the administrator to help identify sessions. +

+
+ )} + +
+
+
+
+
+
+
+ ); +} \ No newline at end of file diff --git a/ui/src/components/PendingApprovalOverlay.tsx b/ui/src/components/PendingApprovalOverlay.tsx new file mode 100644 index 000000000..6d96ab760 --- /dev/null +++ b/ui/src/components/PendingApprovalOverlay.tsx @@ -0,0 +1,53 @@ +import { useEffect, useState } from "react"; +import { ClockIcon } from "@heroicons/react/24/outline"; + +interface PendingApprovalOverlayProps { + show: boolean; +} + +export default function PendingApprovalOverlay({ show }: PendingApprovalOverlayProps) { + const [dots, setDots] = useState(""); + + useEffect(() => { + if (!show) return; + + const timer = setInterval(() => { + setDots(prev => (prev.length >= 3 ? "" : prev + ".")); + }, 500); + + return () => clearInterval(timer); + }, [show]); + + if (!show) return null; + + return ( +
+
+
+ + +
+

+ Awaiting Approval{dots} +

+

+ Your session is pending approval from the primary session +

+
+ +
+

+ The primary user will receive a notification to approve or deny your access. + This typically takes less than 30 seconds. +

+
+ +
+
+ Waiting for response from primary session +
+
+
+
+ ); +} \ No newline at end of file diff --git a/ui/src/components/SessionControlPanel.tsx b/ui/src/components/SessionControlPanel.tsx new file mode 100644 index 000000000..675b306e7 --- /dev/null +++ b/ui/src/components/SessionControlPanel.tsx @@ -0,0 +1,143 @@ +import { + LockClosedIcon, + LockOpenIcon, + ClockIcon +} from "@heroicons/react/16/solid"; +import clsx from "clsx"; + +import { useSessionStore } from "@/stores/sessionStore"; +import { sessionApi } from "@/api/sessionApi"; +import { Button } from "@/components/Button"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; + +type RpcSendFunction = (method: string, params: Record, callback: (response: { result?: unknown; error?: { message: string } }) => void) => void; + +interface SessionControlPanelProps { + sendFn: RpcSendFunction; + className?: string; +} + +export default function SessionControlPanel({ sendFn, className }: SessionControlPanelProps) { + const { + currentSessionId, + currentMode, + sessions, + isRequestingPrimary, + setRequestingPrimary, + setSessionError, + canRequestPrimary + } = useSessionStore(); + const { hasPermission } = usePermissions(); + + + const handleRequestPrimary = async () => { + if (!currentSessionId || isRequestingPrimary) return; + + setRequestingPrimary(true); + setSessionError(null); + + try { + const result = await sessionApi.requestPrimary(sendFn, currentSessionId); + + if (result.status === "success") { + if (result.mode === "primary") { + // Immediately became primary + setRequestingPrimary(false); + } else if (result.mode === "queued") { + // Request sent, waiting for approval + // Keep isRequestingPrimary true to show waiting state + } + } else if (result.status === "error") { + setSessionError(result.message || "Failed to request primary control"); + setRequestingPrimary(false); + } + } catch (error) { + setSessionError(error instanceof Error ? error.message : "Unknown error"); + console.error("Failed to request primary control:", error); + setRequestingPrimary(false); + } + }; + + const handleReleasePrimary = async () => { + if (!currentSessionId || currentMode !== "primary") return; + + try { + await sessionApi.releasePrimary(sendFn, currentSessionId); + } catch (error) { + setSessionError(error instanceof Error ? error.message : "Unknown error"); + console.error("Failed to release primary control:", error); + } + }; + + const canReleasePrimary = () => { + const otherEligibleSessions = sessions.filter( + s => s.id !== currentSessionId && (s.mode === "observer" || s.mode === "queued") + ); + return otherEligibleSessions.length > 0; + }; + + + return ( +
+ {/* Current session controls */} +
+

+ Session Control +

+ + {hasPermission(Permission.SESSION_RELEASE_PRIMARY) && ( +
+
+ )} + + {hasPermission(Permission.SESSION_REQUEST_PRIMARY) && ( + <> + {isRequestingPrimary ? ( +
+ + + Waiting for approval from primary session... + +
+ ) : ( +
+ +
+ ); +} \ No newline at end of file diff --git a/ui/src/components/SessionsList.tsx b/ui/src/components/SessionsList.tsx new file mode 100644 index 000000000..46e49626c --- /dev/null +++ b/ui/src/components/SessionsList.tsx @@ -0,0 +1,151 @@ +import { PencilIcon, CheckIcon, XMarkIcon } from "@heroicons/react/20/solid"; +import clsx from "clsx"; + +import { formatters } from "@/utils"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; + +interface Session { + id: string; + mode: string; + nickname?: string; + identity?: string; + source?: string; + createdAt?: string; +} + +interface SessionsListProps { + sessions: Session[]; + currentSessionId?: string; + onEditNickname?: (sessionId: string) => void; + onApprove?: (sessionId: string) => void; + onDeny?: (sessionId: string) => void; + onTransfer?: (sessionId: string) => void; + formatDuration?: (createdAt: string) => string; +} + +export default function SessionsList({ + sessions, + currentSessionId, + onEditNickname, + onApprove, + onDeny, + onTransfer, + formatDuration = (createdAt: string) => formatters.timeAgo(new Date(createdAt)) || "" +}: SessionsListProps) { + const { hasPermission } = usePermissions(); + return ( +
+ {sessions.map(session => ( +
+
+
+ + {session.id === currentSessionId && ( + (You) + )} +
+
+ + {session.createdAt ? formatDuration(session.createdAt) : ""} + + {/* Show approve/deny for pending sessions if user has permission */} + {session.mode === "pending" && hasPermission(Permission.SESSION_APPROVE) && onApprove && onDeny && ( +
+ + +
+ )} + {/* Show Transfer button if user has permission to transfer */} + {hasPermission(Permission.SESSION_TRANSFER) && session.mode === "observer" && session.id !== currentSessionId && onTransfer && ( + + )} + {/* Allow users with session manage permission to edit any nickname, or anyone to edit their own */} + {onEditNickname && (hasPermission(Permission.SESSION_MANAGE) || session.id === currentSessionId) && ( + + )} +
+
+ +
+ {session.nickname && ( +

+ {session.nickname} +

+ )} + {session.identity && ( +

+ {session.source === "cloud" ? "☁️ " : ""}{session.identity} +

+ )} + {session.mode === "pending" && ( +

+ Awaiting approval +

+ )} +
+
+ ))} +
+ ); +} + +export function SessionModeBadge({ mode }: { mode: string }) { + const getBadgeStyle = () => { + switch (mode) { + case "primary": + return "bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-400"; + case "observer": + return "bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400"; + case "queued": + return "bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400"; + case "pending": + return "bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"; + default: + return "bg-slate-100 text-slate-700 dark:bg-slate-900/30 dark:text-slate-400"; + } + }; + + return ( + + {mode} + + ); +} \ No newline at end of file diff --git a/ui/src/components/UnifiedSessionRequestDialog.tsx b/ui/src/components/UnifiedSessionRequestDialog.tsx new file mode 100644 index 000000000..ca936095c --- /dev/null +++ b/ui/src/components/UnifiedSessionRequestDialog.tsx @@ -0,0 +1,262 @@ +import { useEffect, useState } from "react"; +import { XMarkIcon, UserIcon, GlobeAltIcon, ComputerDesktopIcon } from "@heroicons/react/20/solid"; + +import { Button } from "./Button"; + +type RequestType = "session_approval" | "primary_control"; + +interface UnifiedSessionRequest { + id: string; // sessionId or requestId + type: RequestType; + source: "local" | "cloud" | string; // Allow string for IP addresses + identity?: string; + nickname?: string; +} + +interface UnifiedSessionRequestDialogProps { + request: UnifiedSessionRequest | null; + onApprove: (id: string) => void | Promise; + onDeny: (id: string) => void | Promise; + onDismiss?: () => void; + onClose: () => void; +} + +export default function UnifiedSessionRequestDialog({ + request, + onApprove, + onDeny, + onDismiss, + onClose +}: UnifiedSessionRequestDialogProps) { + const [timeRemaining, setTimeRemaining] = useState(0); + const [isProcessing, setIsProcessing] = useState(false); + const [hasTimedOut, setHasTimedOut] = useState(false); + + useEffect(() => { + if (!request) return; + + const isSessionApproval = request.type === "session_approval"; + const initialTime = isSessionApproval ? 60 : 0; // 60s for session approval, no timeout for primary control + + setTimeRemaining(initialTime); + setIsProcessing(false); + setHasTimedOut(false); + + // Only start timer for session approval requests + if (isSessionApproval) { + const timer = setInterval(() => { + setTimeRemaining(prev => { + const newTime = prev - 1; + if (newTime <= 0) { + clearInterval(timer); + setHasTimedOut(true); + return 0; + } + return newTime; + }); + }, 1000); + + return () => clearInterval(timer); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [request?.id, request?.type]); // Only depend on stable properties to avoid unnecessary re-renders + + // Handle auto-deny when timeout occurs + useEffect(() => { + if (hasTimedOut && !isProcessing && request) { + setIsProcessing(true); + Promise.resolve(onDeny(request.id)) + .catch(error => { + console.error("Failed to auto-deny request:", error); + }) + .finally(() => { + onClose(); + }); + } + }, [hasTimedOut, isProcessing, request, onDeny, onClose]); + + if (!request) return null; + + const isSessionApproval = request.type === "session_approval"; + const isPrimaryControl = request.type === "primary_control"; + + // Determine if source is cloud, local, or IP address + const getSourceInfo = () => { + if (request.source === "cloud") { + return { + type: "cloud", + label: "Cloud Session", + icon: GlobeAltIcon, + iconColor: "text-blue-500" + }; + } else if (request.source === "local") { + return { + type: "local", + label: "Local Session", + icon: ComputerDesktopIcon, + iconColor: "text-green-500" + }; + } else { + // Assume it's an IP address or hostname + return { + type: "ip", + label: request.source, + icon: ComputerDesktopIcon, + iconColor: "text-green-500" + }; + } + }; + + const sourceInfo = getSourceInfo(); + + const getTitle = () => { + if (isSessionApproval) return "New Session Request"; + if (isPrimaryControl) return "Primary Control Request"; + return "Session Request"; + }; + + const getDescription = () => { + if (isSessionApproval) return "A new session is attempting to connect to this device:"; + if (isPrimaryControl) return "A user is requesting primary control of this session:"; + return "A user is making a request:"; + }; + + return ( +
+
+
+

+ {getTitle()} +

+ +
+ +
+

+ {getDescription()} +

+ +
+ {/* Session type - always show with icon for both session approval and primary control */} +
+ + + {sourceInfo.type === "cloud" ? "Cloud Session" : + sourceInfo.type === "local" ? "Local Session" : + `Local Session`} + + {sourceInfo.type === "ip" && ( + + ({sourceInfo.label}) + + )} +
+ + {/* Nickname - always show with icon for consistency */} + {request.nickname && ( +
+ + + Nickname:{" "} + {request.nickname} + +
+ )} + + {/* Identity/User */} + {request.identity && ( +
+ {isSessionApproval ? ( +

Identity: {request.identity}

+ ) : ( +

+ User:{" "} + {request.identity} +

+ )} +
+ )} +
+ + {/* Security Note - only for session approval */} + {isSessionApproval && ( +
+

+ Security Note: Only approve sessions you recognize. + Approved sessions will have observer access and can request primary control. +

+
+ )} + + {/* Auto-deny timer - only for session approval */} + {isSessionApproval && ( +
+

+ Auto-deny in {timeRemaining} seconds +

+
+ )} + +
+
+
+ {onDismiss && ( +
+
+
+
+ ); +} \ No newline at end of file diff --git a/ui/src/components/WebRTCVideo.tsx b/ui/src/components/WebRTCVideo.tsx index 1ce25fe21..828e22436 100644 --- a/ui/src/components/WebRTCVideo.tsx +++ b/ui/src/components/WebRTCVideo.tsx @@ -14,6 +14,8 @@ import { useSettingsStore, useVideoStore, } from "@/hooks/stores"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; import useMouse from "@/hooks/useMouse"; import { @@ -35,6 +37,7 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu // Store hooks const settings = useSettingsStore(); + const { hasPermission } = usePermissions(); const { handleKeyPress, resetKeyboardState } = useKeyboard(); const { getRelMouseMoveHandler, @@ -214,29 +217,47 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu document.addEventListener("fullscreenchange", handleFullscreenChange); }, [releaseKeyboardLock]); - const absMouseMoveHandler = useMemo( - () => getAbsMouseMoveHandler({ + const absMouseMoveHandler = useMemo(() => { + const handler = getAbsMouseMoveHandler({ videoClientWidth, videoClientHeight, videoWidth, videoHeight, - }), - [getAbsMouseMoveHandler, videoClientWidth, videoClientHeight, videoWidth, videoHeight], - ); - - const relMouseMoveHandler = useMemo( - () => getRelMouseMoveHandler(), - [getRelMouseMoveHandler], - ); - - const mouseWheelHandler = useMemo( - () => getMouseWheelHandler(), - [getMouseWheelHandler], - ); + }); + return (e: MouseEvent) => { + // Only allow input if user has mouse permission + if (!hasPermission(Permission.MOUSE_INPUT)) return; + handler(e); + }; + }, [getAbsMouseMoveHandler, videoClientWidth, videoClientHeight, videoWidth, videoHeight, hasPermission]); + + const relMouseMoveHandler = useMemo(() => { + const handler = getRelMouseMoveHandler(); + return (e: MouseEvent) => { + // Only allow input if user has mouse permission + if (!hasPermission(Permission.MOUSE_INPUT)) return; + handler(e); + }; + }, [getRelMouseMoveHandler, hasPermission]); + + const mouseWheelHandler = useMemo(() => { + const handler = getMouseWheelHandler(); + return (e: WheelEvent) => { + // Only allow input if user has mouse permission + if (!hasPermission(Permission.MOUSE_INPUT)) return; + handler(e); + }; + }, [getMouseWheelHandler, hasPermission]); const keyDownHandler = useCallback( (e: KeyboardEvent) => { e.preventDefault(); + + // Only allow input if user has keyboard permission + if (!hasPermission(Permission.KEYBOARD_INPUT)) { + return; + } + if (e.repeat) return; const code = getAdjustedKeyCode(e); const hidKey = keys[code]; @@ -252,11 +273,9 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu // https://bugzilla.mozilla.org/show_bug.cgi?id=1299553 if (e.metaKey && hidKey < 0xE0) { setTimeout(() => { - console.debug(`Forcing the meta key release of associated key: ${hidKey}`); handleKeyPress(hidKey, false); }, 10); } - console.debug(`Key down: ${hidKey}`); handleKeyPress(hidKey, true); if (!isKeyboardLockActive && hidKey === keys.MetaLeft) { @@ -264,17 +283,22 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu // we'll never see the keyup event because the browser is going to lose // focus so set a deferred keyup after a short delay setTimeout(() => { - console.debug(`Forcing the left meta key release`); handleKeyPress(hidKey, false); }, 100); } }, - [handleKeyPress, isKeyboardLockActive], + [handleKeyPress, isKeyboardLockActive, hasPermission], ); const keyUpHandler = useCallback( async (e: KeyboardEvent) => { e.preventDefault(); + + // Only allow input if user has keyboard permission + if (!hasPermission(Permission.KEYBOARD_INPUT)) { + return; + } + const code = getAdjustedKeyCode(e); const hidKey = keys[code]; @@ -283,10 +307,9 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu return; } - console.debug(`Key up: ${hidKey}`); handleKeyPress(hidKey, false); }, - [handleKeyPress], + [handleKeyPress, hasPermission], ); const videoKeyUpHandler = useCallback((e: KeyboardEvent) => { @@ -297,7 +320,6 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu // Fix only works in chrome based browsers. if (e.code === "Space") { if (videoElm.current.paused) { - console.debug("Force playing video"); videoElm.current.play(); } } @@ -557,7 +579,7 @@ export default function WebRTCVideo({ hasConnectionIssues }: { hasConnectionIssu )}
- + {hasPermission(Permission.KEYBOARD_INPUT) && }
diff --git a/ui/src/components/popovers/SessionPopover.tsx b/ui/src/components/popovers/SessionPopover.tsx new file mode 100644 index 000000000..cf618eca0 --- /dev/null +++ b/ui/src/components/popovers/SessionPopover.tsx @@ -0,0 +1,208 @@ +import { useState, useEffect, useCallback } from "react"; +import { + UserGroupIcon, + ArrowPathIcon, + PencilIcon, +} from "@heroicons/react/20/solid"; +import clsx from "clsx"; + +import { useSessionStore, useSharedSessionStore } from "@/stores/sessionStore"; +import { useJsonRpc } from "@/hooks/useJsonRpc"; +import SessionControlPanel from "@/components/SessionControlPanel"; +import NicknameModal from "@/components/NicknameModal"; +import SessionsList, { SessionModeBadge } from "@/components/SessionsList"; +import { sessionApi } from "@/api/sessionApi"; + +export default function SessionPopover() { + const { + currentSessionId, + currentMode, + sessions, + sessionError, + setSessions, + } = useSessionStore(); + const { setNickname } = useSharedSessionStore(); + + const [isRefreshing, setIsRefreshing] = useState(false); + const [showNicknameModal, setShowNicknameModal] = useState(false); + const [editingSessionId, setEditingSessionId] = useState(null); + + const { send } = useJsonRpc(); + + // Adapter function to match existing callback pattern + const sendRpc = useCallback((method: string, params: Record, callback?: (response: { result?: unknown; error?: { message: string } }) => void) => { + send(method, params, (response) => { + if (callback) callback(response); + }); + }, [send]); + + const handleRefresh = async () => { + if (isRefreshing) return; + + setIsRefreshing(true); + try { + const refreshedSessions = await sessionApi.getSessions(sendRpc); + setSessions(refreshedSessions); + } catch (error) { + console.error("Failed to refresh sessions:", error); + } finally { + setIsRefreshing(false); + } + }; + + // Fetch sessions on mount + useEffect(() => { + if (sessions.length === 0) { + sessionApi.getSessions(sendRpc) + .then(sessions => setSessions(sessions)) + .catch(error => console.error("Failed to fetch sessions:", error)); + } + }, [sendRpc, sessions.length, setSessions]); + + return ( +
+ {/* Header */} +
+
+
+ +

+ Session Management +

+
+ +
+
+ + {/* Session Error */} + {sessionError && ( +
+

{sessionError}

+
+ )} + + {/* Current Session */} +
+
+
+
+ Your Session + +
+ +
+ + {currentSessionId && ( + <> + {/* Display current session nickname if exists */} + {sessions.find(s => s.id === currentSessionId)?.nickname && ( +
+ Nickname: + + {sessions.find(s => s.id === currentSessionId)?.nickname} + +
+ )} + +
+ +
+ + )} +
+
+ + {/* Active Sessions List */} +
+
+ Active Sessions ({sessions.length}) +
+ + {sessions.length > 0 ? ( + { + setEditingSessionId(sessionId); + setShowNicknameModal(true); + }} + onApprove={(sessionId) => { + sendRpc("approveNewSession", { sessionId }, (response) => { + if (response.error) { + console.error("Failed to approve session:", response.error); + } else { + handleRefresh(); + } + }); + }} + onDeny={(sessionId) => { + sendRpc("denyNewSession", { sessionId }, (response) => { + if (response.error) { + console.error("Failed to deny session:", response.error); + } else { + handleRefresh(); + } + }); + }} + onTransfer={async (sessionId) => { + try { + await sessionApi.transferPrimary(sendRpc, currentSessionId!, sessionId); + handleRefresh(); + } catch (error) { + console.error("Failed to transfer primary:", error); + } + }} + /> + ) : ( +

No active sessions

+ )} +
+ + s.id === currentSessionId)?.nickname ? "Update Your Nickname" : "Set Your Nickname") + : `Set Nickname for ${sessions.find(s => s.id === editingSessionId)?.mode || 'Session'}`} + description={editingSessionId === currentSessionId + ? "Choose a nickname to help identify your session to others" + : "Choose a nickname to help identify this session"} + onSubmit={async (nickname) => { + if (editingSessionId && sendRpc) { + try { + await sessionApi.updateNickname(sendRpc, editingSessionId, nickname); + if (editingSessionId === currentSessionId) { + setNickname(nickname); + } + setShowNicknameModal(false); + setEditingSessionId(null); + handleRefresh(); + } catch (error) { + console.error("Failed to update nickname:", error); + throw error; + } + } + }} + onSkip={() => { + setShowNicknameModal(false); + setEditingSessionId(null); + }} + /> +
+ ); +} + diff --git a/ui/src/contexts/PermissionsContext.ts b/ui/src/contexts/PermissionsContext.ts new file mode 100644 index 000000000..c6268d968 --- /dev/null +++ b/ui/src/contexts/PermissionsContext.ts @@ -0,0 +1,5 @@ +import { createContext } from "react"; + +import { PermissionsContextValue } from "@/hooks/usePermissions"; + +export const PermissionsContext = createContext(undefined); diff --git a/ui/src/hooks/stores.ts b/ui/src/hooks/stores.ts index 488bca5e3..7a2968107 100644 --- a/ui/src/hooks/stores.ts +++ b/ui/src/hooks/stores.ts @@ -343,6 +343,15 @@ export interface SettingsState { developerMode: boolean; setDeveloperMode: (enabled: boolean) => void; + requireSessionNickname: boolean; + setRequireSessionNickname: (required: boolean) => void; + + requireSessionApproval: boolean; + setRequireSessionApproval: (required: boolean) => void; + + maxRejectionAttempts: number; + setMaxRejectionAttempts: (attempts: number) => void; + displayRotation: string; setDisplayRotation: (rotation: string) => void; @@ -383,6 +392,15 @@ export const useSettingsStore = create( developerMode: false, setDeveloperMode: (enabled: boolean) => set({ developerMode: enabled }), + requireSessionNickname: false, + setRequireSessionNickname: (required: boolean) => set({ requireSessionNickname: required }), + + requireSessionApproval: true, + setRequireSessionApproval: (required: boolean) => set({ requireSessionApproval: required }), + + maxRejectionAttempts: 3, + setMaxRejectionAttempts: (attempts: number) => set({ maxRejectionAttempts: attempts }), + displayRotation: "270", setDisplayRotation: (rotation: string) => set({ displayRotation: rotation }), diff --git a/ui/src/hooks/useJsonRpc.ts b/ui/src/hooks/useJsonRpc.ts index 5c52d59cd..91965c744 100644 --- a/ui/src/hooks/useJsonRpc.ts +++ b/ui/src/hooks/useJsonRpc.ts @@ -1,4 +1,4 @@ -import { useCallback, useEffect } from "react"; +import { useCallback, useEffect, useRef } from "react"; import { useRTCStore } from "@/hooks/stores"; @@ -36,6 +36,12 @@ let requestCounter = 0; export function useJsonRpc(onRequest?: (payload: JsonRpcRequest) => void) { const { rpcDataChannel } = useRTCStore(); + const onRequestRef = useRef(onRequest); + + // Update ref when callback changes + useEffect(() => { + onRequestRef.current = onRequest; + }, [onRequest]); const send = useCallback( async (method: string, params: unknown, callback?: (resp: JsonRpcResponse) => void) => { @@ -59,7 +65,7 @@ export function useJsonRpc(onRequest?: (payload: JsonRpcRequest) => void) { // The "API" can also "request" data from the client // If the payload has a method, it's a request if ("method" in payload) { - if (onRequest) onRequest(payload); + if (onRequestRef.current) onRequestRef.current(payload); return; } @@ -79,7 +85,7 @@ export function useJsonRpc(onRequest?: (payload: JsonRpcRequest) => void) { rpcDataChannel.removeEventListener("message", messageHandler); }; }, - [rpcDataChannel, onRequest]); + [rpcDataChannel]); // Remove onRequest from dependencies return { send }; } diff --git a/ui/src/hooks/usePermissions.ts b/ui/src/hooks/usePermissions.ts new file mode 100644 index 000000000..a717bab2f --- /dev/null +++ b/ui/src/hooks/usePermissions.ts @@ -0,0 +1,34 @@ +import { useContext } from "react"; + +import { PermissionsContext } from "@/contexts/PermissionsContext"; +import { Permission } from "@/types/permissions"; + +export interface PermissionsContextValue { + permissions: Record; + isLoading: boolean; + hasPermission: (permission: Permission) => boolean; + hasAnyPermission: (...perms: Permission[]) => boolean; + hasAllPermissions: (...perms: Permission[]) => boolean; + isPrimary: () => boolean; + isObserver: () => boolean; + isPending: () => boolean; +} + +export function usePermissions(): PermissionsContextValue { + const context = useContext(PermissionsContext); + + if (context === undefined) { + return { + permissions: {}, + isLoading: true, + hasPermission: () => false, + hasAnyPermission: () => false, + hasAllPermissions: () => false, + isPrimary: () => false, + isObserver: () => false, + isPending: () => false, + }; + } + + return context; +} diff --git a/ui/src/hooks/useSessionEvents.ts b/ui/src/hooks/useSessionEvents.ts new file mode 100644 index 000000000..58d9715d0 --- /dev/null +++ b/ui/src/hooks/useSessionEvents.ts @@ -0,0 +1,167 @@ +import { useEffect, useRef } from "react"; + +import { useSessionStore, SessionInfo } from "@/stores/sessionStore"; +import { useRTCStore } from "@/hooks/stores"; +import { sessionApi } from "@/api/sessionApi"; +import { notify } from "@/notifications"; + +type RpcSendFunction = (method: string, params: Record, callback: (response: { result?: unknown; error?: { message: string } }) => void) => void; + +interface SessionEventData { + sessions: SessionInfo[]; + yourMode: string; +} + +interface ModeChangedData { + mode: string; +} + +interface ConnectionModeChangedData { + newMode: string; +} + +export function useSessionEvents(sendFn: RpcSendFunction | null) { + const { + currentMode, + setSessions, + updateSessionMode, + setSessionError + } = useSessionStore(); + + const sendFnRef = useRef(sendFn); + sendFnRef.current = sendFn; + + const handleSessionEvent = (method: string, params: unknown) => { + switch (method) { + case "sessionsUpdated": + handleSessionsUpdated(params as SessionEventData); + break; + case "modeChanged": + handleModeChanged(params as ModeChangedData); + break; + case "connectionModeChanged": + handleConnectionModeChanged(params as ConnectionModeChangedData); + break; + case "hidReadyForPrimary": + handleHidReadyForPrimary(); + break; + case "otherSessionConnected": + handleOtherSessionConnected(); + break; + default: + break; + } + }; + + const handleSessionsUpdated = (data: SessionEventData) => { + if (data.sessions) { + setSessions(data.sessions); + } + + // CRITICAL: Only update mode, never show notifications from sessionsUpdated + // Notifications are exclusively handled by handleModeChanged to prevent duplicates + if (data.yourMode && data.yourMode !== currentMode) { + updateSessionMode(data.yourMode as "primary" | "observer" | "queued" | "pending"); + } + }; + + // Debounce notifications to prevent rapid-fire duplicates + const lastNotificationRef = useRef<{mode: string, timestamp: number}>({mode: "", timestamp: 0}); + + const handleModeChanged = (data: ModeChangedData) => { + if (data.mode) { + // Get the most current mode from the store to avoid race conditions + const { currentMode: currentModeFromStore } = useSessionStore.getState(); + const previousMode = currentModeFromStore; + updateSessionMode(data.mode as "primary" | "observer" | "queued" | "pending"); + + if (previousMode === "queued" && data.mode !== "queued") { + const { setRequestingPrimary } = useSessionStore.getState(); + setRequestingPrimary(false); + } + + if (previousMode === "pending" && data.mode === "observer") { + const { resetRejectionCount } = useSessionStore.getState(); + resetRejectionCount(); + } + + // HID re-initialization is now handled automatically by permission changes in usePermissions + + // CRITICAL: Debounce notifications to prevent duplicates from rapid-fire events + const now = Date.now(); + const lastNotification = lastNotificationRef.current; + + // Only show notification if: + // 1. Mode actually changed, AND + // 2. Haven't shown the same notification in the last 2 seconds + const shouldNotify = previousMode !== data.mode && + (lastNotification.mode !== data.mode || now - lastNotification.timestamp > 2000); + + if (shouldNotify) { + if (data.mode === "primary") { + notify.success("Primary control granted"); + lastNotificationRef.current = {mode: "primary", timestamp: now}; + } else if (data.mode === "observer" && previousMode === "primary") { + notify.info("Primary control released"); + lastNotificationRef.current = {mode: "observer", timestamp: now}; + } + } + } + }; + + const handleConnectionModeChanged = (data: ConnectionModeChangedData) => { + if (data.newMode) { + handleModeChanged({ mode: data.newMode }); + } + }; + + const handleHidReadyForPrimary = () => { + const { rpcHidChannel } = useRTCStore.getState(); + if (rpcHidChannel?.readyState === "open") { + rpcHidChannel.dispatchEvent(new Event("open")); + } + }; + + const handleOtherSessionConnected = () => { + notify.warning("Another session is connecting", { + duration: 5000 + }); + }; + + useEffect(() => { + if (!sendFnRef.current) return; + + const fetchSessions = async () => { + try { + const sessions = await sessionApi.getSessions(sendFnRef.current!); + setSessions(sessions); + } catch (error) { + console.error("Failed to fetch sessions:", error); + setSessionError("Failed to fetch session information"); + } + }; + + fetchSessions(); + }, [setSessions, setSessionError]); + + useEffect(() => { + if (!sendFnRef.current) return; + + const intervalId = setInterval(async () => { + if (!sendFnRef.current) return; + + try { + const sessions = await sessionApi.getSessions(sendFnRef.current); + setSessions(sessions); + } catch { + // Silently fail on refresh errors + } + }, 30000); // Refresh every 30 seconds + + return () => clearInterval(intervalId); + }, [setSessions]); + + return { + handleSessionEvent + }; +} \ No newline at end of file diff --git a/ui/src/hooks/useSessionManagement.ts b/ui/src/hooks/useSessionManagement.ts new file mode 100644 index 000000000..8925b3902 --- /dev/null +++ b/ui/src/hooks/useSessionManagement.ts @@ -0,0 +1,173 @@ +import { useEffect, useCallback, useState } from "react"; + +import { useSessionStore } from "@/stores/sessionStore"; +import { useSessionEvents } from "@/hooks/useSessionEvents"; +import { useSettingsStore } from "@/hooks/stores"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; + +type RpcSendFunction = (method: string, params: Record, callback: (response: { result?: unknown; error?: { message: string } }) => void) => void; + +interface SessionResponse { + sessionId?: string; + mode?: string; +} + +interface PrimaryControlRequest { + requestId: string; + identity: string; + source: string; + nickname?: string; +} + +interface NewSessionRequest { + sessionId: string; + source: "local" | "cloud"; + identity?: string; + nickname?: string; +} + +export function useSessionManagement(sendFn: RpcSendFunction | null) { + const { + setCurrentSession, + clearSession + } = useSessionStore(); + + const { hasPermission, isLoading: isLoadingPermissions } = usePermissions(); + + const { requireSessionApproval } = useSettingsStore(); + const { handleSessionEvent } = useSessionEvents(sendFn); + const [primaryControlRequest, setPrimaryControlRequest] = useState(null); + const [newSessionRequest, setNewSessionRequest] = useState(null); + + const handleSessionResponse = useCallback((response: SessionResponse) => { + if (response.sessionId && response.mode) { + setCurrentSession(response.sessionId, response.mode as "primary" | "observer" | "queued" | "pending"); + } + }, [setCurrentSession]); + + const handleApprovePrimaryRequest = useCallback(async (requestId: string) => { + if (!sendFn) return; + + return new Promise((resolve, reject) => { + sendFn("approvePrimaryRequest", { requesterID: requestId }, (response: { result?: unknown; error?: { message: string } }) => { + if (response.error) { + console.error("Failed to approve primary request:", response.error); + reject(new Error(response.error.message || "Failed to approve")); + } else { + setPrimaryControlRequest(null); + resolve(); + } + }); + }); + }, [sendFn]); + + const handleDenyPrimaryRequest = useCallback(async (requestId: string) => { + if (!sendFn) return; + + return new Promise((resolve, reject) => { + sendFn("denyPrimaryRequest", { requesterID: requestId }, (response: { result?: unknown; error?: { message: string } }) => { + if (response.error) { + console.error("Failed to deny primary request:", response.error); + reject(new Error(response.error.message || "Failed to deny")); + } else { + setPrimaryControlRequest(null); + resolve(); + } + }); + }); + }, [sendFn]); + + const handleApproveNewSession = useCallback(async (sessionId: string) => { + if (!sendFn) return; + + return new Promise((resolve, reject) => { + sendFn("approveNewSession", { sessionId }, (response: { result?: unknown; error?: { message: string } }) => { + if (response.error) { + console.error("Failed to approve new session:", response.error); + reject(new Error(response.error.message || "Failed to approve")); + } else { + setNewSessionRequest(null); + resolve(); + } + }); + }); + }, [sendFn]); + + const handleDenyNewSession = useCallback(async (sessionId: string) => { + if (!sendFn) return; + + return new Promise((resolve, reject) => { + sendFn("denyNewSession", { sessionId }, (response: { result?: unknown; error?: { message: string } }) => { + if (response.error) { + console.error("Failed to deny new session:", response.error); + reject(new Error(response.error.message || "Failed to deny")); + } else { + setNewSessionRequest(null); + resolve(); + } + }); + }); + }, [sendFn]); + + const handleRpcEvent = useCallback((method: string, params: unknown) => { + if (method === "sessionsUpdated" || + method === "modeChanged" || + method === "connectionModeChanged" || + method === "otherSessionConnected") { + handleSessionEvent(method, params); + } + + if (method === "newSessionPending" && requireSessionApproval) { + if (isLoadingPermissions || hasPermission(Permission.SESSION_APPROVE)) { + setNewSessionRequest(params as NewSessionRequest); + } + } + + if (method === "primaryControlRequested") { + setPrimaryControlRequest(params as PrimaryControlRequest); + } + + if (method === "primaryControlApproved") { + const { setRequestingPrimary } = useSessionStore.getState(); + setRequestingPrimary(false); + } + + if (method === "primaryControlDenied") { + const { setRequestingPrimary, setSessionError } = useSessionStore.getState(); + setRequestingPrimary(false); + setSessionError("Your primary control request was denied"); + } + + if (method === "sessionAccessDenied") { + const { setSessionError } = useSessionStore.getState(); + const errorParams = params as { message?: string }; + setSessionError(errorParams.message || "Session access was denied by the primary session"); + } + }, [handleSessionEvent, hasPermission, isLoadingPermissions, requireSessionApproval]); + + useEffect(() => { + if (!isLoadingPermissions && newSessionRequest && !hasPermission(Permission.SESSION_APPROVE)) { + setNewSessionRequest(null); + } + }, [isLoadingPermissions, hasPermission, newSessionRequest]); + + useEffect(() => { + return () => { + clearSession(); + }; + }, [clearSession]); + + return { + handleSessionResponse, + handleRpcEvent, + primaryControlRequest, + handleApprovePrimaryRequest, + handleDenyPrimaryRequest, + closePrimaryControlRequest: () => setPrimaryControlRequest(null), + newSessionRequest, + handleApproveNewSession, + handleDenyNewSession, + closeNewSessionRequest: () => setNewSessionRequest(null) + }; +} \ No newline at end of file diff --git a/ui/src/hooks/useVersion.tsx b/ui/src/hooks/useVersion.tsx index 7341dacb0..e017014ba 100644 --- a/ui/src/hooks/useVersion.tsx +++ b/ui/src/hooks/useVersion.tsx @@ -29,7 +29,10 @@ export function useVersion() { return new Promise((resolve, reject) => { send("getUpdateStatus", {}, (resp: JsonRpcResponse) => { if ("error" in resp) { - notifications.error(`Failed to check for updates: ${resp.error}`); + const errorMsg = typeof resp.error === 'object' && resp.error.message + ? resp.error.message + : String(resp.error); + notifications.error(`Failed to check for updates: ${errorMsg}`); reject(new Error("Failed to check for updates")); } else { const result = resp.result as SystemVersionInfo; @@ -56,8 +59,11 @@ export function useVersion() { console.warn("Failed to get device version, using legacy version"); return getVersionInfo().then(result => resolve(result.local)).catch(reject); } - console.error("Failed to get device version N", resp.error); - notifications.error(`Failed to get device version: ${resp.error}`); + console.error("Failed to get device version:", resp.error); + const errorMsg = typeof resp.error === 'object' && resp.error.message + ? resp.error.message + : String(resp.error); + notifications.error(`Failed to get device version: ${errorMsg}`); reject(new Error("Failed to get device version")); } else { const result = resp.result as VersionInfo; diff --git a/ui/src/main.tsx b/ui/src/main.tsx index 79ca67170..f05d92f7c 100644 --- a/ui/src/main.tsx +++ b/ui/src/main.tsx @@ -49,6 +49,7 @@ const SecurityAccessLocalAuthRoute = lazy(() => import("@routes/devices.$id.sett const SettingsMacrosRoute = lazy(() => import("@routes/devices.$id.settings.macros")); const SettingsMacrosAddRoute = lazy(() => import("@routes/devices.$id.settings.macros.add")); const SettingsMacrosEditRoute = lazy(() => import("@routes/devices.$id.settings.macros.edit")); +const SettingsMultiSessionsRoute = lazy(() => import("@routes/devices.$id.settings.multi-session")); export const isOnDevice = import.meta.env.MODE === "device"; export const isInCloud = !isOnDevice; @@ -211,6 +212,10 @@ if (isOnDevice) { }, ], }, + { + path: "sessions", + element: , + }, ], }, ], @@ -344,6 +349,10 @@ if (isOnDevice) { }, ], }, + { + path: "sessions", + element: , + }, ], }, ], diff --git a/ui/src/notifications.tsx b/ui/src/notifications.tsx index 5158d8d30..a10e63a34 100644 --- a/ui/src/notifications.tsx +++ b/ui/src/notifications.tsx @@ -1,6 +1,11 @@ import toast, { Toast, Toaster, useToasterStore } from "react-hot-toast"; import React, { useEffect } from "react"; -import { CheckCircleIcon, XCircleIcon } from "@heroicons/react/20/solid"; +import { + CheckCircleIcon, + XCircleIcon, + InformationCircleIcon, + ExclamationTriangleIcon +} from "@heroicons/react/20/solid"; import Card from "@/components/Card"; @@ -57,6 +62,32 @@ const notifications = { { duration: 2000, ...options }, ); }, + + info: (message: string, options?: NotificationOptions) => { + return toast.custom( + t => ( + } + message={message} + t={t} + /> + ), + { duration: 2000, ...options }, + ); + }, + + warning: (message: string, options?: NotificationOptions) => { + return toast.custom( + t => ( + } + message={message} + t={t} + /> + ), + { duration: 3000, ...options }, + ); + }, }; function useMaxToasts(max: number) { @@ -82,7 +113,12 @@ export function Notifications({ } // eslint-disable-next-line react-refresh/only-export-components -export default Object.assign(Notifications, { +export const notify = { success: notifications.success, error: notifications.error, -}); + info: notifications.info, + warning: notifications.warning, +}; + +// eslint-disable-next-line react-refresh/only-export-components +export default Object.assign(Notifications, notify); diff --git a/ui/src/providers/PermissionsProvider.tsx b/ui/src/providers/PermissionsProvider.tsx new file mode 100644 index 000000000..57ac8e486 --- /dev/null +++ b/ui/src/providers/PermissionsProvider.tsx @@ -0,0 +1,106 @@ +import { useState, useEffect, useRef, useCallback, ReactNode } from "react"; + +import { useJsonRpc } from "@/hooks/useJsonRpc"; +import { useSessionStore } from "@/stores/sessionStore"; +import { useRTCStore } from "@/hooks/stores"; +import { Permission } from "@/types/permissions"; +import { PermissionsContextValue } from "@/hooks/usePermissions"; +import { PermissionsContext } from "@/contexts/PermissionsContext"; + +type RpcSendFunction = (method: string, params: Record, callback: (response: { result?: unknown; error?: { message: string } }) => void) => void; + +interface PermissionsResponse { + mode: string; + permissions: Record; +} + +export function PermissionsProvider({ children }: { children: ReactNode }) { + const { currentMode } = useSessionStore(); + const { setRpcHidProtocolVersion, rpcHidChannel, rpcDataChannel } = useRTCStore(); + const [permissions, setPermissions] = useState>({}); + const [isLoading, setIsLoading] = useState(true); + const previousCanControl = useRef(false); + + const pollPermissions = useCallback((send: RpcSendFunction) => { + if (!send) return; + + setIsLoading(true); + send("getPermissions", {}, (response: { result?: unknown; error?: { message: string } }) => { + if (!response.error && response.result) { + const result = response.result as PermissionsResponse; + setPermissions(result.permissions); + } + setIsLoading(false); + }); + }, []); + + const { send } = useJsonRpc(); + + useEffect(() => { + if (rpcDataChannel?.readyState !== "open") return; + pollPermissions(send); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [currentMode, rpcDataChannel?.readyState]); + + const hasPermission = useCallback((permission: Permission): boolean => { + return permissions[permission] === true; + }, [permissions]); + + const hasAnyPermission = useCallback((...perms: Permission[]): boolean => { + return perms.some(perm => hasPermission(perm)); + }, [hasPermission]); + + const hasAllPermissions = useCallback((...perms: Permission[]): boolean => { + return perms.every(perm => hasPermission(perm)); + }, [hasPermission]); + + useEffect(() => { + const currentCanControl = hasPermission(Permission.KEYBOARD_INPUT) && hasPermission(Permission.MOUSE_INPUT); + const hadControl = previousCanControl.current; + + if (currentCanControl && !hadControl && rpcHidChannel?.readyState === "open") { + console.info("Gained control permissions, re-initializing HID"); + + setRpcHidProtocolVersion(null); + + import("@/hooks/hidRpc").then(({ HID_RPC_VERSION, HandshakeMessage }) => { + setTimeout(() => { + if (rpcHidChannel?.readyState === "open") { + const handshakeMessage = new HandshakeMessage(HID_RPC_VERSION); + try { + const data = handshakeMessage.marshal(); + rpcHidChannel.send(data as unknown as ArrayBuffer); + console.info("Sent HID handshake after permission change"); + } catch (e) { + console.error("Failed to send HID handshake", e); + } + } + }, 100); + }); + } + + previousCanControl.current = currentCanControl; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [permissions, rpcHidChannel, setRpcHidProtocolVersion]); + + const isPrimary = useCallback(() => currentMode === "primary", [currentMode]); + const isObserver = useCallback(() => currentMode === "observer", [currentMode]); + const isPending = useCallback(() => currentMode === "pending", [currentMode]); + + const value: PermissionsContextValue = { + permissions, + isLoading, + hasPermission, + hasAnyPermission, + hasAllPermissions, + isPrimary, + isObserver, + isPending, + }; + + return ( + + {children} + + ); +} diff --git a/ui/src/routes/devices.$id.settings.access._index.tsx b/ui/src/routes/devices.$id.settings.access._index.tsx index f30bfef1c..18a680a49 100644 --- a/ui/src/routes/devices.$id.settings.access._index.tsx +++ b/ui/src/routes/devices.$id.settings.access._index.tsx @@ -201,6 +201,7 @@ export default function SettingsAccessIndexRoute() { if ("error" in resp) return console.error(resp.error); setDeviceId(resp.result as string); }); + }, [send, getCloudState, getTLSState]); return ( @@ -327,6 +328,7 @@ export default function SettingsAccessIndexRoute() { )} +
{ @@ -77,16 +80,19 @@ export default function SettingsHardwareRoute() { }; useEffect(() => { - send("getBacklightSettings", {}, (resp: JsonRpcResponse) => { - if ("error" in resp) { - return notifications.error( - `Failed to get backlight settings: ${resp.error.data || "Unknown error"}`, - ); - } - const result = resp.result as BacklightSettings; - setBacklightSettings(result); - }); - }, [send, setBacklightSettings]); + // Only fetch settings if user has permission + if (!isLoading && permissions[Permission.SETTINGS_READ] === true) { + send("getBacklightSettings", {}, (resp: JsonRpcResponse) => { + if ("error" in resp) { + return notifications.error( + `Failed to get backlight settings: ${resp.error.data || "Unknown error"}`, + ); + } + const result = resp.result as BacklightSettings; + setBacklightSettings(result); + }); + } + }, [send, setBacklightSettings, isLoading, permissions]); useEffect(() => { send("getVideoSleepMode", {}, (resp: JsonRpcResponse) => { @@ -99,6 +105,24 @@ export default function SettingsHardwareRoute() { }); }, [send]); + // Return early if permissions are loading + if (isLoading) { + return ( +
+
Loading...
+
+ ); + } + + // Return early if no permission + if (!hasPermission(Permission.SETTINGS_READ)) { + return ( +
+
Access Denied: You do not have permission to view these settings.
+
+ ); + } + return (
{ + send("getSessionSettings", {}, (response: JsonRpcResponse) => { + if ("error" in response) { + console.error("Failed to get session settings:", response.error); + } else { + const settings = response.result as { + requireApproval: boolean; + requireNickname: boolean; + reconnectGrace?: number; + primaryTimeout?: number; + privateKeystrokes?: boolean; + maxRejectionAttempts?: number; + maxSessions?: number; + observerTimeout?: number; + }; + setRequireSessionApproval(settings.requireApproval); + setRequireSessionNickname(settings.requireNickname); + if (settings.reconnectGrace !== undefined) { + setReconnectGrace(settings.reconnectGrace); + } + if (settings.primaryTimeout !== undefined) { + setPrimaryTimeout(settings.primaryTimeout); + } + if (settings.privateKeystrokes !== undefined) { + setPrivateKeystrokes(settings.privateKeystrokes); + } + if (settings.maxRejectionAttempts !== undefined) { + setMaxRejectionAttempts(settings.maxRejectionAttempts); + } + if (settings.maxSessions !== undefined) { + setMaxSessions(settings.maxSessions); + } + if (settings.observerTimeout !== undefined) { + setObserverTimeout(settings.observerTimeout); + } + } + }); + }, [send, setRequireSessionApproval, setRequireSessionNickname, setMaxRejectionAttempts]); + + const updateSessionSettings = (updates: Partial<{ + requireApproval: boolean; + requireNickname: boolean; + reconnectGrace: number; + primaryTimeout: number; + privateKeystrokes: boolean; + maxRejectionAttempts: number; + maxSessions: number; + observerTimeout: number; + }>) => { + if (!canModifySettings) { + notify.error("Only the primary session can change this setting"); + return; + } + + send("setSessionSettings", { + settings: { + requireApproval: requireSessionApproval, + requireNickname: requireSessionNickname, + reconnectGrace: reconnectGrace, + primaryTimeout: primaryTimeout, + privateKeystrokes: privateKeystrokes, + maxRejectionAttempts: maxRejectionAttempts, + maxSessions: maxSessions, + observerTimeout: observerTimeout, + ...updates + } + }, (response: JsonRpcResponse) => { + if ("error" in response) { + console.error("Failed to update session settings:", response.error); + notify.error("Failed to update session settings"); + } + }); + }; + + return ( +
+ + + {!canModifySettings && ( + +
+ Note: Only the primary session can modify these settings. + Request primary control to change settings. +
+
+ )} + + +
+
+ +

+ Access Control +

+
+ + + { + const newValue = e.target.checked; + setRequireSessionApproval(newValue); + updateSessionSettings({ requireApproval: newValue }); + notify.success( + newValue + ? "New sessions will require approval" + : "New sessions will be automatically approved" + ); + }} + /> + + + + { + const newValue = e.target.checked; + setRequireSessionNickname(newValue); + updateSessionSettings({ requireNickname: newValue }); + notify.success( + newValue + ? "Session nicknames are now required" + : "Session nicknames are now optional" + ); + }} + /> + + + +
+ { + const newValue = parseInt(e.target.value) || 3; + if (newValue < 1 || newValue > 10) { + notify.error("Maximum attempts must be between 1 and 10"); + return; + } + setMaxRejectionAttempts(newValue); + updateSessionSettings({ maxRejectionAttempts: newValue }); + notify.success( + `Denied sessions can now retry up to ${newValue} time${newValue === 1 ? '' : 's'}` + ); + }} + className="w-20 px-2 py-1.5 border rounded-md bg-white dark:bg-slate-800 border-slate-300 dark:border-slate-600 text-slate-900 dark:text-white disabled:opacity-50 disabled:cursor-not-allowed text-sm" + /> + attempts +
+
+ + +
+ { + const newValue = parseInt(e.target.value) || 10; + if (newValue < 5 || newValue > 60) { + notify.error("Grace period must be between 5 and 60 seconds"); + return; + } + setReconnectGrace(newValue); + updateSessionSettings({ reconnectGrace: newValue }); + notify.success( + `Session will have ${newValue} seconds to reconnect` + ); + }} + className="w-20 px-2 py-1.5 border rounded-md bg-white dark:bg-slate-800 border-slate-300 dark:border-slate-600 text-slate-900 dark:text-white disabled:opacity-50 disabled:cursor-not-allowed text-sm" + /> + seconds +
+
+ + +
+ { + const newValue = parseInt(e.target.value) || 0; + if (newValue < 0 || newValue > 3600) { + notify.error("Timeout must be between 0 and 3600 seconds"); + return; + } + setPrimaryTimeout(newValue); + updateSessionSettings({ primaryTimeout: newValue }); + notify.success( + newValue === 0 + ? "Primary session timeout disabled" + : `Primary session will timeout after ${Math.round(newValue / 60)} minutes of inactivity` + ); + }} + className="w-24 px-2 py-1.5 border rounded-md bg-white dark:bg-slate-800 border-slate-300 dark:border-slate-600 text-slate-900 dark:text-white disabled:opacity-50 disabled:cursor-not-allowed text-sm" + /> + seconds +
+
+ + +
+ { + const newValue = parseInt(e.target.value) || 10; + if (newValue < 1 || newValue > 20) { + notify.error("Max sessions must be between 1 and 20"); + return; + } + setMaxSessions(newValue); + updateSessionSettings({ maxSessions: newValue }); + notify.success( + `Maximum concurrent sessions set to ${newValue}` + ); + }} + className="w-20 px-2 py-1.5 border rounded-md bg-white dark:bg-slate-800 border-slate-300 dark:border-slate-600 text-slate-900 dark:text-white disabled:opacity-50 disabled:cursor-not-allowed text-sm" + /> + sessions +
+
+ + +
+ { + const newValue = parseInt(e.target.value) || 120; + if (newValue < 30 || newValue > 600) { + notify.error("Timeout must be between 30 and 600 seconds"); + return; + } + setObserverTimeout(newValue); + updateSessionSettings({ observerTimeout: newValue }); + notify.success( + `Observer cleanup timeout set to ${Math.round(newValue / 60)} minute${Math.round(newValue / 60) === 1 ? '' : 's'}` + ); + }} + className="w-20 px-2 py-1.5 border rounded-md bg-white dark:bg-slate-800 border-slate-300 dark:border-slate-600 text-slate-900 dark:text-white disabled:opacity-50 disabled:cursor-not-allowed text-sm" + /> + seconds +
+
+ + + { + const newValue = e.target.checked; + setPrivateKeystrokes(newValue); + updateSessionSettings({ privateKeystrokes: newValue }); + notify.success( + newValue + ? "Keystrokes are now private to primary session" + : "Keystrokes are visible to all authorized sessions" + ); + }} + /> + +
+
+ + +
+
+

+ How Multi-Session Access Works +

+
+
+ Primary: + Full control over the KVM device including keyboard, mouse, and settings +
+
+ Observer: + View-only access to monitor activity without control capabilities +
+
+ Pending: + Awaiting approval from the primary session (when approval is required) +
+
+
+ Use the Sessions panel in the top navigation bar to view and manage active sessions. +
+
+
+
+
+ ); +} \ No newline at end of file diff --git a/ui/src/routes/devices.$id.settings.tsx b/ui/src/routes/devices.$id.settings.tsx index 338beb976..a81f5c1c4 100644 --- a/ui/src/routes/devices.$id.settings.tsx +++ b/ui/src/routes/devices.$id.settings.tsx @@ -1,5 +1,5 @@ import React, { useEffect, useRef, useState } from "react"; -import { NavLink, Outlet, useLocation } from "react-router"; +import { NavLink, Outlet, useLocation , useNavigate } from "react-router"; import { LuSettings, LuMouse, @@ -12,6 +12,7 @@ import { LuPalette, LuCommand, LuNetwork, + LuUsers, } from "react-icons/lu"; import { useResizeObserver } from "usehooks-ts"; @@ -20,11 +21,24 @@ import Card from "@components/Card"; import { LinkButton } from "@components/Button"; import { FeatureFlag } from "@components/FeatureFlag"; import { useUiStore } from "@/hooks/stores"; +import { useSessionStore } from "@/stores/sessionStore"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; /* TODO: Migrate to using URLs instead of the global state. To simplify the refactoring, we'll keep the global state for now. */ export default function SettingsRoute() { const location = useLocation(); + const navigate = useNavigate(); const { setDisableVideoFocusTrap } = useUiStore(); + const { currentMode } = useSessionStore(); + const { hasPermission, isLoading, permissions } = usePermissions(); + + useEffect(() => { + if (!isLoading && !permissions[Permission.SETTINGS_ACCESS] && currentMode !== null) { + navigate("/", { replace: true }); + } + }, [permissions, isLoading, currentMode, navigate]); + const scrollContainerRef = useRef(null); const [showLeftGradient, setShowLeftGradient] = useState(false); const [showRightGradient, setShowRightGradient] = useState(false); @@ -69,6 +83,21 @@ export default function SettingsRoute() { }; }, [setDisableVideoFocusTrap]); + // Check permissions first - return early to prevent any content flash + // Show loading state while permissions are being checked + if (isLoading) { + return ( +
+
Checking permissions...
+
+ ); + } + + // Don't render settings content if user doesn't have permission + if (!hasPermission(Permission.SETTINGS_ACCESS)) { + return null; + } + return (
@@ -223,6 +252,17 @@ export default function SettingsRoute() {
+
+ (isActive ? "active" : "")} + > +
+ +

Multi-Session Access

+
+
+
import('@/components/sidebar/connectionStats')); -const Terminal = lazy(() => import('@components/Terminal')); -const UpdateInProgressStatusCard = lazy(() => import("@/components/UpdateInProgressStatusCard")); +const ConnectionStatsSidebar = lazy(() => import("@/components/sidebar/connectionStats")); +const Terminal = lazy(() => import("@components/Terminal")); +const UpdateInProgressStatusCard = lazy( + () => import("@/components/UpdateInProgressStatusCard"), +); import Modal from "@/components/Modal"; -import { JsonRpcRequest, JsonRpcResponse, RpcMethodNotFound, useJsonRpc } from "@/hooks/useJsonRpc"; +import { + JsonRpcRequest, + JsonRpcResponse, + RpcMethodNotFound, + useJsonRpc, +} from "@/hooks/useJsonRpc"; import { ConnectionFailedOverlay, LoadingConnectionOverlay, @@ -50,8 +62,14 @@ import { } from "@/components/VideoOverlay"; import { useDeviceUiNavigation } from "@/hooks/useAppNavigation"; import { FeatureFlagProvider } from "@/providers/FeatureFlagProvider"; +import { PermissionsProvider } from "@/providers/PermissionsProvider"; +import { usePermissions } from "@/hooks/usePermissions"; +import { Permission } from "@/types/permissions"; import { DeviceStatus } from "@routes/welcome-local"; import { useVersion } from "@/hooks/useVersion"; +import { useSessionManagement } from "@/hooks/useSessionManagement"; +import { useSessionStore, useSharedSessionStore } from "@/stores/sessionStore"; +import { sessionApi } from "@/api/sessionApi"; interface LocalLoaderResp { authMode: "password" | "noPassword" | null; @@ -124,15 +142,25 @@ export default function KvmIdRoute() { const authMode = "authMode" in loaderResp ? loaderResp.authMode : null; const params = useParams() as { id: string }; - const { sidebarView, setSidebarView, disableVideoFocusTrap, rebootState, setRebootState } = useUiStore(); + const { + sidebarView, + setSidebarView, + disableVideoFocusTrap, + setDisableVideoFocusTrap, + rebootState, + setRebootState, + } = useUiStore(); const [queryParams, setQueryParams] = useSearchParams(); const { - peerConnection, setPeerConnection, - peerConnectionState, setPeerConnectionState, + peerConnection, + setPeerConnection, + peerConnectionState, + setPeerConnectionState, setMediaStream, setRpcDataChannel, - isTurnServerInUse, setTurnServerInUse, + isTurnServerInUse, + setTurnServerInUse, rpcDataChannel, setTransceiver, setRpcHidChannel, @@ -143,15 +171,22 @@ export default function KvmIdRoute() { const location = useLocation(); const isLegacySignalingEnabled = useRef(false); const [connectionFailed, setConnectionFailed] = useState(false); + const [showNicknameModal, setShowNicknameModal] = useState(false); + const [accessDenied, setAccessDenied] = useState(false); const navigate = useNavigate(); const { otaState, setOtaState, setModalView } = useUpdateStore(); + const { currentSessionId, currentMode, setCurrentSession } = useSessionStore(); + const { nickname, setNickname } = useSharedSessionStore(); + const { setRequireSessionApproval, setRequireSessionNickname } = useSettingsStore(); + const [globalSessionSettings, setGlobalSessionSettings] = useState<{ + requireApproval: boolean; + requireNickname: boolean; + } | null>(null); const [loadingMessage, setLoadingMessage] = useState("Connecting to device..."); const cleanupAndStopReconnecting = useCallback( function cleanupAndStopReconnecting() { - console.log("Closing peer connection"); - setConnectionFailed(true); if (peerConnection) { setPeerConnectionState(peerConnection.connectionState); @@ -188,7 +223,6 @@ export default function KvmIdRoute() { try { await pc.setRemoteDescription(new RTCSessionDescription(remoteDescription)); - console.log("[setRemoteSessionDescription] Remote description set successfully"); setLoadingMessage("Establishing secure connection..."); } catch (error) { console.error( @@ -206,7 +240,6 @@ export default function KvmIdRoute() { // When vivaldi has disabled "Broadcast IP for Best WebRTC Performance", this never connects if (pc.sctp?.state === "connected") { - console.log("[setRemoteSessionDescription] Remote description set"); clearInterval(checkInterval); setLoadingMessage("Connection established"); } else if (attempts >= 10) { @@ -219,11 +252,6 @@ export default function KvmIdRoute() { ); cleanupAndStopReconnecting(); clearInterval(checkInterval); - } else { - console.log("[setRemoteSessionDescription] Waiting for connection, state:", { - connectionState: pc.connectionState, - iceConnectionState: pc.iceConnectionState, - }); } }, 1000); }, @@ -246,7 +274,6 @@ export default function KvmIdRoute() { reconnectAttempts: 2000, reconnectInterval: 1000, onReconnectStop: () => { - console.debug("Reconnect stopped"); cleanupAndStopReconnecting(); }, @@ -255,9 +282,8 @@ export default function KvmIdRoute() { return !isLegacySignalingEnabled.current; }, - onClose(event) { - console.debug("[Websocket] onClose", event); - // We don't want to close everything down, we wait for the reconnect to stop instead + onClose(_event) { + // Handled by onReconnectStop instead }, onError(event) { @@ -296,27 +322,51 @@ export default function KvmIdRoute() { const parsedMessage = JSON.parse(message.data); if (parsedMessage.type === "device-metadata") { - const { deviceVersion } = parsedMessage.data; - console.debug("[Websocket] Received device-metadata message"); - console.debug("[Websocket] Device version", deviceVersion); + const { deviceVersion, sessionSettings } = parsedMessage.data; + + // Store session settings if provided + if (sessionSettings) { + setGlobalSessionSettings({ + requireNickname: sessionSettings.requireNickname || false, + requireApproval: sessionSettings.requireApproval || false, + }); + // Also update the settings store for approval handling + setRequireSessionApproval(sessionSettings.requireApproval || false); + setRequireSessionNickname(sessionSettings.requireNickname || false); + } + // If the device version is not set, we can assume the device is using the legacy signaling if (!deviceVersion) { - console.log("[Websocket] Device is using legacy signaling"); - // Now we don't need the websocket connection anymore, as we've established that we need to use the legacy signaling // which does everything over HTTP(at least from the perspective of the client) isLegacySignalingEnabled.current = true; getWebSocket()?.close(); } else { - console.log("[Websocket] Device is using new signaling"); isLegacySignalingEnabled.current = false; } + + // Always setup peer connection first to establish RPC channel for nickname generation setupPeerConnection(); + + // Check if nickname is required and not set - modal will be shown after RPC channel is ready + const requiresNickname = sessionSettings?.requireNickname || false; + + if (requiresNickname && !nickname) { + // Store that we need to show the nickname modal once RPC is ready + // The useEffect in NicknameModal will handle waiting for RPC channel readiness + setShowNicknameModal(true); + setDisableVideoFocusTrap(true); + } } - if (!peerConnection) return; + if (!peerConnection) { + console.warn( + "[Websocket] Ignoring message because peerConnection is not ready:", + parsedMessage.type, + ); + return; + } if (parsedMessage.type === "answer") { - console.debug("[Websocket] Received answer"); const readyForOffer = // If we're making an offer, we don't want to accept an answer !makingOffer && @@ -330,14 +380,46 @@ export default function KvmIdRoute() { // Set so we don't accept an answer while we're setting the remote description isSettingRemoteAnswerPending.current = parsedMessage.type === "answer"; - console.debug( - "[Websocket] Setting remote answer pending", - isSettingRemoteAnswerPending.current, - ); const sd = atob(parsedMessage.data); const remoteSessionDescription = JSON.parse(sd); + if (parsedMessage.sessionId && parsedMessage.mode) { + handleSessionResponse({ + sessionId: parsedMessage.sessionId, + mode: parsedMessage.mode, + }); + + // Store sessionId via zustand (persists to sessionStorage for per-tab isolation) + setCurrentSession(parsedMessage.sessionId, parsedMessage.mode); + if ( + parsedMessage.requireNickname !== undefined && + parsedMessage.requireApproval !== undefined + ) { + setGlobalSessionSettings({ + requireNickname: parsedMessage.requireNickname, + requireApproval: parsedMessage.requireApproval, + }); + // Also update the settings store for approval handling + setRequireSessionApproval(parsedMessage.requireApproval); + setRequireSessionNickname(parsedMessage.requireNickname); + } + + // Show nickname modal if: + // 1. Nickname is required by backend settings + // 2. We don't already have a nickname + // This happens even for pending sessions so the nickname is included in approval + const hasNickname = + parsedMessage.nickname && parsedMessage.nickname.length > 0; + const requiresNickname = + parsedMessage.requireNickname || globalSessionSettings?.requireNickname; + + if (requiresNickname && !hasNickname) { + setShowNicknameModal(true); + setDisableVideoFocusTrap(true); + } + } + setRemoteSessionDescription( peerConnection, new RTCSessionDescription(remoteSessionDescription), @@ -346,21 +428,63 @@ export default function KvmIdRoute() { // Reset the remote answer pending flag isSettingRemoteAnswerPending.current = false; } else if (parsedMessage.type === "new-ice-candidate") { - console.debug("[Websocket] Received new-ice-candidate"); const candidate = parsedMessage.data; - peerConnection.addIceCandidate(candidate); + // Always try to add the ICE candidate - the browser will queue it internally if needed + peerConnection.addIceCandidate(candidate).catch(error => { + console.warn("[Websocket] Failed to add ICE candidate:", error); + }); + } else if (parsedMessage.type === "connectionModeChanged") { + // Handle mode changes via WebSocket (fallback when RPC channel stale) + const { newMode, action } = parsedMessage.data; + + if (action === "reconnect_required" && newMode) { + // Update session state immediately + if (currentSessionId) { + setCurrentSession(currentSessionId, newMode); + } + + // Trigger RPC event handler + handleRpcEvent("connectionModeChanged", parsedMessage.data); + + // Only reconnect if the peer connection is actually stale + // If already connected, the mode change via RPC is sufficient + const isConnectionHealthy = + peerConnection?.connectionState === "connected" && + peerConnection?.iceConnectionState === "connected"; + + if (!isConnectionHealthy) { + console.log( + `[Websocket] Mode changed to ${newMode}, connection unhealthy, reconnecting...`, + ); + setTimeout(() => { + peerConnection?.close(); + setupPeerConnection(); + }, 500); + } else { + console.log( + `[Websocket] Mode changed to ${newMode}, connection healthy, skipping reconnect`, + ); + } + } } }, - } + }, ); const sendWebRTCSignal = useCallback( (type: string, data: unknown) => { // Second argument tells the library not to queue the message, and send it once the connection is established again. // We have event handlers that handle the connection set up, so we don't need to queue the message. - sendMessage(JSON.stringify({ type, data }), false); + const message = JSON.stringify({ type, data }); + const ws = getWebSocket(); + if (ws?.readyState === WebSocket.OPEN) { + sendMessage(message, false); + } else { + console.warn(`[WebSocket] WebSocket not open, queuing message:`, message); + sendMessage(message, true); // Queue the message + } }, - [sendMessage], + [sendMessage, getWebSocket], ); const legacyHTTPSignaling = useCallback( @@ -371,12 +495,12 @@ export default function KvmIdRoute() { // In device mode, old devices wont server this JS, and on newer devices legacy mode wont be enabled const sessionUrl = `${CLOUD_API}/webrtc/session`; - console.log("Trying to get remote session description"); setLoadingMessage( `Getting remote session description... ${signalingAttempts.current > 0 ? `(attempt ${signalingAttempts.current + 1})` : ""}`, ); const res = await api.POST(sessionUrl, { sd, + userAgent: navigator.userAgent, // When on device, we don't need to specify the device id, as it's already known ...(isOnDevice ? {} : { id: params.id }), }); @@ -389,7 +513,6 @@ export default function KvmIdRoute() { return; } - console.debug("Successfully got Remote Session Description. Setting."); setLoadingMessage("Setting remote session description..."); const decodedSd = atob(json.sd); @@ -400,13 +523,11 @@ export default function KvmIdRoute() { ); const setupPeerConnection = useCallback(async () => { - console.debug("[setupPeerConnection] Setting up peer connection"); setConnectionFailed(false); setLoadingMessage("Connecting to device..."); let pc: RTCPeerConnection; try { - console.debug("[setupPeerConnection] Creating peer connection"); setLoadingMessage("Creating peer connection..."); pc = new RTCPeerConnection({ // We only use STUN or TURN servers if we're in the cloud @@ -416,7 +537,6 @@ export default function KvmIdRoute() { }); setPeerConnectionState(pc.connectionState); - console.debug("[setupPeerConnection] Peer connection created", pc); setLoadingMessage("Setting up connection to device..."); } catch (e) { console.error(`[setupPeerConnection] Error creating peer connection: ${e}`); @@ -428,13 +548,11 @@ export default function KvmIdRoute() { // Set up event listeners and data channels pc.onconnectionstatechange = () => { - console.debug("[setupPeerConnection] Connection state changed", pc.connectionState); setPeerConnectionState(pc.connectionState); }; pc.onnegotiationneeded = async () => { try { - console.debug("[setupPeerConnection] Creating offer"); makingOffer.current = true; const offer = await pc.createOffer(); @@ -442,9 +560,19 @@ export default function KvmIdRoute() { const sd = btoa(JSON.stringify(pc.localDescription)); const isNewSignalingEnabled = isLegacySignalingEnabled.current === false; if (isNewSignalingEnabled) { - sendWebRTCSignal("offer", { sd: sd }); - } else { - console.log("Legacy signaling. Waiting for ICE Gathering to complete..."); + // Get nickname and sessionId from zustand stores + // sessionId is per-tab (sessionStorage), nickname is shared (localStorage) + const { currentSessionId: storeSessionId } = useSessionStore.getState(); + const { nickname: storeNickname } = useSharedSessionStore.getState(); + + sendWebRTCSignal("offer", { + sd: sd, + sessionId: storeSessionId || undefined, + userAgent: navigator.userAgent, + sessionSettings: { + nickname: storeNickname || undefined, + }, + }); } } catch (e) { console.error( @@ -458,15 +586,18 @@ export default function KvmIdRoute() { }; pc.onicecandidate = ({ candidate }) => { - if (!candidate) return; - if (candidate.candidate === "") return; + if (!candidate) { + return; + } + if (candidate.candidate === "") { + return; + } sendWebRTCSignal("new-ice-candidate", candidate); }; pc.onicegatheringstatechange = event => { const pc = event.currentTarget as RTCPeerConnection; if (pc.iceGatheringState === "complete") { - console.debug("ICE Gathering completed"); setLoadingMessage("ICE Gathering completed"); if (isLegacySignalingEnabled.current) { @@ -474,7 +605,6 @@ export default function KvmIdRoute() { legacyHTTPSignaling(pc); } } else if (pc.iceGatheringState === "gathering") { - console.debug("ICE Gathering Started"); setLoadingMessage("Gathering ICE candidates..."); } }; @@ -505,10 +635,13 @@ export default function KvmIdRoute() { setRpcHidUnreliableChannel(rpcHidUnreliableChannel); }; - const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel("hidrpc-unreliable-nonordered", { - ordered: false, - maxRetransmits: 0, - }); + const rpcHidUnreliableNonOrderedChannel = pc.createDataChannel( + "hidrpc-unreliable-nonordered", + { + ordered: false, + maxRetransmits: 0, + }, + ); rpcHidUnreliableNonOrderedChannel.binaryType = "arraybuffer"; rpcHidUnreliableNonOrderedChannel.onopen = () => { setRpcHidUnreliableNonOrderedChannel(rpcHidUnreliableNonOrderedChannel); @@ -599,19 +732,24 @@ export default function KvmIdRoute() { } // Fire and forget - api.POST(`${CLOUD_API}/webrtc/turn_activity`, { - bytesReceived: bytesReceivedDelta, - bytesSent: bytesSentDelta, - }).catch(() => { - // we don't care about errors here, but we don't want unhandled promise rejections - }); + api + .POST(`${CLOUD_API}/webrtc/turn_activity`, { + bytesReceived: bytesReceivedDelta, + bytesSent: bytesSentDelta, + }) + .catch(() => { + // we don't care about errors here, but we don't want unhandled promise rejections + }); }, 10000); const { setNetworkState } = useNetworkStateStore(); const { setHdmiState } = useVideoStore(); const { - keyboardLedState, setKeyboardLedState, - keysDownState, setKeysDownState, setUsbState, + keyboardLedState, + setKeyboardLedState, + keysDownState, + setKeysDownState, + setUsbState, } = useHidStore(); const setHidRpcDisabled = useRTCStore(state => state.setHidRpcDisabled); @@ -619,42 +757,56 @@ export default function KvmIdRoute() { const { navigateTo } = useDeviceUiNavigation(); function onJsonRpcRequest(resp: JsonRpcRequest) { - if (resp.method === "otherSessionConnected") { - navigateTo("/other-session"); + // Handle session-related events + if ( + resp.method === "sessionsUpdated" || + resp.method === "modeChanged" || + resp.method === "connectionModeChanged" || + resp.method === "otherSessionConnected" || + resp.method === "primaryControlRequested" || + resp.method === "primaryControlApproved" || + resp.method === "primaryControlDenied" || + resp.method === "newSessionPending" || + resp.method === "sessionAccessDenied" + ) { + handleRpcEvent(resp.method, resp.params); + + // Show access denied overlay if our session was denied + if (resp.method === "sessionAccessDenied") { + setAccessDenied(true); + } + + if (resp.method === "otherSessionConnected") { + navigateTo("/other-session"); + } } if (resp.method === "usbState") { const usbState = resp.params as unknown as USBStates; - console.debug("Setting USB state", usbState); setUsbState(usbState); } if (resp.method === "videoInputState") { const hdmiState = resp.params as Parameters[0]; - console.debug("Setting HDMI state", hdmiState); setHdmiState(hdmiState); } if (resp.method === "networkState") { - console.debug("Setting network state", resp.params); setNetworkState(resp.params as NetworkState); } if (resp.method === "keyboardLedState") { const ledState = resp.params as KeyboardLedState; - console.debug("Setting keyboard led state", ledState); setKeyboardLedState(ledState); } if (resp.method === "keysDownState") { const downState = resp.params as KeysDownState; - console.debug("Setting key down state:", downState); setKeysDownState(downState); } if (resp.method === "otaState") { const otaState = resp.params as OtaState; - console.debug("Setting OTA state", otaState); setOtaState(otaState); if (otaState.updating === true) { @@ -687,24 +839,44 @@ export default function KvmIdRoute() { const { send } = useJsonRpc(onJsonRpcRequest); + const { + handleSessionResponse, + handleRpcEvent, + primaryControlRequest, + handleApprovePrimaryRequest, + handleDenyPrimaryRequest, + closePrimaryControlRequest, + newSessionRequest, + handleApproveNewSession, + handleDenyNewSession, + closeNewSessionRequest, + } = useSessionManagement(send); + + const { hasPermission, isLoading: isLoadingPermissions } = usePermissions(); + useEffect(() => { if (rpcDataChannel?.readyState !== "open") return; - console.log("Requesting video state"); + if (isLoadingPermissions || !hasPermission(Permission.VIDEO_VIEW)) return; + send("getVideoState", {}, (resp: JsonRpcResponse) => { if ("error" in resp) return; const hdmiState = resp.result as Parameters[0]; - console.debug("Setting HDMI state", hdmiState); setHdmiState(hdmiState); }); - }, [rpcDataChannel?.readyState, send, setHdmiState]); + }, [ + rpcDataChannel?.readyState, + hasPermission, + isLoadingPermissions, + send, + setHdmiState, + ]); const [needLedState, setNeedLedState] = useState(true); - // request keyboard led state from the device useEffect(() => { if (rpcDataChannel?.readyState !== "open") return; if (!needLedState) return; - console.log("Requesting keyboard led state"); + if (isLoadingPermissions || !hasPermission(Permission.KEYBOARD_INPUT)) return; send("getKeyboardLedState", {}, (resp: JsonRpcResponse) => { if ("error" in resp) { @@ -712,39 +884,54 @@ export default function KvmIdRoute() { return; } else { const ledState = resp.result as KeyboardLedState; - console.debug("Keyboard led state: ", ledState); setKeyboardLedState(ledState); } setNeedLedState(false); }); - }, [rpcDataChannel?.readyState, send, setKeyboardLedState, keyboardLedState, needLedState]); + }, [ + rpcDataChannel?.readyState, + send, + setKeyboardLedState, + keyboardLedState, + needLedState, + hasPermission, + isLoadingPermissions, + ]); const [needKeyDownState, setNeedKeyDownState] = useState(true); - // request keyboard key down state from the device useEffect(() => { if (rpcDataChannel?.readyState !== "open") return; if (!needKeyDownState) return; - console.log("Requesting keys down state"); + if (isLoadingPermissions || !hasPermission(Permission.KEYBOARD_INPUT)) return; send("getKeyDownState", {}, (resp: JsonRpcResponse) => { if ("error" in resp) { - // -32601 means the method is not supported if (resp.error.code === RpcMethodNotFound) { - // if we don't support key down state, we know key press is also not available - console.warn("Failed to get key down state, switching to old-school", resp.error); + console.warn( + "Failed to get key down state, switching to old-school", + resp.error, + ); setHidRpcDisabled(true); } else { console.error("Failed to get key down state", resp.error); } } else { const downState = resp.result as KeysDownState; - console.debug("Keyboard key down state", downState); setKeysDownState(downState); } setNeedKeyDownState(false); }); - }, [keysDownState, needKeyDownState, rpcDataChannel?.readyState, send, setKeysDownState, setHidRpcDisabled]); + }, [ + keysDownState, + needKeyDownState, + rpcDataChannel?.readyState, + send, + setKeysDownState, + setHidRpcDisabled, + hasPermission, + isLoadingPermissions, + ]); // When the update is successful, we need to refresh the client javascript and show a success modal useEffect(() => { @@ -777,9 +964,12 @@ export default function KvmIdRoute() { useEffect(() => { if (appVersion) return; + if (rpcDataChannel?.readyState !== "open") return; + if (currentMode === "pending") return; getLocalVersion(); - }, [appVersion, getLocalVersion]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [appVersion, rpcDataChannel?.readyState, currentMode]); const ConnectionStatusElement = useMemo(() => { const isOtherSession = location.pathname.includes("other-session"); @@ -787,7 +977,9 @@ export default function KvmIdRoute() { // Rebooting takes priority over connection status if (rebootState?.isRebooting) { - return ; + return ( + + ); } const hasConnectionFailed = @@ -814,87 +1006,188 @@ export default function KvmIdRoute() { } return null; - }, [location.pathname, rebootState?.isRebooting, rebootState?.postRebootAction, connectionFailed, peerConnectionState, peerConnection, setupPeerConnection, loadingMessage]); + }, [ + location.pathname, + rebootState?.isRebooting, + rebootState?.postRebootAction, + connectionFailed, + peerConnectionState, + peerConnection, + setupPeerConnection, + loadingMessage, + ]); return ( - - {!outlet && otaState.updating && ( - - - - - - )} -
- -
-
-
- -
- - -
- -
+ + {!outlet && otaState.updating && ( + + -
- {!!ConnectionStatusElement && ConnectionStatusElement} -
+ +
+
+ )} +
+ +
+
+
+ +
+ + +
+ {/* Only show video feed if nickname is set (when required) and not pending approval */} + {!showNicknameModal && currentMode !== "pending" ? ( + <> + +
+
+ {!!ConnectionStatusElement && ConnectionStatusElement} +
+
+ + ) : ( +
+
+ {showNicknameModal &&

Please set your nickname to continue

} + {currentMode === "pending" &&

Waiting for session approval...

} +
+
+ )} +
-
-
-
e.stopPropagation()} - onMouseUp={e => e.stopPropagation()} - onMouseDown={e => e.stopPropagation()} - onKeyUp={e => e.stopPropagation()} - onKeyDown={e => { - e.stopPropagation(); - if (e.key === "Escape") navigateTo("/"); - }} - > - - {/* The 'used by other session' modal needs to have access to the connectWebRTC function */} - - -
+
e.stopPropagation()} + onMouseUp={e => e.stopPropagation()} + onMouseDown={e => e.stopPropagation()} + onKeyUp={e => e.stopPropagation()} + onKeyDown={e => { + e.stopPropagation(); + if (e.key === "Escape") navigateTo("/"); + }} + > + + {/* The 'used by other session' modal needs to have access to the connectWebRTC function */} + + + + { + setNickname(nickname); + setShowNicknameModal(false); + setDisableVideoFocusTrap(false); + + if (currentSessionId && send) { + try { + await sessionApi.updateNickname(send, currentSessionId, nickname); + } catch (error) { + console.error("Failed to update nickname:", error); + } + } + }} + onSkip={() => { + setShowNicknameModal(false); + setDisableVideoFocusTrap(false); + }} + /> +
- {kvmTerminal && ( - - )} + {kvmTerminal && ( + + )} + + {serialConsole && ( + + )} + + {/* Unified Session Request Dialog */} + {(primaryControlRequest || newSessionRequest) && ( + + )} + + { + if (!send) return; + try { + await sessionApi.requestSessionApproval(send); + setAccessDenied(false); + } catch (error) { + console.error("Failed to re-request approval:", error); + } + }} + /> - {serialConsole && ( - - )} - + + + ); } diff --git a/ui/src/stores/sessionStore.ts b/ui/src/stores/sessionStore.ts new file mode 100644 index 000000000..eee8aa0a2 --- /dev/null +++ b/ui/src/stores/sessionStore.ts @@ -0,0 +1,175 @@ +import { create } from "zustand"; +import { persist, createJSONStorage } from "zustand/middleware"; + +export type SessionMode = "primary" | "observer" | "queued" | "pending"; + +export interface SessionInfo { + id: string; + mode: SessionMode; + source: "local" | "cloud"; + identity?: string; + nickname?: string; + createdAt: string; + lastActive: string; +} + +export interface SessionState { + // Current session info + currentSessionId: string | null; + currentMode: SessionMode | null; + + // All active sessions + sessions: SessionInfo[]; + + // UI state + isRequestingPrimary: boolean; + sessionError: string | null; + rejectionCount: number; + + // Actions + setCurrentSession: (id: string, mode: SessionMode) => void; + setSessions: (sessions: SessionInfo[]) => void; + setRequestingPrimary: (requesting: boolean) => void; + setSessionError: (error: string | null) => void; + updateSessionMode: (mode: SessionMode) => void; + clearSession: () => void; + incrementRejectionCount: () => number; + resetRejectionCount: () => void; + + // Computed getters + isPrimary: () => boolean; + isObserver: () => boolean; + isQueued: () => boolean; + isPending: () => boolean; + canRequestPrimary: () => boolean; + getPrimarySession: () => SessionInfo | undefined; + getQueuePosition: () => number; +} + +export const useSessionStore = create()( + persist( + (set, get) => ({ + // Initial state + currentSessionId: null, + currentMode: null, + sessions: [], + isRequestingPrimary: false, + sessionError: null, + rejectionCount: 0, + + // Actions + setCurrentSession: (id: string, mode: SessionMode) => { + set({ + currentSessionId: id, + currentMode: mode, + sessionError: null + }); + }, + + setSessions: (sessions: SessionInfo[]) => { + set({ sessions }); + }, + + setRequestingPrimary: (requesting: boolean) => { + set({ isRequestingPrimary: requesting }); + }, + + setSessionError: (error: string | null) => { + set({ sessionError: error }); + }, + + updateSessionMode: (mode: SessionMode) => { + set({ currentMode: mode }); + }, + + clearSession: () => { + set({ + currentSessionId: null, + currentMode: null, + sessions: [], + sessionError: null, + isRequestingPrimary: false, + rejectionCount: 0 + }); + }, + + incrementRejectionCount: () => { + const newCount = get().rejectionCount + 1; + set({ rejectionCount: newCount }); + return newCount; + }, + + resetRejectionCount: () => { + set({ rejectionCount: 0 }); + }, + + // Computed getters + isPrimary: () => { + return get().currentMode === "primary"; + }, + + isObserver: () => { + return get().currentMode === "observer"; + }, + + isQueued: () => { + return get().currentMode === "queued"; + }, + + isPending: () => { + return get().currentMode === "pending"; + }, + + canRequestPrimary: () => { + const state = get(); + return state.currentMode === "observer" && + !state.isRequestingPrimary && + state.sessions.some(s => s.mode === "primary"); + }, + + getPrimarySession: () => { + return get().sessions.find(s => s.mode === "primary"); + }, + + getQueuePosition: () => { + const state = get(); + if (state.currentMode !== "queued") return -1; + + const queuedSessions = state.sessions + .filter(s => s.mode === "queued") + .sort((a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime()); + + return queuedSessions.findIndex(s => s.id === state.currentSessionId) + 1; + } + }), + { + name: 'session', + storage: createJSONStorage(() => sessionStorage), + partialize: (state) => ({ + currentSessionId: state.currentSessionId, + }), + } + ) +); + +// Shared session store - separate with localStorage (shared across tabs) +// Used for user preferences that should be consistent across all tabs +export interface SharedSessionState { + nickname: string | null; + setNickname: (nickname: string | null) => void; + clearNickname: () => void; +} + +export const useSharedSessionStore = create()( + persist( + (set) => ({ + nickname: null, + setNickname: (nickname: string | null) => set({ nickname }), + clearNickname: () => set({ nickname: null }), + }), + { + name: 'sharedSession', + storage: createJSONStorage(() => localStorage), + } + ) +); \ No newline at end of file diff --git a/ui/src/types/permissions.ts b/ui/src/types/permissions.ts new file mode 100644 index 000000000..5035fed88 --- /dev/null +++ b/ui/src/types/permissions.ts @@ -0,0 +1,30 @@ +export enum Permission { + VIDEO_VIEW = "video.view", + KEYBOARD_INPUT = "keyboard.input", + MOUSE_INPUT = "mouse.input", + PASTE = "clipboard.paste", + SESSION_TRANSFER = "session.transfer", + SESSION_APPROVE = "session.approve", + SESSION_KICK = "session.kick", + SESSION_REQUEST_PRIMARY = "session.request_primary", + SESSION_RELEASE_PRIMARY = "session.release_primary", + SESSION_MANAGE = "session.manage", + MOUNT_MEDIA = "mount.media", + UNMOUNT_MEDIA = "mount.unmedia", + MOUNT_LIST = "mount.list", + EXTENSION_MANAGE = "extension.manage", + EXTENSION_ATX = "extension.atx", + EXTENSION_DC = "extension.dc", + EXTENSION_SERIAL = "extension.serial", + EXTENSION_WOL = "extension.wol", + SETTINGS_READ = "settings.read", + SETTINGS_WRITE = "settings.write", + SETTINGS_ACCESS = "settings.access", + SYSTEM_REBOOT = "system.reboot", + SYSTEM_UPDATE = "system.update", + SYSTEM_NETWORK = "system.network", + POWER_CONTROL = "power.control", + USB_CONTROL = "usb.control", + TERMINAL_ACCESS = "terminal.access", + SERIAL_ACCESS = "serial.access", +} diff --git a/ui/src/utils/jsonrpc.ts b/ui/src/utils/jsonrpc.ts index ecfa1c4b3..bf750dbd3 100644 --- a/ui/src/utils/jsonrpc.ts +++ b/ui/src/utils/jsonrpc.ts @@ -51,7 +51,7 @@ export function callJsonRpc(options: JsonRpcCallOptions): Promise, callback: (response: { result?: unknown; error?: { message: string } }) => void) => void; + +// Main function that uses backend generation +export async function generateNickname(sendFn?: RpcSendFunction): Promise { + // Require backend function - no fallback + if (!sendFn) { + throw new Error('Backend connection required for nickname generation'); + } + + return new Promise((resolve, reject) => { + try { + const result = sendFn('generateNickname', { userAgent: navigator.userAgent }, (response: { result?: unknown; error?: { message: string } }) => { + const result = response.result as { nickname?: string } | undefined; + if (response && !response.error && result?.nickname) { + resolve(result.nickname); + } else { + reject(new Error('Failed to generate nickname from backend')); + } + }); + + // If sendFn returns undefined (RPC channel not ready), reject immediately + if (result === undefined) { + reject(new Error('RPC connection not ready yet')); + } + } catch (error) { + reject(error); + } + }); +} + +// Synchronous version removed - backend generation is always async +export function generateNicknameSync(): string { + throw new Error('Synchronous nickname generation not supported - use backend generateNickname()'); +} \ No newline at end of file diff --git a/usb.go b/usb.go index af57692f6..9b4a16ab5 100644 --- a/usb.go +++ b/usb.go @@ -27,20 +27,43 @@ func initUsbGadget() { }() gadget.SetOnKeyboardStateChange(func(state usbgadget.KeyboardState) { - if currentSession != nil { - currentSession.reportHidRPCKeyboardLedState(state) + // Check if keystrokes should be private + if currentSessionSettings != nil && currentSessionSettings.PrivateKeystrokes { + // Report to primary session only + if primary := sessionManager.GetPrimarySession(); primary != nil { + primary.reportHidRPCKeyboardLedState(state) + } + } else { + // Report to all authorized sessions (primary and observers, but not pending) + sessionManager.ForEachSession(func(s *Session) { + if s.Mode == SessionModePrimary || s.Mode == SessionModeObserver { + s.reportHidRPCKeyboardLedState(state) + } + }) } }) gadget.SetOnKeysDownChange(func(state usbgadget.KeysDownState) { - if currentSession != nil { - currentSession.enqueueKeysDownState(state) + // Check if keystrokes should be private + if currentSessionSettings != nil && currentSessionSettings.PrivateKeystrokes { + // Report to primary session only + if primary := sessionManager.GetPrimarySession(); primary != nil { + primary.enqueueKeysDownState(state) + } + } else { + // Report to all authorized sessions (primary and observers, but not pending) + sessionManager.ForEachSession(func(s *Session) { + if s.Mode == SessionModePrimary || s.Mode == SessionModeObserver { + s.enqueueKeysDownState(state) + } + }) } }) gadget.SetOnKeepAliveReset(func() { - if currentSession != nil { - currentSession.resetKeepAliveTime() + // Reset keep-alive for primary session + if primary := sessionManager.GetPrimarySession(); primary != nil { + primary.resetKeepAliveTime() } }) @@ -50,26 +73,82 @@ func initUsbGadget() { } } -func rpcKeyboardReport(modifier byte, keys []byte) error { +func (s *Session) rpcKeyboardReport(modifier byte, keys []byte) error { + if s == nil || !s.HasPermission(PermissionKeyboardInput) { + return ErrPermissionDeniedKeyboard + } + sessionManager.UpdateLastActive(s.ID) return gadget.KeyboardReport(modifier, keys) } -func rpcKeypressReport(key byte, press bool) error { +func (s *Session) rpcKeypressReport(key byte, press bool) error { + if s == nil || !s.HasPermission(PermissionKeyboardInput) { + return ErrPermissionDeniedKeyboard + } + sessionManager.UpdateLastActive(s.ID) return gadget.KeypressReport(key, press) } -func rpcAbsMouseReport(x int, y int, buttons uint8) error { +func (s *Session) rpcAbsMouseReport(x int, y int, buttons uint8) error { + if s == nil || !s.HasPermission(PermissionMouseInput) { + return ErrPermissionDeniedMouse + } + sessionManager.UpdateLastActive(s.ID) return gadget.AbsMouseReport(x, y, buttons) } -func rpcRelMouseReport(dx int8, dy int8, buttons uint8) error { +func (s *Session) rpcRelMouseReport(dx int8, dy int8, buttons uint8) error { + if s == nil || !s.HasPermission(PermissionMouseInput) { + return ErrPermissionDeniedMouse + } + sessionManager.UpdateLastActive(s.ID) return gadget.RelMouseReport(dx, dy, buttons) } -func rpcWheelReport(wheelY int8) error { +func (s *Session) rpcWheelReport(wheelY int8) error { + if s == nil || !s.HasPermission(PermissionMouseInput) { + return ErrPermissionDeniedMouse + } + sessionManager.UpdateLastActive(s.ID) return gadget.AbsMouseWheelReport(wheelY) } +// RPC functions that route to the primary session +func rpcKeyboardReport(modifier byte, keys []byte) error { + if primary := sessionManager.GetPrimarySession(); primary != nil { + return primary.rpcKeyboardReport(modifier, keys) + } + return ErrNotPrimarySession +} + +func rpcKeypressReport(key byte, press bool) error { + if primary := sessionManager.GetPrimarySession(); primary != nil { + return primary.rpcKeypressReport(key, press) + } + return ErrNotPrimarySession +} + +func rpcAbsMouseReport(x int, y int, buttons uint8) error { + if primary := sessionManager.GetPrimarySession(); primary != nil { + return primary.rpcAbsMouseReport(x, y, buttons) + } + return ErrNotPrimarySession +} + +func rpcRelMouseReport(dx int8, dy int8, buttons uint8) error { + if primary := sessionManager.GetPrimarySession(); primary != nil { + return primary.rpcRelMouseReport(dx, dy, buttons) + } + return ErrNotPrimarySession +} + +func rpcWheelReport(wheelY int8) error { + if primary := sessionManager.GetPrimarySession(); primary != nil { + return primary.rpcWheelReport(wheelY) + } + return ErrNotPrimarySession +} + func rpcGetKeyboardLedState() (state usbgadget.KeyboardState) { return gadget.GetKeyboardState() } @@ -89,11 +168,7 @@ func rpcGetUSBState() (state string) { func triggerUSBStateUpdate() { go func() { - if currentSession == nil { - usbLogger.Info().Msg("No active RPC session, skipping USB state update") - return - } - writeJSONRPCEvent("usbState", usbState, currentSession) + broadcastJSONRPCEvent("usbState", usbState) }() } diff --git a/video.go b/video.go index cd74e6804..a77db0004 100644 --- a/video.go +++ b/video.go @@ -20,7 +20,7 @@ const ( func triggerVideoStateUpdate() { go func() { - writeJSONRPCEvent("videoInputState", lastVideoState, currentSession) + broadcastJSONRPCEvent("videoInputState", lastVideoState) }() nativeLogger.Info().Interface("state", lastVideoState).Msg("video state updated") diff --git a/web.go b/web.go index 0fd968b88..ff00952bb 100644 --- a/web.go +++ b/web.go @@ -34,10 +34,25 @@ import ( var staticFiles embed.FS type WebRTCSessionRequest struct { - Sd string `json:"sd"` - OidcGoogle string `json:"OidcGoogle,omitempty"` - IP string `json:"ip,omitempty"` - ICEServers []string `json:"iceServers,omitempty"` + Sd string `json:"sd"` + SessionId string `json:"sessionId,omitempty"` + OidcGoogle string `json:"OidcGoogle,omitempty"` + IP string `json:"ip,omitempty"` + ICEServers []string `json:"iceServers,omitempty"` + UserAgent string `json:"userAgent,omitempty"` // Browser user agent for nickname generation + SessionSettings *SessionSettings `json:"sessionSettings,omitempty"` +} + +type SessionSettings struct { + RequireApproval bool `json:"requireApproval"` + RequireNickname bool `json:"requireNickname"` + ReconnectGrace int `json:"reconnectGrace,omitempty"` // Grace period in seconds for primary reconnection + PrimaryTimeout int `json:"primaryTimeout,omitempty"` // Inactivity timeout in seconds for primary session + Nickname string `json:"nickname,omitempty"` + PrivateKeystrokes bool `json:"privateKeystrokes,omitempty"` // If true, only primary session sees keystroke events + MaxRejectionAttempts int `json:"maxRejectionAttempts,omitempty"` // Number of times denied session can retry before modal hides + MaxSessions int `json:"maxSessions,omitempty"` // Maximum number of concurrent sessions (default: 10) + ObserverTimeout int `json:"observerTimeout,omitempty"` // Time in seconds to wait before cleaning up inactive observer sessions (default: 120) } type SetPasswordRequest struct { @@ -158,32 +173,16 @@ func setupRouter() *gin.Engine { protected := r.Group("/") protected.Use(protectedMiddleware()) { - /* - * Legacy WebRTC session endpoint - * - * This endpoint is maintained for backward compatibility when users upgrade from a version - * using the legacy HTTP-based signaling method to the new WebSocket-based signaling method. - * - * During the upgrade process, when the "Rebooting device after update..." message appears, - * the browser still runs the previous JavaScript code which polls this endpoint to establish - * a new WebRTC session. Once the session is established, the page will automatically reload - * with the updated code. - * - * Without this endpoint, the stale JavaScript would fail to establish a connection, - * causing users to see the "Rebooting device after update..." message indefinitely - * until they manually refresh the page, leading to a confusing user experience. - */ - protected.POST("/webrtc/session", handleWebRTCSession) protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal) protected.POST("/cloud/register", handleCloudRegister) protected.GET("/cloud/state", handleCloudState) protected.GET("/device", handleDevice) protected.POST("/auth/logout", handleLogout) - protected.POST("/auth/password-local", handleCreatePassword) - protected.PUT("/auth/password-local", handleUpdatePassword) - protected.DELETE("/auth/local-password", handleDeletePassword) - protected.POST("/storage/upload", handleUploadHttp) + protected.POST("/auth/password-local", requirePermissionMiddleware(PermissionSettingsWrite), handleCreatePassword) + protected.PUT("/auth/password-local", requirePermissionMiddleware(PermissionSettingsWrite), handleUpdatePassword) + protected.DELETE("/auth/local-password", requirePermissionMiddleware(PermissionSettingsWrite), handleDeletePassword) + protected.POST("/storage/upload", requirePermissionMiddleware(PermissionMountMedia), handleUploadHttp) } // Catch-all route for SPA @@ -198,44 +197,6 @@ func setupRouter() *gin.Engine { return r } -// TODO: support multiple sessions? -var currentSession *Session - -func handleWebRTCSession(c *gin.Context) { - var req WebRTCSessionRequest - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - session, err := newSession(SessionConfig{}) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err}) - return - } - - sd, err := session.ExchangeOffer(req.Sd) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err}) - return - } - if currentSession != nil { - writeJSONRPCEvent("otherSessionConnected", nil, currentSession) - peerConn := currentSession.peerConnection - go func() { - time.Sleep(1 * time.Second) - _ = peerConn.Close() - }() - } - - // Cancel any ongoing keyboard macro when session changes - cancelKeyboardMacro() - - currentSession = session - c.JSON(http.StatusOK, gin.H{"sd": sd}) -} - var ( pingMessage = []byte("ping") pongMessage = []byte("pong") @@ -244,7 +205,15 @@ var ( func handleLocalWebRTCSignal(c *gin.Context) { // get the source from the request source := c.ClientIP() - connectionID := uuid.New().String() + + // Try to get existing session ID from cookie for session persistence + sessionID, _ := c.Cookie("sessionId") + if sessionID == "" { + sessionID = uuid.New().String() + // Set session ID cookie with same expiry as auth token (7 days) + c.SetCookie("sessionId", sessionID, 7*24*60*60, "/", "", false, true) + } + connectionID := sessionID scopedLogger := websocketLogger.With(). Str("component", "websocket"). @@ -276,7 +245,17 @@ func handleLocalWebRTCSignal(c *gin.Context) { // Now use conn for websocket operations defer wsCon.Close(websocket.StatusNormalClosure, "") - err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}}) + // Include session settings in device metadata so client knows requirements upfront + sessionSettingsData := gin.H{ + "deviceVersion": builtAppVersion, + } + if currentSessionSettings != nil { + sessionSettingsData["sessionSettings"] = gin.H{ + "requireNickname": currentSessionSettings.RequireNickname, + "requireApproval": currentSessionSettings.RequireApproval, + } + } + err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": sessionSettingsData}) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -380,6 +359,13 @@ func handleWebRTCSignalWsMessages( typ, msg, err := wsCon.Read(runCtx) if err != nil { l.Warn().Str("error", err.Error()).Msg("websocket read error") + // Clean up session when websocket closes + if session := sessionManager.GetSession(connectionID); session != nil && session.peerConnection != nil { + l.Info(). + Str("sessionID", session.ID). + Msg("Closing peer connection due to websocket error") + _ = session.peerConnection.Close() + } return err } if typ != websocket.MessageText { @@ -412,14 +398,17 @@ func handleWebRTCSignalWsMessages( continue } + l.Info().Str("type", message.Type).Str("dataLen", fmt.Sprintf("%d", len(message.Data))).Msg("received WebSocket message") + if message.Type == "offer" { - l.Info().Msg("new session request received") + l.Info().Str("dataRaw", string(message.Data)).Msg("new session request received with raw data") var req WebRTCSessionRequest err = json.Unmarshal(message.Data, &req) if err != nil { - l.Warn().Str("error", err.Error()).Msg("unable to parse session request data") + l.Warn().Str("error", err.Error()).Str("dataRaw", string(message.Data)).Msg("unable to parse session request data") continue } + l.Info().Str("sd", req.Sd[:50]).Msg("parsed session request") if req.OidcGoogle != "" { l.Info().Str("oidcGoogle", req.OidcGoogle).Msg("new session request with OIDC Google") @@ -427,7 +416,7 @@ func handleWebRTCSignalWsMessages( metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() - err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source, &l) + err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source, connectionID, &l) if err != nil { l.Warn().Str("error", err.Error()).Msg("error starting new session") continue @@ -449,14 +438,16 @@ func handleWebRTCSignalWsMessages( l.Info().Str("data", fmt.Sprintf("%v", candidate)).Msg("unmarshalled incoming ICE candidate") - if currentSession == nil { - l.Warn().Msg("no current session, skipping incoming ICE candidate") + // Find the session this ICE candidate belongs to using the connectionID + session := sessionManager.GetSession(connectionID) + if session == nil { + l.Warn().Str("connectionID", connectionID).Msg("no session found for connection ID, skipping incoming ICE candidate") continue } - l.Info().Str("data", fmt.Sprintf("%v", candidate)).Msg("adding incoming ICE candidate to current session") - if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil { - l.Warn().Str("error", err.Error()).Msg("failed to add incoming ICE candidate to our peer connection") + l.Info().Str("sessionID", session.ID).Str("data", fmt.Sprintf("%v", candidate)).Msg("adding incoming ICE candidate to correct session") + if err = session.peerConnection.AddICECandidate(candidate); err != nil { + l.Warn().Str("error", err.Error()).Str("sessionID", session.ID).Msg("failed to add incoming ICE candidate to peer connection") } } } @@ -481,7 +472,16 @@ func handleLogin(c *gin.Context) { return } - config.LocalAuthToken = uuid.New().String() + // Don't generate a new token - use the existing one + // This ensures all sessions can share the same auth token + if config.LocalAuthToken == "" { + // Only generate if we don't have one (shouldn't happen in normal operation) + config.LocalAuthToken = uuid.New().String() + if err := SaveConfig(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save configuration"}) + return + } + } // Set the cookie c.SetCookie("authToken", config.LocalAuthToken, 7*24*60*60, "/", "", false, true) @@ -490,14 +490,30 @@ func handleLogin(c *gin.Context) { } func handleLogout(c *gin.Context) { - config.LocalAuthToken = "" - if err := SaveConfig(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save configuration"}) - return + // Get session ID from cookie before clearing + sessionID, _ := c.Cookie("sessionId") + + // Close the WebRTC session immediately for intentional logout + if sessionID != "" { + if session := sessionManager.GetSession(sessionID); session != nil { + websocketLogger.Info(). + Str("sessionID", sessionID). + Msg("Closing session due to intentional logout - no grace period") + + // Close peer connection (will trigger cleanupSession) + if session.peerConnection != nil { + _ = session.peerConnection.Close() + } + + // Clear grace period for intentional logout - observer should be promoted immediately + sessionManager.ClearGracePeriod(sessionID) + } } - // Clear the auth cookie + // Clear the cookies for this session, don't invalidate the token + // The token should remain valid for other sessions c.SetCookie("authToken", "", -1, "/", "", false, true) + c.SetCookie("sessionId", "", -1, "/", "", false, true) c.JSON(http.StatusOK, gin.H{"message": "Logout successful"}) } @@ -519,6 +535,38 @@ func protectedMiddleware() gin.HandlerFunc { } } +// requirePermissionMiddleware creates a middleware that enforces specific permissions +func requirePermissionMiddleware(permission Permission) gin.HandlerFunc { + return func(c *gin.Context) { + // Get session ID from cookie + sessionID, err := c.Cookie("sessionId") + if err != nil || sessionID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "No session ID found"}) + c.Abort() + return + } + + // Get session from manager + session := sessionManager.GetSession(sessionID) + if session == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Session not found"}) + c.Abort() + return + } + + // Check permission + if !session.HasPermission(permission) { + c.JSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("Permission denied: %s required", permission)}) + c.Abort() + return + } + + // Store session in context for use by handlers + c.Set("session", session) + c.Next() + } +} + func sendErrorJsonThenAbort(c *gin.Context, status int, message string) { c.JSON(status, gin.H{"error": message}) c.Abort() @@ -591,7 +639,7 @@ func RunWebServer() { logger.Info().Str("bindAddress", bindAddress).Bool("loopbackOnly", config.LocalLoopbackOnly).Msg("Starting web server") if err := r.Run(bindAddress); err != nil { - panic(err) + logger.Fatal().Err(err).Msg("failed to start web server") } } diff --git a/web_tls.go b/web_tls.go index 41f532ea9..5d04b031b 100644 --- a/web_tls.go +++ b/web_tls.go @@ -184,7 +184,7 @@ func runWebSecureServer() { err := server.ListenAndServeTLS("", "") if !errors.Is(err, http.ErrServerClosed) { - panic(err) + websecureLogger.Fatal().Err(err).Msg("failed to start websecure server") } } diff --git a/webrtc.go b/webrtc.go index 37488f778..a0558df3d 100644 --- a/webrtc.go +++ b/webrtc.go @@ -7,6 +7,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "github.com/coder/websocket" @@ -19,15 +20,42 @@ import ( "github.com/rs/zerolog" ) +// Predefined browser string constants for memory efficiency +var ( + BrowserChrome = "chrome" + BrowserFirefox = "firefox" + BrowserSafari = "safari" + BrowserEdge = "edge" + BrowserOpera = "opera" + BrowserUnknown = "user" +) + type Session struct { + ID string + Mode SessionMode + Source string + Identity string + Nickname string + Browser *string // Pointer to predefined browser string constant for memory efficiency + CreatedAt time.Time + LastActive time.Time + LastBroadcast time.Time // Per-session broadcast throttle + + // RPC rate limiting (DoS protection) + rpcRateLimitMu sync.Mutex // Protects rate limit fields + rpcRateLimit int // Count of RPCs in current window + rpcRateLimitWin time.Time // Start of current rate limit window + lastBroadcastMu sync.Mutex // Protects LastBroadcast field + peerConnection *webrtc.PeerConnection VideoTrack *webrtc.TrackLocalStaticSample ControlChannel *webrtc.DataChannel RPCChannel *webrtc.DataChannel HidChannel *webrtc.DataChannel shouldUmountVirtualMedia bool - - rpcQueue chan webrtc.DataChannelMessage + flushCandidates func() // Callback to flush buffered ICE candidates + ws *websocket.Conn // WebSocket for critical signaling when RPC unavailable + rpcQueue chan webrtc.DataChannelMessage hidRPCAvailable bool lastKeepAliveArrivalTime time.Time // Track when last keep-alive packet arrived @@ -39,32 +67,38 @@ type Session struct { keysDownStateQueue chan usbgadget.KeysDownState } -var ( - actionSessions int = 0 - activeSessionsMutex = &sync.Mutex{} -) - -func incrActiveSessions() int { - activeSessionsMutex.Lock() - defer activeSessionsMutex.Unlock() +var actionSessions atomic.Int32 - actionSessions++ - return actionSessions +func incrActiveSessions() int32 { + return actionSessions.Add(1) } -func decrActiveSessions() int { - activeSessionsMutex.Lock() - defer activeSessionsMutex.Unlock() - - actionSessions-- - return actionSessions +func getActiveSessions() int32 { + return actionSessions.Load() } -func getActiveSessions() int { - activeSessionsMutex.Lock() - defer activeSessionsMutex.Unlock() +// CheckRPCRateLimit checks if the session has exceeded RPC rate limits (DoS protection) +func (s *Session) CheckRPCRateLimit() bool { + const ( + maxRPCPerSecond = 500 // Increased to support 10+ concurrent sessions with broadcasts and state updates + rateLimitWindow = time.Second + ) + + s.rpcRateLimitMu.Lock() + defer s.rpcRateLimitMu.Unlock() + + now := time.Now() + // Reset window if it has expired + if now.Sub(s.rpcRateLimitWin) > rateLimitWindow { + s.rpcRateLimit = 0 + s.rpcRateLimitWin = now + } - return actionSessions + s.rpcRateLimit++ + if s.rpcRateLimit > maxRPCPerSecond { + return false // Rate limit exceeded + } + return true // Within limits } func (s *Session) resetKeepAliveTime() { @@ -74,6 +108,25 @@ func (s *Session) resetKeepAliveTime() { s.lastTimerResetTime = time.Time{} // Reset auto-release timer tracking } +// sendWebSocketSignal sends critical state changes via WebSocket (fallback when RPC channel stale) +func (s *Session) sendWebSocketSignal(messageType string, data map[string]interface{}) error { + if s == nil || s.ws == nil { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := wsjson.Write(ctx, s.ws, gin.H{"type": messageType, "data": data}) + if err != nil { + webrtcLogger.Debug().Err(err).Str("sessionId", s.ID).Msg("Failed to send WebSocket signal") + return err + } + + webrtcLogger.Info().Str("sessionId", s.ID).Str("messageType", messageType).Msg("Sent WebSocket signal") + return nil +} + type hidQueueMessage struct { webrtc.DataChannelMessage channel string @@ -83,6 +136,7 @@ type SessionConfig struct { ICEServers []string LocalIP string IsCloud bool + UserAgent string // User agent for browser detection and nickname generation ws *websocket.Conn Logger *zerolog.Logger } @@ -246,7 +300,11 @@ func newSession(config SessionConfig) (*Session, error) { return nil, err } - session := &Session{peerConnection: peerConnection} + session := &Session{ + peerConnection: peerConnection, + Browser: extractBrowserFromUserAgent(config.UserAgent), + ws: config.ws, + } session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) session.initQueues() session.initKeysDownStateQueue() @@ -283,16 +341,22 @@ func newSession(config SessionConfig) (*Session, error) { case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) { - // Enqueue to ensure ordered processing + queueLen := len(session.rpcQueue) + if queueLen > 200 { + scopedLogger.Warn(). + Str("sessionID", session.ID). + Int("queueLen", queueLen). + Msg("RPC queue approaching capacity") + } session.rpcQueue <- msg }) triggerOTAStateUpdate() triggerVideoStateUpdate() triggerUSBStateUpdate() case "terminal": - handleTerminalChannel(d) + handleTerminalChannel(d, session) case "serial": - handleSerialChannel(d) + handleSerialChannel(d, session) default: if strings.HasPrefix(d.Label(), uploadIdPrefix) { go handleUploadChannel(d) @@ -325,67 +389,147 @@ func newSession(config SessionConfig) (*Session, error) { }() var isConnected bool + // Buffer to hold ICE candidates until answer is sent + var candidateBuffer []webrtc.ICECandidateInit + var candidateBufferMutex sync.Mutex + var answerSent bool + peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { scopedLogger.Info().Interface("candidate", candidate).Msg("WebRTC peerConnection has a new ICE candidate") if candidate != nil { - err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()}) + candidateBufferMutex.Lock() + if !answerSent { + // Buffer the candidate until answer is sent + candidateBuffer = append(candidateBuffer, candidate.ToJSON()) + candidateBufferMutex.Unlock() + return + } + candidateBufferMutex.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := wsjson.Write(ctx, config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()}) if err != nil { scopedLogger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel") } } }) - peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - scopedLogger.Info().Str("connectionState", connectionState.String()).Msg("ICE Connection State has changed") - if connectionState == webrtc.ICEConnectionStateConnected { - if !isConnected { - isConnected = true - onActiveSessionsChanged() - if incrActiveSessions() == 1 { - onFirstSessionConnected() - } + // Store the callback to flush buffered candidates + session.flushCandidates = func() { + candidateBufferMutex.Lock() + answerSent = true + // Send all buffered candidates + for _, candidate := range candidateBuffer { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := wsjson.Write(ctx, config.ws, gin.H{"type": "new-ice-candidate", "data": candidate}) + cancel() + if err != nil { + scopedLogger.Warn().Err(err).Msg("failed to write buffered new-ice-candidate to WebRTC signaling channel") } } - //state changes on closing browser tab disconnected->failed, we need to manually close it - if connectionState == webrtc.ICEConnectionStateFailed { - scopedLogger.Debug().Msg("ICE Connection State is failed, closing peerConnection") - _ = peerConnection.Close() + candidateBuffer = nil + candidateBufferMutex.Unlock() + } + + // Track cleanup state to prevent double cleanup + var cleanedUp bool + var cleanupMutex sync.Mutex + + cleanupSession := func(reason string) { + cleanupMutex.Lock() + defer cleanupMutex.Unlock() + + if cleanedUp { + return + } + cleanedUp = true + + scopedLogger.Info(). + Str("sessionID", session.ID). + Str("reason", reason). + Msg("Cleaning up session") + + // Remove from session manager + sessionManager.RemoveSession(session.ID) + + // Cancel any ongoing keyboard macro if session has permission + if session.HasPermission(PermissionPaste) { + cancelKeyboardMacro() } - if connectionState == webrtc.ICEConnectionStateClosed { - scopedLogger.Debug().Msg("ICE Connection State is closed, unmounting virtual media") - if session == currentSession { - // Cancel any ongoing keyboard report multi when session closes - cancelKeyboardMacro() - currentSession = nil - } - // Stop RPC processor - if session.rpcQueue != nil { - close(session.rpcQueue) - session.rpcQueue = nil - } - // Stop HID RPC processor - for i := 0; i < len(session.hidQueue); i++ { + // Stop RPC processor + if session.rpcQueue != nil { + close(session.rpcQueue) + session.rpcQueue = nil + } + + // Stop HID RPC processor + for i := 0; i < len(session.hidQueue); i++ { + if session.hidQueue[i] != nil { close(session.hidQueue[i]) session.hidQueue[i] = nil } + } + if session.keysDownStateQueue != nil { close(session.keysDownStateQueue) session.keysDownStateQueue = nil + } - if session.shouldUmountVirtualMedia { - if err := rpcUnmountImage(); err != nil { - scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close") - } + if session.shouldUmountVirtualMedia { + if err := rpcUnmountImage(); err != nil { + scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close") } - if isConnected { - isConnected = false + } + + if isConnected { + isConnected = false + newCount := actionSessions.Add(-1) + onActiveSessionsChanged() + if newCount == 0 { + onLastSessionDisconnected() + } + } + } + + peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { + scopedLogger.Info(). + Str("sessionID", session.ID). + Str("connectionState", connectionState.String()). + Msg("ICE Connection State has changed") + + if connectionState == webrtc.ICEConnectionStateConnected { + if !isConnected { + isConnected = true onActiveSessionsChanged() - if decrActiveSessions() == 0 { - onLastSessionDisconnected() + if incrActiveSessions() == 1 { + onFirstSessionConnected() } } } + + // Handle disconnection and failure states + if connectionState == webrtc.ICEConnectionStateDisconnected { + scopedLogger.Info(). + Str("sessionID", session.ID). + Msg("ICE Connection State is disconnected, connection may recover") + } + + if connectionState == webrtc.ICEConnectionStateFailed { + scopedLogger.Info(). + Str("sessionID", session.ID). + Msg("ICE Connection State is failed, closing peerConnection and cleaning up") + cleanupSession("ice-failed") + _ = peerConnection.Close() + } + + if connectionState == webrtc.ICEConnectionStateClosed { + scopedLogger.Info(). + Str("sessionID", session.ID). + Msg("ICE Connection State is closed, cleaning up") + cleanupSession("ice-closed") + } }) return session, nil }