diff --git a/cmd/tatanka/main.go b/cmd/tatanka/main.go index 7b5543d..f804bcb 100644 --- a/cmd/tatanka/main.go +++ b/cmd/tatanka/main.go @@ -1,32 +1,38 @@ package main import ( + "bufio" "context" + "errors" "fmt" "io" "os" "os/signal" "path/filepath" + "strings" "syscall" + "github.com/bisoncraft/mesh/tatanka" "github.com/decred/slog" "github.com/jessevdk/go-flags" "github.com/jrick/logrotate/rotator" - "github.com/bisoncraft/mesh/tatanka" + "github.com/libp2p/go-libp2p/core/peer" ) var log slog.Logger // Config defines the configuration options for the Tatanka node. type Config struct { - AppDataDir string `short:"A" long:"appdata" description:"Path to application home directory."` - ConfigFile string `short:"C" long:"configfile" description:"Path to configuration file."` - DebugLevel string `short:"d" long:"debuglevel" description:"Logging level {trace, debug, info, warn, error, critical}."` - ListenIP string `long:"listenip" description:"IP address to listen on."` - ListenPort int `long:"listenport" description:"Port to listen on."` - MetricsPort int `long:"metricsport" description:"Port to scrape metrics and fetch profiles from."` - AdminPort int `long:"adminport" description:"Port to expose the admin interface on."` - WhitelistPath string `long:"whitelistpath" description:"Path to local whitelist file."` + AppDataDir string `short:"A" long:"appdata" description:"Path to application home directory."` + ConfigFile string `short:"C" long:"configfile" description:"Path to configuration file."` + DebugLevel string `short:"d" long:"debuglevel" description:"Logging level {trace, debug, info, warn, error, critical}."` + ListenIP string `long:"listenip" description:"IP address to listen on."` + ListenPort int `long:"listenport" description:"Port to listen on."` + MetricsPort int `long:"metricsport" description:"Port to scrape metrics and fetch profiles from."` + AdminPort int `long:"adminport" description:"Port to expose the admin interface on."` + Whitelist string 
`long:"whitelist" description:"Path to whitelist file (plain text: one peer ID per line). On first run, installs it into the data directory. On subsequent runs, validates it matches the existing whitelist unless --forcewl is set."` + ForceWhitelist bool `long:"forcewl" description:"Overwrite any existing whitelist in the data directory with the provided --whitelist file."` + Bootstrap []string `long:"bootstrap" description:"Bootstrap peer address in multiaddr format (/ip4/.../tcp/.../p2p/12D3KooW...). Can be specified multiple times."` // Oracle Configuration CMCKey string `long:"cmckey" description:"coinmarketcap API key"` @@ -49,20 +55,26 @@ func initLogRotator(dir string) (*rotator.Rotator, error) { } func main() { + // Handle "init" subcommand before parsing flags. + if len(os.Args) > 1 && os.Args[1] == "init" { + runInit() + return + } + // Default config values cfg := Config{ - AppDataDir: defaultAppDataDir(), - ConfigFile: defaultConfigFile(), - DebugLevel: "info", - ListenIP: "0.0.0.0", - ListenPort: 12345, - MetricsPort: 12355, - WhitelistPath: "", - CMCKey: "", + AppDataDir: defaultAppDataDir(), + ConfigFile: defaultConfigFile(), + DebugLevel: "info", + ListenIP: "0.0.0.0", + ListenPort: 12345, + MetricsPort: 12355, + CMCKey: "", } // Parse command-line flags (overrides file values) parser := flags.NewParser(&cfg, flags.Default) + parser.Usage = "[OPTIONS]\n\nSubcommands:\n init\tGenerate (or load) a private key and print the peer ID" if _, err := parser.Parse(); err != nil { if e, ok := err.(*flags.Error); ok && e.Type == flags.ErrHelp { os.Exit(0) @@ -98,15 +110,49 @@ func main() { log.Infof("Using app data directory: %s", cfg.AppDataDir) + if cfg.ForceWhitelist && cfg.Whitelist == "" { + log.Errorf("--forcewl requires --whitelist") + os.Exit(1) + } + + var whitelistPeers []peer.ID + if cfg.Whitelist != "" { + var err error + whitelistPeers, err = loadWhitelistPeerIDsList(cfg.Whitelist) + if err != nil { + log.Errorf("Failed to load whitelist: %v", err) 
+ os.Exit(1) + } + } + + if cfg.Whitelist != "" { + log.Infof("Loaded %d whitelist peer IDs from %s", len(whitelistPeers), cfg.Whitelist) + } + + if cfg.ForceWhitelist { + log.Infof("Whitelist force-update enabled") + } + + if cfg.Whitelist == "" { + log.Infof("No whitelist file provided; will load existing whitelist from data directory") + } + + if len(whitelistPeers) == 0 && cfg.Whitelist != "" { + log.Errorf("No peer IDs found in whitelist file %s", cfg.Whitelist) + os.Exit(1) + } + // Create Tatanka config tatankaCfg := &tatanka.Config{ - DataDir: cfg.AppDataDir, - Logger: log, - ListenIP: cfg.ListenIP, - ListenPort: cfg.ListenPort, - MetricsPort: cfg.MetricsPort, - WhitelistPath: cfg.WhitelistPath, - AdminPort: cfg.AdminPort, + DataDir: cfg.AppDataDir, + Logger: log, + ListenIP: cfg.ListenIP, + ListenPort: cfg.ListenPort, + MetricsPort: cfg.MetricsPort, + AdminPort: cfg.AdminPort, + BootstrapAddrs: cfg.Bootstrap, + WhitelistPeers: whitelistPeers, + ForceWhitelist: cfg.ForceWhitelist, CMCKey: cfg.CMCKey, TatumKey: cfg.TatumKey, BlockcypherToken: cfg.BlockcypherToken, @@ -140,6 +186,66 @@ func main() { } } +func loadWhitelistPeerIDsList(path string) ([]peer.ID, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + lineNum := 0 + ids := make([]peer.ID, 0) + seen := make(map[peer.ID]struct{}) + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + if idx := strings.IndexByte(line, '#'); idx >= 0 { + line = strings.TrimSpace(line[:idx]) + } + if line == "" { + continue + } + pid, err := peer.Decode(line) + if err != nil { + return nil, fmt.Errorf("invalid peer ID on line %d: %w", lineNum, err) + } + if _, ok := seen[pid]; ok { + continue + } + seen[pid] = struct{}{} + ids = append(ids, pid) + } + if err := scanner.Err(); err != nil { + return nil, err + } + if len(ids) == 0 { + return nil, errors.New("no peer IDs found") + } + return ids, nil +} + +// runInit 
generates (or loads) a private key and prints the peer ID. +func runInit() { + // Parse a minimal set of flags for init. + type initConfig struct { + AppDataDir string `short:"A" long:"appdata" description:"Path to application home directory."` + } + initCfg := &initConfig{ + AppDataDir: defaultAppDataDir(), + } + parser := flags.NewParser(initCfg, flags.Default) + parser.Parse() + + pid, err := tatanka.InitTatankaNode(initCfg.AppDataDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Init failed: %v\n", err) + os.Exit(1) + } + + fmt.Println(pid.String()) +} + // defaultAppDataDir returns the default application data directory. func defaultAppDataDir() string { homeDir, _ := os.UserHomeDir() diff --git a/cmd/tatankactl/api.go b/cmd/tatankactl/api.go index 09c2a7c..34acdf0 100644 --- a/cmd/tatankactl/api.go +++ b/cmd/tatankactl/api.go @@ -1,8 +1,11 @@ package main import ( + "bytes" "encoding/json" "fmt" + "io" + "net/http" "net/url" "sort" "strings" @@ -13,6 +16,7 @@ import ( "github.com/gorilla/websocket" "github.com/bisoncraft/mesh/oracle" "github.com/bisoncraft/mesh/tatanka/admin" + "github.com/bisoncraft/mesh/tatanka/types" ) // --- Navigation messages --- @@ -27,11 +31,13 @@ const ( viewOracleDetail viewOracleAggregated viewAggregatedDetail + viewWhitelist + viewWhitelistEditor ) type navigateMsg struct{ view viewID } type navigateBackMsg struct{} -type navigateToDiffMsg struct{ node admin.NodeInfo } +type navigateToDiffMsg struct{ node admin.PeerInfo } type navigateToSourceDetailMsg struct { sourceName string } @@ -39,6 +45,11 @@ type navigateToAggregatedDetailMsg struct { dataType oracle.DataType key string // ticker or network name } +type navigateToWhitelistEditorMsg struct{} + +type proposeResultMsg struct{ err error } +type clearProposalResultMsg struct{ err error } +type adoptResultMsg struct{ err error } // --- Data messages --- @@ -46,6 +57,10 @@ type adminStateMsg struct { state *admin.AdminState } +type peerUpdateMsg struct{ peer admin.PeerInfo } +type 
whitelistStateUpdateMsg struct{ state *types.WhitelistState } +type whitelistUpdateMsg struct{ update admin.WhitelistUpdate } + type wsConnectedMsg struct{} type wsErrorMsg struct{ err error } type wsReconnectMsg struct{} @@ -287,6 +302,24 @@ func (c *apiClient) connectWebSocket(ch chan<- tea.Msg) tea.Cmd { continue } msg = adminStateMsg{state: &state} + case "peer_update": + var pi admin.PeerInfo + if err := json.Unmarshal(envelope.Data, &pi); err != nil { + continue + } + msg = peerUpdateMsg{peer: pi} + case "whitelist_state": + var ws types.WhitelistState + if err := json.Unmarshal(envelope.Data, &ws); err != nil { + continue + } + msg = whitelistStateUpdateMsg{state: &ws} + case "whitelist_update": + var wu admin.WhitelistUpdate + if err := json.Unmarshal(envelope.Data, &wu); err != nil { + continue + } + msg = whitelistUpdateMsg{update: wu} case "oracle_snapshot": var snapshot oracleSnapshotMsg if err := json.Unmarshal(envelope.Data, &snapshot); err != nil { @@ -342,3 +375,217 @@ func listenForWSUpdates(ch <-chan tea.Msg) tea.Cmd { return msg } } + +func (c *apiClient) proposeWhitelist(peers []string) tea.Cmd { + return func() tea.Msg { + body, _ := json.Marshal(map[string][]string{"peers": peers}) + resp, err := http.Post(c.address+"/admin/propose-whitelist", "application/json", bytes.NewReader(body)) + if err != nil { + return proposeResultMsg{err: err} + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return proposeResultMsg{err: fmt.Errorf("%s", strings.TrimSpace(string(b)))} + } + return proposeResultMsg{} + } +} + +func (c *apiClient) adoptWhitelist(peerID string) tea.Cmd { + return func() tea.Msg { + body, _ := json.Marshal(map[string]string{"peer_id": peerID}) + resp, err := http.Post(c.address+"/admin/adopt-whitelist", "application/json", bytes.NewReader(body)) + if err != nil { + return adoptResultMsg{err: err} + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := 
io.ReadAll(resp.Body) + return adoptResultMsg{err: fmt.Errorf("%s", strings.TrimSpace(string(b)))} + } + return adoptResultMsg{} + } +} + +func (c *apiClient) clearProposal() tea.Cmd { + return func() tea.Msg { + req, _ := http.NewRequest(http.MethodDelete, c.address+"/admin/propose-whitelist", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return clearProposalResultMsg{err: err} + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return clearProposalResultMsg{err: fmt.Errorf("%s", strings.TrimSpace(string(b)))} + } + return clearProposalResultMsg{} + } +} + +// --- Whitelist helpers --- + +// peerIDStrings returns sorted peer ID strings from a Whitelist. +func peerIDStrings(wl *types.Whitelist) []string { + if wl == nil { + return nil + } + strs := make([]string, 0, len(wl.PeerIDs)) + for pid := range wl.PeerIDs { + strs = append(strs, pid.String()) + } + sort.Strings(strs) + return strs +} + +// getOurWhitelist returns the sorted peer ID strings from our current whitelist. +func getOurWhitelist(state *admin.AdminState) []string { + if state == nil || state.WhitelistState == nil { + return nil + } + return peerIDStrings(state.WhitelistState.Current) +} + +// getOurProposal returns the sorted peer ID strings from our proposed whitelist. +func getOurProposal(state *admin.AdminState) []string { + if state == nil || state.WhitelistState == nil { + return nil + } + return peerIDStrings(state.WhitelistState.Proposed) +} + +// networkProposal represents a proposed whitelist with its supporters. +type networkProposal struct { + proposedPeerIDs []string + supporters map[string]bool // peerID string -> ready +} + +// computeNetworkProposals derives network proposals from peer states. 
+func computeNetworkProposals(state *admin.AdminState) []networkProposal { + if state == nil { + return nil + } + + type proposal struct { + peerIDs []string + supporters map[string]bool + } + proposals := make(map[string]*proposal) + + // Include our own proposal from the top-level WhitelistState. + if state.WhitelistState != nil && state.WhitelistState.Proposed != nil { + hash := state.WhitelistState.Proposed.Hash() + np := &proposal{ + peerIDs: peerIDStrings(state.WhitelistState.Proposed), + supporters: make(map[string]bool), + } + np.supporters[state.OurPeerID] = state.WhitelistState.Ready + proposals[hash] = np + } + + for _, peer := range state.Peers { + if peer.WhitelistState == nil || peer.WhitelistState.Proposed == nil { + continue + } + hash := peer.WhitelistState.Proposed.Hash() + np, ok := proposals[hash] + if !ok { + np = &proposal{ + peerIDs: peerIDStrings(peer.WhitelistState.Proposed), + supporters: make(map[string]bool), + } + proposals[hash] = np + } + np.supporters[peer.PeerID] = peer.WhitelistState.Ready + } + + result := make([]networkProposal, 0, len(proposals)) + for _, np := range proposals { + result = append(result, networkProposal{ + proposedPeerIDs: np.peerIDs, + supporters: np.supporters, + }) + } + + sort.Slice(result, func(i, j int) bool { + if len(result[i].supporters) != len(result[j].supporters) { + return len(result[i].supporters) > len(result[j].supporters) + } + return strings.Join(result[i].proposedPeerIDs, ",") < strings.Join(result[j].proposedPeerIDs, ",") + }) + + return result +} + +// overlapPeerInfo describes one overlap peer's consensus status. +type overlapPeerInfo struct { + peerID string + agreeing bool + ready bool + online bool +} + +// consensusInfo summarises consensus progress for a whitelist proposal. 
+type consensusInfo struct { + overlap []overlapPeerInfo + nOverlap int + threshold int + agreeing int + blocking int // online overlap peers not agreeing +} + +// computeConsensusInfo computes consensus progress for a proposal relative +// to the current whitelist, given its set of supporters. +func computeConsensusInfo(state *admin.AdminState, currentWl, proposedWl []string, supporters map[string]bool) consensusInfo { + currentSet := make(map[string]bool, len(currentWl)) + for _, id := range currentWl { + currentSet[id] = true + } + proposedSet := make(map[string]bool, len(proposedWl)) + for _, id := range proposedWl { + proposedSet[id] = true + } + + // Overlap = peers in both current and proposed whitelists. + var overlapIDs []string + for id := range currentSet { + if proposedSet[id] { + overlapIDs = append(overlapIDs, id) + } + } + sort.Strings(overlapIDs) + + nOverlap := len(overlapIDs) + threshold := (2*nOverlap + 2) / 3 + + info := consensusInfo{nOverlap: nOverlap, threshold: threshold} + + for _, id := range overlapIDs { + opi := overlapPeerInfo{peerID: id} + + _, supports := supporters[id] + if supports { + opi.agreeing = true + opi.ready = supporters[id] + info.agreeing++ + } + + if id == state.OurPeerID { + opi.online = true + if !supports { + info.blocking++ + } + } else if peer, ok := state.Peers[id]; ok && (peer.State == admin.StateConnected || peer.State == admin.StateWhitelistMismatch) { + opi.online = true + if !supports { + info.blocking++ + } + } + + info.overlap = append(info.overlap, opi) + } + + return info +} + diff --git a/cmd/tatankactl/connections.go b/cmd/tatankactl/connections.go index 527fd4f..f3cdd2c 100644 --- a/cmd/tatankactl/connections.go +++ b/cmd/tatankactl/connections.go @@ -13,7 +13,7 @@ import ( type connectionsModel struct { state *admin.AdminState - nodes []admin.NodeInfo + nodes []admin.PeerInfo mismatchIndices []int cursor int // index into mismatchIndices lastUpdate time.Time @@ -50,8 +50,8 @@ func (m connectionsModel) 
Update(msg tea.Msg) (connectionsModel, tea.Cmd) { } func (m *connectionsModel) sortNodes() { - nodes := make([]admin.NodeInfo, 0, len(m.state.Nodes)) - for _, node := range m.state.Nodes { + nodes := make([]admin.PeerInfo, 0, len(m.state.Peers)) + for _, node := range m.state.Peers { nodes = append(nodes, node) } @@ -71,7 +71,7 @@ func (m *connectionsModel) sortNodes() { m.nodes = nodes m.mismatchIndices = nil for i, n := range nodes { - if n.State == admin.StateWhitelistMismatch { + if n.State == admin.StateWhitelistMismatch && n.WhitelistState != nil { m.mismatchIndices = append(m.mismatchIndices, i) } } @@ -126,6 +126,9 @@ func (m connectionsModel) View() string { for i, node := range m.nodes { icon := getStateIcon(node.State) stateStr := getStateString(node.State) + if node.State == admin.StateWhitelistMismatch && node.WhitelistState == nil { + stateStr = mismatchStyle.Render("Not on peer's whitelist") + } cursorStr := "" if i == selectedNodeIdx { diff --git a/cmd/tatankactl/diff.go b/cmd/tatankactl/diff.go index 4c629ab..f36ab46 100644 --- a/cmd/tatankactl/diff.go +++ b/cmd/tatankactl/diff.go @@ -10,28 +10,36 @@ import ( ) type diffModel struct { - node admin.NodeInfo + node admin.PeerInfo + api *apiClient inBoth []string onlyOurs []string onlyPeers []string scrollOffset int height int + confirming bool + statusMsg string } -func newDiffModel(node admin.NodeInfo, state *admin.AdminState) diffModel { - ourSet := make(map[string]bool) - for _, id := range state.OurWhitelist { +func newDiffModel(node admin.PeerInfo, state *admin.AdminState, api *apiClient) diffModel { + ourWl := getOurWhitelist(state) + ourSet := make(map[string]bool, len(ourWl)) + for _, id := range ourWl { ourSet[id] = true } - peerSet := make(map[string]bool) - for _, id := range node.PeerWhitelist { + var peerWl []string + if node.WhitelistState != nil { + peerWl = peerIDStrings(node.WhitelistState.Current) + } + peerSet := make(map[string]bool, len(peerWl)) + for _, id := range peerWl { 
peerSet[id] = true } var inBoth, onlyOurs, onlyPeers []string - for _, id := range state.OurWhitelist { + for _, id := range ourWl { if peerSet[id] { inBoth = append(inBoth, id) } else { @@ -39,7 +47,7 @@ func newDiffModel(node admin.NodeInfo, state *admin.AdminState) diffModel { } } - for _, id := range node.PeerWhitelist { + for _, id := range peerWl { if !ourSet[id] { onlyPeers = append(onlyPeers, id) } @@ -51,6 +59,7 @@ func newDiffModel(node admin.NodeInfo, state *admin.AdminState) diffModel { return diffModel{ node: node, + api: api, inBoth: inBoth, onlyOurs: onlyOurs, onlyPeers: onlyPeers, @@ -67,6 +76,16 @@ func (m diffModel) Update(msg tea.Msg) (diffModel, tea.Cmd) { case tea.WindowSizeMsg: m.height = msg.Height case tea.KeyMsg: + if m.confirming { + switch msg.String() { + case "y": + m.confirming = false + return m, m.api.adoptWhitelist(m.node.PeerID) + case "n", "esc": + m.confirming = false + } + return m, nil + } switch msg.String() { case "up", "k": if m.scrollOffset > 0 { @@ -76,6 +95,10 @@ func (m diffModel) Update(msg tea.Msg) (diffModel, tea.Cmd) { if m.scrollOffset < m.maxOffset() { m.scrollOffset++ } + case "f": + if m.node.WhitelistState != nil { + m.confirming = true + } case "esc", "q": return m, navBack() } @@ -116,6 +139,22 @@ func (m diffModel) View() string { lines = append(lines, headerStyle.Render(fmt.Sprintf(" Whitelist Diff \u2014 %s", m.node.PeerID)), "", + ) + + if m.statusMsg != "" { + lines = append(lines, fmt.Sprintf(" %s", disconnectedStyle.Render(m.statusMsg)), "") + } + + if m.confirming { + lines = append(lines, + fmt.Sprintf(" %s %s", + cursorStyle.Render("Adopt peer's whitelist?"), + dimStyle.Render("(y/n)")), + "", + ) + } + + lines = append(lines, dimStyle.Render(" "+strings.Repeat("\u2500", 50)), "", ) @@ -147,7 +186,12 @@ func (m diffModel) View() string { lines = append(lines, "") } - lines = append(lines, helpStyle.Render(" \u2191\u2193 Scroll Esc: Back to connections")) + helpParts := []string{"\u2191\u2193 Scroll"} 
+ if m.node.WhitelistState != nil { + helpParts = append(helpParts, "f: Adopt peer's whitelist") + } + helpParts = append(helpParts, "Esc: Back to connections") + lines = append(lines, helpStyle.Render(" "+strings.Join(helpParts, " "))) // Apply scroll maxOffset := len(lines) - m.height + 2 diff --git a/cmd/tatankactl/main.go b/cmd/tatankactl/main.go index 98b282d..dcca177 100644 --- a/cmd/tatankactl/main.go +++ b/cmd/tatankactl/main.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/tatanka/admin" ) // rootModel is the top-level bubbletea model that routes between views. @@ -18,6 +19,8 @@ type rootModel struct { menu menuModel connections connectionsModel diff diffModel + whitelistView whitelistViewModel + whitelistEditor whitelistEditorModel oracle oracleModel oracleDetail oracleDetailModel oracleAggregated oracleAggregatedModel @@ -76,6 +79,12 @@ func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case viewDiff: m.diff, cmd = m.diff.Update(msg) return m, cmd + case viewWhitelist: + m.whitelistView, cmd = m.whitelistView.Update(msg) + return m, cmd + case viewWhitelistEditor: + m.whitelistEditor, cmd = m.whitelistEditor.Update(msg) + return m, cmd case viewOracleSources: m.oracle, cmd = m.oracle.Update(msg) return m, cmd @@ -97,6 +106,11 @@ func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.activeView = viewConnections m.connections.height = m.height return m, nil + case viewWhitelist: + m.whitelistView = newWhitelistViewModel(m.connections.state, m.api) + m.whitelistView.height = m.height + m.activeView = viewWhitelist + return m, nil case viewOracleSources: m.activeView = viewOracleSources m.oracle = newOracleModel(m.oracleData) @@ -115,6 +129,14 @@ func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case viewConnections: m.activeView = viewMenu return m, nil + case viewWhitelist: + m.activeView = viewMenu + return m, nil + case 
viewWhitelistEditor: + m.whitelistView = newWhitelistViewModel(m.connections.state, m.api) + m.whitelistView.height = m.height + m.activeView = viewWhitelist + return m, nil case viewDiff: m.activeView = viewConnections return m, nil @@ -165,12 +187,103 @@ func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case adminStateMsg: if msg.state != nil { + m.menu.peerID = msg.state.OurPeerID m.connections.state = msg.state m.connections.sortNodes() m.connections.lastUpdate = time.Now() + if m.activeView == viewWhitelist { + m.whitelistView.state = msg.state + m.whitelistView.lastUpdate = time.Now() + m.whitelistView.buildSections() + } + } + return m, listenForWSUpdates(m.wsCh) + + case peerUpdateMsg: + if m.connections.state != nil && msg.peer.PeerID != m.connections.state.OurPeerID { + m.connections.state.Peers[msg.peer.PeerID] = msg.peer + m.connections.sortNodes() + m.connections.lastUpdate = time.Now() + if m.activeView == viewWhitelist { + m.whitelistView.state = m.connections.state + m.whitelistView.lastUpdate = time.Now() + m.whitelistView.buildSections() + } + } + return m, listenForWSUpdates(m.wsCh) + + case whitelistStateUpdateMsg: + if m.connections.state != nil { + m.connections.state.WhitelistState = msg.state + if m.activeView == viewWhitelist { + m.whitelistView.state = m.connections.state + m.whitelistView.lastUpdate = time.Now() + m.whitelistView.buildSections() + } + } + return m, listenForWSUpdates(m.wsCh) + + case whitelistUpdateMsg: + if m.connections.state != nil && msg.update.WhitelistState != nil && msg.update.WhitelistState.Current != nil { + newWl := msg.update.WhitelistState.Current.PeerIDs + // Build string set for O(1) lookup. + newPeers := make(map[string]struct{}, len(newWl)) + for pid := range newWl { + newPeers[pid.String()] = struct{}{} + } + // Remove peers no longer in the whitelist. 
+ for id := range m.connections.state.Peers { + if _, ok := newPeers[id]; !ok { + delete(m.connections.state.Peers, id) + } + } + // Add new peers that aren't already tracked. + for pid := range newWl { + s := pid.String() + if s == m.connections.state.OurPeerID { + continue + } + if _, ok := m.connections.state.Peers[s]; !ok { + m.connections.state.Peers[s] = admin.PeerInfo{ + PeerID: s, + State: admin.StateDisconnected, + } + } + } + m.connections.state.WhitelistState = msg.update.WhitelistState + m.connections.sortNodes() + m.connections.lastUpdate = time.Now() + if m.activeView == viewWhitelist { + m.whitelistView.state = m.connections.state + m.whitelistView.lastUpdate = time.Now() + m.whitelistView.buildSections() + } } return m, listenForWSUpdates(m.wsCh) + case navigateToWhitelistEditorMsg: + m.whitelistEditor = newWhitelistEditorModel(m.connections.state, m.api) + m.whitelistEditor.height = m.height + m.activeView = viewWhitelistEditor + return m, nil + + case proposeResultMsg: + if msg.err != nil { + m.whitelistView.statusMsg = fmt.Sprintf("Propose failed: %v", msg.err) + } else { + m.whitelistView.statusMsg = "Proposal submitted" + } + return m, nil + + case clearProposalResultMsg: + if msg.err != nil { + m.whitelistView.statusMsg = fmt.Sprintf("Clear failed: %v", msg.err) + } else { + m.whitelistView.statusMsg = "Proposal cleared" + } + return m, nil + + case renderTickMsg: if m.isOracleView() { // Re-render for relative time updates @@ -191,10 +304,18 @@ func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, m.oracleAggregatedDetail.Init() case navigateToDiffMsg: - m.diff = newDiffModel(msg.node, m.connections.state) + m.diff = newDiffModel(msg.node, m.connections.state, m.api) m.diff.height = m.height m.activeView = viewDiff return m, m.diff.Init() + + case adoptResultMsg: + if msg.err != nil { + m.diff.statusMsg = fmt.Sprintf("Adopt failed: %v", msg.err) + } else { + m.activeView = viewConnections + } + return m, nil } // Delegate to 
active view @@ -206,6 +327,10 @@ func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.connections, cmd = m.connections.Update(msg) case viewDiff: m.diff, cmd = m.diff.Update(msg) + case viewWhitelist: + m.whitelistView, cmd = m.whitelistView.Update(msg) + case viewWhitelistEditor: + m.whitelistEditor, cmd = m.whitelistEditor.Update(msg) case viewOracleSources: m.oracle, cmd = m.oracle.Update(msg) case viewOracleDetail: @@ -226,6 +351,10 @@ func (m rootModel) View() string { return m.connections.View() case viewDiff: return m.diff.View() + case viewWhitelist: + return m.whitelistView.View() + case viewWhitelistEditor: + return m.whitelistEditor.View() case viewOracleSources: return m.oracle.View() case viewOracleDetail: diff --git a/cmd/tatankactl/menu.go b/cmd/tatankactl/menu.go index 4a21daa..3a0c0c0 100644 --- a/cmd/tatankactl/menu.go +++ b/cmd/tatankactl/menu.go @@ -12,12 +12,13 @@ type menuModel struct { views []viewID cursor int height int + peerID string } func newMenuModel() menuModel { return menuModel{ - choices: []string{"Connections", "Oracle Sources", "Oracle Data"}, - views: []viewID{viewConnections, viewOracleSources, viewOracleAggregated}, + choices: []string{"Connections", "Whitelist", "Oracle Sources", "Oracle Data"}, + views: []viewID{viewConnections, viewWhitelist, viewOracleSources, viewOracleAggregated}, cursor: 0, } } @@ -55,6 +56,10 @@ func (m menuModel) View() string { var b strings.Builder b.WriteString(titleStyle.Render("Tatanka Admin")) + b.WriteString("\n") + if m.peerID != "" { + b.WriteString(dimStyle.Render(" " + m.peerID)) + } b.WriteString("\n\n") for i, choice := range m.choices { diff --git a/cmd/tatankactl/oracle_aggregated.go b/cmd/tatankactl/oracle_aggregated.go index caf583f..378241a 100644 --- a/cmd/tatankactl/oracle_aggregated.go +++ b/cmd/tatankactl/oracle_aggregated.go @@ -36,7 +36,7 @@ func (m oracleAggregatedModel) Update(msg tea.Msg) (oracleAggregatedModel, tea.C case tea.KeyMsg: if m.filter.active { - if 
m.filter.handleFilterKey(msg.String()) { + if m.filter.handleFilterKey(msg) { m.buildSections() } return m, nil @@ -127,8 +127,8 @@ func (m oracleAggregatedModel) View() string { m.filter.renderFilterBar(&b) if len(m.sections) == 0 { - if m.filter.text != "" { - b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("No matches for \""+m.filter.text+"\""))) + if m.filter.input.text != "" { + b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("No matches for \""+m.filter.input.text+"\""))) } else if len(m.data.Prices) == 0 && len(m.data.FeeRates) == 0 { b.WriteString(" " + dimStyle.Render("No aggregated data available") + "\n") } diff --git a/cmd/tatankactl/oracle_detail.go b/cmd/tatankactl/oracle_detail.go index 182356c..bd0fbbd 100644 --- a/cmd/tatankactl/oracle_detail.go +++ b/cmd/tatankactl/oracle_detail.go @@ -65,7 +65,7 @@ func (m oracleDetailModel) Update(msg tea.Msg) (oracleDetailModel, tea.Cmd) { m.height = msg.Height case tea.KeyMsg: if m.filter.active { - if m.filter.handleFilterKey(msg.String()) { + if m.filter.handleFilterKey(msg) { m.buildSections() } return m, nil @@ -188,8 +188,8 @@ func (m oracleDetailModel) View() string { m.filter.renderFilterBar(&b) if len(m.sections) == 0 { - if m.filter.text != "" { - b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("No matches for \""+m.filter.text+"\""))) + if m.filter.input.text != "" { + b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("No matches for \""+m.filter.input.text+"\""))) } else { b.WriteString(" " + dimStyle.Render("No data available") + "\n") } diff --git a/cmd/tatankactl/section.go b/cmd/tatankactl/section.go index e5270bd..415bc9e 100644 --- a/cmd/tatankactl/section.go +++ b/cmd/tatankactl/section.go @@ -3,6 +3,8 @@ package main import ( "fmt" "strings" + + tea "github.com/charmbracelet/bubbletea" ) const sectionMaxVisible = 10 @@ -33,8 +35,12 @@ func (s *detailSection) scrollUp() { } func (s *detailSection) cursorDown() { - if s.itemCursor < len(s.lines)-1 { - s.itemCursor++ + // Skip lines 
with empty keys to only land on selectable items. + for next := s.itemCursor + 1; next < len(s.lines); next++ { + if next < len(s.keys) && s.keys[next] != "" { + s.itemCursor = next + break + } } if s.itemCursor >= s.offset+sectionMaxVisible { s.offset = s.itemCursor - sectionMaxVisible + 1 @@ -42,8 +48,12 @@ func (s *detailSection) cursorDown() { } func (s *detailSection) cursorUp() { - if s.itemCursor > 0 { - s.itemCursor-- + // Skip lines with empty keys to only land on selectable items. + for prev := s.itemCursor - 1; prev >= 0; prev-- { + if prev < len(s.keys) && s.keys[prev] != "" { + s.itemCursor = prev + break + } } if s.itemCursor < s.offset { s.offset = s.itemCursor @@ -111,13 +121,9 @@ func renderSection(b *strings.Builder, sec *detailSection, focused bool) { for i, line := range sec.visibleLines() { absIdx := visibleStart + i if hasCursor && focused && absIdx == sec.itemCursor { - if len(line) > 0 { - b.WriteString(cursorStyle.Render(">") + line[1:] + "\n") - } else { - b.WriteString(cursorStyle.Render(">") + "\n") - } + b.WriteString(cursorStyle.Render(">") + line + "\n") } else { - b.WriteString(line + "\n") + b.WriteString(" " + line + "\n") } } @@ -159,7 +165,7 @@ func buildFilterHelp(sections []detailSection, filter filterState, extra ...stri } parts = append(parts, extra...) parts = append(parts, "/: Filter") - if filter.text != "" { + if filter.input.text != "" { parts = append(parts, "Esc: Clear filter") } else { parts = append(parts, "Esc: Back") @@ -167,55 +173,98 @@ func buildFilterHelp(sections []detailSection, filter filterState, extra ...stri return helpStyle.Render(" " + strings.Join(parts, " ")) } +// textInput is a shared single-line text input with cursor movement support. 
+type textInput struct { + text string + cursor int +} + +func (t *textInput) handleKey(msg tea.KeyMsg) { + switch msg.String() { + case "left": + if t.cursor > 0 { + t.cursor-- + } + case "right": + if t.cursor < len(t.text) { + t.cursor++ + } + case "home", "ctrl+a": + t.cursor = 0 + case "end", "ctrl+e": + t.cursor = len(t.text) + case "backspace": + if t.cursor > 0 { + t.text = t.text[:t.cursor-1] + t.text[t.cursor:] + t.cursor-- + } + case "delete": + if t.cursor < len(t.text) { + t.text = t.text[:t.cursor] + t.text[t.cursor+1:] + } + case "ctrl+u": + t.text = "" + t.cursor = 0 + default: + if msg.Type == tea.KeyRunes { + s := string(msg.Runes) + t.text = t.text[:t.cursor] + s + t.text[t.cursor:] + t.cursor += len(s) + } + } +} + +func (t *textInput) clear() { + t.text = "" + t.cursor = 0 +} + +func (t *textInput) render() string { + return t.text[:t.cursor] + "\u2588" + t.text[t.cursor:] +} + // filterState manages text filtering shared by multiple views. type filterState struct { active bool - text string + input textInput } func (f *filterState) startFiltering() { f.active = true - f.text = "" + f.input.clear() } func (f *filterState) matches(name string) bool { - if f.text == "" { + if f.input.text == "" { return true } - return strings.Contains(strings.ToUpper(name), strings.ToUpper(f.text)) + return strings.Contains(strings.ToUpper(name), strings.ToUpper(f.input.text)) } // handleFilterKey processes a key press while in filter mode. // Returns true if sections need rebuilding. 
-func (f *filterState) handleFilterKey(key string) bool { - switch key { +func (f *filterState) handleFilterKey(msg tea.KeyMsg) bool { + switch msg.String() { case "enter": f.active = false return true case "esc": f.active = false - f.text = "" + f.input.clear() return true - case "backspace": - if len(f.text) > 0 { - f.text = f.text[:len(f.text)-1] - return true - } default: - if len(key) == 1 && key[0] >= 32 && key[0] <= 126 { - f.text += key - return true - } + old := f.input.text + f.input.handleKey(msg) + return f.input.text != old } - return false } // handleEscOrQ handles esc/q when not in filter mode. // Returns true if the filter was cleared (sections need rebuilding). // Returns false if navigation back should happen. func (f *filterState) handleEscOrQ() bool { - if f.text != "" { - f.text = "" + if f.input.text != "" { + f.input.clear() return true } return false @@ -224,12 +273,12 @@ func (f *filterState) handleEscOrQ() bool { // renderFilterBar renders the filter input or active filter indicator. func (f *filterState) renderFilterBar(b *strings.Builder) { if f.active { - b.WriteString(fmt.Sprintf(" %s %s\u2588\n\n", + b.WriteString(fmt.Sprintf(" %s %s\n\n", cursorStyle.Render("/"), - f.text)) - } else if f.text != "" { + f.input.render())) + } else if f.input.text != "" { b.WriteString(fmt.Sprintf(" %s %s\n\n", dimStyle.Render("Filter:"), - connectedStyle.Render(f.text))) + connectedStyle.Render(f.input.text))) } } diff --git a/cmd/tatankactl/styles.go b/cmd/tatankactl/styles.go index dc27083..8da7e76 100644 --- a/cmd/tatankactl/styles.go +++ b/cmd/tatankactl/styles.go @@ -61,6 +61,12 @@ var ( diffRedStyle = lipgloss.NewStyle(). Foreground(colorRedFg) + diffAddedBgStyle = lipgloss.NewStyle(). + Background(lipgloss.Color("22")) + + diffRemovedBgStyle = lipgloss.NewStyle(). + Background(lipgloss.Color("52")) + menuBoxStyle = lipgloss.NewStyle(). Border(lipgloss.RoundedBorder()). BorderForeground(colorBorder). 
diff --git a/cmd/tatankactl/whitelist_editor.go b/cmd/tatankactl/whitelist_editor.go new file mode 100644 index 0000000..038e358 --- /dev/null +++ b/cmd/tatankactl/whitelist_editor.go @@ -0,0 +1,316 @@ +package main + +import ( + "fmt" + "sort" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/bisoncraft/mesh/tatanka/admin" + "github.com/libp2p/go-libp2p/core/peer" +) + +type editorEntry struct { + peerID string + checked bool + inCurrent bool + source string // "current", "network", or "manual" +} + +type whitelistEditorModel struct { + entries []editorEntry + cursor int + confirming bool + adding bool // text input mode for adding a new peer ID + addInput textInput // text being typed + addError string // validation error + height int + api *apiClient + scrollOffset int + ourPeerID string // our own peer ID, cannot be unchecked +} + +func newWhitelistEditorModel(state *admin.AdminState, api *apiClient) whitelistEditorModel { + // Collect union of all known peer IDs. + allIDs := make(map[string]string) // peerID -> source + currentSet := make(map[string]bool) + + ourWl := getOurWhitelist(state) + for _, id := range ourWl { + allIDs[id] = "current" + currentSet[id] = true + } + + if state != nil { + // Add peer whitelist from mismatch nodes. + for _, p := range state.Peers { + if p.State == admin.StateWhitelistMismatch && p.WhitelistState != nil && p.WhitelistState.Current != nil { + for _, id := range peerIDStrings(p.WhitelistState.Current) { + if _, ok := allIDs[id]; !ok { + allIDs[id] = "network" + } + } + } + } + // Add proposed peer IDs from all network proposals. + for _, np := range computeNetworkProposals(state) { + for _, id := range np.proposedPeerIDs { + if _, ok := allIDs[id]; !ok { + allIDs[id] = "network" + } + } + } + } + + // Pre-check the proposed whitelist if one exists, otherwise the current. 
+ preselected := getOurProposal(state) + if preselected == nil { + preselected = ourWl + } + checkedSet := make(map[string]bool, len(preselected)) + for _, id := range preselected { + checkedSet[id] = true + } + + // Sort and build entries. + sorted := make([]string, 0, len(allIDs)) + for id := range allIDs { + sorted = append(sorted, id) + } + sort.Strings(sorted) + + entries := make([]editorEntry, len(sorted)) + for i, id := range sorted { + entries[i] = editorEntry{ + peerID: id, + checked: checkedSet[id], + inCurrent: currentSet[id], + source: allIDs[id], + } + } + + var ourPeerID string + if state != nil { + ourPeerID = state.OurPeerID + } + + return whitelistEditorModel{ + entries: entries, + api: api, + ourPeerID: ourPeerID, + } +} + +func (m whitelistEditorModel) selectedCount() int { + n := 0 + for _, e := range m.entries { + if e.checked { + n++ + } + } + return n +} + +func (m whitelistEditorModel) selectedPeers() []string { + var peers []string + for _, e := range m.entries { + if e.checked { + peers = append(peers, e.peerID) + } + } + return peers +} + +func (m *whitelistEditorModel) addPeerID() { + id := strings.TrimSpace(m.addInput.text) + if id == "" { + return + } + // Validate peer ID format. + if _, err := peer.Decode(id); err != nil { + m.addError = "invalid peer ID" + return + } + // Check for duplicates. + for _, e := range m.entries { + if e.peerID == id { + m.addError = "already in list" + return + } + } + m.entries = append(m.entries, editorEntry{ + peerID: id, + checked: true, + source: "manual", + }) + // Move cursor to the new entry. 
+ m.cursor = len(m.entries) - 1 + visible := m.visibleCount() + if m.cursor >= m.scrollOffset+visible { + m.scrollOffset = m.cursor - visible + 1 + } +} + +func (m whitelistEditorModel) Update(msg tea.Msg) (whitelistEditorModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + case tea.KeyMsg: + if m.confirming { + switch msg.String() { + case "y": + m.confirming = false + peers := m.selectedPeers() + return m, tea.Batch( + m.api.proposeWhitelist(peers), + func() tea.Msg { return navigateBackMsg{} }, + ) + case "n", "esc": + m.confirming = false + } + return m, nil + } + + if m.adding { + switch msg.String() { + case "enter": + m.addPeerID() + if m.addError == "" { + m.adding = false + m.addInput.clear() + } + case "esc": + m.adding = false + m.addInput.clear() + m.addError = "" + default: + old := m.addInput.text + m.addInput.handleKey(msg) + if m.addInput.text != old { + m.addError = "" + } + } + return m, nil + } + + switch msg.String() { + case "up", "k": + if m.cursor > 0 { + m.cursor-- + if m.cursor < m.scrollOffset { + m.scrollOffset = m.cursor + } + } + case "down", "j": + if m.cursor < len(m.entries)-1 { + m.cursor++ + visible := m.visibleCount() + if m.cursor >= m.scrollOffset+visible { + m.scrollOffset = m.cursor - visible + 1 + } + } + case " ": + if m.cursor < len(m.entries) { + e := &m.entries[m.cursor] + // Cannot uncheck our own peer ID. 
+ if e.peerID == m.ourPeerID && e.checked { + break + } + e.checked = !e.checked + } + case "a": + m.adding = true + m.addInput.clear() + m.addError = "" + case "enter": + if m.selectedCount() > 0 { + m.confirming = true + } + case "esc", "q": + return m, navBack() + } + } + return m, nil +} + +func (m whitelistEditorModel) visibleCount() int { + v := m.height - 10 // header, footer, spacing, input line + if v < 3 { + v = 3 + } + return v +} + +func (m whitelistEditorModel) View() string { + var b strings.Builder + + b.WriteString(fmt.Sprintf(" %s\n\n", headerStyle.Render("Propose New Whitelist"))) + + if m.confirming { + b.WriteString(fmt.Sprintf(" %s %s\n\n", + cursorStyle.Render(fmt.Sprintf("Propose whitelist with %d peers?", m.selectedCount())), + dimStyle.Render("(y/n)"))) + } + + if m.adding { + prompt := cursorStyle.Render("+ ") + "Peer ID: " + m.addInput.render() + b.WriteString(fmt.Sprintf(" %s\n", prompt)) + if m.addError != "" { + b.WriteString(fmt.Sprintf(" %s\n", diffRedStyle.Render(" "+m.addError))) + } + b.WriteString("\n") + } + + visible := m.visibleCount() + end := m.scrollOffset + visible + if end > len(m.entries) { + end = len(m.entries) + } + + if m.scrollOffset > 0 { + b.WriteString(dimStyle.Render(" \u25b2 more above") + "\n") + } + + for i := m.scrollOffset; i < end; i++ { + e := m.entries[i] + check := "[ ]" + if e.checked { + check = connectedStyle.Render("[x]") + } + + cursor := " " + if i == m.cursor && !m.adding { + cursor = cursorStyle.Render("> ") + } + + var source string + if e.peerID == m.ourPeerID { + source = connectedStyle.Render(" (self)") + } else { + switch e.source { + case "current": + source = dimStyle.Render(" (current)") + case "network": + source = dimStyle.Render(" (from network)") + case "manual": + source = mismatchStyle.Render(" (added)") + } + } + + b.WriteString(fmt.Sprintf(" %s%s %s%s\n", cursor, check, truncatePeerID(e.peerID), source)) + } + + if end < len(m.entries) { + b.WriteString(dimStyle.Render(" \u25bc 
more below") + "\n") + } + + b.WriteString(fmt.Sprintf("\n %s\n", dimStyle.Render(fmt.Sprintf(" %d selected", m.selectedCount())))) + + if m.adding { + b.WriteString(helpStyle.Render("\n Enter: Add Esc: Cancel Ctrl+U: Clear")) + } else { + b.WriteString(helpStyle.Render("\n Space: Toggle a: Add peer Enter: Submit Esc: Cancel")) + } + + return fitToHeight(b.String(), m.height) +} diff --git a/cmd/tatankactl/whitelist_view.go b/cmd/tatankactl/whitelist_view.go new file mode 100644 index 0000000..0a7d488 --- /dev/null +++ b/cmd/tatankactl/whitelist_view.go @@ -0,0 +1,472 @@ +package main + +import ( + "fmt" + "slices" + "sort" + "strings" + "time" + + "github.com/bisoncraft/mesh/tatanka/admin" + tea "github.com/charmbracelet/bubbletea" +) + +const ( + wlModeNormal = 0 + wlModeConfirm = 1 + wlModeInfo = 2 +) + +type whitelistViewModel struct { + state *admin.AdminState + api *apiClient + mode int + sections []detailSection + focusedSection int + confirmAction string + confirmCmd tea.Cmd + height int + lastUpdate time.Time + statusMsg string +} + +func newWhitelistViewModel(state *admin.AdminState, api *apiClient) whitelistViewModel { + m := whitelistViewModel{ + state: state, + api: api, + } + m.buildSections() + return m +} + +func (m *whitelistViewModel) buildSections() { + m.sections = nil + m.statusMsg = "" + + ourWl := getOurWhitelist(m.state) + proposedWl := getOurProposal(m.state) + hasProposal := len(proposedWl) > 0 + + // Compute network proposals once for both sections. + netProposals := computeNetworkProposals(m.state) + + // Build our proposed set once for matching against network proposals. 
+ ourProposedSet := make(map[string]bool, len(proposedWl)) + for _, id := range proposedWl { + ourProposedSet[id] = true + } + + var ourSupporters map[string]bool + var otherProposals []networkProposal + for _, np := range netProposals { + if hasProposal && len(np.proposedPeerIDs) == len(proposedWl) && matchesSet(np.proposedPeerIDs, ourProposedSet) { + ourSupporters = np.supporters + } else { + otherProposals = append(otherProposals, np) + } + } + + // Section 1: Our Whitelist + title := "Our Whitelist" + if hasProposal { + title = "Our Whitelist " + mismatchStyle.Render("updating") + } + sec := detailSection{title: title} + if m.state != nil { + ourSet := make(map[string]bool, len(ourWl)) + for _, id := range ourWl { + ourSet[id] = true + } + + if hasProposal { + proposedSet := make(map[string]bool, len(proposedWl)) + for _, id := range proposedWl { + proposedSet[id] = true + } + + // Sort: remaining first, then added, then removed. + var remaining, added, removed []string + for id := range ourSet { + if proposedSet[id] { + remaining = append(remaining, id) + } else { + removed = append(removed, id) + } + } + for id := range proposedSet { + if !ourSet[id] { + added = append(added, id) + } + } + sort.Strings(remaining) + sort.Strings(added) + sort.Strings(removed) + + for _, id := range remaining { + short := truncatePeerID(id) + connIcon := m.connectionIcon(id) + sec.lines = append(sec.lines, fmt.Sprintf(" %s %s", connIcon, short)) + } + for _, id := range added { + short := truncatePeerID(id) + connIcon := m.connectionIcon(id) + sec.lines = append(sec.lines, fmt.Sprintf(" %s %s %s", connIcon, diffGreenStyle.Render("+"), diffAddedBgStyle.Render(short))) + } + for _, id := range removed { + short := truncatePeerID(id) + connIcon := m.connectionIcon(id) + sec.lines = append(sec.lines, fmt.Sprintf(" %s %s %s", connIcon, diffRedStyle.Render("-"), diffRemovedBgStyle.Render(short))) + } + + // Consensus info for our proposal. 
+ if ourSupporters != nil { + sec.lines = append(sec.lines, "") + ci := computeConsensusInfo(m.state, ourWl, proposedWl, ourSupporters) + sec.lines = append(sec.lines, m.renderConsensusLines(ci)...) + } + } else { + for _, id := range ourWl { + short := truncatePeerID(id) + connIcon := m.connectionIcon(id) + sec.lines = append(sec.lines, fmt.Sprintf(" %s %s", connIcon, short)) + } + } + + if len(sec.lines) == 0 { + sec.lines = append(sec.lines, dimStyle.Render(" No peers in whitelist")) + } + } else { + sec.lines = append(sec.lines, dimStyle.Render(" Waiting for data...")) + } + m.sections = append(m.sections, sec) + + // Section 2: Network Proposals — only shown when there are proposals + // from other peers (i.e., not just our own proposal). + if len(otherProposals) > 0 { + propSec := detailSection{title: "Network Proposals"} + for i, np := range otherProposals { + if i > 0 { + propSec.lines = append(propSec.lines, "") + } + + label := fmt.Sprintf(" Proposal %d", i+1) + propSec.lines = append(propSec.lines, headerStyle.Render(label)) + + // Align the key with the header line. 
+ for len(propSec.keys) < len(propSec.lines)-1 { + propSec.keys = append(propSec.keys, "") + } + propSec.keys = append(propSec.keys, fmt.Sprintf("proposal:%d", i)) + + // Show peers + proposedSorted := make([]string, len(np.proposedPeerIDs)) + copy(proposedSorted, np.proposedPeerIDs) + sort.Strings(proposedSorted) + + // Diff against current if possible + if len(ourWl) > 0 { + currentSet := make(map[string]bool, len(ourWl)) + for _, id := range ourWl { + currentSet[id] = true + } + proposedSet := make(map[string]bool, len(proposedSorted)) + for _, id := range proposedSorted { + proposedSet[id] = true + } + + var added, removed []string + for _, id := range proposedSorted { + if !currentSet[id] { + added = append(added, truncatePeerID(id)) + } + } + for _, id := range ourWl { + if !proposedSet[id] { + removed = append(removed, truncatePeerID(id)) + } + } + + if len(added) > 0 || len(removed) > 0 { + var changes []string + for _, id := range added { + changes = append(changes, diffGreenStyle.Render("+"+id)) + } + for _, id := range removed { + changes = append(changes, diffRedStyle.Render("-"+id)) + } + propSec.lines = append(propSec.lines, " "+strings.Join(changes, ", ")) + } else { + propSec.lines = append(propSec.lines, dimStyle.Render(" (same as current whitelist)")) + } + } else { + for _, id := range proposedSorted { + propSec.lines = append(propSec.lines, fmt.Sprintf(" %s", truncatePeerID(id))) + } + } + + // Consensus display + ci := computeConsensusInfo(m.state, ourWl, np.proposedPeerIDs, np.supporters) + propSec.lines = append(propSec.lines, m.renderConsensusLines(ci)...) + } + // Pad keys to match final lines length. 
+ for len(propSec.keys) < len(propSec.lines) { + propSec.keys = append(propSec.keys, "") + } + m.sections = append(m.sections, propSec) + } + + if m.focusedSection >= len(m.sections) { + m.focusedSection = 0 + } +} + +func (m *whitelistViewModel) connectionIcon(peerID string) string { + if m.state == nil { + return dimStyle.Render("\u25cf") + } + if peerID == m.state.OurPeerID { + return getStateIcon(admin.StateConnected) + } + peer, ok := m.state.Peers[peerID] + if !ok { + return dimStyle.Render("\u25cf") + } + return getStateIcon(peer.State) +} + +// matchesSet returns true if every element in ids is in the set. +func matchesSet(ids []string, set map[string]bool) bool { + for _, id := range ids { + if !set[id] { + return false + } + } + return true +} + +// renderConsensusLines renders the consensus progress display for a proposal. +func (m *whitelistViewModel) renderConsensusLines(ci consensusInfo) []string { + var lines []string + + header := fmt.Sprintf(" Consensus (%d of %d overlap agreeing, need %d)", ci.agreeing, ci.nOverlap, ci.threshold) + if ci.agreeing >= ci.threshold && ci.blocking == 0 { + header += " " + connectedStyle.Render("\u2713") + } + lines = append(lines, header) + + if ci.blocking > 0 { + plural := "" + if ci.blocking > 1 { + plural = "s" + } + lines = append(lines, disconnectedStyle.Render(fmt.Sprintf(" \u26a0 %d online peer%s not agreeing \u2014 update blocked", ci.blocking, plural))) + } + + for _, op := range ci.overlap { + connIcon := m.connectionIcon(op.peerID) + short := truncatePeerID(op.peerID) + var status string + if op.agreeing { + if op.ready { + status = connectedStyle.Render("agreeing (ready)") + } else { + status = connectedStyle.Render("agreeing") + } + } else if op.online { + status = disconnectedStyle.Render("not agreeing") + } else { + status = dimStyle.Render("offline") + } + lines = append(lines, fmt.Sprintf(" %s %s %s", connIcon, short, status)) + } + + return lines +} + +func (m whitelistViewModel) Update(msg tea.Msg) 
(whitelistViewModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + case tea.KeyMsg: + if m.mode == wlModeInfo { + switch msg.String() { + case "esc", "q", "i": + m.mode = wlModeNormal + } + return m, nil + } + + if m.mode == wlModeConfirm { + switch msg.String() { + case "y": + m.mode = wlModeNormal + return m, m.confirmCmd + case "n", "esc": + m.mode = wlModeNormal + m.confirmAction = "" + m.confirmCmd = nil + } + return m, nil + } + + // Normal mode + switch msg.String() { + case "up", "k": + if len(m.sections) > 0 { + sec := &m.sections[m.focusedSection] + if len(sec.keys) > 0 { + sec.cursorUp() + } else { + sec.scrollUp() + } + } + case "down", "j": + if len(m.sections) > 0 { + sec := &m.sections[m.focusedSection] + if len(sec.keys) > 0 { + sec.cursorDown() + } else { + sec.scrollDown() + } + } + case "tab": + if len(m.sections) > 1 { + m.focusedSection = (m.focusedSection + 1) % len(m.sections) + } + case "shift+tab": + if len(m.sections) > 1 { + m.focusedSection = (m.focusedSection - 1 + len(m.sections)) % len(m.sections) + } + case "enter": + // Adopt a network proposal — only when proposals section exists and is focused. + if m.focusedSection > 0 && m.focusedSection < len(m.sections) { + proposalIdx := m.findProposalAtCursor() + if proposalIdx >= 0 { + ourProposal := getOurProposal(m.state) + ourSet := make(map[string]bool, len(ourProposal)) + for _, id := range ourProposal { + ourSet[id] = true + } + var others []networkProposal + for _, np := range computeNetworkProposals(m.state) { + if len(np.proposedPeerIDs) != len(ourProposal) || !matchesSet(np.proposedPeerIDs, ourSet) { + others = append(others, np) + } + } + if proposalIdx < len(others) { + np := others[proposalIdx] + // Cannot adopt a whitelist that removes us. 
+ if m.state != nil && !slices.Contains(np.proposedPeerIDs, m.state.OurPeerID) { + m.statusMsg = "Cannot adopt: proposal does not include our peer" + return m, nil + } + m.mode = wlModeConfirm + m.confirmAction = fmt.Sprintf("Propose whitelist with %d peers?", len(np.proposedPeerIDs)) + m.confirmCmd = m.api.proposeWhitelist(np.proposedPeerIDs) + } + } + } + case "i": + m.mode = wlModeInfo + return m, nil + case "p": + // Open whitelist editor + return m, func() tea.Msg { + return navigateToWhitelistEditorMsg{} + } + case "c": + // Clear our active proposal + if len(getOurProposal(m.state)) > 0 { + m.mode = wlModeConfirm + m.confirmAction = "Clear active proposal?" + m.confirmCmd = m.api.clearProposal() + } + case "esc", "q": + return m, navBack() + } + } + return m, nil +} + +func (m *whitelistViewModel) findProposalAtCursor() int { + if m.focusedSection <= 0 || m.focusedSection >= len(m.sections) { + return -1 + } + sec := &m.sections[m.focusedSection] + if len(sec.keys) == 0 { + return -1 + } + key := sec.selectedKey() + if key == "" { + return -1 + } + var idx int + if _, err := fmt.Sscanf(key, "proposal:%d", &idx); err == nil { + return idx + } + return -1 +} + +func (m whitelistViewModel) View() string { + var b strings.Builder + + // Header + ts := "" + if !m.lastUpdate.IsZero() { + ts = dimStyle.Render(m.lastUpdate.Format("15:04:05")) + } + b.WriteString(fmt.Sprintf(" %s%s\n\n", + headerStyle.Render("Whitelist Management"), + pad(ts, 40))) + + // Info overlay + if m.mode == wlModeInfo { + b.WriteString(dimStyle.Render(" A whitelist update is a two-phase process. First, a node proposes a new") + "\n") + b.WriteString(dimStyle.Render(" whitelist and publishes it to the network. Other nodes that agree adopt the") + "\n") + b.WriteString(dimStyle.Render(" same proposal. 
Once 2/3 of the overlapping peers (peers in both the current") + "\n") + b.WriteString(dimStyle.Render(" and proposed whitelists) are online and agreeing, all participating nodes mark") + "\n") + b.WriteString(dimStyle.Render(" themselves as ready. If any overlapping peer is online but has not adopted the") + "\n") + b.WriteString(dimStyle.Render(" proposal or supports a different proposal, the update will not proceed. When") + "\n") + b.WriteString(dimStyle.Render(" all online agreeing nodes are ready, the switch executes automatically.") + "\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return fitToHeight(b.String(), m.height) + } + + if m.state == nil { + b.WriteString(" Waiting for data...\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return fitToHeight(b.String(), m.height) + } + + // Status message + if m.statusMsg != "" { + b.WriteString(fmt.Sprintf(" %s\n\n", connectedStyle.Render(m.statusMsg))) + } + + // Confirm mode overlay + if m.mode == wlModeConfirm { + b.WriteString(fmt.Sprintf(" %s %s\n\n", + cursorStyle.Render(m.confirmAction), + dimStyle.Render("(y/n)"))) + } + + // Sections + for i := range m.sections { + renderSection(&b, &m.sections[i], i == m.focusedSection) + } + + // Help + var parts []string + parts = append(parts, "\u2191\u2193 Scroll") + if len(m.sections) > 1 { + parts = append(parts, "Tab: Section", "Enter: Adopt") + } + parts = append(parts, "p: Propose") + if len(getOurProposal(m.state)) > 0 { + parts = append(parts, "c: Clear proposal") + } + parts = append(parts, "i: Info", "Esc: Back") + b.WriteString(helpStyle.Render(" " + strings.Join(parts, " "))) + + return fitToHeight(b.String(), m.height) +} diff --git a/tatanka/admin/http_handlers.go b/tatanka/admin/http_handlers.go new file mode 100644 index 0000000..9304d0d --- /dev/null +++ b/tatanka/admin/http_handlers.go @@ -0,0 +1,121 @@ +package admin + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "slices" +) + +type proposeWhitelistRequest struct { 
+ Peers []string `json:"peers"` +} + +// handleProposeWhitelist sets (POST) or clears (DELETE) a whitelist proposal. +func (s *Server) handleProposeWhitelist(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "POST": + var req proposeWhitelistRequest + if err := decodeStrictJSONBody(w, r, &req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + + if len(req.Peers) == 0 { + http.Error(w, "peers list cannot be empty", http.StatusBadRequest) + return + } + + // A node cannot propose a whitelist that removes itself. + state := s.getState() + if !slices.Contains(req.Peers, state.OurPeerID) { + http.Error(w, "proposed whitelist must include our own peer", http.StatusBadRequest) + return + } + + if err := s.proposeWhitelist(req.Peers); err != nil { + http.Error(w, "failed to set proposal: "+err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "proposal set"}) + + case "DELETE": + s.clearProposal() + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "proposal cleared"}) + + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +type adoptWhitelistRequest struct { + PeerID string `json:"peer_id"` +} + +// handleAdoptWhitelist replaces our whitelist with a mismatched peer's current whitelist. 
+func (s *Server) handleAdoptWhitelist(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req adoptWhitelistRequest + if err := decodeStrictJSONBody(w, r, &req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + + if req.PeerID == "" { + http.Error(w, "peer_id is required", http.StatusBadRequest) + return + } + + state := s.getState() + pi, found := state.Peers[req.PeerID] + + if !found { + http.Error(w, "peer not found", http.StatusNotFound) + return + } + + if pi.State != StateWhitelistMismatch { + http.Error(w, "peer is not in whitelist_mismatch state", http.StatusBadRequest) + return + } + + if pi.WhitelistState == nil || pi.WhitelistState.Current == nil || len(pi.WhitelistState.Current.PeerIDs) == 0 { + http.Error(w, "peer has no whitelist to adopt", http.StatusBadRequest) + return + } + + // Convert the peer's current whitelist to string slice for forceWhitelist. 
+ peerWlStrings := make([]string, 0, len(pi.WhitelistState.Current.PeerIDs)) + for pid := range pi.WhitelistState.Current.PeerIDs { + peerWlStrings = append(peerWlStrings, pid.String()) + } + + if err := s.forceWhitelist(peerWlStrings); err != nil { + http.Error(w, "failed to adopt whitelist: "+err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "whitelist adopted"}) +} + +func decodeStrictJSONBody(w http.ResponseWriter, r *http.Request, dest any) error { + r.Body = http.MaxBytesReader(w, r.Body, adminRequestBodyLimitBytes) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(dest); err != nil { + return err + } + if err := dec.Decode(&struct{}{}); err != io.EOF { + return errors.New("request body must contain a single JSON object") + } + return nil +} diff --git a/tatanka/admin/server.go b/tatanka/admin/server.go index 7c89d7c..369a675 100644 --- a/tatanka/admin/server.go +++ b/tatanka/admin/server.go @@ -3,14 +3,17 @@ package admin import ( "context" "encoding/json" + "net" "net/http" + "net/url" + "strings" "sync" "time" + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/decred/slog" "github.com/gorilla/websocket" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/bisoncraft/mesh/oracle" ) // NodeConnectionState defines the status of a peer connection @@ -20,32 +23,29 @@ const ( StateConnected NodeConnectionState = "connected" StateDisconnected NodeConnectionState = "disconnected" StateWhitelistMismatch NodeConnectionState = "whitelist_mismatch" + + adminRequestBodyLimitBytes = 1 << 20 // 1 MiB + adminWebSocketReadLimit = 1 << 20 // 1 MiB ) -// NodeInfo contains information about a specific peer -type NodeInfo struct { - PeerID string `json:"peer_id"` - State NodeConnectionState `json:"state"` - Addresses []string `json:"addresses,omitempty"` - PeerWhitelist []string 
`json:"peer_whitelist,omitempty"` +// PeerInfo contains information about a specific peer +type PeerInfo struct { + PeerID string `json:"peer_id"` + State NodeConnectionState `json:"state"` + Addresses []string `json:"addresses,omitempty"` + WhitelistState *types.WhitelistState `json:"whitelist_state,omitempty"` } // AdminState represents the global admin state type AdminState struct { - Nodes map[string]NodeInfo `json:"nodes"` - OurWhitelist []string `json:"our_whitelist"` + OurPeerID string `json:"our_peer_id"` + WhitelistState *types.WhitelistState `json:"whitelist_state"` + Peers map[string]PeerInfo `json:"peers"` // remote peers only } -func (s AdminState) DeepCopy() AdminState { - newState := AdminState{ - Nodes: make(map[string]NodeInfo, len(s.Nodes)), - OurWhitelist: make([]string, len(s.OurWhitelist)), - } - for k, v := range s.Nodes { - newState.Nodes[k] = v - } - copy(newState.OurWhitelist, s.OurWhitelist) - return newState +// WhitelistUpdate is sent when whitelist membership changes. +type WhitelistUpdate struct { + WhitelistState *types.WhitelistState `json:"whitelist_state"` } // WSMessage is the envelope for all WebSocket messages. @@ -54,25 +54,62 @@ type WSMessage struct { Data json.RawMessage `json:"data"` } -// Client represents a connected WebSocket user. -type Client struct { +// client represents a connected WebSocket user. +type client struct { conn *websocket.Conn send chan WSMessage } +// Config holds the configuration for the admin server. 
+type Config struct { + Log slog.Logger + Addr string + PeerID string + Oracle Oracle + GetState func() AdminState + ProposeWhitelist func(peers []string) error + ClearProposal func() + ForceWhitelist func(peers []string) error +} + +func (c *Config) verify() { + if c.Log == nil { + panic("admin.Config.Log is nil") + } + if c.Addr == "" { + panic("admin.Config.Addr is empty") + } + if c.Oracle == nil { + panic("admin.Config.Oracle is nil") + } + if c.GetState == nil { + panic("admin.Config.GetState is nil") + } + if c.ProposeWhitelist == nil { + panic("admin.Config.ProposeWhitelist is nil") + } + if c.ClearProposal == nil { + panic("admin.Config.ClearProposal is nil") + } + if c.ForceWhitelist == nil { + panic("admin.Config.ForceWhitelist is nil") + } +} + // Server manages the admin server for a tatanka node. type Server struct { - log slog.Logger + log slog.Logger + oracle Oracle + getState func() AdminState + proposeWhitelist func(peers []string) error + clearProposal func() + forceWhitelist func(peers []string) error + httpServer *http.Server upgrader websocket.Upgrader - stateMtx sync.RWMutex - state AdminState - clientsMtx sync.RWMutex - clients map[*Client]bool - - oracle Oracle + clients map[*client]bool } // Oracle supplies data for admin oracle endpoints. @@ -81,27 +118,37 @@ type Oracle interface { } // NewServer initializes the admin server. 
-func NewServer(log slog.Logger, addr string, oracle Oracle) *Server { - server := &Server{ - log: log, - upgrader: websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}, - clients: make(map[*Client]bool), - state: AdminState{ - Nodes: make(map[string]NodeInfo), - OurWhitelist: []string{}, +func NewServer(cfg *Config) *Server { + cfg.verify() + return &Server{ + log: cfg.Log, + oracle: cfg.Oracle, + getState: cfg.GetState, + proposeWhitelist: cfg.ProposeWhitelist, + clearProposal: cfg.ClearProposal, + forceWhitelist: cfg.ForceWhitelist, + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return allowWebSocketOrigin(r) + }, + }, + clients: make(map[*client]bool), + httpServer: &http.Server{ + Addr: cfg.Addr, + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 60 * time.Second, }, - httpServer: &http.Server{Addr: addr}, - oracle: oracle, } - - return server } // Start launches the HTTP server func (s *Server) Start(ctx context.Context) error { mux := http.NewServeMux() - mux.HandleFunc("/admin/state", s.handleGetState) - mux.HandleFunc("/admin/ws", s.handleWebSocket) + mux.HandleFunc("/admin/ws", s.localOnly(s.handleWebSocket)) + mux.HandleFunc("/admin/propose-whitelist", s.localOnly(s.handleProposeWhitelist)) + mux.HandleFunc("/admin/adopt-whitelist", s.localOnly(s.handleAdoptWhitelist)) s.httpServer.Handler = mux s.log.Infof("Starting admin server on %s", s.httpServer.Addr) @@ -116,156 +163,92 @@ func (s *Server) Start(ctx context.Context) error { return s.httpServer.ListenAndServe() } -// UpdateConnectionState updates a specific peer's info and broadcasts to clients -func (s *Server) UpdateConnectionState(peerID peer.ID, state NodeConnectionState, addresses []string, peerWhitelist []string) { - s.stateMtx.Lock() - node := NodeInfo{ - PeerID: peerID.String(), - State: state, - Addresses: addresses, - } - if state == StateWhitelistMismatch { - 
node.PeerWhitelist = peerWhitelist +func (s *Server) localOnly(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !isLoopbackRequest(r) { + http.Error(w, "admin interface is only available from localhost", http.StatusForbidden) + return + } + next(w, r) } - s.state.Nodes[peerID.String()] = node - snapshot := s.state.DeepCopy() - s.stateMtx.Unlock() - - s.broadcastState(snapshot) -} - -// UpdateWhitelist updates the local whitelist and broadcasts to clients -func (s *Server) UpdateWhitelist(whitelist []string) { - s.stateMtx.Lock() - s.state.OurWhitelist = whitelist - snapshot := s.state.DeepCopy() - s.stateMtx.Unlock() - - s.broadcastState(snapshot) } -// broadcastState sends the admin state to all clients. -func (s *Server) broadcastState(state AdminState) { - data, err := json.Marshal(state) +func isLoopbackRequest(r *http.Request) bool { + host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { - s.log.Errorf("Failed to marshal admin state: %v", err) - return - } - msg := WSMessage{ - Type: "admin_state", - Data: json.RawMessage(data), + host = r.RemoteAddr } - s.broadcast(msg) + host = strings.Trim(host, "[]") + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() } -// BroadcastOracleUpdate broadcasts a typed oracle update to all connected clients. -func (s *Server) BroadcastOracleUpdate(msgType string, snapshotDiff *oracle.OracleSnapshot) { - data, err := json.Marshal(snapshotDiff) - if err != nil { - s.log.Errorf("Failed to marshal oracle update (%s): %v", msgType, err) - return - } - msg := WSMessage{ - Type: msgType, - Data: json.RawMessage(data), +func allowWebSocketOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + // Non-browser clients (e.g. tatankactl) usually don't set Origin. + return true } - s.broadcast(msg) -} - -// broadcast sends a WSMessage to all connected clients. 
-func (s *Server) broadcast(msg WSMessage) { - s.clientsMtx.RLock() - defer s.clientsMtx.RUnlock() - for client := range s.clients { - select { - case client.send <- msg: - default: - s.log.Errorf("Client buffer full, skipping update") - } + u, err := url.Parse(origin) + if err != nil { + return false } + + return isLocalhostHost(u.Host) } -func (s *Server) handleGetState(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return +func isLocalhostHost(hostPort string) bool { + host := hostPort + if parsedHost, _, err := net.SplitHostPort(hostPort); err == nil { + host = parsedHost } - s.stateMtx.RLock() - state := s.state.DeepCopy() - s.stateMtx.RUnlock() + host = strings.Trim(host, "[]") + if strings.EqualFold(host, "localhost") { + return true + } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(state) + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() } -func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { - conn, err := s.upgrader.Upgrade(w, r, nil) +// BroadcastPeerUpdate sends a single peer update to all connected clients. +func (s *Server) BroadcastPeerUpdate(pi PeerInfo) { + data, err := json.Marshal(pi) if err != nil { - s.log.Warnf("WebSocket upgrade failed: %v", err) + s.log.Errorf("Failed to marshal peer update: %v", err) return } + s.broadcast(WSMessage{Type: "peer_update", Data: json.RawMessage(data)}) +} - client := &Client{ - conn: conn, - send: make(chan WSMessage, 10), +// BroadcastWhitelistState sends our own whitelist state to all connected clients. 
+func (s *Server) BroadcastWhitelistState(ws *types.WhitelistState) { + data, err := json.Marshal(ws) + if err != nil { + s.log.Errorf("Failed to marshal whitelist state: %v", err) + return } + s.broadcast(WSMessage{Type: "whitelist_state", Data: json.RawMessage(data)}) +} - s.clientsMtx.Lock() - s.clients[client] = true - s.clientsMtx.Unlock() - - // Send initial admin state - s.stateMtx.RLock() - initialState := s.state.DeepCopy() - s.stateMtx.RUnlock() - stateData, err := json.Marshal(initialState) - if err == nil { - select { - case client.send <- WSMessage{Type: "admin_state", Data: json.RawMessage(stateData)}: - default: - } +// BroadcastWhitelistUpdate sends whitelist membership changes to all connected clients. +func (s *Server) BroadcastWhitelistUpdate(wu WhitelistUpdate) { + data, err := json.Marshal(wu) + if err != nil { + s.log.Errorf("Failed to marshal whitelist update: %v", err) + return } + s.broadcast(WSMessage{Type: "whitelist_update", Data: json.RawMessage(data)}) +} - // Send oracle snapshot - snapshot := s.oracle.OracleSnapshot() - if snapshot != nil { - snapshotData, err := json.Marshal(snapshot) - if err == nil { - select { - case client.send <- WSMessage{Type: "oracle_snapshot", Data: json.RawMessage(snapshotData)}: - default: - } - } +// BroadcastOracleUpdate broadcasts a typed oracle update to all connected clients. +func (s *Server) BroadcastOracleUpdate(msgType string, snapshotDiff *oracle.OracleSnapshot) { + data, err := json.Marshal(snapshotDiff) + if err != nil { + s.log.Errorf("Failed to marshal oracle update (%s): %v", msgType, err) + return } - - // 1. Writer Goroutine - go func() { - defer conn.Close() - for msg := range client.send { - if err := conn.WriteJSON(msg); err != nil { - return - } - } - }() - - // 2. 
Reader Goroutine - go func() { - defer func() { - s.clientsMtx.Lock() - if _, ok := s.clients[client]; ok { - delete(s.clients, client) - close(client.send) - } - s.clientsMtx.Unlock() - conn.Close() - }() - - for { - if _, _, err := conn.ReadMessage(); err != nil { - break - } - } - }() + s.broadcast(WSMessage{Type: msgType, Data: json.RawMessage(data)}) } diff --git a/tatanka/admin/server_test.go b/tatanka/admin/server_test.go new file mode 100644 index 0000000..efa41ba --- /dev/null +++ b/tatanka/admin/server_test.go @@ -0,0 +1,96 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestIsLoopbackRequest(t *testing.T) { + tests := []struct { + name string + remoteAddr string + want bool + }{ + {name: "ipv4 loopback", remoteAddr: "127.0.0.1:9000", want: true}, + {name: "ipv6 loopback", remoteAddr: "[::1]:9000", want: true}, + {name: "non loopback", remoteAddr: "192.168.1.10:9000", want: false}, + {name: "invalid remote addr", remoteAddr: "localhost:9000", want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/admin/ws", nil) + req.RemoteAddr = tc.remoteAddr + got := isLoopbackRequest(req) + if got != tc.want { + t.Fatalf("isLoopbackRequest(%q) = %v, want %v", tc.remoteAddr, got, tc.want) + } + }) + } +} + +func TestAllowWebSocketOrigin(t *testing.T) { + tests := []struct { + name string + origin string + want bool + }{ + {name: "empty origin allowed", origin: "", want: true}, + {name: "localhost origin", origin: "http://localhost:2000", want: true}, + {name: "loopback ipv4 origin", origin: "http://127.0.0.1:2000", want: true}, + {name: "loopback ipv6 origin", origin: "http://[::1]:2000", want: true}, + {name: "non loopback origin", origin: "https://example.com", want: false}, + {name: "malformed origin", origin: "://bad-origin", want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := 
httptest.NewRequest(http.MethodGet, "/admin/ws", nil) + if tc.origin != "" { + req.Header.Set("Origin", tc.origin) + } + got := allowWebSocketOrigin(req) + if got != tc.want { + t.Fatalf("allowWebSocketOrigin(%q) = %v, want %v", tc.origin, got, tc.want) + } + }) + } +} + +func TestLocalOnlyMiddleware(t *testing.T) { + s := &Server{} + + tests := []struct { + name string + remoteAddr string + wantStatus int + wantCalled bool + }{ + {name: "loopback allowed", remoteAddr: "127.0.0.1:1111", wantStatus: http.StatusNoContent, wantCalled: true}, + {name: "non loopback forbidden", remoteAddr: "10.0.0.5:1111", wantStatus: http.StatusForbidden, wantCalled: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodGet, "/admin/propose-whitelist", nil) + req.RemoteAddr = tc.remoteAddr + rr := httptest.NewRecorder() + + s.localOnly(next).ServeHTTP(rr, req) + + if rr.Code != tc.wantStatus { + t.Fatalf("status = %d, want %d", rr.Code, tc.wantStatus) + } + if called != tc.wantCalled { + t.Fatalf("handler called = %v, want %v", called, tc.wantCalled) + } + }) + } +} diff --git a/tatanka/admin/websocket.go b/tatanka/admin/websocket.go new file mode 100644 index 0000000..a846571 --- /dev/null +++ b/tatanka/admin/websocket.go @@ -0,0 +1,89 @@ +package admin + +import ( + "encoding/json" + "net/http" +) + +// broadcast sends a WSMessage to all connected clients. 
+func (s *Server) broadcast(msg WSMessage) { + s.clientsMtx.RLock() + defer s.clientsMtx.RUnlock() + + for c := range s.clients { + select { + case c.send <- msg: + default: + s.log.Errorf("Client buffer full, skipping update") + } + } +} + +func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + s.log.Warnf("WebSocket upgrade failed: %v", err) + return + } + conn.SetReadLimit(adminWebSocketReadLimit) + + c := &client{ + conn: conn, + send: make(chan WSMessage, 10), + } + + s.clientsMtx.Lock() + s.clients[c] = true + s.clientsMtx.Unlock() + + // Send initial admin state + initialState := s.getState() + stateData, err := json.Marshal(initialState) + if err == nil { + select { + case c.send <- WSMessage{Type: "admin_state", Data: json.RawMessage(stateData)}: + default: + } + } + + // Send oracle snapshot + snapshot := s.oracle.OracleSnapshot() + if snapshot != nil { + snapshotData, err := json.Marshal(snapshot) + if err == nil { + select { + case c.send <- WSMessage{Type: "oracle_snapshot", Data: json.RawMessage(snapshotData)}: + default: + } + } + } + + // Writer goroutine + go func() { + defer conn.Close() + for msg := range c.send { + if err := conn.WriteJSON(msg); err != nil { + return + } + } + }() + + // Reader goroutine + go func() { + defer func() { + s.clientsMtx.Lock() + if _, ok := s.clients[c]; ok { + delete(s.clients, c) + close(c.send) + } + s.clientsMtx.Unlock() + conn.Close() + }() + + for { + if _, _, err := conn.ReadMessage(); err != nil { + break + } + } + }() +} diff --git a/tatanka/admin_notifier.go b/tatanka/admin_notifier.go new file mode 100644 index 0000000..f990d23 --- /dev/null +++ b/tatanka/admin_notifier.go @@ -0,0 +1,45 @@ +package tatanka + +import ( + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/tatanka/admin" + "github.com/bisoncraft/mesh/tatanka/types" +) + +// adminNotifier abstracts the admin server notifications so that callers 
+// don't need nil-checks when the admin UI is disabled. +type adminNotifier interface { + BroadcastOracleUpdate(msgType string, snapshot *oracle.OracleSnapshot) + BroadcastPeerUpdate(pi admin.PeerInfo) + BroadcastWhitelistState(ws *types.WhitelistState) + BroadcastWhitelistUpdate(wu admin.WhitelistUpdate) +} + +// noopAdminNotifier is used when the admin UI is disabled (AdminPort <= 0). +type noopAdminNotifier struct{} + +func (noopAdminNotifier) BroadcastOracleUpdate(string, *oracle.OracleSnapshot) {} +func (noopAdminNotifier) BroadcastPeerUpdate(admin.PeerInfo) {} +func (noopAdminNotifier) BroadcastWhitelistState(*types.WhitelistState) {} +func (noopAdminNotifier) BroadcastWhitelistUpdate(admin.WhitelistUpdate) {} + +// liveAdminNotifier delegates to a real admin.Server. +type liveAdminNotifier struct { + server *admin.Server +} + +func (n *liveAdminNotifier) BroadcastOracleUpdate(msgType string, snapshot *oracle.OracleSnapshot) { + n.server.BroadcastOracleUpdate(msgType, snapshot) +} + +func (n *liveAdminNotifier) BroadcastPeerUpdate(pi admin.PeerInfo) { + n.server.BroadcastPeerUpdate(pi) +} + +func (n *liveAdminNotifier) BroadcastWhitelistState(ws *types.WhitelistState) { + n.server.BroadcastWhitelistState(ws) +} + +func (n *liveAdminNotifier) BroadcastWhitelistUpdate(wu admin.WhitelistUpdate) { + n.server.BroadcastWhitelistUpdate(wu) +} diff --git a/tatanka/file.go b/tatanka/file.go new file mode 100644 index 0000000..4ab9dfb --- /dev/null +++ b/tatanka/file.go @@ -0,0 +1,17 @@ +package tatanka + +import "os" + +// atomicWriteFile writes data to a file atomically by writing to a temporary +// file first and then renaming it to the target path. 
+func atomicWriteFile(path string, data []byte) error { + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0600); err != nil { + return err + } + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return err + } + return nil +} diff --git a/tatanka/gossipsub.go b/tatanka/gossipsub.go index 641f6b6..3d9ba4c 100644 --- a/tatanka/gossipsub.go +++ b/tatanka/gossipsub.go @@ -5,15 +5,16 @@ import ( "errors" "fmt" + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/oracle/sources" + protocolsPb "github.com/bisoncraft/mesh/protocols/pb" + pb "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/decred/slog" "github.com/klauspost/compress/zstd" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" - "github.com/bisoncraft/mesh/oracle" - "github.com/bisoncraft/mesh/oracle/sources" - protocolsPb "github.com/bisoncraft/mesh/protocols/pb" - pb "github.com/bisoncraft/mesh/tatanka/pb" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) @@ -34,6 +35,9 @@ const ( // quotaHeartbeatTopicName is the name of the pubsub topic used to // periodically share quota information between tatanka nodes. quotaHeartbeatTopicName = "quota_heartbeat" + + // whitelistUpdatesTopicName is the pubsub topic for whitelist updates. 
+ whitelistUpdatesTopicName = "whitelist_updates" ) type clientConnectionUpdate struct { @@ -79,6 +83,7 @@ type gossipSubCfg struct { handleClientConnectionMessage func(update *clientConnectionUpdate) handleOracleUpdate func(senderID peer.ID, update *pb.NodeOracleUpdate) handleQuotaHeartbeat func(senderID peer.ID, heartbeat *pb.QuotaHandshake) + handleWhitelistUpdate func(senderID peer.ID, ws *types.WhitelistState, timestamp int64) } // gossipSub manages the nodes connection to a gossip sub network between tatanka @@ -92,6 +97,7 @@ type gossipSub struct { clientConnectionsTopic *pubsub.Topic oracleUpdatesTopic *pubsub.Topic quotaHeartbeatTopic *pubsub.Topic + whitelistUpdatesTopic *pubsub.Topic zstdEncoder *zstd.Encoder zstdDecoder *zstd.Decoder } @@ -111,10 +117,6 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { return nil, err } - // Currently using a single topic for all client messages, but - // we could use separate topics for each client topic to only - // send messages to tatanka nodes that have clients subscribed to - // a certain topic. 
clientMessageTopic, err := ps.Join(clientMessageTopicName) if err != nil { return nil, fmt.Errorf("failed to join client message topic: %w", err) @@ -135,6 +137,11 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { return nil, fmt.Errorf("failed to join quota heartbeat topic: %w", err) } + whitelistUpdatesTopic, err := ps.Join(whitelistUpdatesTopicName) + if err != nil { + return nil, fmt.Errorf("failed to join whitelist updates topic: %w", err) + } + zstdEncoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) if err != nil { return nil, fmt.Errorf("failed to create zstd encoder: %w", err) @@ -146,22 +153,22 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { } return &gossipSub{ - log: cfg.log, - ps: ps, - cfg: cfg, - clientMessageTopic: clientMessageTopic, - clientConnectionsTopic: clientConnectionsTopic, - oracleUpdatesTopic: oracleUpdatesTopic, - quotaHeartbeatTopic: quotaHeartbeatTopic, - zstdEncoder: zstdEncoder, - zstdDecoder: zstdDecoder, + log: cfg.log, + ps: ps, + cfg: cfg, + clientMessageTopic: clientMessageTopic, + clientConnectionsTopic: clientConnectionsTopic, + oracleUpdatesTopic: oracleUpdatesTopic, + quotaHeartbeatTopic: quotaHeartbeatTopic, + whitelistUpdatesTopic: whitelistUpdatesTopic, + zstdEncoder: zstdEncoder, + zstdDecoder: zstdDecoder, }, nil } // listenForClientMessages subscribes to the pubsub client messages topic, and // distributes messages to subscribed clients as they come in. 
func (gs *gossipSub) listenForClientMessages(ctx context.Context) error { - // TODO: configure buffer size if needed sub, err := gs.clientMessageTopic.Subscribe() if err != nil { return fmt.Errorf("failed to subscribe to client message topic: %w", err) @@ -313,6 +320,32 @@ func (gs *gossipSub) listenForQuotaHeartbeats(ctx context.Context) error { } } +func (gs *gossipSub) listenForWhitelistUpdates(ctx context.Context) error { + sub, err := gs.whitelistUpdatesTopic.Subscribe() + if err != nil { + return fmt.Errorf("failed to subscribe to whitelist updates topic: %w", err) + } + + for { + msg, err := sub.Next(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + + if msg != nil && gs.cfg.handleWhitelistUpdate != nil { + update := &pb.WhitelistState{} + if err := proto.Unmarshal(msg.Data, update); err != nil { + gs.log.Errorf("Failed to unmarshal whitelist update: %v", err) + continue + } + gs.cfg.handleWhitelistUpdate(msg.GetFrom(), pbToWhitelistState(update), update.Timestamp) + } + } +} + func (gs *gossipSub) publishQuotaHeartbeat(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { heartbeat := &pb.QuotaHandshake{ Quotas: quotaStatusesToPb(quotas), @@ -324,6 +357,14 @@ func (gs *gossipSub) publishQuotaHeartbeat(ctx context.Context, quotas map[strin return gs.quotaHeartbeatTopic.Publish(ctx, data) } +func (gs *gossipSub) publishWhitelistUpdate(ctx context.Context, update *pb.WhitelistState) error { + data, err := proto.Marshal(update) + if err != nil { + return fmt.Errorf("failed to marshal whitelist update: %w", err) + } + return gs.whitelistUpdatesTopic.Publish(ctx, data) +} + func (gs *gossipSub) run(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) @@ -351,5 +392,11 @@ func (gs *gossipSub) run(ctx context.Context) error { return err }) + g.Go(func() error { + err := gs.listenForWhitelistUpdates(ctx) + gs.log.Debug("Whitelist proposals listener stopped.") + return err + }) + return 
g.Wait() } diff --git a/tatanka/handler_permissions.go b/tatanka/handler_permissions.go index 46805a6..3c3ce78 100644 --- a/tatanka/handler_permissions.go +++ b/tatanka/handler_permissions.go @@ -3,9 +3,9 @@ package tatanka import ( "errors" + "github.com/bisoncraft/mesh/codec" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/protocol" - "github.com/bisoncraft/mesh/codec" ) var errUnauthorized = errors.New("unauthorized") @@ -49,7 +49,7 @@ func (t *TatankaNode) requireBonds(s network.Stream) error { func (t *TatankaNode) isWhitelistPeer(s network.Stream) error { peerID := s.Conn().RemotePeer() - if _, ok := t.getWhitelist().allPeerIDs()[peerID]; !ok { + if _, ok := t.whitelistManager.getWhitelist().PeerIDs[peerID]; !ok { return errUnauthorized } return nil diff --git a/tatanka/handler_permissions_test.go b/tatanka/handler_permissions_test.go index ae7da91..233fc63 100644 --- a/tatanka/handler_permissions_test.go +++ b/tatanka/handler_permissions_test.go @@ -8,6 +8,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/bisoncraft/mesh/codec" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" ) @@ -28,9 +29,7 @@ func TestPushPermissions(t *testing.T) { clientHost := mnet.Host(allPeers[1]) // Create a whitelist with just the mesh node - mockWhitelist := &whitelist{ - peers: []*peer.AddrInfo{{ID: meshHost.ID(), Addrs: meshHost.Addrs()}}, - } + mockWhitelist := types.NewWhitelist([]peer.ID{meshHost.ID()}) // Create the test node with bondStorage initially set to score 0 dir := t.TempDir() diff --git a/tatanka/handlers.go b/tatanka/handlers.go index 2758a34..306e353 100644 --- a/tatanka/handlers.go +++ b/tatanka/handlers.go @@ -6,14 +6,15 @@ import ( "strings" "time" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" "github.com/bisoncraft/mesh/bond" 
"github.com/bisoncraft/mesh/codec" "github.com/bisoncraft/mesh/oracle" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/tatanka/types" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" "google.golang.org/protobuf/proto" ) @@ -570,7 +571,7 @@ func (t *TatankaNode) handleDiscovery(s network.Stream) { } // Only share addresses for whitelist peers. - if _, ok := t.getWhitelist().allPeerIDs()[targetPeerID]; !ok { + if _, ok := t.whitelistManager.getWhitelist().PeerIDs[targetPeerID]; !ok { t.log.Warnf("Tatanka peer %s attempted to discover addresses for non-whitelist peer %s.", remotePeerID.ShortString(), targetPeerID.ShortString()) if err := codec.WriteLengthPrefixedMessage(s, pbDiscoveryResponseNotFound()); err != nil { t.log.Warnf("Failed to write discovery response to peer %s: %v.", remotePeerID.ShortString(), err) @@ -591,51 +592,50 @@ func (t *TatankaNode) handleDiscovery(s network.Stream) { } } -// handleWhitelist handles a request from another tatanka node to verify the -// whitelist alignment. The counterparty sends the list of peer IDs in their -// whitelist. If they match with ours, we send a success response, otherwise -// we send our whitelist peer IDs so the counterparty can see the difference. +// handleWhitelist handles a symmetric whitelist handshake with another tatanka +// node. Both sides exchange a WhitelistState, check for a match, and protect +// or close the connection based on the result. 
The receiver waits a few seconds +// to close the connection to allow the initiator to read the mismatch result func (t *TatankaNode) handleWhitelist(s network.Stream) { defer func() { _ = s.Close() }() remotePeerID := s.Conn().RemotePeer() - req := &pb.WhitelistRequest{} - if err := codec.ReadLengthPrefixedMessage(s, req); err != nil { - t.log.Warnf("Failed to read whitelist request from %s: %v", remotePeerID.ShortString(), err) + // 1. Read peer's state. + peerState := &pb.WhitelistState{} + if err := codec.ReadLengthPrefixedMessage(s, peerState); err != nil { + t.log.Warnf("Failed to read whitelist handshake from %s: %v", remotePeerID.ShortString(), err) return } - whitelist := t.getWhitelist() - localPeerIDs := whitelist.allPeerIDs() - var mismatch bool - - // Check if the incoming peer IDs are in the local whitelist. - for _, idBytes := range req.PeerIDs { - id, err := peer.IDFromBytes(idBytes) - if err != nil { - mismatch = true - break - } - if _, ok := localPeerIDs[id]; !ok { - mismatch = true - break - } + // 2. Send our state. + ownWs := t.whitelistManager.getLocalWhitelistState() + if err := codec.WriteLengthPrefixedMessage(s, whitelistStateToPb(ownWs)); err != nil { + t.log.Warnf("Failed to write whitelist handshake to %s: %v", remotePeerID.ShortString(), err) + return } - // Make sure there aren't additional peer IDs in the local whitelist. - mismatch = mismatch || len(req.PeerIDs) != len(localPeerIDs) + // 3. Check match independently. + matched := flexibleWhitelistMatch(ownWs.Current, ownWs.Proposed, peerState.PeerIDs, peerState.ProposedPeerIDs) - var resp *pb.WhitelistResponse - if mismatch { - resp = pbWhitelistResponseMismatch(whitelist.peerIDsBytes()) + // 4. Protect if matched. On mismatch the initiator's verifyWhitelist + // handles ClosePeer; doing it here synchronously races with the + // client's stream read and can prevent the mismatch from being + // detected. A delayed close acts as a safety net in case the + // initiator never cleans up. 
+ if matched { + t.node.ConnManager().Protect(remotePeerID, "tatanka-node") } else { - resp = pbWhitelistResponseSuccess() + go func() { + time.Sleep(5 * time.Second) + if !t.node.ConnManager().IsProtected(remotePeerID, "tatanka-node") { + _ = t.node.Network().ClosePeer(remotePeerID) + } + }() } - if err := codec.WriteLengthPrefixedMessage(s, resp); err != nil { - t.log.Warnf("Failed to write whitelist response to %s: %v", remotePeerID.ShortString(), err) - } + // 5. Record peer's whitelist info regardless of match. + t.whitelistManager.updatePeerWhitelistState(remotePeerID, pbToWhitelistState(peerState), peerState.Timestamp) } // handleAvailableMeshNodes handles a request from a client to get a list of @@ -643,14 +643,14 @@ func (t *TatankaNode) handleWhitelist(s network.Stream) { func (t *TatankaNode) handleAvailableMeshNodes(s network.Stream) { defer func() { _ = s.Close() }() - whitelist := t.getWhitelist() + whitelistPeers := t.whitelistManager.getWhitelist().PeerIDs peerStore := t.node.Peerstore() var peers []*protocolsPb.PeerInfo - for _, p := range whitelist.peers { + for pid := range whitelistPeers { // Include ourselves - if p.ID == t.node.ID() { + if pid == t.node.ID() { peers = append(peers, libp2pPeerInfoToPb(peer.AddrInfo{ ID: t.node.ID(), Addrs: t.node.Addrs(), @@ -659,13 +659,13 @@ func (t *TatankaNode) handleAvailableMeshNodes(s network.Stream) { } // Only include connected peers - if t.node.Network().Connectedness(p.ID) != network.Connected { + if t.node.Network().Connectedness(pid) != network.Connected { continue } - addrs := peerStore.Addrs(p.ID) + addrs := peerStore.Addrs(pid) peers = append(peers, libp2pPeerInfoToPb(peer.AddrInfo{ - ID: p.ID, + ID: pid, Addrs: addrs, })) } @@ -683,6 +683,17 @@ func (t *TatankaNode) handleQuotaHeartbeat(senderID peer.ID, heartbeat *pb.Quota } } +func (t *TatankaNode) handleWhitelistUpdate(senderID peer.ID, ws *types.WhitelistState, timestamp int64) { + if senderID == t.node.ID() { + return + } + if 
!t.whitelistManager.updatePeerWhitelistState(senderID, ws, timestamp) { + return + } + pi := t.connectionManager.getPeerInfo(senderID) + t.adminNotify.BroadcastPeerUpdate(pi) +} + // handleQuotaHandshake handles a quota handshake request from another tatanka node. // This is used to exchange quota information on connection. func (t *TatankaNode) handleQuotaHandshake(s network.Stream) { diff --git a/tatanka/mesh_connection_manager.go b/tatanka/mesh_connection_manager.go index 1d2dedb..c1551a3 100644 --- a/tatanka/mesh_connection_manager.go +++ b/tatanka/mesh_connection_manager.go @@ -10,13 +10,15 @@ import ( "sync/atomic" "time" + "github.com/bisoncraft/mesh/codec" + "github.com/bisoncraft/mesh/tatanka/admin" + pb "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/decred/slog" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" - "github.com/bisoncraft/mesh/codec" - pb "github.com/bisoncraft/mesh/tatanka/pb" ma "github.com/multiformats/go-multiaddr" ) @@ -37,9 +39,6 @@ var ( errDiscoveryNotFound = errors.New("discovery not found") ) -// AdminUpdateCallback is called when connection states change -type AdminUpdateCallback func(peerID peer.ID, connected bool, whitelistMismatch bool, addresses []string, peerWhitelist []string) - // retryState tracks exponential backoff state for connection retries. type retryState struct { failures int @@ -71,12 +70,8 @@ type peerTracker struct { // immediately. signalReconnect chan struct{} // whitelistMismatch is set when the most recent connection attempt failed - // specifically due to whitelist mismatch. If this is the case, we know - // there's no reason to try discovery. - whitelistMismatch bool - // peerWhitelist contains the peer's whitelist if we received it during - // whitelist verification. - peerWhitelist []string + // due to whitelist mismatch. 
+ whitelistMismatch atomic.Bool // initialCh is closed after the first connection attempt. initialCh chan struct{} initialOnce sync.Once @@ -88,16 +83,6 @@ func (t *peerTracker) markInitial() { }) } -// getAddresses returns the known addresses for this peer as strings. -func (t *peerTracker) getAddresses() []string { - addrs := t.m.node.Peerstore().Addrs(t.peerID) - result := make([]string, len(addrs)) - for i, addr := range addrs { - result[i] = addr.String() - } - return result -} - // run starts the main event loop for a single peer connection. func (t *peerTracker) run() { timer := time.NewTimer(0) @@ -111,7 +96,7 @@ func (t *peerTracker) run() { } success := func() { - t.m.adminCallback(t.peerID, true, false, t.getAddresses(), nil) + t.m.peerStateUpdated(t.m.getPeerInfo(t.peerID)) t.markInitial() retry.onSuccess() resetTimerWithJitter(maxRetryDelay) @@ -123,7 +108,7 @@ func (t *peerTracker) run() { } failure := func() { - t.m.adminCallback(t.peerID, false, t.whitelistMismatch, t.getAddresses(), t.peerWhitelist) + t.m.peerStateUpdated(t.m.getPeerInfo(t.peerID)) t.markInitial() resetTimerWithJitter(retry.backoff) retry.onFailure() @@ -138,7 +123,7 @@ func (t *peerTracker) run() { // - If we have no addresses: discover immediately on first attempt (pass 0), // then every N failures. shouldAttemptDiscovery := func() bool { - if t.whitelistMismatch { + if t.whitelistMismatch.Load() { return false } @@ -158,7 +143,7 @@ func (t *peerTracker) run() { case <-t.signalReconnect: // Do not force a reconnect if we just disconnected due to a whitelist mismatch. 
- if t.whitelistMismatch { + if t.whitelistMismatch.Load() && t.m.node.Network().Connectedness(t.peerID) != network.Connected { continue } if !timer.Stop() { @@ -168,8 +153,9 @@ func (t *peerTracker) run() { } } forceReconnect() + case <-timer.C: - if t.m.node.Network().Connectedness(t.peerID) == network.Connected { + if !t.whitelistMismatch.Load() && t.m.node.Network().Connectedness(t.peerID) == network.Connected { success() continue } @@ -203,8 +189,6 @@ func (t *peerTracker) connect() error { dialCtx, dialCancel := context.WithTimeout(t.ctx, connectToPeerTimeout) defer dialCancel() - t.whitelistMismatch = false - if err := t.m.node.Connect(dialCtx, peer.AddrInfo{ID: t.peerID}); err != nil { return err } @@ -212,11 +196,14 @@ func (t *peerTracker) connect() error { err := t.verifyWhitelist() if err != nil { if errors.Is(err, errWhitelistMismatch) { - t.whitelistMismatch = true + t.whitelistMismatch.Store(true) } return fmt.Errorf("failed to verify whitelist for peer %s: %w", t.peerID, err) } + // Whitelist verified — clear any prior mismatch. + t.whitelistMismatch.Store(false) + go t.exchangeOracleQuotas() return nil @@ -238,12 +225,14 @@ func (t *peerTracker) exchangeOracleQuotas() { req := &pb.QuotaHandshake{Quotas: localQuotas} if err := codec.WriteLengthPrefixedMessage(stream, req); err != nil { t.m.log.Debugf("Failed to send quota handshake to %s: %v", t.peerID, err) + _ = stream.Reset() return } resp := &pb.QuotaHandshake{} if err := codec.ReadLengthPrefixedMessage(stream, resp); err != nil { t.m.log.Debugf("Failed to read quota handshake from %s: %v", t.peerID, err) + _ = stream.Reset() return } @@ -252,16 +241,16 @@ func (t *peerTracker) exchangeOracleQuotas() { // discoverAddresses asks connected whitelist peers for the address of the target. 
func (t *peerTracker) discoverAddresses() bool { - whitelist := t.m.getWhitelist() + whitelist := t.m.getLocalWhitelistState().Current var wg sync.WaitGroup var success atomic.Bool - for _, p := range whitelist.peers { - if p.ID == t.m.node.ID() || p.ID == t.peerID { + for pid := range whitelist.PeerIDs { + if pid == t.m.node.ID() || pid == t.peerID { continue } - if t.m.node.Network().Connectedness(p.ID) != network.Connected { + if t.m.node.Network().Connectedness(pid) != network.Connected { continue } @@ -276,7 +265,7 @@ func (t *peerTracker) discoverAddresses() bool { } else if !errors.Is(err, errDiscoveryNotFound) { t.m.log.Warnf("Failed to discover addresses for %s via %s: %v", t.peerID, helper, err) } - }(p.ID) + }(pid) } wg.Wait() @@ -284,54 +273,66 @@ func (t *peerTracker) discoverAddresses() bool { return success.Load() } -// verifyWhitelist sends a whitelist request to the peer and verifies the response. -// If the response is a success, we protect the connection and return nil. Otherwise, -// we close the connection and return an error. +// verifyWhitelist performs a whitelist handshake with the peer. Both sides +// exchange a WhitelistState, check for a match, and protect or close the connection. func (t *peerTracker) verifyWhitelist() error { - var success bool - defer func() { - if success { - t.m.node.ConnManager().Protect(t.peerID, "tatanka-node") - } else { - _ = t.m.node.Network().ClosePeer(t.peerID) - } - }() + ctx, cancel := context.WithTimeout(t.ctx, connectToPeerTimeout) + defer cancel() - stream, err := t.m.node.NewStream(t.ctx, t.peerID, whitelistProtocol) + stream, err := t.m.node.NewStream(ctx, t.peerID, whitelistProtocol) if err != nil { + _ = t.m.node.Network().ClosePeer(t.peerID) return err } defer func() { _ = stream.Close() }() - req := &pb.WhitelistRequest{PeerIDs: t.m.getWhitelist().peerIDsBytes()} - if err := codec.WriteLengthPrefixedMessage(stream, req); err != nil { + // 1. Send our state. 
+ state := t.m.getLocalWhitelistState() + currentWl := state.Current + proposedWl := state.Proposed + ownState := &pb.WhitelistState{ + PeerIDs: currentWl.PeerIDsBytes(), + Timestamp: time.Now().UnixNano(), + Ready: state.Ready, + } + if proposedWl != nil { + ownState.ProposedPeerIDs = proposedWl.PeerIDsBytes() + } + if err := codec.WriteLengthPrefixedMessage(stream, ownState); err != nil { + _ = stream.Reset() + _ = t.m.node.Network().ClosePeer(t.peerID) return err } - resp := &pb.WhitelistResponse{} - if err := codec.ReadLengthPrefixedMessage(stream, resp); err != nil { + // 2. Read peer's state. + peerState := &pb.WhitelistState{} + if err := codec.ReadLengthPrefixedMessage(stream, peerState); err != nil { + _ = stream.Reset() + _ = t.m.node.Network().ClosePeer(t.peerID) return err } - if resp.GetSuccess() != nil { - success = true - return nil + // 3. An empty PeerIDs means the peer rejected our handshake (the + // permission decorator sent a generic Response, not a WhitelistState). + if len(peerState.PeerIDs) == 0 { + t.m.updatePeerWhitelistState(t.peerID, nil, 0) + _ = t.m.node.Network().ClosePeer(t.peerID) + return fmt.Errorf("%w: peer %s rejected handshake", errWhitelistMismatch, t.peerID) } - if resp.GetMismatch() != nil { - // Extract peer's whitelist from mismatch response - peerIDs := resp.GetMismatch().GetPeerIDs() - t.peerWhitelist = make([]string, len(peerIDs)) - for i, peerIDBytes := range peerIDs { - peerID, err := peer.IDFromBytes(peerIDBytes) - if err != nil { - t.m.log.Warnf("Failed to parse peer ID from mismatch response: %v", err) - continue - } - t.peerWhitelist[i] = peerID.String() - } + // 4. Record peer's whitelist info regardless of match. + ws := pbToWhitelistState(peerState) + t.m.updatePeerWhitelistState(t.peerID, ws, peerState.Timestamp) + + // 5. Check match. 
+ matched := flexibleWhitelistMatch(currentWl, proposedWl, peerState.PeerIDs, peerState.ProposedPeerIDs) + + if matched { + t.m.node.ConnManager().Protect(t.peerID, "tatanka-node") + return nil } + _ = t.m.node.Network().ClosePeer(t.peerID) return fmt.Errorf("%w for peer %s", errWhitelistMismatch, t.peerID) } @@ -340,50 +341,59 @@ type meshConnectionManager struct { log slog.Logger node host.Host ctx context.Context - whitelist atomic.Value // *whitelist + cancelCtx context.CancelFunc trackersMtx sync.RWMutex peerTrackers map[peer.ID]*peerTracker - initialCh chan struct{} - initialOnce sync.Once - initialErr atomic.Value // error - adminCallback AdminUpdateCallback + initialCh chan struct{} + initialOnce sync.Once + + peerStateUpdated func(admin.PeerInfo) + + // Whitelist handshake callbacks + getLocalWhitelistState func() *types.WhitelistState + updatePeerWhitelistState func(peerID peer.ID, ws *types.WhitelistState, timestamp int64) bool + getPeerWhitelistState func(peerID peer.ID) *types.WhitelistState - // Quota exchange callbacks + // Oracle quota handshake callbacks getLocalQuotas func() map[string]*pb.QuotaStatus handlePeerQuotas func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) } -func newMeshConnectionManager( - log slog.Logger, - node host.Host, - whitelist *whitelist, - adminCallback AdminUpdateCallback, - getLocalQuotas func() map[string]*pb.QuotaStatus, - handlePeerQuotas func(peerID peer.ID, quotas map[string]*pb.QuotaStatus), -) *meshConnectionManager { - m := &meshConnectionManager{ - log: log, - node: node, - peerTrackers: make(map[peer.ID]*peerTracker), - initialCh: make(chan struct{}), - adminCallback: adminCallback, - getLocalQuotas: getLocalQuotas, - handlePeerQuotas: handlePeerQuotas, - } - m.whitelist.Store(whitelist) - - // Add all bootstrap addresses to the peerstore. 
- for _, whitelistPeer := range whitelist.peers { - if whitelistPeer.ID == m.node.ID() || len(whitelistPeer.Addrs) == 0 { - continue - } - m.node.Peerstore().AddAddrs(whitelistPeer.ID, whitelistPeer.Addrs, peerstore.PermanentAddrTTL) - } +// meshConnectionManagerConfig holds configuration for creating a meshConnectionManager. +type meshConnectionManagerConfig struct { + log slog.Logger + node host.Host + peerStateUpdated func(admin.PeerInfo) + getLocalQuotas func() map[string]*pb.QuotaStatus + handlePeerQuotas func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) + getLocalWhitelistState func() *types.WhitelistState + updatePeerWhitelistState func(peerID peer.ID, ws *types.WhitelistState, timestamp int64) bool + getPeerWhitelistState func(peerID peer.ID) *types.WhitelistState +} - // Register for network events to react instantly to disconnects. - node.Network().Notify(&network.NotifyBundle{ +func newMeshConnectionManager(cfg *meshConnectionManagerConfig) *meshConnectionManager { + m := &meshConnectionManager{ + log: cfg.log, + node: cfg.node, + peerTrackers: make(map[peer.ID]*peerTracker), + initialCh: make(chan struct{}), + peerStateUpdated: cfg.peerStateUpdated, + getLocalQuotas: cfg.getLocalQuotas, + handlePeerQuotas: cfg.handlePeerQuotas, + getLocalWhitelistState: cfg.getLocalWhitelistState, + updatePeerWhitelistState: cfg.updatePeerWhitelistState, + getPeerWhitelistState: cfg.getPeerWhitelistState, + } + + m.ctx, m.cancelCtx = context.WithCancel(context.Background()) + + // Register for network events to react instantly to connects and disconnects. 
+ cfg.node.Network().Notify(&network.NotifyBundle{ + ConnectedF: func(_ network.Network, conn network.Conn) { + m.triggerReconnect(conn.RemotePeer()) + }, DisconnectedF: func(_ network.Network, conn network.Conn) { m.triggerReconnect(conn.RemotePeer()) }, @@ -392,10 +402,6 @@ func newMeshConnectionManager( return m } -func (m *meshConnectionManager) getWhitelist() *whitelist { - return m.whitelist.Load().(*whitelist) -} - // triggerReconnect signals the specific peer tracker to wake up and retry immediately. func (m *meshConnectionManager) triggerReconnect(pid peer.ID) { m.trackersMtx.RLock() @@ -437,7 +443,6 @@ func (m *meshConnectionManager) startTracker(ctx context.Context, pid peer.ID) * cancel: cancel, signalReconnect: make(chan struct{}, 1), initialCh: make(chan struct{}), - initialOnce: sync.Once{}, } m.peerTrackers[pid] = t go t.run() @@ -461,7 +466,10 @@ func (m *meshConnectionManager) stopTracker(pid peer.ID) { // sendDiscoveryRequest sends a discovery request to reqPeerID for the addresses of targetPeerID. 
func (t *peerTracker) sendDiscoveryRequest(reqPeerID, targetPeerID peer.ID) ([]ma.Multiaddr, error) { - s, err := t.m.node.NewStream(t.ctx, reqPeerID, discoveryProtocol) + ctx, cancel := context.WithTimeout(t.ctx, connectToPeerTimeout) + defer cancel() + + s, err := t.m.node.NewStream(ctx, reqPeerID, discoveryProtocol) if err != nil { return nil, err } @@ -469,11 +477,13 @@ func (t *peerTracker) sendDiscoveryRequest(reqPeerID, targetPeerID peer.ID) ([]m request := &pb.DiscoveryRequest{Id: []byte(targetPeerID)} if err := codec.WriteLengthPrefixedMessage(s, request); err != nil { + _ = s.Reset() return nil, err } response := &pb.DiscoveryResponse{} if err := codec.ReadLengthPrefixedMessage(s, response); err != nil { + _ = s.Reset() return nil, err } @@ -492,10 +502,99 @@ func (t *peerTracker) sendDiscoveryRequest(reqPeerID, targetPeerID peer.ID) ([]m return nil, errDiscoveryNotFound } -func (m *meshConnectionManager) updateWhitelist(whitelist *whitelist) { - m.whitelist.Store(whitelist) - // TODO: Stop trackers for peers removed from whitelist and - // start trackers for new peers. +// reconcileTrackers should be called when the whitelist changes. It stops +// trackers for peers not in the new whitelist and starts trackers for peers +// in the new whitelist that aren't already tracked. +func (m *meshConnectionManager) reconcileTrackers() { + whitelistState := m.getLocalWhitelistState() + currentPeers := whitelistState.Current.PeerIDs + + // Stop trackers for peers not in the new whitelist. + m.trackersMtx.RLock() + var toStop []peer.ID + for pid := range m.peerTrackers { + if _, ok := currentPeers[pid]; !ok { + toStop = append(toStop, pid) + } + } + m.trackersMtx.RUnlock() + + for _, pid := range toStop { + m.stopTracker(pid) + _ = m.node.Network().ClosePeer(pid) + } + + // Start trackers for new peers. 
+ m.trackersMtx.RLock() + var toStart []peer.ID + for pid := range currentPeers { + if pid == m.node.ID() { + continue + } + if _, ok := m.peerTrackers[pid]; ok { + continue // already tracked + } + toStart = append(toStart, pid) + } + m.trackersMtx.RUnlock() + + for _, pid := range toStart { + m.startTracker(m.ctx, pid) + } + + // Clear stale mismatch flags on existing trackers and signal them to + // reconnect immediately. After a whitelist change, prior mismatches may + // no longer apply. + m.trackersMtx.RLock() + for _, t := range m.peerTrackers { + t.whitelistMismatch.Store(false) + } + m.trackersMtx.RUnlock() + m.triggerReconnectAll() +} + +// getPeerInfo builds an admin.PeerInfo for a single peer. +func (m *meshConnectionManager) getPeerInfo(pid peer.ID) admin.PeerInfo { + state := admin.StateDisconnected + + m.trackersMtx.RLock() + t := m.peerTrackers[pid] + m.trackersMtx.RUnlock() + + if t != nil && t.whitelistMismatch.Load() { + state = admin.StateWhitelistMismatch + } else if m.node.Network().Connectedness(pid) == network.Connected { + state = admin.StateConnected + } + + addrs := m.node.Peerstore().Addrs(pid) + addrStrs := make([]string, len(addrs)) + for i, addr := range addrs { + addrStrs[i] = addr.String() + } + + return admin.PeerInfo{ + PeerID: pid.String(), + State: state, + Addresses: addrStrs, + WhitelistState: m.getPeerWhitelistState(pid), + } +} + +// peerInfoSnapshot returns a snapshot of all tracked peers' connection info. +func (m *meshConnectionManager) peerInfoSnapshot() map[peer.ID]admin.PeerInfo { + m.trackersMtx.RLock() + pids := make([]peer.ID, 0, len(m.peerTrackers)) + for pid := range m.peerTrackers { + pids = append(pids, pid) + } + m.trackersMtx.RUnlock() + + result := make(map[peer.ID]admin.PeerInfo, len(pids)) + for _, pid := range pids { + result[pid] = m.getPeerInfo(pid) + } + return result } // waitInitial blocks until the initial connectivity pass is marked complete. 
@@ -515,9 +614,13 @@ func (m *meshConnectionManager) markInitial() { // run starts the connection manager. func (m *meshConnectionManager) run(ctx context.Context) { - m.ctx = ctx + // Link the lifecycle of our internal context to the run context + go func() { + <-ctx.Done() + m.cancelCtx() + }() - whitelist := m.getWhitelist() + whitelist := m.getLocalWhitelistState().Current allTrackers := make(map[peer.ID]*peerTracker) waitForAllTrackers := func() { @@ -533,26 +636,26 @@ func (m *meshConnectionManager) run(ctx context.Context) { // PASS 1: Start trackers for peers with addresses. Wait for // them all to finish their initial connection loop before // starting trackers for peers without addresses. - for _, p := range whitelist.peers { - if p.ID == m.node.ID() { + for pid := range whitelist.PeerIDs { + if pid == m.node.ID() { continue } - hasAddrs := len(m.node.Peerstore().Addrs(p.ID)) > 0 + hasAddrs := len(m.node.Peerstore().Addrs(pid)) > 0 if hasAddrs { - allTrackers[p.ID] = m.startTracker(ctx, p.ID) + allTrackers[pid] = m.startTracker(ctx, pid) } } waitForAllTrackers() // PASS 2: Start peers WITHOUT addresses as well. - for _, p := range whitelist.peers { - if p.ID == m.node.ID() { + for pid := range whitelist.PeerIDs { + if pid == m.node.ID() { continue } // no-op if already started. 
- t := m.startTracker(ctx, p.ID) - allTrackers[p.ID] = t + t := m.startTracker(ctx, pid) + allTrackers[pid] = t } waitForAllTrackers() diff --git a/tatanka/mesh_connection_manager_test.go b/tatanka/mesh_connection_manager_test.go new file mode 100644 index 0000000..369946d --- /dev/null +++ b/tatanka/mesh_connection_manager_test.go @@ -0,0 +1,316 @@ +package tatanka + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/bisoncraft/mesh/codec" + "github.com/bisoncraft/mesh/tatanka/admin" + pb "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/tatanka/types" + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" +) + +// testMCMWLUpdate records a single updatePeerWhitelistState invocation. +type testMCMWLUpdate struct { + peerID peer.ID + ws *types.WhitelistState + timestamp int64 +} + +// testMCMQuotaUpdate records a single handlePeerQuotas invocation. +type testMCMQuotaUpdate struct { + peerID peer.ID + quotas map[string]*pb.QuotaStatus +} + +// testMCM is a lightweight harness for unit-testing meshConnectionManager +// without spinning up a full TatankaNode. +type testMCM struct { + t *testing.T + mnet mocknet.Mocknet + local host.Host + remote host.Host + mcm *meshConnectionManager + + // Callback recording channels. + peerStates chan admin.PeerInfo + wlUpdates chan testMCMWLUpdate + quotaUpdates chan testMCMQuotaUpdate + + // In-memory peer whitelist state store. + peerWLMu sync.Mutex + peerWLStates map[peer.ID]*types.WhitelistState + + localWLState *types.WhitelistState + peerReplyWLState *types.WhitelistState +} + +// newTestMCM builds a meshConnectionManager wired to a two-host mocknet. +// If peerReplyWLState is nil the remote host replies with a matching whitelist. 
+func newTestMCM(t *testing.T, peerReplyWLState *types.WhitelistState) *testMCM { + t.Helper() + + mnet, err := mocknet.WithNPeers(2) + if err != nil { + t.Fatal(err) + } + + peers := mnet.Peers() + local := mnet.Host(peers[0]) + remote := mnet.Host(peers[1]) + + // LinkPeers creates a virtual link so the hosts can dial each other. + // AddAddrs seeds the peerstores so the connection manager can find + // the peer's address. This mirrors linkNodeWithMesh in tatanka_test.go. + if _, err := mnet.LinkPeers(local.ID(), remote.ID()); err != nil { + t.Fatal(err) + } + local.Peerstore().AddAddrs(remote.ID(), remote.Addrs(), peerstore.PermanentAddrTTL) + remote.Peerstore().AddAddrs(local.ID(), local.Addrs(), peerstore.PermanentAddrTTL) + + sharedWL := types.NewWhitelist([]peer.ID{local.ID(), remote.ID()}) + localWLState := &types.WhitelistState{Current: sharedWL} + + if peerReplyWLState == nil { + peerReplyWLState = &types.WhitelistState{Current: sharedWL} + } + + tm := &testMCM{ + t: t, + mnet: mnet, + local: local, + remote: remote, + peerStates: make(chan admin.PeerInfo, 10), + wlUpdates: make(chan testMCMWLUpdate, 10), + quotaUpdates: make(chan testMCMQuotaUpdate, 10), + peerWLStates: make(map[peer.ID]*types.WhitelistState), + localWLState: localWLState, + peerReplyWLState: peerReplyWLState, + } + + // Register protocol handlers on the remote host. 
+ remote.SetStreamHandler(whitelistProtocol, tm.handleWhitelist) + remote.SetStreamHandler(quotaHandshakeProtocol, tm.handleQuotaHandshake) + + logBackend := slog.NewBackend(os.Stdout) + log := logBackend.Logger("test-mcm") + log.SetLevel(slog.LevelDebug) + + tm.mcm = newMeshConnectionManager(&meshConnectionManagerConfig{ + log: log, + node: local, + peerStateUpdated: func(pi admin.PeerInfo) { + tm.peerStates <- pi + }, + getLocalQuotas: func() map[string]*pb.QuotaStatus { + return nil + }, + handlePeerQuotas: func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) { + tm.quotaUpdates <- testMCMQuotaUpdate{peerID: peerID, quotas: quotas} + }, + getLocalWhitelistState: func() *types.WhitelistState { + return tm.localWLState + }, + updatePeerWhitelistState: func(peerID peer.ID, ws *types.WhitelistState, timestamp int64) bool { + tm.peerWLMu.Lock() + tm.peerWLStates[peerID] = ws + tm.peerWLMu.Unlock() + tm.wlUpdates <- testMCMWLUpdate{peerID: peerID, ws: ws, timestamp: timestamp} + return true + }, + getPeerWhitelistState: func(peerID peer.ID) *types.WhitelistState { + tm.peerWLMu.Lock() + defer tm.peerWLMu.Unlock() + return tm.peerWLStates[peerID] + }, + }) + + return tm +} + +// handleWhitelist is the whitelist protocol handler registered on the remote +// host. It reads the initiator's state and replies with peerReplyWLState. +func (tm *testMCM) handleWhitelist(s network.Stream) { + defer s.Close() + + req := &pb.WhitelistState{} + if err := codec.ReadLengthPrefixedMessage(s, req); err != nil { + return + } + + reply := whitelistStateToPb(tm.peerReplyWLState) + if err := codec.WriteLengthPrefixedMessage(s, reply); err != nil { + _ = s.Reset() + return + } +} + +// handleQuotaHandshake is the quota protocol handler registered on the remote +// host. It reads the request and replies with an empty QuotaHandshake. 
+func (tm *testMCM) handleQuotaHandshake(s network.Stream) { + defer s.Close() + + req := &pb.QuotaHandshake{} + if err := codec.ReadLengthPrefixedMessage(s, req); err != nil { + return + } + + reply := &pb.QuotaHandshake{} + if err := codec.WriteLengthPrefixedMessage(s, reply); err != nil { + _ = s.Reset() + return + } +} + +// waitForPeerState drains the peerStates channel until an update with the +// expected state is received or the timeout elapses. +func waitForPeerState(t *testing.T, ch <-chan admin.PeerInfo, state admin.NodeConnectionState, timeout time.Duration) admin.PeerInfo { + t.Helper() + deadline := time.After(timeout) + for { + select { + case pi := <-ch: + if pi.State == state { + return pi + } + case <-deadline: + t.Fatalf("timeout waiting for peer state %s", state) + return admin.PeerInfo{} + } + } +} + +func TestMeshConnectionManager_PeerConnect(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + tm := newTestMCM(t, nil) + go tm.mcm.run(ctx) + + // Wait for connected state. + pi := waitForPeerState(t, tm.peerStates, admin.StateConnected, 10*time.Second) + if pi.PeerID != tm.remote.ID().String() { + t.Fatalf("connected peer: got %s, want %s", pi.PeerID, tm.remote.ID()) + } + + // updatePeerWhitelistState is called before peerStateUpdated, so the + // update should already be in the channel (or stored in the map). + select { + case wu := <-tm.wlUpdates: + if wu.peerID != tm.remote.ID() { + t.Fatalf("whitelist update peer: got %s, want %s", wu.peerID, tm.remote.ID()) + } + if wu.ws == nil { + t.Fatal("whitelist state is nil") + } + case <-time.After(time.Second): + tm.peerWLMu.Lock() + ws := tm.peerWLStates[tm.remote.ID()] + tm.peerWLMu.Unlock() + if ws == nil { + t.Fatal("updatePeerWhitelistState was not called") + } + } + + // handlePeerQuotas runs asynchronously; give it a moment. 
+ select { + case qu := <-tm.quotaUpdates: + if qu.peerID != tm.remote.ID() { + t.Fatalf("quota update peer: got %s, want %s", qu.peerID, tm.remote.ID()) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for handlePeerQuotas") + } + + // getPeerInfo should return StateConnected. + info := tm.mcm.getPeerInfo(tm.remote.ID()) + if info.State != admin.StateConnected { + t.Fatalf("getPeerInfo state: got %s, want %s", info.State, admin.StateConnected) + } + + // peerInfoSnapshot should contain the peer with StateConnected. + snapshot := tm.mcm.peerInfoSnapshot() + snap, ok := snapshot[tm.remote.ID()] + if !ok { + t.Fatal("peer missing from peerInfoSnapshot") + } + if snap.State != admin.StateConnected { + t.Fatalf("snapshot state: got %s, want %s", snap.State, admin.StateConnected) + } +} + +func TestMeshConnectionManager_WhitelistMismatch(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + // Remote replies with a completely different whitelist. + mismatchWL := &types.WhitelistState{ + Current: types.NewWhitelist([]peer.ID{randomPeerID(t), randomPeerID(t)}), + } + tm := newTestMCM(t, mismatchWL) + go tm.mcm.run(ctx) + + // Wait for mismatch state. + pi := waitForPeerState(t, tm.peerStates, admin.StateWhitelistMismatch, 10*time.Second) + if pi.PeerID != tm.remote.ID().String() { + t.Fatalf("mismatch peer: got %s, want %s", pi.PeerID, tm.remote.ID()) + } + + // getPeerInfo should reflect the mismatch. + info := tm.mcm.getPeerInfo(tm.remote.ID()) + if info.State != admin.StateWhitelistMismatch { + t.Fatalf("getPeerInfo state: got %s, want %s", info.State, admin.StateWhitelistMismatch) + } +} + +func TestMeshConnectionManager_Reconnect(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + tm := newTestMCM(t, nil) + go tm.mcm.run(ctx) + + // 1. Wait for initial connection. 
+ waitForPeerState(t, tm.peerStates, admin.StateConnected, 10*time.Second) + + // Drain extra state events from the connect/notify cycle. + time.Sleep(500 * time.Millisecond) + for len(tm.peerStates) > 0 { + <-tm.peerStates + } + + // 2. Disconnect: unlink prevents new connections, ClosePeer tears down + // the existing one. + if err := tm.mnet.UnlinkPeers(tm.local.ID(), tm.remote.ID()); err != nil { + t.Fatalf("UnlinkPeers: %v", err) + } + if err := tm.local.Network().ClosePeer(tm.remote.ID()); err != nil { + t.Fatalf("ClosePeer: %v", err) + } + + // 3. Wait for disconnected state. + waitForPeerState(t, tm.peerStates, admin.StateDisconnected, 5*time.Second) + + // 4. Re-link to allow the tracker to reconnect. + if _, err := tm.mnet.LinkPeers(tm.local.ID(), tm.remote.ID()); err != nil { + t.Fatalf("LinkPeers: %v", err) + } + + // 5. Wait for reconnection (tracker retries on backoff ~2-4 s). + waitForPeerState(t, tm.peerStates, admin.StateConnected, 10*time.Second) + + // Verify final state. 
+ info := tm.mcm.getPeerInfo(tm.remote.ID()) + if info.State != admin.StateConnected { + t.Fatalf("getPeerInfo after reconnect: got %s, want %s", info.State, admin.StateConnected) + } +} diff --git a/tatanka/pb/messages.pb.go b/tatanka/pb/messages.pb.go index 013182f..18ecb47 100644 --- a/tatanka/pb/messages.pb.go +++ b/tatanka/pb/messages.pb.go @@ -2,7 +2,7 @@ // versions: // protoc-gen-go v1.36.10 // protoc v6.33.1 -// source: tatanka/pb/messages.proto +// source: messages.proto package pb @@ -34,7 +34,7 @@ type ClientConnectionMsg struct { func (x *ClientConnectionMsg) Reset() { *x = ClientConnectionMsg{} - mi := &file_tatanka_pb_messages_proto_msgTypes[0] + mi := &file_messages_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -46,7 +46,7 @@ func (x *ClientConnectionMsg) String() string { func (*ClientConnectionMsg) ProtoMessage() {} func (x *ClientConnectionMsg) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[0] + mi := &file_messages_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -59,7 +59,7 @@ func (x *ClientConnectionMsg) ProtoReflect() protoreflect.Message { // Deprecated: Use ClientConnectionMsg.ProtoReflect.Descriptor instead. 
func (*ClientConnectionMsg) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{0} + return file_messages_proto_rawDescGZIP(), []int{0} } func (x *ClientConnectionMsg) GetId() []byte { @@ -101,7 +101,7 @@ type TatankaForwardRelayRequest struct { func (x *TatankaForwardRelayRequest) Reset() { *x = TatankaForwardRelayRequest{} - mi := &file_tatanka_pb_messages_proto_msgTypes[1] + mi := &file_messages_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -113,7 +113,7 @@ func (x *TatankaForwardRelayRequest) String() string { func (*TatankaForwardRelayRequest) ProtoMessage() {} func (x *TatankaForwardRelayRequest) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[1] + mi := &file_messages_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -126,7 +126,7 @@ func (x *TatankaForwardRelayRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TatankaForwardRelayRequest.ProtoReflect.Descriptor instead. 
func (*TatankaForwardRelayRequest) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{1} + return file_messages_proto_rawDescGZIP(), []int{1} } func (x *TatankaForwardRelayRequest) GetInitiatorId() []byte { @@ -165,7 +165,7 @@ type TatankaForwardRelayResponse struct { func (x *TatankaForwardRelayResponse) Reset() { *x = TatankaForwardRelayResponse{} - mi := &file_tatanka_pb_messages_proto_msgTypes[2] + mi := &file_messages_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -177,7 +177,7 @@ func (x *TatankaForwardRelayResponse) String() string { func (*TatankaForwardRelayResponse) ProtoMessage() {} func (x *TatankaForwardRelayResponse) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[2] + mi := &file_messages_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -190,7 +190,7 @@ func (x *TatankaForwardRelayResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TatankaForwardRelayResponse.ProtoReflect.Descriptor instead. 
func (*TatankaForwardRelayResponse) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{2} + return file_messages_proto_rawDescGZIP(), []int{2} } func (x *TatankaForwardRelayResponse) GetResponse() isTatankaForwardRelayResponse_Response { @@ -273,7 +273,7 @@ type DiscoveryRequest struct { func (x *DiscoveryRequest) Reset() { *x = DiscoveryRequest{} - mi := &file_tatanka_pb_messages_proto_msgTypes[3] + mi := &file_messages_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -285,7 +285,7 @@ func (x *DiscoveryRequest) String() string { func (*DiscoveryRequest) ProtoMessage() {} func (x *DiscoveryRequest) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[3] + mi := &file_messages_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -298,7 +298,7 @@ func (x *DiscoveryRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryRequest.ProtoReflect.Descriptor instead. 
func (*DiscoveryRequest) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{3} + return file_messages_proto_rawDescGZIP(), []int{3} } func (x *DiscoveryRequest) GetId() []byte { @@ -321,7 +321,7 @@ type DiscoveryResponse struct { func (x *DiscoveryResponse) Reset() { *x = DiscoveryResponse{} - mi := &file_tatanka_pb_messages_proto_msgTypes[4] + mi := &file_messages_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -333,7 +333,7 @@ func (x *DiscoveryResponse) String() string { func (*DiscoveryResponse) ProtoMessage() {} func (x *DiscoveryResponse) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[4] + mi := &file_messages_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -346,7 +346,7 @@ func (x *DiscoveryResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryResponse.ProtoReflect.Descriptor instead. func (*DiscoveryResponse) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{4} + return file_messages_proto_rawDescGZIP(), []int{4} } func (x *DiscoveryResponse) GetResponse() isDiscoveryResponse_Response { @@ -390,28 +390,33 @@ func (*DiscoveryResponse_Success_) isDiscoveryResponse_Response() {} func (*DiscoveryResponse_NotFound_) isDiscoveryResponse_Response() {} -type WhitelistRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - PeerIDs [][]byte `protobuf:"bytes,1,rep,name=peerIDs,proto3" json:"peerIDs,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *WhitelistRequest) Reset() { - *x = WhitelistRequest{} - mi := &file_tatanka_pb_messages_proto_msgTypes[5] +// WhitelistState is used both for the symmetric whitelist handshake between +// tatanka nodes and for gossipsub whitelist proposal broadcasts. 
+type WhitelistState struct { + state protoimpl.MessageState `protogen:"open.v1"` + PeerIDs [][]byte `protobuf:"bytes,1,rep,name=peerIDs,proto3" json:"peerIDs,omitempty"` + ProposedPeerIDs [][]byte `protobuf:"bytes,2,rep,name=proposed_peerIDs,json=proposedPeerIDs,proto3" json:"proposed_peerIDs,omitempty"` // empty = no proposal + Timestamp int64 `protobuf:"varint,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + Ready bool `protobuf:"varint,4,opt,name=ready,proto3" json:"ready,omitempty"` // true = node is ready to switch + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WhitelistState) Reset() { + *x = WhitelistState{} + mi := &file_messages_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *WhitelistRequest) String() string { +func (x *WhitelistState) String() string { return protoimpl.X.MessageStringOf(x) } -func (*WhitelistRequest) ProtoMessage() {} +func (*WhitelistState) ProtoMessage() {} -func (x *WhitelistRequest) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[5] +func (x *WhitelistState) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -422,100 +427,39 @@ func (x *WhitelistRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use WhitelistRequest.ProtoReflect.Descriptor instead. -func (*WhitelistRequest) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{5} +// Deprecated: Use WhitelistState.ProtoReflect.Descriptor instead. 
+func (*WhitelistState) Descriptor() ([]byte, []int) { + return file_messages_proto_rawDescGZIP(), []int{5} } -func (x *WhitelistRequest) GetPeerIDs() [][]byte { +func (x *WhitelistState) GetPeerIDs() [][]byte { if x != nil { return x.PeerIDs } return nil } -type WhitelistResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Types that are valid to be assigned to Response: - // - // *WhitelistResponse_Success_ - // *WhitelistResponse_Mismatch_ - Response isWhitelistResponse_Response `protobuf_oneof:"response"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *WhitelistResponse) Reset() { - *x = WhitelistResponse{} - mi := &file_tatanka_pb_messages_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *WhitelistResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*WhitelistResponse) ProtoMessage() {} - -func (x *WhitelistResponse) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[6] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use WhitelistResponse.ProtoReflect.Descriptor instead. 
-func (*WhitelistResponse) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{6} -} - -func (x *WhitelistResponse) GetResponse() isWhitelistResponse_Response { +func (x *WhitelistState) GetProposedPeerIDs() [][]byte { if x != nil { - return x.Response + return x.ProposedPeerIDs } return nil } -func (x *WhitelistResponse) GetSuccess() *WhitelistResponse_Success { +func (x *WhitelistState) GetTimestamp() int64 { if x != nil { - if x, ok := x.Response.(*WhitelistResponse_Success_); ok { - return x.Success - } + return x.Timestamp } - return nil + return 0 } -func (x *WhitelistResponse) GetMismatch() *WhitelistResponse_Mismatch { +func (x *WhitelistState) GetReady() bool { if x != nil { - if x, ok := x.Response.(*WhitelistResponse_Mismatch_); ok { - return x.Mismatch - } + return x.Ready } - return nil -} - -type isWhitelistResponse_Response interface { - isWhitelistResponse_Response() -} - -type WhitelistResponse_Success_ struct { - Success *WhitelistResponse_Success `protobuf:"bytes,1,opt,name=success,proto3,oneof"` -} - -type WhitelistResponse_Mismatch_ struct { - Mismatch *WhitelistResponse_Mismatch `protobuf:"bytes,2,opt,name=mismatch,proto3,oneof"` + return false } -func (*WhitelistResponse_Success_) isWhitelistResponse_Response() {} - -func (*WhitelistResponse_Mismatch_) isWhitelistResponse_Response() {} - // NodeOracleUpdate contains oracle data for sharing between Tatanka mesh nodes. 
type NodeOracleUpdate struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -530,7 +474,7 @@ type NodeOracleUpdate struct { func (x *NodeOracleUpdate) Reset() { *x = NodeOracleUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[7] + mi := &file_messages_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -542,7 +486,7 @@ func (x *NodeOracleUpdate) String() string { func (*NodeOracleUpdate) ProtoMessage() {} func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[7] + mi := &file_messages_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -555,7 +499,7 @@ func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { // Deprecated: Use NodeOracleUpdate.ProtoReflect.Descriptor instead. func (*NodeOracleUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{7} + return file_messages_proto_rawDescGZIP(), []int{6} } func (x *NodeOracleUpdate) GetSource() string { @@ -605,7 +549,7 @@ type QuotaStatus struct { func (x *QuotaStatus) Reset() { *x = QuotaStatus{} - mi := &file_tatanka_pb_messages_proto_msgTypes[8] + mi := &file_messages_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -617,7 +561,7 @@ func (x *QuotaStatus) String() string { func (*QuotaStatus) ProtoMessage() {} func (x *QuotaStatus) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[8] + mi := &file_messages_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -630,7 +574,7 @@ func (x *QuotaStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use QuotaStatus.ProtoReflect.Descriptor instead. 
func (*QuotaStatus) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{8} + return file_messages_proto_rawDescGZIP(), []int{7} } func (x *QuotaStatus) GetFetchesRemaining() int64 { @@ -665,7 +609,7 @@ type QuotaHandshake struct { func (x *QuotaHandshake) Reset() { *x = QuotaHandshake{} - mi := &file_tatanka_pb_messages_proto_msgTypes[9] + mi := &file_messages_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -677,7 +621,7 @@ func (x *QuotaHandshake) String() string { func (*QuotaHandshake) ProtoMessage() {} func (x *QuotaHandshake) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[9] + mi := &file_messages_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -690,7 +634,7 @@ func (x *QuotaHandshake) ProtoReflect() protoreflect.Message { // Deprecated: Use QuotaHandshake.ProtoReflect.Descriptor instead. 
func (*QuotaHandshake) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{9} + return file_messages_proto_rawDescGZIP(), []int{8} } func (x *QuotaHandshake) GetQuotas() map[string]*QuotaStatus { @@ -708,7 +652,7 @@ type TatankaForwardRelayResponse_ClientNotFound struct { func (x *TatankaForwardRelayResponse_ClientNotFound) Reset() { *x = TatankaForwardRelayResponse_ClientNotFound{} - mi := &file_tatanka_pb_messages_proto_msgTypes[10] + mi := &file_messages_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -720,7 +664,7 @@ func (x *TatankaForwardRelayResponse_ClientNotFound) String() string { func (*TatankaForwardRelayResponse_ClientNotFound) ProtoMessage() {} func (x *TatankaForwardRelayResponse_ClientNotFound) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[10] + mi := &file_messages_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -733,7 +677,7 @@ func (x *TatankaForwardRelayResponse_ClientNotFound) ProtoReflect() protoreflect // Deprecated: Use TatankaForwardRelayResponse_ClientNotFound.ProtoReflect.Descriptor instead. 
func (*TatankaForwardRelayResponse_ClientNotFound) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{2, 0} + return file_messages_proto_rawDescGZIP(), []int{2, 0} } type TatankaForwardRelayResponse_ClientRejected struct { @@ -744,7 +688,7 @@ type TatankaForwardRelayResponse_ClientRejected struct { func (x *TatankaForwardRelayResponse_ClientRejected) Reset() { *x = TatankaForwardRelayResponse_ClientRejected{} - mi := &file_tatanka_pb_messages_proto_msgTypes[11] + mi := &file_messages_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -756,7 +700,7 @@ func (x *TatankaForwardRelayResponse_ClientRejected) String() string { func (*TatankaForwardRelayResponse_ClientRejected) ProtoMessage() {} func (x *TatankaForwardRelayResponse_ClientRejected) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[11] + mi := &file_messages_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -769,7 +713,7 @@ func (x *TatankaForwardRelayResponse_ClientRejected) ProtoReflect() protoreflect // Deprecated: Use TatankaForwardRelayResponse_ClientRejected.ProtoReflect.Descriptor instead. 
func (*TatankaForwardRelayResponse_ClientRejected) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{2, 1} + return file_messages_proto_rawDescGZIP(), []int{2, 1} } type DiscoveryResponse_Success struct { @@ -781,7 +725,7 @@ type DiscoveryResponse_Success struct { func (x *DiscoveryResponse_Success) Reset() { *x = DiscoveryResponse_Success{} - mi := &file_tatanka_pb_messages_proto_msgTypes[12] + mi := &file_messages_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -793,7 +737,7 @@ func (x *DiscoveryResponse_Success) String() string { func (*DiscoveryResponse_Success) ProtoMessage() {} func (x *DiscoveryResponse_Success) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[12] + mi := &file_messages_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -806,7 +750,7 @@ func (x *DiscoveryResponse_Success) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryResponse_Success.ProtoReflect.Descriptor instead. 
func (*DiscoveryResponse_Success) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{4, 0} + return file_messages_proto_rawDescGZIP(), []int{4, 0} } func (x *DiscoveryResponse_Success) GetAddrs() [][]byte { @@ -824,7 +768,7 @@ type DiscoveryResponse_NotFound struct { func (x *DiscoveryResponse_NotFound) Reset() { *x = DiscoveryResponse_NotFound{} - mi := &file_tatanka_pb_messages_proto_msgTypes[13] + mi := &file_messages_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -836,7 +780,7 @@ func (x *DiscoveryResponse_NotFound) String() string { func (*DiscoveryResponse_NotFound) ProtoMessage() {} func (x *DiscoveryResponse_NotFound) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[13] + mi := &file_messages_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -849,94 +793,14 @@ func (x *DiscoveryResponse_NotFound) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryResponse_NotFound.ProtoReflect.Descriptor instead. 
func (*DiscoveryResponse_NotFound) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{4, 1} + return file_messages_proto_rawDescGZIP(), []int{4, 1} } -type WhitelistResponse_Success struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} +var File_messages_proto protoreflect.FileDescriptor -func (x *WhitelistResponse_Success) Reset() { - *x = WhitelistResponse_Success{} - mi := &file_tatanka_pb_messages_proto_msgTypes[14] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *WhitelistResponse_Success) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*WhitelistResponse_Success) ProtoMessage() {} - -func (x *WhitelistResponse_Success) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[14] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use WhitelistResponse_Success.ProtoReflect.Descriptor instead. 
-func (*WhitelistResponse_Success) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{6, 0} -} - -type WhitelistResponse_Mismatch struct { - state protoimpl.MessageState `protogen:"open.v1"` - PeerIDs [][]byte `protobuf:"bytes,1,rep,name=peerIDs,proto3" json:"peerIDs,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *WhitelistResponse_Mismatch) Reset() { - *x = WhitelistResponse_Mismatch{} - mi := &file_tatanka_pb_messages_proto_msgTypes[15] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *WhitelistResponse_Mismatch) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*WhitelistResponse_Mismatch) ProtoMessage() {} - -func (x *WhitelistResponse_Mismatch) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[15] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use WhitelistResponse_Mismatch.ProtoReflect.Descriptor instead. 
-func (*WhitelistResponse_Mismatch) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{6, 1} -} - -func (x *WhitelistResponse_Mismatch) GetPeerIDs() [][]byte { - if x != nil { - return x.PeerIDs - } - return nil -} - -var File_tatanka_pb_messages_proto protoreflect.FileDescriptor - -const file_tatanka_pb_messages_proto_rawDesc = "" + +const file_messages_proto_rawDesc = "" + "\n" + - "\x19tatanka/pb/messages.proto\x12\x02pb\"\x82\x01\n" + + "\x0emessages.proto\x12\x02pb\"\x82\x01\n" + "\x13ClientConnectionMsg\x12\x0e\n" + "\x02id\x18\x01 \x01(\fR\x02id\x12\x1f\n" + "\vreporter_id\x18\x02 \x01(\fR\n" + @@ -966,17 +830,12 @@ const file_tatanka_pb_messages_proto_rawDesc = "" + "\n" + "\bNotFoundB\n" + "\n" + - "\bresponse\",\n" + - "\x10WhitelistRequest\x12\x18\n" + - "\apeerIDs\x18\x01 \x03(\fR\apeerIDs\"\xc9\x01\n" + - "\x11WhitelistResponse\x129\n" + - "\asuccess\x18\x01 \x01(\v2\x1d.pb.WhitelistResponse.SuccessH\x00R\asuccess\x12<\n" + - "\bmismatch\x18\x02 \x01(\v2\x1e.pb.WhitelistResponse.MismatchH\x00R\bmismatch\x1a\t\n" + - "\aSuccess\x1a$\n" + - "\bMismatch\x12\x18\n" + - "\apeerIDs\x18\x01 \x03(\fR\apeerIDsB\n" + - "\n" + - "\bresponse\"\xe2\x02\n" + + "\bresponse\"\x89\x01\n" + + "\x0eWhitelistState\x12\x18\n" + + "\apeerIDs\x18\x01 \x03(\fR\apeerIDs\x12)\n" + + "\x10proposed_peerIDs\x18\x02 \x03(\fR\x0fproposedPeerIDs\x12\x1c\n" + + "\ttimestamp\x18\x03 \x01(\x03R\ttimestamp\x12\x14\n" + + "\x05ready\x18\x04 \x01(\bR\x05ready\"\xe2\x02\n" + "\x10NodeOracleUpdate\x12\x16\n" + "\x06source\x18\x01 \x01(\tR\x06source\x12\x1c\n" + "\ttimestamp\x18\x02 \x01(\x03R\ttimestamp\x128\n" + @@ -1000,92 +859,83 @@ const file_tatanka_pb_messages_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\v2\x0f.pb.QuotaStatusR\x05value:\x028\x01B'Z%github.com/bisoncraft/mesh/tatanka/pbb\x06proto3" var ( - file_tatanka_pb_messages_proto_rawDescOnce sync.Once - file_tatanka_pb_messages_proto_rawDescData []byte + file_messages_proto_rawDescOnce 
sync.Once + file_messages_proto_rawDescData []byte ) -func file_tatanka_pb_messages_proto_rawDescGZIP() []byte { - file_tatanka_pb_messages_proto_rawDescOnce.Do(func() { - file_tatanka_pb_messages_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_tatanka_pb_messages_proto_rawDesc), len(file_tatanka_pb_messages_proto_rawDesc))) +func file_messages_proto_rawDescGZIP() []byte { + file_messages_proto_rawDescOnce.Do(func() { + file_messages_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_messages_proto_rawDesc), len(file_messages_proto_rawDesc))) }) - return file_tatanka_pb_messages_proto_rawDescData + return file_messages_proto_rawDescData } -var file_tatanka_pb_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 19) -var file_tatanka_pb_messages_proto_goTypes = []any{ +var file_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 16) +var file_messages_proto_goTypes = []any{ (*ClientConnectionMsg)(nil), // 0: pb.ClientConnectionMsg (*TatankaForwardRelayRequest)(nil), // 1: pb.TatankaForwardRelayRequest (*TatankaForwardRelayResponse)(nil), // 2: pb.TatankaForwardRelayResponse (*DiscoveryRequest)(nil), // 3: pb.DiscoveryRequest (*DiscoveryResponse)(nil), // 4: pb.DiscoveryResponse - (*WhitelistRequest)(nil), // 5: pb.WhitelistRequest - (*WhitelistResponse)(nil), // 6: pb.WhitelistResponse - (*NodeOracleUpdate)(nil), // 7: pb.NodeOracleUpdate - (*QuotaStatus)(nil), // 8: pb.QuotaStatus - (*QuotaHandshake)(nil), // 9: pb.QuotaHandshake - (*TatankaForwardRelayResponse_ClientNotFound)(nil), // 10: pb.TatankaForwardRelayResponse.ClientNotFound - (*TatankaForwardRelayResponse_ClientRejected)(nil), // 11: pb.TatankaForwardRelayResponse.ClientRejected - (*DiscoveryResponse_Success)(nil), // 12: pb.DiscoveryResponse.Success - (*DiscoveryResponse_NotFound)(nil), // 13: pb.DiscoveryResponse.NotFound - (*WhitelistResponse_Success)(nil), // 14: pb.WhitelistResponse.Success - (*WhitelistResponse_Mismatch)(nil), // 
15: pb.WhitelistResponse.Mismatch - nil, // 16: pb.NodeOracleUpdate.PricesEntry - nil, // 17: pb.NodeOracleUpdate.FeeRatesEntry - nil, // 18: pb.QuotaHandshake.QuotasEntry -} -var file_tatanka_pb_messages_proto_depIdxs = []int32{ - 10, // 0: pb.TatankaForwardRelayResponse.client_not_found:type_name -> pb.TatankaForwardRelayResponse.ClientNotFound - 11, // 1: pb.TatankaForwardRelayResponse.client_rejected:type_name -> pb.TatankaForwardRelayResponse.ClientRejected - 12, // 2: pb.DiscoveryResponse.success:type_name -> pb.DiscoveryResponse.Success - 13, // 3: pb.DiscoveryResponse.not_found:type_name -> pb.DiscoveryResponse.NotFound - 14, // 4: pb.WhitelistResponse.success:type_name -> pb.WhitelistResponse.Success - 15, // 5: pb.WhitelistResponse.mismatch:type_name -> pb.WhitelistResponse.Mismatch - 16, // 6: pb.NodeOracleUpdate.prices:type_name -> pb.NodeOracleUpdate.PricesEntry - 17, // 7: pb.NodeOracleUpdate.fee_rates:type_name -> pb.NodeOracleUpdate.FeeRatesEntry - 8, // 8: pb.NodeOracleUpdate.quota:type_name -> pb.QuotaStatus - 18, // 9: pb.QuotaHandshake.quotas:type_name -> pb.QuotaHandshake.QuotasEntry - 8, // 10: pb.QuotaHandshake.QuotasEntry.value:type_name -> pb.QuotaStatus - 11, // [11:11] is the sub-list for method output_type - 11, // [11:11] is the sub-list for method input_type - 11, // [11:11] is the sub-list for extension type_name - 11, // [11:11] is the sub-list for extension extendee - 0, // [0:11] is the sub-list for field type_name -} - -func init() { file_tatanka_pb_messages_proto_init() } -func file_tatanka_pb_messages_proto_init() { - if File_tatanka_pb_messages_proto != nil { + (*WhitelistState)(nil), // 5: pb.WhitelistState + (*NodeOracleUpdate)(nil), // 6: pb.NodeOracleUpdate + (*QuotaStatus)(nil), // 7: pb.QuotaStatus + (*QuotaHandshake)(nil), // 8: pb.QuotaHandshake + (*TatankaForwardRelayResponse_ClientNotFound)(nil), // 9: pb.TatankaForwardRelayResponse.ClientNotFound + (*TatankaForwardRelayResponse_ClientRejected)(nil), // 10: 
pb.TatankaForwardRelayResponse.ClientRejected + (*DiscoveryResponse_Success)(nil), // 11: pb.DiscoveryResponse.Success + (*DiscoveryResponse_NotFound)(nil), // 12: pb.DiscoveryResponse.NotFound + nil, // 13: pb.NodeOracleUpdate.PricesEntry + nil, // 14: pb.NodeOracleUpdate.FeeRatesEntry + nil, // 15: pb.QuotaHandshake.QuotasEntry +} +var file_messages_proto_depIdxs = []int32{ + 9, // 0: pb.TatankaForwardRelayResponse.client_not_found:type_name -> pb.TatankaForwardRelayResponse.ClientNotFound + 10, // 1: pb.TatankaForwardRelayResponse.client_rejected:type_name -> pb.TatankaForwardRelayResponse.ClientRejected + 11, // 2: pb.DiscoveryResponse.success:type_name -> pb.DiscoveryResponse.Success + 12, // 3: pb.DiscoveryResponse.not_found:type_name -> pb.DiscoveryResponse.NotFound + 13, // 4: pb.NodeOracleUpdate.prices:type_name -> pb.NodeOracleUpdate.PricesEntry + 14, // 5: pb.NodeOracleUpdate.fee_rates:type_name -> pb.NodeOracleUpdate.FeeRatesEntry + 7, // 6: pb.NodeOracleUpdate.quota:type_name -> pb.QuotaStatus + 15, // 7: pb.QuotaHandshake.quotas:type_name -> pb.QuotaHandshake.QuotasEntry + 7, // 8: pb.QuotaHandshake.QuotasEntry.value:type_name -> pb.QuotaStatus + 9, // [9:9] is the sub-list for method output_type + 9, // [9:9] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name +} + +func init() { file_messages_proto_init() } +func file_messages_proto_init() { + if File_messages_proto != nil { return } - file_tatanka_pb_messages_proto_msgTypes[2].OneofWrappers = []any{ + file_messages_proto_msgTypes[2].OneofWrappers = []any{ (*TatankaForwardRelayResponse_Success)(nil), (*TatankaForwardRelayResponse_ClientNotFound_)(nil), (*TatankaForwardRelayResponse_ClientRejected_)(nil), (*TatankaForwardRelayResponse_Error)(nil), } - file_tatanka_pb_messages_proto_msgTypes[4].OneofWrappers = []any{ + 
file_messages_proto_msgTypes[4].OneofWrappers = []any{ (*DiscoveryResponse_Success_)(nil), (*DiscoveryResponse_NotFound_)(nil), } - file_tatanka_pb_messages_proto_msgTypes[6].OneofWrappers = []any{ - (*WhitelistResponse_Success_)(nil), - (*WhitelistResponse_Mismatch_)(nil), - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_tatanka_pb_messages_proto_rawDesc), len(file_tatanka_pb_messages_proto_rawDesc)), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_messages_proto_rawDesc), len(file_messages_proto_rawDesc)), NumEnums: 0, - NumMessages: 19, + NumMessages: 16, NumExtensions: 0, NumServices: 0, }, - GoTypes: file_tatanka_pb_messages_proto_goTypes, - DependencyIndexes: file_tatanka_pb_messages_proto_depIdxs, - MessageInfos: file_tatanka_pb_messages_proto_msgTypes, + GoTypes: file_messages_proto_goTypes, + DependencyIndexes: file_messages_proto_depIdxs, + MessageInfos: file_messages_proto_msgTypes, }.Build() - File_tatanka_pb_messages_proto = out.File - file_tatanka_pb_messages_proto_goTypes = nil - file_tatanka_pb_messages_proto_depIdxs = nil + File_messages_proto = out.File + file_messages_proto_goTypes = nil + file_messages_proto_depIdxs = nil } diff --git a/tatanka/pb/messages.proto b/tatanka/pb/messages.proto index c86e9a9..9b553bc 100644 --- a/tatanka/pb/messages.proto +++ b/tatanka/pb/messages.proto @@ -50,23 +50,15 @@ message DiscoveryResponse { message NotFound {} } -message WhitelistRequest { +// WhitelistState is used both for the symmetric whitelist handshake between +// tatanka nodes and for gossipsub whitelist proposal broadcasts. 
+message WhitelistState { repeated bytes peerIDs = 1; + repeated bytes proposed_peerIDs = 2; // empty = no proposal + int64 timestamp = 3; + bool ready = 4; // true = node is ready to switch } -message WhitelistResponse { - oneof response { - Success success = 1; - Mismatch mismatch = 2; - } - - message Success { - } - - message Mismatch { - repeated bytes peerIDs = 1; - } -} // NodeOracleUpdate contains oracle data for sharing between Tatanka mesh nodes. message NodeOracleUpdate { diff --git a/tatanka/pb_helpers.go b/tatanka/pb_helpers.go index e4f28ca..c1bdcd8 100644 --- a/tatanka/pb_helpers.go +++ b/tatanka/pb_helpers.go @@ -10,9 +10,32 @@ import ( "github.com/bisoncraft/mesh/oracle/sources" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/tatanka/types" ma "github.com/multiformats/go-multiaddr" ) +// whitelistStateToPb converts a types.WhitelistState to a pb.WhitelistState. +func whitelistStateToPb(ws *types.WhitelistState) *pb.WhitelistState { + state := &pb.WhitelistState{ + PeerIDs: ws.Current.PeerIDsBytes(), + Timestamp: time.Now().UnixNano(), + Ready: ws.Ready, + } + if ws.Proposed != nil { + state.ProposedPeerIDs = ws.Proposed.PeerIDsBytes() + } + return state +} + +// pbToWhitelistState converts a pb.WhitelistState to a types.WhitelistState. +func pbToWhitelistState(state *pb.WhitelistState) *types.WhitelistState { + return &types.WhitelistState{ + Current: types.DecodePeerIDBytes(state.PeerIDs), + Proposed: types.DecodePeerIDBytes(state.ProposedPeerIDs), + Ready: state.Ready, + } +} + // libp2pPeerInfoToPb converts a peer.AddrInfo to a protocolsPb.PeerInfo. 
func libp2pPeerInfoToPb(peerInfo peer.AddrInfo) *protocolsPb.PeerInfo { addrBytes := make([][]byte, len(peerInfo.Addrs)) @@ -307,24 +330,6 @@ func pbTatankaForwardRelayError(message string) *pb.TatankaForwardRelayResponse } } -func pbWhitelistResponseSuccess() *pb.WhitelistResponse { - return &pb.WhitelistResponse{ - Response: &pb.WhitelistResponse_Success_{ - Success: &pb.WhitelistResponse_Success{}, - }, - } -} - -func pbWhitelistResponseMismatch(mismatchedPeerIDs [][]byte) *pb.WhitelistResponse { - return &pb.WhitelistResponse{ - Response: &pb.WhitelistResponse_Mismatch_{ - Mismatch: &pb.WhitelistResponse_Mismatch{ - PeerIDs: mismatchedPeerIDs, - }, - }, - } -} - func pbDiscoveryResponseNotFound() *pb.DiscoveryResponse { return &pb.DiscoveryResponse{ Response: &pb.DiscoveryResponse_NotFound_{ diff --git a/tatanka/peerstore_cache.go b/tatanka/peerstore_cache.go new file mode 100644 index 0000000..d8b5ef4 --- /dev/null +++ b/tatanka/peerstore_cache.go @@ -0,0 +1,162 @@ +package tatanka + +import ( + "context" + "encoding/json" + "fmt" + "os" + "sync" + "time" + + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + ma "github.com/multiformats/go-multiaddr" +) + +// peerstoreCache persists known peer addresses to disk so that a node can +// reconnect after a cold restart without relying on bootstrap flags. +type peerstoreCache struct { + path string + log slog.Logger + host host.Host + getWhitelist func() map[peer.ID]struct{} + writingMtx sync.Mutex +} + +func newPeerstoreCache(log slog.Logger, path string, host host.Host, getWhitelist func() map[peer.ID]struct{}) *peerstoreCache { + return &peerstoreCache{ + path: path, + log: log, + host: host, + getWhitelist: getWhitelist, + } +} + +// load reads cached addresses from disk and adds them to the peerstore. +// No-op if the file doesn't exist. 
+func (c *peerstoreCache) load() { + data, err := os.ReadFile(c.path) + if err != nil { + if os.IsNotExist(err) { + return + } + c.log.Warnf("Failed to read peerstore cache: %v", err) + return + } + + var cache map[string][]string + if err := json.Unmarshal(data, &cache); err != nil { + c.log.Warnf("Failed to parse peerstore cache: %v", err) + return + } + + for pidStr, addrStrs := range cache { + pid, err := peer.Decode(pidStr) + if err != nil { + c.log.Warnf("Skipping invalid peer ID in cache: %s", pidStr) + continue + } + addrs := make([]ma.Multiaddr, 0, len(addrStrs)) + for _, a := range addrStrs { + addr, err := ma.NewMultiaddr(a) + if err != nil { + c.log.Warnf("Skipping invalid address in cache for %s: %s", pidStr, a) + continue + } + addrs = append(addrs, addr) + } + if len(addrs) > 0 { + c.host.Peerstore().AddAddrs(pid, addrs, peerstore.PermanentAddrTTL) + } + } + + c.log.Infof("Loaded peerstore cache with %d peers", len(cache)) +} + +// save writes the current addresses for the whitelist peers to the disk. +func (c *peerstoreCache) save() { + c.writingMtx.Lock() + defer c.writingMtx.Unlock() + + cache := make(map[string][]string) + for pid := range c.getWhitelist() { + addrs := c.host.Peerstore().Addrs(pid) + if len(addrs) == 0 { + continue + } + addrStrs := make([]string, len(addrs)) + for i, a := range addrs { + addrStrs[i] = a.String() + } + cache[pid.String()] = addrStrs + } + + data, err := json.Marshal(cache) + if err != nil { + c.log.Errorf("Failed to marshal peerstore cache: %v", err) + return + } + + if err := atomicWriteFile(c.path, data); err != nil { + c.log.Errorf("Failed to write peerstore cache: %v", err) + return + } +} + +// run periodically persists peerstore addresses for current whitelist peers. 
+func (c *peerstoreCache) run(ctx context.Context) { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.save() + } + } +} + +// parseBootstrapAddrs parses multiaddr strings into peer.AddrInfo slices, +// merging multiple addresses for the same peer. +func parseBootstrapAddrs(addrs []string) ([]peer.AddrInfo, error) { + peerMap := make(map[peer.ID]*peer.AddrInfo) + + for _, addrStr := range addrs { + maddr, err := ma.NewMultiaddr(addrStr) + if err != nil { + return nil, fmt.Errorf("failed to parse address %q: %w", addrStr, err) + } + info, err := peer.AddrInfoFromP2pAddr(maddr) + if err != nil { + return nil, fmt.Errorf("failed to parse peer info from %q: %w", addrStr, err) + } + + if existing, ok := peerMap[info.ID]; ok { + existing.Addrs = append(existing.Addrs, info.Addrs...) + } else { + peerMap[info.ID] = info + } + } + + result := make([]peer.AddrInfo, 0, len(peerMap)) + for _, info := range peerMap { + result = append(result, *info) + } + return result, nil +} + +// seedBootstrapAddrs adds the given addresses to the peerstore. +// If any address is invalid an error is returned. 
+func seedBootstrapAddrs(host host.Host, addrs []string) error { + infos, err := parseBootstrapAddrs(addrs) + if err != nil { + return err + } + for _, info := range infos { + host.Peerstore().AddAddrs(info.ID, info.Addrs, peerstore.PermanentAddrTTL) + } + return nil +} diff --git a/tatanka/peerstore_cache_test.go b/tatanka/peerstore_cache_test.go new file mode 100644 index 0000000..600ec0a --- /dev/null +++ b/tatanka/peerstore_cache_test.go @@ -0,0 +1,181 @@ +package tatanka + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/peer" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + ma "github.com/multiformats/go-multiaddr" +) + +func TestPeerstoreCacheSaveLoad(t *testing.T) { + mnet, err := mocknet.WithNPeers(3) + if err != nil { + t.Fatal(err) + } + + hosts := mnet.Hosts() + h := hosts[0] + + // Add known addresses for host 1 and host 2 to host 0's peerstore. + addr1, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + addr2, _ := ma.NewMultiaddr("/ip4/5.6.7.8/tcp/5678") + h.Peerstore().AddAddrs(hosts[1].ID(), []ma.Multiaddr{addr1}, 1<<62) + h.Peerstore().AddAddrs(hosts[2].ID(), []ma.Multiaddr{addr2}, 1<<62) + + dir := t.TempDir() + cachePath := filepath.Join(dir, "peerstore.json") + + backend := slog.NewBackend(os.Stdout) + log := backend.Logger("test") + log.SetLevel(slog.LevelOff) + + getWhitelist := func() map[peer.ID]struct{} { + return map[peer.ID]struct{}{ + hosts[1].ID(): {}, + hosts[2].ID(): {}, + } + } + cache := newPeerstoreCache(log, cachePath, h, getWhitelist) + + // Save with both peer IDs in the whitelist. + cache.save() + + // Verify the file exists and has correct content. 
+ data, err := os.ReadFile(cachePath) + if err != nil { + t.Fatalf("Failed to read cache file: %v", err) + } + + var saved map[string][]string + if err := json.Unmarshal(data, &saved); err != nil { + t.Fatalf("Failed to unmarshal cache: %v", err) + } + + if len(saved) != 2 { + t.Fatalf("Expected 2 peers in cache, got %d", len(saved)) + } + + // Now load into a fresh host and verify addresses are added. + mnet2, err := mocknet.WithNPeers(1) + if err != nil { + t.Fatal(err) + } + h2 := mnet2.Hosts()[0] + + cache2 := newPeerstoreCache(log, cachePath, h2, getWhitelist) + cache2.load() + + addrs1 := h2.Peerstore().Addrs(hosts[1].ID()) + if len(addrs1) == 0 { + t.Fatal("Expected addresses for host 1 after load") + } + addrs2 := h2.Peerstore().Addrs(hosts[2].ID()) + if len(addrs2) == 0 { + t.Fatal("Expected addresses for host 2 after load") + } +} + +func TestPeerstoreCacheOnlyWhitelistPeers(t *testing.T) { + mnet, err := mocknet.WithNPeers(3) + if err != nil { + t.Fatal(err) + } + + hosts := mnet.Hosts() + h := hosts[0] + + addr1, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + addr2, _ := ma.NewMultiaddr("/ip4/5.6.7.8/tcp/5678") + h.Peerstore().AddAddrs(hosts[1].ID(), []ma.Multiaddr{addr1}, 1<<62) + h.Peerstore().AddAddrs(hosts[2].ID(), []ma.Multiaddr{addr2}, 1<<62) + + dir := t.TempDir() + cachePath := filepath.Join(dir, "peerstore.json") + + backend := slog.NewBackend(os.Stdout) + log := backend.Logger("test") + log.SetLevel(slog.LevelOff) + + // Only save host 1 (not host 2). 
+ cache := newPeerstoreCache(log, cachePath, h, func() map[peer.ID]struct{} { + return map[peer.ID]struct{}{ + hosts[1].ID(): {}, + } + }) + cache.save() + + data, err := os.ReadFile(cachePath) + if err != nil { + t.Fatalf("Failed to read cache file: %v", err) + } + + var saved map[string][]string + if err := json.Unmarshal(data, &saved); err != nil { + t.Fatalf("Failed to unmarshal cache: %v", err) + } + + if len(saved) != 1 { + t.Fatalf("Expected 1 peer in cache, got %d", len(saved)) + } + if _, ok := saved[hosts[1].ID().String()]; !ok { + t.Fatal("Expected host 1 in cache") + } + if _, ok := saved[hosts[2].ID().String()]; ok { + t.Fatal("Host 2 should not be in cache") + } +} + +func TestPeerstoreCacheSeedAddresses(t *testing.T) { + mnet, err := mocknet.WithNPeers(2) + if err != nil { + t.Fatal(err) + } + + hosts := mnet.Hosts() + h := hosts[0] + + // Build a full p2p multiaddr targeting hosts[1]. + addrStr := "/ip4/10.0.0.1/tcp/9999/p2p/" + hosts[1].ID().String() + if err := seedBootstrapAddrs(h, []string{addrStr}); err != nil { + t.Fatalf("seedBootstrapAddrs failed: %v", err) + } + + // Verify address was added to peerstore. + addrs := h.Peerstore().Addrs(hosts[1].ID()) + if len(addrs) == 0 { + t.Fatal("Expected addresses after seedBootstrapAddrs") + } + + expected, _ := ma.NewMultiaddr("/ip4/10.0.0.1/tcp/9999") + found := false + for _, a := range addrs { + if a.Equal(expected) { + found = true + break + } + } + if !found { + t.Fatal("Seeded address not found in peerstore") + } +} + +func TestPeerstoreCacheLoadNoFile(t *testing.T) { + mnet, err := mocknet.WithNPeers(1) + if err != nil { + t.Fatal(err) + } + h := mnet.Hosts()[0] + + backend := slog.NewBackend(os.Stdout) + log := backend.Logger("test") + log.SetLevel(slog.LevelOff) + + cache := newPeerstoreCache(log, "/nonexistent/path/peerstore.json", h, nil) + // Should not panic or error. 
+ cache.load() +} diff --git a/tatanka/tatanka.go b/tatanka/tatanka.go index 9170a1e..f12e39c 100644 --- a/tatanka/tatanka.go +++ b/tatanka/tatanka.go @@ -14,17 +14,19 @@ import ( "sync/atomic" "time" - "github.com/decred/slog" - "github.com/libp2p/go-libp2p" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/peer" "github.com/bisoncraft/mesh/oracle" "github.com/bisoncraft/mesh/oracle/sources" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/bisoncraft/mesh/tatanka/admin" - "github.com/bisoncraft/mesh/tatanka/pb" + pb "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/tatanka/types" + "github.com/decred/slog" + "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -33,6 +35,10 @@ const ( // for the tatanka node. privateKeyFileName = "p.key" + // whitelistFileName is the name of the file that contains the whitelist for + // the tatanka node. + whitelistFileName = "whitelist.json" + // forwardRelayProtocol is the protocol used to forward a relay message between two tatanka nodes. forwardRelayProtocol = "/tatanka/forward-relay/1.0.0" @@ -49,13 +55,17 @@ const ( // Config is the configuration for the tatanka node type Config struct { - DataDir string - Logger slog.Logger - ListenIP string - ListenPort int - MetricsPort int - AdminPort int - WhitelistPath string + DataDir string + Logger slog.Logger + ListenIP string + ListenPort int + MetricsPort int + AdminPort int + BootstrapAddrs []string + WhitelistPeers []peer.ID + // ForceWhitelist overwrites any existing whitelist on disk with the provided + // whitelist when WhitelistPeers is non-empty. 
+ ForceWhitelist bool // Oracle Configuration CMCKey string @@ -73,8 +83,8 @@ func WithHost(h host.Host) Option { } } -// Oracle defines the requirements for implementing an oracle. -type Oracle interface { +// oracleService defines the requirements for implementing an oracle. +type oracleService interface { Run(ctx context.Context) Merge(update *oracle.OracleUpdate, senderID string) *oracle.MergeResult Price(ticker oracle.Ticker) (float64, bool) @@ -86,37 +96,58 @@ type Oracle interface { // TatankaNode is a permissioned node in the tatanka mesh type TatankaNode struct { - config *Config - node host.Host - log slog.Logger - whitelist atomic.Value // *whitelist - readyCh chan struct{} - readyOnce sync.Once - readyErr atomic.Value // error - privateKey crypto.PrivKey - bondVerifier *bondVerifier - bondStorage bondStorage - + config *Config + node host.Host + log slog.Logger + initWhitelist *types.Whitelist + readyCh chan struct{} + readyOnce sync.Once + readyErr atomic.Value // error + privateKey crypto.PrivKey + + bondVerifier *bondVerifier + bondStorage bondStorage + peerstoreCache *peerstoreCache gossipSub *gossipSub clientConnectionManager *clientConnectionManager subscriptionManager *subscriptionManager pushStreamManager *pushStreamManager + whitelistManager *whitelistManager connectionManager *meshConnectionManager adminServer *admin.Server - - metricsServer *http.Server - - oracle Oracle + adminNotify adminNotifier + metricsServer *http.Server + oracle oracleService } -// NewTatankaNode creates a new TatankaNode with the given configuration and options. 
-func NewTatankaNode(config *Config, opts ...Option) (*TatankaNode, error) { - privateKey, err := getOrCreatePrivateKey(filepath.Join(config.DataDir, privateKeyFileName)) +func initTatankaNode(dataDir string) (crypto.PrivKey, error) { + if dataDir == "" { + return nil, errors.New("no data directory provided") + } + if err := os.MkdirAll(dataDir, 0o700); err != nil { + return nil, fmt.Errorf("failed to create data directory %q: %w", dataDir, err) + } + priv, err := getOrCreatePrivateKey(filepath.Join(dataDir, privateKeyFileName)) if err != nil { return nil, err } + return priv, nil +} + +// InitTatankaNode ensures the data directory exists, generates (or loads) the +// node private key, and returns the corresponding peer ID. +func InitTatankaNode(dataDir string) (peer.ID, error) { + priv, err := initTatankaNode(dataDir) + if err != nil { + return "", err + } - whitelist, err := loadWhitelist(config.WhitelistPath) + return peer.IDFromPrivateKey(priv) +} + +// NewTatankaNode creates a new TatankaNode with the given configuration and options. +func NewTatankaNode(config *Config, opts ...Option) (*TatankaNode, error) { + privateKey, err := initTatankaNode(config.DataDir) if err != nil { return nil, err } @@ -131,21 +162,42 @@ func NewTatankaNode(config *Config, opts ...Option) (*TatankaNode, error) { bondStorage: newMemoryBondStorage(time.Now), readyCh: make(chan struct{}), } - t.whitelist.Store(whitelist) for _, opt := range opts { opt(t) } - return t, nil -} + // Derive local peer ID from the host if one was injected (e.g. tests), + // otherwise from the generated private key. 
+ var localPeerID peer.ID + if t.node != nil { + localPeerID = t.node.ID() + } else { + localPeerID, err = peer.IDFromPrivateKey(privateKey) + if err != nil { + return nil, err + } + } + + t.initWhitelist, err = initWhitelist(config.DataDir, config.WhitelistPeers, config.ForceWhitelist, localPeerID) + if err != nil { + return nil, err + } -func (t *TatankaNode) getWhitelist() *whitelist { - return t.whitelist.Load().(*whitelist) + return t, nil } -func (t *TatankaNode) getWhitelistPeers() map[peer.ID]struct{} { - return t.getWhitelist().allPeerIDs() +// decodePeerIDStrings decodes a slice of peer ID strings into peer.ID values. +func decodePeerIDStrings(peers []string) ([]peer.ID, error) { + peerIDs := make([]peer.ID, 0, len(peers)) + for _, p := range peers { + pid, err := peer.Decode(p) + if err != nil { + return nil, fmt.Errorf("invalid peer ID %q: %w", p, err) + } + peerIDs = append(peerIDs, pid) + } + return peerIDs, nil } func (t *TatankaNode) handleBroadcastMessage(msg *protocolsPb.PushMessage) { @@ -159,29 +211,71 @@ func (t *TatankaNode) handleClientConnectionMessage(update *clientConnectionUpda t.clientConnectionManager.updateClientConnectionInfo(update) } +// peerStates computes the current admin state on demand by querying the +// connection manager and whitelist manager. +func (t *TatankaNode) peerStates() admin.AdminState { + state := admin.AdminState{ + OurPeerID: t.node.ID().String(), + WhitelistState: t.whitelistManager.getLocalWhitelistState(), + Peers: make(map[string]admin.PeerInfo), + } + for pid, pi := range t.connectionManager.peerInfoSnapshot() { + state.Peers[pid.String()] = pi + } + return state +} + // Run starts the tatanka node and blocks until the context is done. func (t *TatankaNode) Run(ctx context.Context) error { - wg := sync.WaitGroup{} + t.adminNotify = noopAdminNotifier{} - // Setup libp2p node if not provided in options. 
- if t.node == nil { - listenAddrs := []string{ - fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", t.config.ListenPort), - fmt.Sprintf("/ip6/::/tcp/%d", t.config.ListenPort), - } - var err error - t.node, err = libp2p.New( - libp2p.Identity(t.privateKey), - libp2p.ListenAddrStrings(listenAddrs...), - // EnableRelayService for p2p communication between clients - libp2p.EnableRelayService(), - ) - if err != nil { - t.markReady(err) - return err - } + if err := t.initHost(); err != nil { + t.markReady(err) + return err + } + if err := t.initMessaging(ctx); err != nil { + t.markReady(err) + return err + } + if err := t.initOracle(); err != nil { + t.markReady(err) + return err + } + if err := t.initAdmin(ctx); err != nil { + t.markReady(err) + return err + } + if err := t.initConnectivity(); err != nil { + t.markReady(err) + return err + } + + t.setupStreamHandlers() + t.setupObservability() + + return t.serve(ctx) +} + +// initHost creates the libp2p host if not already injected (e.g. via WithHost). +func (t *TatankaNode) initHost() error { + if t.node != nil { + return nil } + listenAddrs := []string{ + fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", t.config.ListenPort), + fmt.Sprintf("/ip6/::/tcp/%d", t.config.ListenPort), + } + var err error + t.node, err = libp2p.New( + libp2p.Identity(t.privateKey), + libp2p.ListenAddrStrings(listenAddrs...), + ) + return err +} + +// initMessaging creates the gossipSub and pushStreamManager. 
+func (t *TatankaNode) initMessaging(ctx context.Context) error { t.log.Infof("Node ID: %s", t.node.ID().String()) listenAddrs := t.node.Network().ListenAddresses() @@ -192,16 +286,22 @@ func (t *TatankaNode) Run(ctx context.Context) error { var err error t.gossipSub, err = newGossipSub(ctx, &gossipSubCfg{ - node: t.node, - log: t.config.Logger, - getWhitelistPeers: t.getWhitelistPeers, + node: t.node, + log: t.config.Logger, + getWhitelistPeers: func() map[peer.ID]struct{} { + return t.whitelistManager.getWhitelist().PeerIDs + }, handleBroadcastMessage: t.handleBroadcastMessage, handleClientConnectionMessage: t.handleClientConnectionMessage, handleOracleUpdate: t.handleOracleUpdate, handleQuotaHeartbeat: t.handleQuotaHeartbeat, + // t.whitelistManager and t.connectionManager are set in initConnectivity, + // after gossipSub is created. The closure is safe because it is only + // called during gossipSub.run, which starts in serve() after all init + // phases. + handleWhitelistUpdate: t.handleWhitelistUpdate, }) if err != nil { - t.markReady(err) return err } @@ -217,70 +317,149 @@ func (t *TatankaNode) Run(ctx context.Context) error { } }) - // Only create oracle if not provided (e.g., via test setup) - if t.oracle == nil { - t.oracle, err = oracle.New(&oracle.Config{ - Log: t.config.Logger, - CMCKey: t.config.CMCKey, - TatumKey: t.config.TatumKey, - BlockcypherToken: t.config.BlockcypherToken, - NodeID: t.node.ID().String(), - PublishUpdate: t.gossipSub.publishOracleUpdate, - OnStateUpdate: func(update *oracle.OracleSnapshot) { - if t.adminServer != nil { - t.adminServer.BroadcastOracleUpdate("oracle_update", update) - } - }, - PublishQuotaHeartbeat: t.gossipSub.publishQuotaHeartbeat, - }) - if err != nil { - return fmt.Errorf("failed to create oracle: %v", err) - } + return nil +} + +// initOracle creates the oracle if not already injected (e.g. via test setup). 
+func (t *TatankaNode) initOracle() error { + if t.oracle != nil { + return nil } - // Create admin callback function and setup the admin server if configured. - adminCallback := func(peerID peer.ID, connected bool, whitelistMismatch bool, addresses []string, peerWhitelist []string) { + var err error + t.oracle, err = oracle.New(&oracle.Config{ + Log: t.config.Logger, + CMCKey: t.config.CMCKey, + TatumKey: t.config.TatumKey, + BlockcypherToken: t.config.BlockcypherToken, + NodeID: t.node.ID().String(), + PublishUpdate: t.gossipSub.publishOracleUpdate, + OnStateUpdate: func(update *oracle.OracleSnapshot) { + t.adminNotify.BroadcastOracleUpdate("oracle_update", update) + }, + PublishQuotaHeartbeat: t.gossipSub.publishQuotaHeartbeat, + }) + if err != nil { + return fmt.Errorf("failed to create oracle: %v", err) } - if t.config.AdminPort > 0 { - adminAddr := fmt.Sprintf(":%d", t.config.AdminPort) - server := admin.NewServer(t.config.Logger, adminAddr, t.oracle) - whitelistIDs := t.getWhitelist().allPeerIDs() - whitelist := make([]string, 0, len(whitelistIDs)) - for id := range whitelistIDs { - whitelist = append(whitelist, id.String()) - } - server.UpdateWhitelist(whitelist) - - adminCallback = func(peerID peer.ID, connected, whitelistMismatch bool, addresses []string, peerWhitelist []string) { - state := admin.StateDisconnected - switch { - case connected: - state = admin.StateConnected - case whitelistMismatch: - state = admin.StateWhitelistMismatch + + return nil +} + +// initAdmin creates the admin server when configured (AdminPort > 0) and +// upgrades adminNotify from the noop to the live implementation. 
+func (t *TatankaNode) initAdmin(ctx context.Context) error { + if t.config.AdminPort <= 0 { + return nil + } + + adminAddr := fmt.Sprintf("127.0.0.1:%d", t.config.AdminPort) + server := admin.NewServer(&admin.Config{ + Log: t.config.Logger, + Addr: adminAddr, + PeerID: t.node.ID().String(), + Oracle: t.oracle, + GetState: func() admin.AdminState { + return t.peerStates() + }, + ProposeWhitelist: func(peers []string) error { + peerIDs, err := decodePeerIDStrings(peers) + if err != nil { + return err } - server.UpdateConnectionState(peerID, state, addresses, peerWhitelist) - } + return t.whitelistManager.proposeWhitelist(types.NewWhitelist(peerIDs)) + }, + ClearProposal: func() { + t.whitelistManager.clearProposal() + }, + ForceWhitelist: func(peers []string) error { + peerIDs, err := decodePeerIDStrings(peers) + if err != nil { + return err + } + return t.whitelistManager.forceWhitelist(types.NewWhitelist(peerIDs)) + }, + }) - t.adminServer = server + t.adminServer = server + t.adminNotify = &liveAdminNotifier{server: server} + + return nil +} + +// initConnectivity creates the peerstore cache and mesh connection manager. 
+func (t *TatankaNode) initConnectivity() error { + t.peerstoreCache = newPeerstoreCache( + t.log, + filepath.Join(t.config.DataDir, "peerstore.json"), + t.node, + func() map[peer.ID]struct{} { + return t.whitelistManager.getWhitelist().PeerIDs + }, + ) + t.peerstoreCache.load() + + if len(t.config.BootstrapAddrs) > 0 { + if err := seedBootstrapAddrs(t.node, t.config.BootstrapAddrs); err != nil { + return fmt.Errorf("failed to seed bootstrap addresses: %w", err) + } } - t.connectionManager = newMeshConnectionManager( - t.config.Logger, t.node, t.getWhitelist(), adminCallback, - func() map[string]*pb.QuotaStatus { + whitelistPath := filepath.Join(t.config.DataDir, whitelistFileName) + t.whitelistManager = newWhitelistManager(&whitelistManagerConfig{ + log: t.config.Logger, + peerID: t.node.ID(), + isConnected: func(pid peer.ID) bool { + return t.node.Network().Connectedness(pid) == network.Connected + }, + whitelist: t.initWhitelist, + whitelistUpdated: func(newWl *types.Whitelist) { + if err := saveWhitelist(whitelistPath, newWl); err != nil { + t.log.Errorf("Failed to save whitelist: %v", err) + } + t.connectionManager.reconcileTrackers() + t.adminNotify.BroadcastWhitelistUpdate(admin.WhitelistUpdate{ + WhitelistState: t.whitelistManager.getLocalWhitelistState(), + }) + t.peerstoreCache.save() + }, + broadcastLocalState: func(ws *types.WhitelistState) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + if err := t.gossipSub.publishWhitelistUpdate(ctx, whitelistStateToPb(ws)); err != nil { + t.log.Errorf("Failed to publish whitelist update: %v", err) + } + t.adminNotify.BroadcastWhitelistState(ws) + }, + }) + + t.connectionManager = newMeshConnectionManager(&meshConnectionManagerConfig{ + log: t.config.Logger, + node: t.node, + peerStateUpdated: func(pi admin.PeerInfo) { + t.adminNotify.BroadcastPeerUpdate(pi) + }, + getPeerWhitelistState: t.whitelistManager.getPeerWhitelistState, + getLocalQuotas: func() 
map[string]*pb.QuotaStatus { return quotaStatusesToPb(t.oracle.GetLocalQuotas()) }, - func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) { + handlePeerQuotas: func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) { for source, q := range quotas { t.oracle.UpdatePeerSourceQuota(peerID.String(), pbToTimestampedQuotaStatus(q), source) } }, - ) + getLocalWhitelistState: t.whitelistManager.getLocalWhitelistState, + updatePeerWhitelistState: t.whitelistManager.updatePeerWhitelistState, + }) - t.log.Infof("Admin interface available (or not) on :%d", t.config.AdminPort) + return nil +} - t.setupStreamHandlers() - t.setupObservability() +// serve launches all long-running goroutines, waits for the initial connectivity +// pass to complete, marks the node as ready, then blocks until context +// cancellation. After all goroutines exit it runs shutdown. +func (t *TatankaNode) serve(ctx context.Context) error { + var wg sync.WaitGroup go func() { t.log.Infof("Metrics available on :%d/metrics", t.config.MetricsPort) @@ -292,7 +471,6 @@ func (t *TatankaNode) Run(ctx context.Context) error { } }() - // Start admin server if configured if t.adminServer != nil { wg.Add(1) go func() { @@ -316,7 +494,12 @@ func (t *TatankaNode) Run(ctx context.Context) error { } }() - // Maintain mesh connectivity + wg.Add(1) + go func() { + defer wg.Done() + t.whitelistManager.run(ctx) + }() + wg.Add(1) go func() { defer wg.Done() @@ -327,28 +510,37 @@ func (t *TatankaNode) Run(ctx context.Context) error { t.connectionManager.waitInitial(ctx) t.markReady(nil) - // Run Oracle wg.Add(1) go func() { defer wg.Done() t.oracle.Run(ctx) }() + wg.Add(1) + go func() { + defer wg.Done() + t.peerstoreCache.run(ctx) + }() + wg.Wait() + return t.shutdown() +} + +// shutdown performs graceful teardown of the metrics server and libp2p host. 
+func (t *TatankaNode) shutdown() error { t.log.Infof("Shutting down tatanka node...") - err = t.metricsServer.Shutdown(ctx) - if err != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := t.metricsServer.Shutdown(shutdownCtx); err != nil { return err } - err = t.node.Close() - if err != nil { + if err := t.node.Close(); err != nil { return err } t.log.Infof("Tatanka node shutdown complete.") - return nil } @@ -376,6 +568,8 @@ func (t *TatankaNode) markReady(err error) { }) } +// getOrCreatePrivateKey loads an existing private key from filePath, or +// generates a new Ed25519 key and writes it to filePath if none exists. func getOrCreatePrivateKey(filePath string) (crypto.PrivKey, error) { data, err := os.ReadFile(filePath) if err != nil { diff --git a/tatanka/tatanka_test.go b/tatanka/tatanka_test.go index 0786e48..07466c8 100644 --- a/tatanka/tatanka_test.go +++ b/tatanka/tatanka_test.go @@ -2,7 +2,6 @@ package tatanka import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -19,10 +18,12 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/bisoncraft/mesh/bond" "github.com/bisoncraft/mesh/codec" "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/bisoncraft/mesh/oracle/sources" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" @@ -77,7 +78,7 @@ type tOracle struct { feeRates map[oracle.Network]*big.Int } -var _ Oracle = (*tOracle)(nil) +var _ oracleService = (*tOracle)(nil) func newTOracle() *tOracle { return &tOracle{ @@ -152,25 +153,19 @@ func (t *tOracle) SetFeeRates(feeRates map[oracle.Network]*big.Int) { t.feeRates = feeRates } -func newTestNode(t *testing.T, ctx context.Context, h host.Host, dataDir 
string, whitelist *whitelist) *TatankaNode { +func newTestNode(t *testing.T, ctx context.Context, h host.Host, dataDir string, wl *types.Whitelist) *TatankaNode { logBackend := slog.NewBackend(os.Stdout) log := logBackend.Logger(h.ID().ShortString()) log.SetLevel(slog.LevelDebug) // Write the whitelist to the data directory. - whitelistPath := filepath.Join(dataDir, "whitelist.json") - whitelistData, err := json.Marshal(whitelist.toFile()) - if err != nil { - t.Fatalf("Failed to marshal whitelist: %v", err) - } - if err := os.WriteFile(whitelistPath, whitelistData, 0644); err != nil { + if err := saveWhitelist(filepath.Join(dataDir, "whitelist.json"), wl); err != nil { t.Fatalf("Failed to write whitelist: %v", err) } n, err := NewTatankaNode(&Config{ - Logger: log, - DataDir: dataDir, - WhitelistPath: filepath.Join(dataDir, "whitelist.json"), + Logger: log, + DataDir: dataDir, }, WithHost(h)) if err != nil { t.Fatalf("Failed to create test node: %v", err) @@ -193,25 +188,19 @@ func newTestNode(t *testing.T, ctx context.Context, h host.Host, dataDir string, } // newTestNodeWithOracle creates a test node with a custom oracle implementation. -func newTestNodeWithOracle(t *testing.T, ctx context.Context, h host.Host, dataDir string, whitelist *whitelist, testOracle Oracle) *TatankaNode { +func newTestNodeWithOracle(t *testing.T, ctx context.Context, h host.Host, dataDir string, wl *types.Whitelist, testOracle oracleService) *TatankaNode { logBackend := slog.NewBackend(os.Stdout) log := logBackend.Logger(h.ID().ShortString()) log.SetLevel(slog.LevelDebug) // Write the whitelist to the data directory. 
- whitelistPath := filepath.Join(dataDir, "whitelist.json") - whitelistData, err := json.Marshal(whitelist.toFile()) - if err != nil { - t.Fatalf("Failed to marshal whitelist: %v", err) - } - if err := os.WriteFile(whitelistPath, whitelistData, 0644); err != nil { + if err := saveWhitelist(filepath.Join(dataDir, "whitelist.json"), wl); err != nil { t.Fatalf("Failed to write whitelist: %v", err) } n, err := NewTatankaNode(&Config{ - Logger: log, - DataDir: dataDir, - WhitelistPath: filepath.Join(dataDir, "whitelist.json"), + Logger: log, + DataDir: dataDir, }, WithHost(h)) if err != nil { t.Fatalf("Failed to create test node: %v", err) @@ -542,7 +531,7 @@ func (tc *testClient) handleIncomingRelay(s network.Stream) { } func fullyConnectedMeshWithClients(ctx context.Context, t *testing.T, numMeshNodes, numClients int, clientToNode func(int) int) ( - net mocknet.Mocknet, meshNodes []*TatankaNode, clients []*testClient) { + meshNodes []*TatankaNode, wl *types.Whitelist, clients []*testClient) { mnet, err := mocknet.WithNPeers(numMeshNodes + numClients) if err != nil { t.Fatal(err) @@ -559,13 +548,11 @@ func fullyConnectedMeshWithClients(ctx context.Context, t *testing.T, numMeshNod clientHosts[i] = mnet.Host(allPeers[numMeshNodes+i]) } - whitelistPeers := make([]*peer.AddrInfo, numMeshNodes) + whitelistPeerIDs := make([]peer.ID, numMeshNodes) for i, h := range meshHosts { - whitelistPeers[i] = &peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()} - } - mockWhitelist := &whitelist{ - peers: whitelistPeers, + whitelistPeerIDs[i] = h.ID() } + mockWhitelist := types.NewWhitelist(whitelistPeerIDs) runningNodes := make([]*TatankaNode, 0, numMeshNodes) for i, h := range meshHosts { @@ -599,7 +586,7 @@ func fullyConnectedMeshWithClients(ctx context.Context, t *testing.T, numMeshNod time.Sleep(time.Second) - return mnet, runningNodes, clients + return runningNodes, mockWhitelist, clients } // checkFullyConnected verifies that all provided nodes are connected to each other. 
@@ -649,6 +636,10 @@ func linkNodeWithMesh(mesh mocknet.Mocknet, host host.Host, runningNodes []*Tata if _, err := mesh.LinkPeers(host.ID(), otherNode.node.ID()); err != nil { return err } + // Add addresses to each other's peerstores so the connection + // manager can dial. LinkPeers only creates a virtual link. + host.Peerstore().AddAddrs(otherNode.node.ID(), otherNode.node.Addrs(), peerstore.PermanentAddrTTL) + otherNode.node.Peerstore().AddAddrs(host.ID(), host.Addrs(), peerstore.PermanentAddrTTL) } else { if err := mesh.DisconnectPeers(host.ID(), otherNode.node.ID()); err != nil { return err @@ -682,15 +673,13 @@ func TestProgressiveMeshStartup(t *testing.T) { h3 := mesh.Host(peerIDs[2]) h4 := mesh.Host(peerIDs[3]) h5 := mesh.Host(peerIDs[4]) - mockWhitelist := &whitelist{ - peers: []*peer.AddrInfo{ - {ID: h1.ID(), Addrs: h1.Addrs()}, - {ID: h2.ID()}, - {ID: h3.ID(), Addrs: h3.Addrs()}, - {ID: h4.ID()}, - {ID: h5.ID()}, - }, - } + mockWhitelist := types.NewWhitelist([]peer.ID{ + h1.ID(), + h2.ID(), + h3.ID(), + h4.ID(), + h5.ID(), + }) // runningNodes will hold all running nodes that have been connected // to the mesh. 
@@ -786,11 +775,11 @@ func TestMeshRecovery(t *testing.T) { // Fully connected whitelist (no discovery required) hosts := mesh.Hosts() - var whitelistPeers []*peer.AddrInfo + var whitelistPeerIDs []peer.ID for _, h := range hosts { - whitelistPeers = append(whitelistPeers, &peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()}) + whitelistPeerIDs = append(whitelistPeerIDs, h.ID()) } - mockWhitelist := &whitelist{peers: whitelistPeers} + mockWhitelist := types.NewWhitelist(whitelistPeerIDs) // Start the nodes var nodes []*TatankaNode @@ -846,28 +835,24 @@ func TestWhitelistMismatch(t *testing.T) { hosts := mesh.Hosts() h1, h2, h3 := hosts[0], hosts[1], hosts[2] - goodWhitelist := &whitelist{ - peers: []*peer.AddrInfo{ - {ID: h1.ID(), Addrs: h1.Addrs()}, - {ID: h2.ID(), Addrs: h2.Addrs()}, - {ID: h3.ID(), Addrs: h3.Addrs()}, - }, - } + goodWhitelist := types.NewWhitelist([]peer.ID{ + h1.ID(), + h2.ID(), + h3.ID(), + }) - badWhitelist := &whitelist{ - peers: []*peer.AddrInfo{ - {ID: h1.ID(), Addrs: h1.Addrs()}, - {ID: h2.ID(), Addrs: h2.Addrs()}, - {ID: h3.ID(), Addrs: h3.Addrs()}, - {ID: randomPeerID(t), Addrs: nil}, - }, - } + badWhitelist := types.NewWhitelist([]peer.ID{ + h1.ID(), + h2.ID(), + h3.ID(), + randomPeerID(t), + }) // Start Nodes // Node 1 & 2 get the Good Whitelist // Node 3 gets the Bad Whitelist var nodes []*TatankaNode - startNode := func(h host.Host, whitelist *whitelist) (*TatankaNode, context.CancelFunc) { + startNode := func(h host.Host, whitelist *types.Whitelist) (*TatankaNode, context.CancelFunc) { err = linkNodeWithMesh(mesh, h, nodes, true) if err != nil { t.Fatal(err) @@ -1139,13 +1124,11 @@ func TestGossipSubOracleUpdates_PriceUpdates(t *testing.T) { } // Create whitelist with all nodes - whitelistPeers := make([]*peer.AddrInfo, numMeshNodes) + whitelistPeerIDs := make([]peer.ID, numMeshNodes) for i, h := range hosts { - whitelistPeers[i] = &peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()} - } - mockWhitelist := &whitelist{ - peers: whitelistPeers, + 
whitelistPeerIDs[i] = h.ID() } + mockWhitelist := types.NewWhitelist(whitelistPeerIDs) // Create nodes with custom oracles that track merges nodes := make([]*TatankaNode, numMeshNodes) @@ -1220,13 +1203,11 @@ func TestGossipSubOracleUpdates_FeeRateUpdates(t *testing.T) { } // Create whitelist with all nodes - whitelistPeers := make([]*peer.AddrInfo, numMeshNodes) + whitelistPeerIDs := make([]peer.ID, numMeshNodes) for i, h := range hosts { - whitelistPeers[i] = &peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()} - } - mockWhitelist := &whitelist{ - peers: whitelistPeers, + whitelistPeerIDs[i] = h.ID() } + mockWhitelist := types.NewWhitelist(whitelistPeerIDs) // Create nodes with custom oracles nodes := make([]*TatankaNode, numMeshNodes) @@ -1301,13 +1282,11 @@ func TestGossipSubOracleUpdates_MultipleNodes(t *testing.T) { } // Create whitelist with all nodes - whitelistPeers := make([]*peer.AddrInfo, numMeshNodes) + whitelistPeerIDs := make([]peer.ID, numMeshNodes) for i, h := range hosts { - whitelistPeers[i] = &peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()} - } - mockWhitelist := &whitelist{ - peers: whitelistPeers, + whitelistPeerIDs[i] = h.ID() } + mockWhitelist := types.NewWhitelist(whitelistPeerIDs) // Create nodes with custom oracles nodes := make([]*TatankaNode, numMeshNodes) @@ -1383,13 +1362,11 @@ func TestGossipSubOracleUpdates_ClientDelivery(t *testing.T) { } // Create whitelist - whitelistPeers := make([]*peer.AddrInfo, numMeshNodes) + whitelistPeerIDs := make([]peer.ID, numMeshNodes) for i, h := range meshHosts { - whitelistPeers[i] = &peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()} - } - mockWhitelist := &whitelist{ - peers: whitelistPeers, + whitelistPeerIDs[i] = h.ID() } + mockWhitelist := types.NewWhitelist(whitelistPeerIDs) // Create nodes with custom oracles that return updated prices/fee rates nodes := make([]*TatankaNode, numMeshNodes) @@ -1574,3 +1551,149 @@ func TestClientSubscriptionEvents(t *testing.T) { clients[idx].Close() } } + +// proposeOnAll 
proposes the same whitelist on every node. +func proposeOnAll(t *testing.T, nodes []*TatankaNode, proposed *types.Whitelist) { + t.Helper() + for i, n := range nodes { + if err := n.whitelistManager.proposeWhitelist(proposed); err != nil { + t.Fatalf("node %d: proposeWhitelist: %v", i, err) + } + } +} + +// waitTransitionComplete waits until every node's proposal is cleared, +// indicating the transition committed. +func waitTransitionComplete(t *testing.T, nodes []*TatankaNode) { + t.Helper() + for i, n := range nodes { + requireEventually(t, func() bool { + return n.whitelistManager.getLocalWhitelistState().Proposed == nil + }, 30*time.Second, 200*time.Millisecond, + "node %d did not complete whitelist transition", i) + } +} + +// TestWhitelistTransition_AllAgree verifies the end-to-end wiring: all nodes +// propose the same whitelist, the proposals propagate via gossipsub, the +// whitelist manager reaches consensus, and the new whitelist is committed. +func TestWhitelistTransition_AllAgree(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + nodes, wl, _ := fullyConnectedMeshWithClients(ctx, t, 3, 0, func(int) int { return 0 }) + + proposeOnAll(t, nodes, wl) + waitTransitionComplete(t, nodes) + + // After commit, every node's current whitelist should match the proposal. + for i, n := range nodes { + cur := n.whitelistManager.getLocalWhitelistState().Current + if !cur.Equals(wl) { + t.Fatalf("node %d: current whitelist doesn't match proposed after transition", i) + } + } +} + +// TestWhitelistTransition_AddNode verifies that adding a new node via whitelist +// transition works end-to-end: existing nodes propose a whitelist that includes +// a new node, the transition completes, and the connection manager reconciles +// trackers so the new node becomes connected. 
+func TestWhitelistTransition_AddNode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + const numPeers = 3 + mnet, err := mocknet.WithNPeers(numPeers) + if err != nil { + t.Fatal(err) + } + + hosts := mnet.Hosts() + whitelistPeerIDs := make([]peer.ID, numPeers) + for i, h := range hosts { + whitelistPeerIDs[i] = h.ID() + } + initialWL := types.NewWhitelist(whitelistPeerIDs) + + existingNodes := make([]*TatankaNode, 0, numPeers) + for i, h := range hosts { + if err := linkNodeWithMesh(mnet, h, existingNodes, true); err != nil { + t.Fatalf("Failed to link node %d: %v", i, err) + } + existingNodes = append(existingNodes, newTestNode(t, ctx, h, t.TempDir(), initialWL)) + } + checkFullyConnected(t, existingNodes) + + // Create the 4th host and link it to the existing mesh. + newHost, err := mnet.GenPeer() + if err != nil { + t.Fatal(err) + } + for _, n := range existingNodes { + if _, err := mnet.LinkPeers(newHost.ID(), n.node.ID()); err != nil { + t.Fatalf("LinkPeers: %v", err) + } + newHost.Peerstore().AddAddrs(n.node.ID(), n.node.Addrs(), peerstore.PermanentAddrTTL) + n.node.Peerstore().AddAddrs(newHost.ID(), newHost.Addrs(), peerstore.PermanentAddrTTL) + } + + // Proposed whitelist: all 3 existing + new node. + proposedIDs := make([]peer.ID, 0, 4) + for _, n := range existingNodes { + proposedIDs = append(proposedIDs, n.node.ID()) + } + proposedIDs = append(proposedIDs, newHost.ID()) + proposedWL := types.NewWhitelist(proposedIDs) + + // Start the new node with the proposed whitelist as its initial whitelist. + _ = newTestNode(t, ctx, newHost, t.TempDir(), proposedWL) + + // All existing nodes propose the new whitelist. + proposeOnAll(t, existingNodes, proposedWL) + waitTransitionComplete(t, existingNodes) + + // After transition, the existing nodes should have reconciled trackers + // and connected to the new node. 
+ for i, n := range existingNodes { + requireEventually(t, func() bool { + return n.node.Network().Connectedness(newHost.ID()) == network.Connected + }, 15*time.Second, 200*time.Millisecond, + "node %d not connected to new node after transition", i) + } +} + +// TestWhitelistTransition_RemoveNode verifies that removing a node via +// whitelist transition works: overlap nodes agree, the transition completes, +// and the connection manager stops tracking the removed node. +func TestWhitelistTransition_RemoveNode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + nodes, _, _ := fullyConnectedMeshWithClients(ctx, t, 3, 0, func(int) int { return 0 }) + + // Proposed whitelist: only nodes 0 and 1 (remove node 2). + proposed := types.NewWhitelist([]peer.ID{nodes[0].node.ID(), nodes[1].node.ID()}) + + // Both overlap nodes propose. + for i := 0; i < 2; i++ { + if err := nodes[i].whitelistManager.proposeWhitelist(proposed); err != nil { + t.Fatalf("node %d: proposeWhitelist: %v", i, err) + } + } + + // Transition should complete on the overlap nodes. + waitTransitionComplete(t, nodes[:2]) + + // After transition, the removed node's tracker should be stopped. 
+ removedID := nodes[2].node.ID() + for i := 0; i < 2; i++ { + requireEventually(t, func() bool { + nodes[i].connectionManager.trackersMtx.RLock() + _, tracked := nodes[i].connectionManager.peerTrackers[removedID] + nodes[i].connectionManager.trackersMtx.RUnlock() + return !tracked + }, 10*time.Second, 100*time.Millisecond, + "node %d still tracking removed node", i) + } +} diff --git a/tatanka/types/whitelist.go b/tatanka/types/whitelist.go new file mode 100644 index 0000000..793712b --- /dev/null +++ b/tatanka/types/whitelist.go @@ -0,0 +1,152 @@ +package types + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "sort" + "strings" + "sync" + + "github.com/libp2p/go-libp2p/core/peer" +) + +// WhitelistState contains a node's whitelist-related state. Current is the +// active whitelist. Proposed, when non-nil, is the whitelist the node wants +// to transition to. Ready indicates the node has observed sufficient +// agreement on the proposal (phase 1 of the two-phase commit) and will +// commit the switch once all ready peers are confirmed. +type WhitelistState struct { + Current *Whitelist `json:"current"` + Proposed *Whitelist `json:"proposed,omitempty"` + Ready bool `json:"ready"` +} + +// DeepCopy returns a fully independent copy of the whitelist state. +func (ws *WhitelistState) DeepCopy() *WhitelistState { + if ws == nil { + return nil + } + return &WhitelistState{ + Current: ws.Current.DeepCopy(), + Proposed: ws.Proposed.DeepCopy(), + Ready: ws.Ready, + } +} + +// Whitelist is a set of peer IDs authorized to participate in the mesh. +// The hash is lazily computed and cached for fast equality comparisons +// during consensus checks. +type Whitelist struct { + PeerIDs map[peer.ID]struct{} + hashOnce sync.Once + hashVal string +} + +// DeepCopy returns a fully independent copy of the whitelist. 
+func (w *Whitelist) DeepCopy() *Whitelist { + if w == nil { + return nil + } + peerIDs := make(map[peer.ID]struct{}, len(w.PeerIDs)) + for pid := range w.PeerIDs { + peerIDs[pid] = struct{}{} + } + return &Whitelist{PeerIDs: peerIDs} +} + +// NewWhitelist creates a new Whitelist from a slice of peer IDs. +func NewWhitelist(peerIDs []peer.ID) *Whitelist { + set := make(map[peer.ID]struct{}, len(peerIDs)) + for _, pid := range peerIDs { + set[pid] = struct{}{} + } + return &Whitelist{PeerIDs: set} +} + +// hashPeerIDSet returns a deterministic hex hash for a set of peer IDs. +// Returns an empty string for nil or empty sets. +func hashPeerIDSet(ids map[peer.ID]struct{}) string { + if len(ids) == 0 { + return "" + } + sorted := make([]string, 0, len(ids)) + for id := range ids { + sorted = append(sorted, id.String()) + } + sort.Strings(sorted) + h := sha256.Sum256([]byte(strings.Join(sorted, ","))) + return hex.EncodeToString(h[:]) +} + +// Hash returns a deterministic hex hash for the whitelist's peer IDs. +// The result is cached after the first call. +func (w *Whitelist) Hash() string { + w.hashOnce.Do(func() { + w.hashVal = hashPeerIDSet(w.PeerIDs) + }) + return w.hashVal +} + +// Equals returns true if both whitelists contain exactly the same +// set of peer IDs. +func (w *Whitelist) Equals(other *Whitelist) bool { + if w == nil || other == nil { + return w == other + } + return w.Hash() == other.Hash() +} + +// MarshalJSON serializes the whitelist as a sorted JSON array of peer ID strings. +func (w *Whitelist) MarshalJSON() ([]byte, error) { + strs := make([]string, 0, len(w.PeerIDs)) + for pid := range w.PeerIDs { + strs = append(strs, pid.String()) + } + sort.Strings(strs) + return json.Marshal(strs) +} + +// UnmarshalJSON deserializes a JSON array of peer ID strings into the whitelist. 
+func (w *Whitelist) UnmarshalJSON(data []byte) error { + var strs []string + if err := json.Unmarshal(data, &strs); err != nil { + return err + } + w.PeerIDs = make(map[peer.ID]struct{}, len(strs)) + for _, s := range strs { + pid, err := peer.Decode(s) + if err != nil { + return err + } + w.PeerIDs[pid] = struct{}{} + } + return nil +} + +// PeerIDsBytes returns the whitelist peer IDs as byte slices. +func (w *Whitelist) PeerIDsBytes() [][]byte { + ids := make([][]byte, 0, len(w.PeerIDs)) + for pid := range w.PeerIDs { + ids = append(ids, []byte(pid)) + } + return ids +} + +// DecodePeerIDBytes decodes a slice of byte-encoded peer IDs (as produced by +// PeerIDsBytes) into a Whitelist. Entries that fail to decode are silently +// skipped. A nil or empty input returns nil. +func DecodePeerIDBytes(raw [][]byte) *Whitelist { + if len(raw) == 0 { + return nil + } + set := make(map[peer.ID]struct{}, len(raw)) + for _, b := range raw { + pid, err := peer.IDFromBytes(b) + if err != nil { + continue + } + set[pid] = struct{}{} + } + return &Whitelist{PeerIDs: set} +} diff --git a/tatanka/types/whitelist_test.go b/tatanka/types/whitelist_test.go new file mode 100644 index 0000000..c7779d9 --- /dev/null +++ b/tatanka/types/whitelist_test.go @@ -0,0 +1,78 @@ +package types + +import ( + "encoding/json" + "testing" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +func randomPeerID(t *testing.T) peer.ID { + t.Helper() + _, pub, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) + if err != nil { + t.Fatalf("Failed to generate key pair: %v", err) + } + id, err := peer.IDFromPublicKey(pub) + if err != nil { + t.Fatalf("Failed to create peer ID: %v", err) + } + return id +} + +func TestWhitelistEquals(t *testing.T) { + id1 := randomPeerID(t) + id2 := randomPeerID(t) + id3 := randomPeerID(t) + + wl1 := NewWhitelist([]peer.ID{id1, id2}) + wl2 := NewWhitelist([]peer.ID{id2, id1}) // same IDs, different order + wl3 := 
NewWhitelist([]peer.ID{id1, id3}) // different set + wl4 := NewWhitelist([]peer.ID{id1}) // subset + + if !wl1.Equals(wl2) { + t.Error("Expected equal peer IDs (same set, different order)") + } + if wl1.Equals(wl3) { + t.Error("Expected unequal peer IDs (different set)") + } + if wl1.Equals(wl4) { + t.Error("Expected unequal peer IDs (subset)") + } +} + +func TestPeerIDBytesRoundTrip(t *testing.T) { + id1 := randomPeerID(t) + id2 := randomPeerID(t) + id3 := randomPeerID(t) + + wl := NewWhitelist([]peer.ID{id1, id2, id3}) + decoded := DecodePeerIDBytes(wl.PeerIDsBytes()) + + if !wl.Equals(decoded) { + t.Error("Round-tripped whitelist does not match original") + } +} + +func TestJSONRoundTrip(t *testing.T) { + id1 := randomPeerID(t) + id2 := randomPeerID(t) + id3 := randomPeerID(t) + + wl := NewWhitelist([]peer.ID{id1, id2, id3}) + + data, err := json.Marshal(wl) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + var decoded Whitelist + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + + if !wl.Equals(&decoded) { + t.Error("Round-tripped whitelist does not match original") + } +} diff --git a/tatanka/whitelist.go b/tatanka/whitelist.go index 7037f8a..33d9bf9 100644 --- a/tatanka/whitelist.go +++ b/tatanka/whitelist.go @@ -3,96 +3,108 @@ package tatanka import ( "encoding/json" "errors" + "fmt" "os" + "path/filepath" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" ) -type whitelistPeer struct { - ID string `json:"id"` - Address string `json:"address,omitempty"` -} - -type whitelistFile struct { - Peers []whitelistPeer `json:"peers"` -} - -// whitelist is the parsed whitelist file. 
-type whitelist struct { - peers []*peer.AddrInfo -} - -func (m *whitelist) allPeerIDs() map[peer.ID]struct{} { - allPeerIDs := make(map[peer.ID]struct{}, len(m.peers)) - for _, peer := range m.peers { - allPeerIDs[peer.ID] = struct{}{} - } - return allPeerIDs -} - -func (m *whitelist) toFile() whitelistFile { - peers := make([]whitelistPeer, len(m.peers)) - - for i, peer := range m.peers { - address := "" - if len(peer.Addrs) > 0 { - address = peer.Addrs[0].String() - } - peers[i] = whitelistPeer{ - ID: peer.ID.String(), - Address: address, - } - } - - return whitelistFile{ - Peers: peers, - } -} - -// peerIDsBytes returns the whitelist peer IDs as byte slices. -func (m *whitelist) peerIDsBytes() [][]byte { - ids := make([][]byte, 0, len(m.peers)) - for _, p := range m.peers { - ids = append(ids, []byte(p.ID)) +// saveWhitelist writes the whitelist to disk atomically. +func saveWhitelist(path string, wl *types.Whitelist) error { + data, err := json.Marshal(wl) + if err != nil { + return err } - return ids + return atomicWriteFile(path, data) } -func loadWhitelist(path string) (*whitelist, error) { +// loadWhitelist loads a whitelist from disk. +func loadWhitelist(path string) (*types.Whitelist, error) { if path == "" { return nil, errors.New("no whitelist path provided") } - var file whitelistFile data, err := os.ReadFile(path) if err != nil { return nil, err } - if err := json.Unmarshal(data, &file); err != nil { + + var wl types.Whitelist + if err := json.Unmarshal(data, &wl); err != nil { return nil, err } + return &wl, nil +} + +// flexibleWhitelistMatch returns true when any of {myCurrentWl, myProposedWl} +// matches any of {theirCurrent, theirProposed}. myProposedWl may be nil. 
+func flexibleWhitelistMatch(myCurrentWl, myProposedWl *types.Whitelist, theirCurrent, theirProposed [][]byte) bool { + myWhitelists := []*types.Whitelist{myCurrentWl} + if myProposedWl != nil { + myWhitelists = append(myWhitelists, myProposedWl) + } - peers := make([]*peer.AddrInfo, 0, len(file.Peers)) + theirCurrentWl := types.DecodePeerIDBytes(theirCurrent) + theirProposedWl := types.DecodePeerIDBytes(theirProposed) + + for _, myWl := range myWhitelists { + if theirCurrentWl != nil && myWl.Equals(theirCurrentWl) { + return true + } + if theirProposedWl != nil && myWl.Equals(theirProposedWl) { + return true + } + } + return false +} - for _, pData := range file.Peers { - peerID, err := peer.Decode(pData.ID) - if err != nil { +// initWhitelist initializes the whitelist from the provided peers or the file on disk. +// If forceUpdate is true, the providedPeers will be saved to the file on disk and used +// as the current whitelist. Otherwise, the existing whitelist in the file will be used, +// unless this is the first time the file is created. If the file already exists, forceUpdate +// is false, and the providedPeers do not match the existing whitelist, an error is returned. +func initWhitelist(dataDir string, providedPeers []peer.ID, forceUpdate bool, localPeerID peer.ID) (*types.Whitelist, error) { + wlPath := filepath.Join(dataDir, whitelistFileName) + + var providedWL *types.Whitelist + if len(providedPeers) > 0 { + // Make sure that the local peer ID is included in the provided whitelist. 
+ providedWL = types.NewWhitelist(providedPeers) + if _, found := providedWL.PeerIDs[localPeerID]; !found { + return nil, fmt.Errorf("local node peer ID %s is not included in whitelist", localPeerID) + } + } + + if forceUpdate { + if providedWL == nil { + return nil, errors.New("force whitelist update requires provided peers") + } + if err := saveWhitelist(wlPath, providedWL); err != nil { return nil, err } + return providedWL, nil + } - addrs := make([]ma.Multiaddr, 0, 1) - if pData.Address != "" { - addr, err := ma.NewMultiaddr(pData.Address) - if err != nil { + existingWL, err := loadWhitelist(wlPath) + if err != nil { + if os.IsNotExist(err) && providedWL != nil { + if err := saveWhitelist(wlPath, providedWL); err != nil { return nil, err } - addrs = append(addrs, addr) + return providedWL, nil } + return nil, err + } + + if providedWL != nil && !existingWL.Equals(providedWL) { + return nil, fmt.Errorf("provided whitelist does not match existing whitelist in %s; use --forcewl to overwrite", wlPath) + } - peers = append(peers, &peer.AddrInfo{ID: peerID, Addrs: addrs}) + if _, found := existingWL.PeerIDs[localPeerID]; !found { + return nil, fmt.Errorf("local node peer ID %s is not included in whitelist", localPeerID) } - return &whitelist{ - peers: peers, - }, nil + + return existingWL, nil } diff --git a/tatanka/whitelist_manager.go b/tatanka/whitelist_manager.go new file mode 100644 index 0000000..4e0dcc1 --- /dev/null +++ b/tatanka/whitelist_manager.go @@ -0,0 +1,412 @@ +package tatanka + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/bisoncraft/mesh/tatanka/types" + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/peer" +) + +// timestampedWhitelistState wraps a WhitelistState with a timestamp used for +// staleness checks during gossip. +type timestampedWhitelistState struct { + types.WhitelistState + timestamp int64 +} + +// whitelistManagerConfig holds configuration for creating a whitelistManager. 
+type whitelistManagerConfig struct { + log slog.Logger + peerID peer.ID + whitelist *types.Whitelist + isConnected func(peer.ID) bool + whitelistUpdated func(newWl *types.Whitelist) + broadcastLocalState func(ws *types.WhitelistState) +} + +// whitelistManager manages whitelist updates and state transitions. +// It accepts updates to the whitelist states of peers and the local node, +// and checks for consensus on a transition using a two-phase flow: +// 1. Ready phase: for overlap peers (present in both current and proposed +// whitelists), require a 2/3 quorum over the full overlap set and no +// online overlap peer disagreement. +// 2. Commit phase: once ready, require every online overlap peer to be ready +// (or already switched) before switching Current to Proposed. +// +// Because quorum is computed from the full overlap set, offline overlap peers +// can block transition progress until they reconnect or are removed. +type whitelistManager struct { + log slog.Logger + peerID peer.ID + isConnected func(peer.ID) bool + + mtx sync.RWMutex + localState types.WhitelistState + peerStates map[peer.ID]*timestampedWhitelistState + + whitelistUpdated func(newWl *types.Whitelist) + broadcastLocalState func(ws *types.WhitelistState) +} + +func newWhitelistManager(cfg *whitelistManagerConfig) *whitelistManager { + return &whitelistManager{ + log: cfg.log, + peerID: cfg.peerID, + isConnected: cfg.isConnected, + localState: types.WhitelistState{ + Current: cfg.whitelist, + }, + peerStates: make(map[peer.ID]*timestampedWhitelistState), + whitelistUpdated: cfg.whitelistUpdated, + broadcastLocalState: cfg.broadcastLocalState, + } +} + +// heartbeat publishes the local node's current state if a proposal is active. 
+func (wm *whitelistManager) heartbeat() { + wm.mtx.RLock() + if wm.localState.Proposed == nil { + wm.mtx.RUnlock() + return + } + stateCopy := wm.localState.DeepCopy() + wm.mtx.RUnlock() + + wm.broadcastLocalState(stateCopy) +} + +// run periodically heartbeats the local whitelist state while a proposal is +// active. +func (wm *whitelistManager) run(ctx context.Context) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + wm.heartbeat() + } + } +} + +func (wm *whitelistManager) getWhitelist() *types.Whitelist { + wm.mtx.RLock() + defer wm.mtx.RUnlock() + return wm.localState.Current.DeepCopy() +} + +// getLocalWhitelistState returns a snapshot of the local node's whitelist state. +func (wm *whitelistManager) getLocalWhitelistState() *types.WhitelistState { + wm.mtx.RLock() + defer wm.mtx.RUnlock() + return wm.localState.DeepCopy() +} + +// getPeerWhitelistState returns a copy of the whitelist state for a peer. +func (wm *whitelistManager) getPeerWhitelistState(peerID peer.ID) *types.WhitelistState { + wm.mtx.RLock() + defer wm.mtx.RUnlock() + + if peerID == wm.peerID { + return wm.localState.DeepCopy() + } + + state := wm.peerStates[peerID] + if state == nil { + return nil + } + + return state.WhitelistState.DeepCopy() +} + +// updatePeerWhitelistState records a peer's whitelist state and triggers a +// transition check. Passing a nil state clears the peer's stored state, +// regardless of the timestamp. +// Returns true if the stored state for that peer changed. 
+func (wm *whitelistManager) updatePeerWhitelistState(peerID peer.ID, ws *types.WhitelistState, timestamp int64) bool {
+	if peerID == wm.peerID {
+		return false
+	}
+	if ws != nil && ws.Current == nil {
+		return false
+	}
+
+	doUpdate := func() (updated, localStateChanged, whitelistSwitched bool, localStateCopy *types.WhitelistState) {
+		wm.mtx.Lock()
+		defer wm.mtx.Unlock()
+
+		existing, exists := wm.peerStates[peerID]
+
+		// If no update, return early
+		if (ws == nil && !exists) || (ws != nil && existing != nil && existing.timestamp >= timestamp) {
+			return false, false, false, nil
+		}
+
+		updated = true
+
+		// Update the peer's state
+		if ws == nil {
+			delete(wm.peerStates, peerID)
+		} else {
+			wsCopy := ws.DeepCopy()
+			wm.peerStates[peerID] = &timestampedWhitelistState{
+				WhitelistState: *wsCopy,
+				timestamp:      timestamp,
+			}
+		}
+
+		// Check if this update causes a transition.
+		localStateChanged, whitelistSwitched = wm.checkTransitionLocked()
+		if localStateChanged {
+			localStateCopy = wm.localState.DeepCopy()
+		}
+
+		return updated, localStateChanged, whitelistSwitched, localStateCopy
+	}
+
+	updated, localStateChanged, whitelistSwitched, localStateCopy := doUpdate()
+	if !updated {
+		return false
+	}
+	if localStateChanged {
+		wm.broadcastLocalState(localStateCopy)
+	}
+	if whitelistSwitched {
+		wm.whitelistUpdated(localStateCopy.Current)
+	}
+
+	return true
+}
+
+// proposeWhitelist sets the proposed whitelist. Connections are not changed
+// until the transition completes. The proposed whitelist must include the
+// local node.
+func (wm *whitelistManager) proposeWhitelist(wl *types.Whitelist) error { + if wl == nil || wl.PeerIDs == nil { + return errors.New("proposed whitelist cannot be nil") + } + if _, ok := wl.PeerIDs[wm.peerID]; !ok { + return errors.New("proposed whitelist must include the local node") + } + + wm.mtx.Lock() + wm.localState.Proposed = wl.DeepCopy() + wm.localState.Ready = false + _, whitelistSwitched := wm.checkTransitionLocked() + stateCopy := wm.localState.DeepCopy() + wm.mtx.Unlock() + + wm.broadcastLocalState(stateCopy) + if whitelistSwitched { + wm.whitelistUpdated(stateCopy.Current) + } + + return nil +} + +// clearProposal removes the active proposal. +func (wm *whitelistManager) clearProposal() { + wm.mtx.Lock() + if wm.localState.Proposed == nil { + wm.mtx.Unlock() + return + } + wm.localState.Proposed = nil + wm.localState.Ready = false + stateCopy := wm.localState.DeepCopy() + wm.mtx.Unlock() + + wm.broadcastLocalState(stateCopy) +} + +// forceWhitelist unilaterally replaces the current whitelist and clears +// any proposed whitelist. +func (wm *whitelistManager) forceWhitelist(wl *types.Whitelist) error { + if wl == nil || wl.PeerIDs == nil { + return errors.New("force whitelist cannot be nil") + } + if _, ok := wl.PeerIDs[wm.peerID]; !ok { + return errors.New("force whitelist must include the local node") + } + + wm.mtx.Lock() + if wl.Equals(wm.localState.Current) { + wm.mtx.Unlock() + return errors.New("force whitelist must be different from the current whitelist") + } + wm.localState.Current = wl.DeepCopy() + wm.localState.Proposed = nil + wm.localState.Ready = false + stateCopy := wm.localState.DeepCopy() + wm.mtx.Unlock() + + wm.broadcastLocalState(stateCopy) + wm.whitelistUpdated(wl) + + wm.log.Infof("Force whitelist applied with %d peers", len(wl.PeerIDs)) + + return nil +} + +// checkTransitionResult indicates the outcome of a transition check. 
+type checkTransitionResult uint8 + +const ( + transitionNoChange checkTransitionResult = iota + transitionMarkReady + transitionUnmarkReady + transitionWhitelist +) + +// checkTransitionLocked checks if the current whitelist state requires a change +// and if so, performs the change. +// This function MUST be called with the lock held. +func (wm *whitelistManager) checkTransitionLocked() (localStateChanged, whitelistSwitched bool) { + if wm.localState.Proposed == nil { + return false, false + } + + connectedPeers := make(map[peer.ID]bool) + for pid := range wm.localState.Current.PeerIDs { + connectedPeers[pid] = wm.isConnected(pid) + } + for pid := range wm.localState.Proposed.PeerIDs { + if _, exists := connectedPeers[pid]; !exists { + connectedPeers[pid] = wm.isConnected(pid) + } + } + + switch checkTransition( + wm.localState, + wm.peerID, + wm.peerStates, + connectedPeers, + ) { + case transitionMarkReady: + wm.localState.Ready = true + localStateChanged = true + wm.log.Infof("Whitelist transition: ready to switch") + case transitionUnmarkReady: + wm.localState.Ready = false + localStateChanged = true + wm.log.Infof("Whitelist transition: ready state regressed, resetting") + case transitionWhitelist: + wm.localState.Current = wm.localState.Proposed + wm.localState.Proposed = nil + wm.localState.Ready = false + localStateChanged = true + whitelistSwitched = true + wm.log.Infof("Whitelist transition complete: switched to new whitelist with %d peers", len(wm.localState.Current.PeerIDs)) + } + return localStateChanged, whitelistSwitched +} + +// checkTransition checks if a transition is warranted based on the current +// whitelist state of the local node and all the peers. 
+func checkTransition( + localState types.WhitelistState, + localPeerID peer.ID, + peerStates map[peer.ID]*timestampedWhitelistState, + connectedPeers map[peer.ID]bool, +) checkTransitionResult { + proposed := localState.Proposed + if proposed == nil { + return transitionNoChange + } + + currentPeers := localState.Current.PeerIDs + proposedPeers := proposed.PeerIDs + + // Compute overlap: peers in both current and proposed whitelists. + overlap := make(map[peer.ID]struct{}) + for pid := range currentPeers { + if _, ok := proposedPeers[pid]; ok { + overlap[pid] = struct{}{} + } + } + + nOverlap := len(overlap) + if nOverlap == 0 { + // Should never happen as the local node must be in the proposed whitelist + return transitionNoChange + } + threshold := (2*nOverlap + 2) / 3 // ceil(2/3 * |overlap|) + + onlineAgreeing := 0 + onlineNotAgreeing := 0 + + for pid := range overlap { + if pid == localPeerID { + onlineAgreeing++ + continue + } + + pp, hasPP := peerStates[pid] + connected := connectedPeers[pid] + + if !connected { + // Disconnected overlap peers are not counted as disagreeing, but the + // fixed threshold above still includes them, so they can block readiness. + continue + } + + if !hasPP { + onlineNotAgreeing++ + continue + } + + peerProposedMatchesOurs := pp.Proposed != nil && pp.Proposed.Equals(proposed) + alreadySwitched := pp.Current.Equals(proposed) + + if peerProposedMatchesOurs || alreadySwitched { + onlineAgreeing++ + } else { + onlineNotAgreeing++ + } + } + + // Phase 1: Ready check. Ready requires quorum over the full overlap set + // and no online overlap peer actively disagreeing. + shouldBeReady := onlineAgreeing >= threshold && onlineNotAgreeing == 0 + if !shouldBeReady { + if localState.Ready { + return transitionUnmarkReady + } + return transitionNoChange + } + + // Phase 2: Commit check — all online overlap peers must be ready or + // already switched. 
+	canCommit := true
+	for pid := range overlap {
+		if pid == localPeerID {
+			continue
+		}
+		if !connectedPeers[pid] {
+			continue
+		}
+		pp, hasPP := peerStates[pid]
+		if !hasPP {
+			canCommit = false
+			break
+		}
+		alreadySwitched := pp.Current.Equals(proposed)
+		if !pp.Ready && !alreadySwitched {
+			canCommit = false
+			break
+		}
+	}
+
+	if canCommit {
+		return transitionWhitelist
+	}
+	if !localState.Ready {
+		return transitionMarkReady
+	}
+	return transitionNoChange
+}
diff --git a/tatanka/whitelist_manager_test.go b/tatanka/whitelist_manager_test.go
new file mode 100644
index 0000000..932ef10
--- /dev/null
+++ b/tatanka/whitelist_manager_test.go
@@ -0,0 +1,654 @@
+package tatanka
+
+import (
+	"io"
+	"sync"
+	"testing"
+
+	"github.com/bisoncraft/mesh/tatanka/types"
+	"github.com/decred/slog"
+	"github.com/libp2p/go-libp2p/core/peer"
+)
+
+// TestCheckTransition tests the core transition logic.
+func TestCheckTransition(t *testing.T) {
+	local := randomPeerID(t)
+	var p [5]peer.ID
+	for i := range p {
+		p[i] = randomPeerID(t)
+	}
+
+	wl := types.NewWhitelist
+
+	ps := func(current, proposed *types.Whitelist, ready bool) *timestampedWhitelistState {
+		return &timestampedWhitelistState{
+			WhitelistState: types.WhitelistState{
+				Current:  current,
+				Proposed: proposed,
+				Ready:    ready,
+			},
+		}
+	}
+
+	// Standard whitelists: overlap={local, p0}, threshold=2
+	cur2 := wl([]peer.ID{local, p[0], p[1]})
+	prop2 := wl([]peer.ID{local, p[0], p[2]})
+
+	// Larger overlap: overlap={local, p0, p1}, threshold=2
+	cur3 := wl([]peer.ID{local, p[0], p[1], p[2]})
+	prop3 := wl([]peer.ID{local, p[0], p[1], p[3]})
+
+	tests := []struct {
+		name       string
+		localState types.WhitelistState
+		peerStates map[peer.ID]*timestampedWhitelistState
+		connected  map[peer.ID]bool
+		want       checkTransitionResult
+	}{
+		{
+			name:       "no_proposal",
+			localState: types.WhitelistState{Current: cur2},
+			want:       transitionNoChange,
+		},
+		{
+			name: "single_overlap_commits_immediately",
+			// only the local peer is in both
whitelists, so it can transition immediately + localState: types.WhitelistState{ + Current: wl([]peer.ID{local, p[0]}), + Proposed: wl([]peer.ID{local, p[1]}), + }, + want: transitionWhitelist, + }, + { + name: "below_threshold_peer_disconnected", + // p0 not connected → agreeing=1, threshold=2 + localState: types.WhitelistState{Current: cur2, Proposed: prop2}, + want: transitionNoChange, + }, + { + name: "connected_no_state_disagrees", + // p0 connected but no stored state → counts as disagreeing + localState: types.WhitelistState{Current: cur2, Proposed: prop2}, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionNoChange, + }, + { + name: "different_proposal_disagrees", + localState: types.WhitelistState{Current: cur2, Proposed: prop2}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur2, wl([]peer.ID{local, p[0], p[4]}), false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionNoChange, + }, + { + name: "matching_proposed_marks_ready", + // p0 agrees (proposed matches) but not Ready → can't commit + localState: types.WhitelistState{Current: cur2, Proposed: prop2}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur2, prop2, false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionMarkReady, + }, + { + name: "already_switched_commits_immediately", + // p0 Current equals our Proposed — already switched + localState: types.WhitelistState{Current: cur2, Proposed: prop2}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(prop2, nil, false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionWhitelist, + }, + { + name: "threshold_met_disagreer_blocks", + // overlap={local, p0, p1}, threshold=2. p0 agrees, p1 disagrees. 
+ localState: types.WhitelistState{Current: cur3, Proposed: prop3}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur3, prop3, false), + p[1]: ps(cur3, wl([]peer.ID{local, p[0], p[1], p[4]}), false), + }, + connected: map[peer.ID]bool{p[0]: true, p[1]: true}, + want: transitionNoChange, + }, + { + name: "disconnected_does_not_disagree", + // overlap={local, p0, p1}, threshold=2. p0 agrees, p1 offline. + // agreeing=2, no disagreers → ready + localState: types.WhitelistState{Current: cur3, Proposed: prop3}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur3, prop3, false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionMarkReady, + }, + { + name: "regression_below_threshold", + // p0 disconnected → agreeing=1 < threshold=2 + localState: types.WhitelistState{Current: cur2, Proposed: prop2, Ready: true}, + want: transitionUnmarkReady, + }, + { + name: "regression_disagreer", + localState: types.WhitelistState{Current: cur2, Proposed: prop2, Ready: true}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur2, wl([]peer.ID{local, p[0], p[4]}), false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionUnmarkReady, + }, + { + name: "regression_connected_no_state", + // connected but no stored state → disagrees + localState: types.WhitelistState{Current: cur2, Proposed: prop2, Ready: true}, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionUnmarkReady, + }, + + { + name: "commit_blocked_peer_not_ready", + // p0 agrees (proposed matches) but not Ready → can't commit + localState: types.WhitelistState{Current: cur2, Proposed: prop2, Ready: true}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur2, prop2, false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionNoChange, + }, + { + name: "commit_all_ready", + localState: types.WhitelistState{Current: cur2, Proposed: prop2, Ready: true}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + 
p[0]: ps(cur2, prop2, true), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionWhitelist, + }, + { + name: "commit_peer_already_switched", + // p0 Current = our Proposed, not explicitly Ready → still commits + localState: types.WhitelistState{Current: cur2, Proposed: prop2, Ready: true}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(prop2, nil, false), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionWhitelist, + }, + { + name: "commit_disconnected_peer_skipped", + // overlap={local, p0, p1}. p0 ready, p1 disconnected → p1 skipped + localState: types.WhitelistState{Current: cur3, Proposed: prop3, Ready: true}, + peerStates: map[peer.ID]*timestampedWhitelistState{ + p[0]: ps(cur3, prop3, true), + }, + connected: map[peer.ID]bool{p[0]: true}, + want: transitionWhitelist, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := checkTransition(tc.localState, local, tc.peerStates, tc.connected) + if got != tc.want { + t.Fatalf("got %d, want %d", got, tc.want) + } + }) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// whitelistManager — method tests (wiring, callbacks, input validation) +// ──────────────────────────────────────────────────────────────────────────── + +type wmHarness struct { + wm *whitelistManager + localID peer.ID + peers []peer.ID + + mu sync.Mutex + stateUpdates []*types.WhitelistState + whitelistUpdates []*types.Whitelist + connected map[peer.ID]bool +} + +func newWMHarness(t *testing.T, numPeers int) *wmHarness { + t.Helper() + + localID := randomPeerID(t) + peers := make([]peer.ID, numPeers) + for i := range peers { + peers[i] = randomPeerID(t) + } + + allPeers := make([]peer.ID, 0, numPeers+1) + allPeers = append(allPeers, localID) + allPeers = append(allPeers, peers...) 
+ + h := &wmHarness{ + localID: localID, + peers: peers, + connected: make(map[peer.ID]bool), + } + + h.wm = newWhitelistManager(&whitelistManagerConfig{ + log: slog.NewBackend(io.Discard).Logger("test"), + peerID: localID, + isConnected: func(pid peer.ID) bool { + h.mu.Lock() + defer h.mu.Unlock() + return h.connected[pid] + }, + whitelist: types.NewWhitelist(allPeers), + whitelistUpdated: func(newWl *types.Whitelist) { + h.mu.Lock() + h.whitelistUpdates = append(h.whitelistUpdates, newWl) + h.mu.Unlock() + }, + broadcastLocalState: func(ws *types.WhitelistState) { + h.mu.Lock() + h.stateUpdates = append(h.stateUpdates, ws) + h.mu.Unlock() + }, + }) + + return h +} + +func (h *wmHarness) setConnected(pids ...peer.ID) { + h.mu.Lock() + defer h.mu.Unlock() + for _, pid := range pids { + h.connected[pid] = true + } +} + +func (h *wmHarness) resetRecords() { + h.mu.Lock() + h.stateUpdates = nil + h.whitelistUpdates = nil + h.mu.Unlock() +} + +func TestWMInitialState(t *testing.T) { + h := newWMHarness(t, 2) + + state := h.wm.getLocalWhitelistState() + if state.Current == nil || len(state.Current.PeerIDs) != 3 { + t.Fatalf("expected 3-peer current whitelist, got %v", state.Current) + } + if state.Proposed != nil { + t.Fatal("Proposed should be nil initially") + } + if state.Ready { + t.Fatal("Ready should be false initially") + } +} + +func TestWMGetPeerWhitelistState(t *testing.T) { + h := newWMHarness(t, 2) + + t.Run("local_returns_own_state", func(t *testing.T) { + got := h.wm.getPeerWhitelistState(h.localID) + expected := h.wm.getLocalWhitelistState() + if !got.Current.Equals(expected.Current) { + t.Fatal("local peer lookup should match getLocalWhitelistState") + } + }) + + t.Run("unknown_returns_nil", func(t *testing.T) { + if h.wm.getPeerWhitelistState(randomPeerID(t)) != nil { + t.Fatal("expected nil for unknown peer") + } + }) + + t.Run("stored_peer", func(t *testing.T) { + ws := &types.WhitelistState{Current: types.NewWhitelist([]peer.ID{h.localID, 
h.peers[0]})} + h.wm.updatePeerWhitelistState(h.peers[0], ws, 100) + + got := h.wm.getPeerWhitelistState(h.peers[0]) + if got == nil || !got.Current.Equals(ws.Current) { + t.Fatal("stored peer state should be retrievable") + } + }) +} + +func TestWMProposeWhitelist(t *testing.T) { + t.Run("stores_and_broadcasts", func(t *testing.T) { + h := newWMHarness(t, 2) + proposed := types.NewWhitelist([]peer.ID{h.localID, h.peers[0], randomPeerID(t)}) + if err := h.wm.proposeWhitelist(proposed); err != nil { + t.Fatal(err) + } + + state := h.wm.getLocalWhitelistState() + if !state.Proposed.Equals(proposed) { + t.Fatal("Proposed doesn't match") + } + if state.Ready { + t.Fatal("should not be ready") + } + + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 1 { + t.Fatalf("expected 1 broadcast, got %d", len(h.stateUpdates)) + } + }) + + t.Run("rejects_nil", func(t *testing.T) { + h := newWMHarness(t, 1) + if err := h.wm.proposeWhitelist(nil); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("rejects_without_local", func(t *testing.T) { + h := newWMHarness(t, 2) + if err := h.wm.proposeWhitelist(types.NewWhitelist([]peer.ID{h.peers[0]})); err == nil { + t.Fatal("expected error") + } + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 0 { + t.Fatal("should not broadcast on error") + } + }) + + t.Run("deep_copies_input", func(t *testing.T) { + h := newWMHarness(t, 1) + proposed := types.NewWhitelist([]peer.ID{h.localID, h.peers[0]}) + h.wm.proposeWhitelist(proposed) + + proposed.PeerIDs[randomPeerID(t)] = struct{}{} + + state := h.wm.getLocalWhitelistState() + if len(state.Proposed.PeerIDs) != 2 { + t.Fatal("stored proposal should not be affected by caller mutation") + } + }) +} + +func TestWMClearProposal(t *testing.T) { + t.Run("clears_and_broadcasts", func(t *testing.T) { + h := newWMHarness(t, 2) + h.wm.proposeWhitelist(types.NewWhitelist([]peer.ID{h.localID, h.peers[0], randomPeerID(t)})) + h.resetRecords() + + h.wm.clearProposal() + + if 
h.wm.getLocalWhitelistState().Proposed != nil { + t.Fatal("Proposed should be nil after clear") + } + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 1 { + t.Fatalf("expected 1 broadcast, got %d", len(h.stateUpdates)) + } + }) + + t.Run("noop_without_proposal", func(t *testing.T) { + h := newWMHarness(t, 1) + h.wm.clearProposal() + + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 0 { + t.Fatal("should not broadcast on noop") + } + }) +} + +func TestWMForceWhitelist(t *testing.T) { + t.Run("replaces_current_and_clears_proposal", func(t *testing.T) { + h := newWMHarness(t, 2) + h.wm.proposeWhitelist(types.NewWhitelist([]peer.ID{h.localID, h.peers[0], randomPeerID(t)})) + // Store a peer state to verify it survives force. + peerWS := &types.WhitelistState{Current: types.NewWhitelist([]peer.ID{h.localID, h.peers[0]})} + h.wm.updatePeerWhitelistState(h.peers[0], peerWS, 100) + h.resetRecords() + + forced := types.NewWhitelist([]peer.ID{h.localID, h.peers[0]}) + if err := h.wm.forceWhitelist(forced); err != nil { + t.Fatal(err) + } + + state := h.wm.getLocalWhitelistState() + if !state.Current.Equals(forced) { + t.Fatal("Current should match forced") + } + if state.Proposed != nil { + t.Fatal("Proposed should be nil after force") + } + if h.wm.getPeerWhitelistState(h.peers[0]) == nil { + t.Fatal("peer state should be preserved after force") + } + + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 1 { + t.Fatalf("expected 1 broadcast, got %d", len(h.stateUpdates)) + } + if len(h.whitelistUpdates) != 1 || !h.whitelistUpdates[0].Equals(forced) { + t.Fatal("expected whitelistUpdated with forced whitelist") + } + }) + + t.Run("rejects_nil", func(t *testing.T) { + h := newWMHarness(t, 1) + if err := h.wm.forceWhitelist(nil); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("rejects_without_local", func(t *testing.T) { + h := newWMHarness(t, 2) + if err := h.wm.forceWhitelist(types.NewWhitelist([]peer.ID{h.peers[0]})); err == 
nil { + t.Fatal("expected error") + } + }) + + t.Run("rejects_same_whitelist", func(t *testing.T) { + h := newWMHarness(t, 2) + if err := h.wm.forceWhitelist(h.wm.getWhitelist()); err == nil { + t.Fatal("expected error for identical whitelist") + } + }) +} + +func TestWMUpdatePeerState(t *testing.T) { + t.Run("stores_and_retrieves", func(t *testing.T) { + h := newWMHarness(t, 2) + ws := &types.WhitelistState{Current: types.NewWhitelist([]peer.ID{h.localID, h.peers[0]})} + + if !h.wm.updatePeerWhitelistState(h.peers[0], ws, 100) { + t.Fatal("expected true for new state") + } + + got := h.wm.getPeerWhitelistState(h.peers[0]) + if got == nil || !got.Current.Equals(ws.Current) { + t.Fatal("stored state should match") + } + + // No proposal active → no transition → no broadcast. + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 0 { + t.Fatalf("expected 0 broadcasts, got %d", len(h.stateUpdates)) + } + }) + + t.Run("rejects_self", func(t *testing.T) { + h := newWMHarness(t, 1) + ws := &types.WhitelistState{Current: types.NewWhitelist([]peer.ID{h.localID})} + if h.wm.updatePeerWhitelistState(h.localID, ws, 100) { + t.Fatal("self-update should return false") + } + }) + + t.Run("rejects_nil_current", func(t *testing.T) { + h := newWMHarness(t, 1) + if h.wm.updatePeerWhitelistState(h.peers[0], &types.WhitelistState{}, 100) { + t.Fatal("nil Current should return false") + } + if h.wm.getPeerWhitelistState(h.peers[0]) != nil { + t.Fatal("nil Current should not be stored") + } + }) + + t.Run("stale_timestamp_rejected", func(t *testing.T) { + h := newWMHarness(t, 2) + wl1 := types.NewWhitelist([]peer.ID{h.localID, h.peers[0]}) + wl2 := types.NewWhitelist([]peer.ID{h.localID, h.peers[1]}) + + h.wm.updatePeerWhitelistState(h.peers[0], &types.WhitelistState{Current: wl1}, 200) + + if h.wm.updatePeerWhitelistState(h.peers[0], &types.WhitelistState{Current: wl2}, 100) { + t.Fatal("older timestamp should return false") + } + if h.wm.updatePeerWhitelistState(h.peers[0], 
&types.WhitelistState{Current: wl2}, 200) { + t.Fatal("equal timestamp should return false") + } + + got := h.wm.getPeerWhitelistState(h.peers[0]) + if !got.Current.Equals(wl1) { + t.Fatal("state should not change from stale update") + } + }) + + t.Run("nil_clears_existing", func(t *testing.T) { + h := newWMHarness(t, 1) + ws := &types.WhitelistState{Current: types.NewWhitelist([]peer.ID{h.localID, h.peers[0]})} + h.wm.updatePeerWhitelistState(h.peers[0], ws, 100) + + if !h.wm.updatePeerWhitelistState(h.peers[0], nil, 0) { + t.Fatal("nil clear should return true") + } + if h.wm.getPeerWhitelistState(h.peers[0]) != nil { + t.Fatal("state should be nil after clear") + } + }) + + t.Run("nil_noop_for_absent", func(t *testing.T) { + h := newWMHarness(t, 1) + if h.wm.updatePeerWhitelistState(h.peers[0], nil, 0) { + t.Fatal("nil for absent peer should return false") + } + }) + + t.Run("deep_copies_input", func(t *testing.T) { + h := newWMHarness(t, 1) + original := types.NewWhitelist([]peer.ID{h.localID, h.peers[0]}) + ws := &types.WhitelistState{Current: original} + h.wm.updatePeerWhitelistState(h.peers[0], ws, 100) + + // Mutate caller-owned state after update. + ws.Current.PeerIDs[randomPeerID(t)] = struct{}{} + + got := h.wm.getPeerWhitelistState(h.peers[0]) + if len(got.Current.PeerIDs) != 2 { + t.Fatalf("expected 2 peers in stored state, got %d", len(got.Current.PeerIDs)) + } + }) +} + +// TestWMTransitionEndToEnd verifies the full propose → agree → ready → +// commit flow through the whitelistManager methods. +func TestWMTransitionEndToEnd(t *testing.T) { + h := newWMHarness(t, 2) // current = {local, p0, p1} + h.setConnected(h.peers[0]) + + proposed := types.NewWhitelist([]peer.ID{h.localID, h.peers[0], randomPeerID(t)}) + // overlap = {local, p0}, threshold = 2 + + if err := h.wm.proposeWhitelist(proposed); err != nil { + t.Fatal(err) + } + h.resetRecords() + + // Step 1: Peer reports matching proposal → local becomes ready. 
+ current := h.wm.getWhitelist() + h.wm.updatePeerWhitelistState(h.peers[0], &types.WhitelistState{ + Current: current, + Proposed: proposed, + }, 100) + + state := h.wm.getLocalWhitelistState() + if !state.Ready { + t.Fatal("should be ready after peer agrees") + } + + h.mu.Lock() + if len(h.stateUpdates) != 1 { + t.Fatalf("expected 1 broadcast for ready, got %d", len(h.stateUpdates)) + } + h.mu.Unlock() + h.resetRecords() + + // Step 2: Peer reports ready → whitelist commits. + h.wm.updatePeerWhitelistState(h.peers[0], &types.WhitelistState{ + Current: current, + Proposed: proposed, + Ready: true, + }, 200) + + state = h.wm.getLocalWhitelistState() + if !state.Current.Equals(proposed) { + t.Fatal("Current should be the proposed whitelist after commit") + } + if state.Proposed != nil { + t.Fatal("Proposed should be nil after commit") + } + + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 1 { + t.Fatalf("expected 1 broadcast for commit, got %d", len(h.stateUpdates)) + } + if len(h.whitelistUpdates) != 1 || !h.whitelistUpdates[0].Equals(proposed) { + t.Fatal("expected whitelistUpdated with new whitelist") + } +} + +// TestWMNilClearTriggersRegression verifies that clearing a peer's state via +// nil update re-evaluates the transition and regresses readiness when the +// agreeing peer is removed. +func TestWMNilClearTriggersRegression(t *testing.T) { + h := newWMHarness(t, 2) // current = {local, p0, p1} + h.setConnected(h.peers[0]) + + proposed := types.NewWhitelist([]peer.ID{h.localID, h.peers[0], randomPeerID(t)}) + if err := h.wm.proposeWhitelist(proposed); err != nil { + t.Fatal(err) + } + + // Peer agrees → local becomes ready. + h.wm.updatePeerWhitelistState(h.peers[0], &types.WhitelistState{ + Current: h.wm.getWhitelist(), + Proposed: proposed, + }, 100) + + if !h.wm.getLocalWhitelistState().Ready { + t.Fatal("should be ready before clearing peer state") + } + h.resetRecords() + + // Clear the agreeing peer → ready should regress. 
+ if !h.wm.updatePeerWhitelistState(h.peers[0], nil, 0) { + t.Fatal("nil clear should return true") + } + + state := h.wm.getLocalWhitelistState() + if state.Ready { + t.Fatal("should regress from ready after peer state clear") + } + if state.Proposed == nil { + t.Fatal("proposal should remain pending") + } + + h.mu.Lock() + defer h.mu.Unlock() + if len(h.stateUpdates) != 1 { + t.Fatalf("expected 1 broadcast for regression, got %d", len(h.stateUpdates)) + } +} diff --git a/tatanka/whitelist_test.go b/tatanka/whitelist_test.go index 872f93a..f30f371 100644 --- a/tatanka/whitelist_test.go +++ b/tatanka/whitelist_test.go @@ -1,61 +1,204 @@ package tatanka import ( - "encoding/json" - "os" "path/filepath" - "reflect" + "strings" "testing" + "github.com/bisoncraft/mesh/tatanka/types" "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" ) func TestWhitelistSaveLoad(t *testing.T) { - // Create test peer IDs and addresses peerID1 := randomPeerID(t) peerID2 := randomPeerID(t) - addr1, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1234") - if err != nil { - t.Fatalf("Failed to create multiaddr 1: %v", err) + originalWhitelist := types.NewWhitelist([]peer.ID{peerID1, peerID2}) + + path := filepath.Join(t.TempDir(), "test_whitelist.json") + if err := saveWhitelist(path, originalWhitelist); err != nil { + t.Fatalf("Failed to save whitelist: %v", err) } - addr2, err := ma.NewMultiaddr("/ip6/::1/tcp/5678") + + loadedWhitelist, err := loadWhitelist(path) if err != nil { - t.Fatalf("Failed to create multiaddr 2: %v", err) + t.Fatalf("Failed to load whitelist from file: %v", err) } - originalWhitelist := &whitelist{ - peers: []*peer.AddrInfo{ - {ID: peerID1, Addrs: []ma.Multiaddr{addr1}}, - {ID: peerID2, Addrs: []ma.Multiaddr{addr2}}, - }, + if len(loadedWhitelist.PeerIDs) != len(originalWhitelist.PeerIDs) { + t.Errorf("Expected %d peers, got %d", len(originalWhitelist.PeerIDs), len(loadedWhitelist.PeerIDs)) } - whitelistFile := originalWhitelist.toFile() 
- tempDir := t.TempDir() - tempFile := filepath.Join(tempDir, "test_whitelist.json") + if !loadedWhitelist.Equals(originalWhitelist) { + t.Error("Loaded whitelist does not match original whitelist") + } +} - data, err := json.Marshal(whitelistFile) - if err != nil { - t.Fatalf("Failed to marshal whitelist to JSON: %v", err) +func TestFlexibleWhitelistMatch(t *testing.T) { + id1 := randomPeerID(t) + id2 := randomPeerID(t) + id3 := randomPeerID(t) + id4 := randomPeerID(t) + + wlA := types.NewWhitelist([]peer.ID{id1, id2}) + wlB := types.NewWhitelist([]peer.ID{id1, id2, id3}) + + // Exact match: my current matches their current. + if !flexibleWhitelistMatch(wlA, nil, wlA.PeerIDsBytes(), nil) { + t.Error("Expected match: exact current↔current") } - err = os.WriteFile(tempFile, data, 0644) - if err != nil { - t.Fatalf("Failed to write whitelist to file: %v", err) + // Cross-match: my proposed matches their current. + if !flexibleWhitelistMatch(wlA, wlB, wlB.PeerIDsBytes(), nil) { + t.Error("Expected match: my proposed matches their current") + } + + // Cross-match: my current matches their proposed. + if !flexibleWhitelistMatch(wlA, nil, wlB.PeerIDsBytes(), wlA.PeerIDsBytes()) { + t.Error("Expected match: my current matches their proposed") + } + + // No match: completely different whitelists. + wlC := types.NewWhitelist([]peer.ID{id3, id4}) + if flexibleWhitelistMatch(wlA, nil, wlC.PeerIDsBytes(), nil) { + t.Error("Expected no match: completely different whitelists") + } + + // Nil proposed on both sides: only current compared. + if !flexibleWhitelistMatch(wlA, nil, wlA.PeerIDsBytes(), nil) { + t.Error("Expected match: nil proposed, current matches") + } + + // No match even with proposed. 
+ if flexibleWhitelistMatch(wlA, nil, wlC.PeerIDsBytes(), wlC.PeerIDsBytes()) { + t.Error("Expected no match: proposed also different") + } +} + +func TestSaveAndLoadWhitelist(t *testing.T) { + id1 := randomPeerID(t) + id2 := randomPeerID(t) + + wl := types.NewWhitelist([]peer.ID{id1, id2}) + + path := filepath.Join(t.TempDir(), "wl.json") + if err := saveWhitelist(path, wl); err != nil { + t.Fatalf("saveWhitelist: %v", err) } - loadedWhitelist, err := loadWhitelist(tempFile) + loaded, err := loadWhitelist(path) if err != nil { - t.Fatalf("Failed to load whitelist from file: %v", err) + t.Fatalf("loadWhitelist: %v", err) + } + + if !wl.Equals(loaded) { + t.Error("Round-tripped whitelist peer IDs differ") } - if len(loadedWhitelist.peers) != len(originalWhitelist.peers) { - t.Errorf("Expected %d peers, got %d", len(originalWhitelist.peers), len(loadedWhitelist.peers)) + if len(loaded.PeerIDs) != 2 { + t.Fatalf("Expected 2 peers, got %d", len(loaded.PeerIDs)) + } +} + +func TestInitWhitelist(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + { + name: "requires_local_peer_when_loading_existing_whitelist", + run: func(t *testing.T) { + dataDir := t.TempDir() + localPeerID := randomPeerID(t) + otherPeerID := randomPeerID(t) + + if err := saveWhitelist(filepath.Join(dataDir, whitelistFileName), types.NewWhitelist([]peer.ID{otherPeerID})); err != nil { + t.Fatalf("saveWhitelist: %v", err) + } + + _, err := initWhitelist(dataDir, nil, false, localPeerID) + if err == nil { + t.Fatal("Expected error when local peer is missing from whitelist") + } + if !strings.Contains(err.Error(), "not included in whitelist") { + t.Fatalf("Expected missing-local-peer error, got: %v", err) + } + }, + }, + { + name: "accepts_local_peer_in_force_update", + run: func(t *testing.T) { + dataDir := t.TempDir() + localPeerID := randomPeerID(t) + otherPeerID := randomPeerID(t) + + wl, err := initWhitelist(dataDir, []peer.ID{localPeerID, otherPeerID}, true, 
localPeerID) + if err != nil { + t.Fatalf("initWhitelist: %v", err) + } + if _, found := wl.PeerIDs[localPeerID]; !found { + t.Fatal("Expected whitelist to include local peer") + } + }, + }, + { + name: "force_update_does_not_save_without_local_peer", + run: func(t *testing.T) { + dataDir := t.TempDir() + localPeerID := randomPeerID(t) + existingPeerID := randomPeerID(t) + missingLocalPeerID := randomPeerID(t) + + originalWL := types.NewWhitelist([]peer.ID{localPeerID, existingPeerID}) + wlPath := filepath.Join(dataDir, whitelistFileName) + if err := saveWhitelist(wlPath, originalWL); err != nil { + t.Fatalf("saveWhitelist: %v", err) + } + + _, err := initWhitelist(dataDir, []peer.ID{missingLocalPeerID}, true, localPeerID) + if err == nil { + t.Fatal("Expected error when force-updating with whitelist missing local peer") + } + if !strings.Contains(err.Error(), "not included in whitelist") { + t.Fatalf("Expected missing-local-peer error, got: %v", err) + } + + savedWL, err := loadWhitelist(wlPath) + if err != nil { + t.Fatalf("loadWhitelist: %v", err) + } + if !savedWL.Equals(originalWL) { + t.Fatal("Expected existing whitelist to remain unchanged after failed force update") + } + }, + }, + { + name: "first_run_saves_provided_whitelist", + run: func(t *testing.T) { + dataDir := t.TempDir() + localPeerID := randomPeerID(t) + otherPeerID := randomPeerID(t) + + wl, err := initWhitelist(dataDir, []peer.ID{localPeerID, otherPeerID}, false, localPeerID) + if err != nil { + t.Fatalf("initWhitelist: %v", err) + } + if _, found := wl.PeerIDs[localPeerID]; !found { + t.Fatal("Expected returned whitelist to include local peer") + } + + savedWL, err := loadWhitelist(filepath.Join(dataDir, whitelistFileName)) + if err != nil { + t.Fatalf("loadWhitelist: %v", err) + } + if !savedWL.Equals(wl) { + t.Fatal("Expected first-run provided whitelist to be saved") + } + }, + }, } - if !reflect.DeepEqual(loadedWhitelist.peers, originalWhitelist.peers) { - t.Errorf("Loaded whitelist does 
not match original whitelist") + for _, test := range tests { + t.Run(test.name, test.run) } } diff --git a/testing/harness.sh b/testing/harness.sh index 7c45f2e..c724c6d 100755 --- a/testing/harness.sh +++ b/testing/harness.sh @@ -5,12 +5,11 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" ROOT_DIR=~/.tatanka-test TATANKA_DIR="$SCRIPT_DIR/../cmd/tatanka" TATANKA_BIN=$TATANKA_DIR/tatanka +TATANKACTL_DIR="$SCRIPT_DIR/../cmd/tatankactl" +TATANKACTL_BIN=$TATANKACTL_DIR/tatankactl TESTCLIENT_DIR="$SCRIPT_DIR/../cmd/testclient" TESTCLIENT_BIN=$TESTCLIENT_DIR/testclient TESTCLIENT_UI_DIR="$SCRIPT_DIR/client/ui" -MAKEPRIVKEY_DIR="$SCRIPT_DIR/makeprivkey" -MAKEPRIVKEY_BIN=$ROOT_DIR/makeprivkey -WHITELIST_FILE=$ROOT_DIR/whitelist.json build_tatanka() { cd "$TATANKA_DIR" @@ -18,8 +17,10 @@ build_tatanka() { cd - > /dev/null } -build_makeprivkey() { - go build -o "$MAKEPRIVKEY_BIN" "$MAKEPRIVKEY_DIR/main.go" +build_tatankactl() { + cd "$TATANKACTL_DIR" + go build -o tatankactl + cd - > /dev/null } build_testclient() { @@ -37,24 +38,15 @@ build_testclient_ui() { cd - > /dev/null } -generate_privkey() { - local node_dir=$1 - local privkey_path=$node_dir/p.key - local peer_id=$("$MAKEPRIVKEY_BIN" "$privkey_path") - echo $peer_id -} - create_config() { local node_dir=$1 - local whitelist_path=$2 - local listen_port=$3 - local metrics_port=$4 - local admin_port=$5 + local listen_port=$2 + local metrics_port=$3 + local admin_port=$4 local config_path=$node_dir/tatanka.conf cat < $config_path appdata=$node_dir -whitelistpath=$whitelist_path listenport=$listen_port metricsport=$metrics_port adminport=$admin_port @@ -88,47 +80,76 @@ start_harness() { mkdir -p $ROOT_DIR build_tatanka - build_makeprivkey + build_tatankactl build_testclient build_testclient_ui - # Store bootstrap peers for the whitelist file - whitelist_peers=() + # Store peer IDs, listen ports, and admin ports for each node. 
node_peer_ids=() node_listen_ports=() + node_admin_ports=() - # Generate private keys and create config files for each node + # Initialize nodes and generate private keys using tatanka init. for i in $(seq 1 $num_nodes); do node_dir=$ROOT_DIR/tatanka-$i mkdir -p $node_dir - peer_id=$(generate_privkey $node_dir) + peer_id=$("$TATANKA_BIN" init --appdata "$node_dir") node_peer_ids+=("$peer_id") listen_port=$((12345 + i)) metrics_port=$((12355 + i)) admin_port=$((12365 + i)) node_listen_ports+=("$listen_port") - config_path=$(create_config $node_dir $WHITELIST_FILE $listen_port $metrics_port $admin_port) - - addr="/ip4/127.0.0.1/tcp/$listen_port" - whitelist_peers+=("{\"id\": \"$peer_id\", \"address\": \"$addr\"}") + node_admin_ports+=("$admin_port") + config_path=$(create_config $node_dir $listen_port $metrics_port $admin_port) done - # Create whitelist file + # Write shared whitelist into each node's data dir. + whitelist_peers=() + for i in $(seq 0 $((num_nodes - 1))); do + whitelist_peers+=("\"${node_peer_ids[$i]}\"") + done whitelist_json=$(printf ",%s" "${whitelist_peers[@]}") whitelist_json="${whitelist_json:1}" # Remove leading comma - whitelist="{\"peers\": [$whitelist_json]}" - echo $whitelist > $WHITELIST_FILE + whitelist="[$whitelist_json]" - # Start tmux session and start nodes + for i in $(seq 1 $num_nodes); do + node_dir=$ROOT_DIR/tatanka-$i + echo "$whitelist" > $node_dir/whitelist.json + done + + # Start tmux session with interleaved node/ctl windows starting at 0 + # so that 5 nodes fit in windows 0-9 (node-1=0, ctl-1=1, node-2=2, ...). session_name="tatanka-test" - tmux new-session -d -s $session_name + tmux new-session -d -s $session_name -x 200 -y 50 + win=0 for i in $(seq 1 $num_nodes); do node_dir=$ROOT_DIR/tatanka-$i config_path=$node_dir/tatanka.conf - tmux new-window -t $session_name:$i -n node-$i - tmux send-keys -t $session_name:$i "$TATANKA_BIN -C $config_path" C-m + + # Build bootstrap flags for all OTHER nodes. 
+ bootstrap_flags="" + for j in $(seq 0 $((num_nodes - 1))); do + if [ $j -ne $((i - 1)) ]; then + bootstrap_flags="$bootstrap_flags --bootstrap /ip4/127.0.0.1/tcp/${node_listen_ports[$j]}/p2p/${node_peer_ids[$j]}" + fi + done + + # Node window + if [ $win -eq 0 ]; then + tmux rename-window -t $session_name:0 node-$i + else + tmux new-window -t $session_name:$win -n node-$i + fi + tmux send-keys -t $session_name:$win "$TATANKA_BIN -C $config_path$bootstrap_flags" C-m + win=$((win + 1)) + + # Ctl window + admin_port=${node_admin_ports[$((i - 1))]} + tmux new-window -t $session_name:$win -n ctl-$i + tmux send-keys -t $session_name:$win "$TATANKACTL_BIN -a localhost:$admin_port" C-m + win=$((win + 1)) done sleep 2 # wait for nodes to start before starting clients @@ -180,4 +201,4 @@ fi start_harness $num_nodes $num_clients -tmux attach-session -t $session_name \ No newline at end of file +tmux attach-session -t $session_name diff --git a/testing/makeprivkey/main.go b/testing/makeprivkey/main.go deleted file mode 100644 index 40750a2..0000000 --- a/testing/makeprivkey/main.go +++ /dev/null @@ -1,52 +0,0 @@ -// makeprivkey generates a new Ed25519 private key for a libp2p node, -// writes it to the specified file path, and outputs the corresponding -// peer ID to stdout. This is used to create node identities for the -// tatanka test harness. 
-package main - -import ( - "fmt" - "os" - - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" -) - -func main() { - if len(os.Args) != 2 { - fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) - os.Exit(1) - } - - filePath := os.Args[1] - - // Generate a new Ed25519 private key - priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) - if err != nil { - fmt.Fprintf(os.Stderr, "Error generating key pair: %v\n", err) - os.Exit(1) - } - - // Marshal the private key to bytes - bytes, err := crypto.MarshalPrivateKey(priv) - if err != nil { - fmt.Fprintf(os.Stderr, "Error marshaling private key: %v\n", err) - os.Exit(1) - } - - // Write the private key to file - if err := os.WriteFile(filePath, bytes, 0600); err != nil { - fmt.Fprintf(os.Stderr, "Error writing private key to file: %v\n", err) - os.Exit(1) - } - - // Get the peer ID from the public key - peerID, err := peer.IDFromPublicKey(priv.GetPublic()) - if err != nil { - fmt.Fprintf(os.Stderr, "Error deriving peer ID: %v\n", err) - os.Exit(1) - } - - // Print the peer ID to stdout - fmt.Println(peerID.String()) -}