diff --git a/cmd/tatanka/.!28982!tatanka b/cmd/tatanka/.!28982!tatanka new file mode 100755 index 0000000..e69de29 diff --git a/cmd/tatanka/main.go b/cmd/tatanka/main.go index b0ab9dc..7b5543d 100644 --- a/cmd/tatanka/main.go +++ b/cmd/tatanka/main.go @@ -29,9 +29,9 @@ type Config struct { WhitelistPath string `long:"whitelistpath" description:"Path to local whitelist file."` // Oracle Configuration - CMCKey string `long:"cmckey" description:"coinmarketcap API key"` - TatumKey string `long:"tatumkey" description:"tatum API key"` - CryptoApisKey string `long:"cryptoapiskey" description:"crypto apis API key"` + CMCKey string `long:"cmckey" description:"coinmarketcap API key"` + TatumKey string `long:"tatumkey" description:"tatum API key"` + BlockcypherToken string `long:"blockcyphertoken" description:"blockcypher API token"` } // initLogRotator initializes the logging rotater to write logs to logFile and @@ -107,9 +107,9 @@ func main() { MetricsPort: cfg.MetricsPort, WhitelistPath: cfg.WhitelistPath, AdminPort: cfg.AdminPort, - CMCKey: cfg.CMCKey, - TatumKey: cfg.TatumKey, - CryptoApisKey: cfg.CryptoApisKey, + CMCKey: cfg.CMCKey, + TatumKey: cfg.TatumKey, + BlockcypherToken: cfg.BlockcypherToken, } // Create Tatanka node diff --git a/cmd/tatankactl/api.go b/cmd/tatankactl/api.go new file mode 100644 index 0000000..09c2a7c --- /dev/null +++ b/cmd/tatankactl/api.go @@ -0,0 +1,344 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/url" + "sort" + "strings" + "sync" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/gorilla/websocket" + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/tatanka/admin" +) + +// --- Navigation messages --- + +type viewID int + +const ( + viewMenu viewID = iota + viewConnections + viewDiff + viewOracleSources + viewOracleDetail + viewOracleAggregated + viewAggregatedDetail +) + +type navigateMsg struct{ view viewID } +type navigateBackMsg struct{} +type navigateToDiffMsg struct{ node admin.NodeInfo } +type navigateToSourceDetailMsg struct { + sourceName string +} +type navigateToAggregatedDetailMsg struct { + dataType oracle.DataType + key string // ticker or network name +} + +// --- Data messages --- + +type adminStateMsg struct { + state *admin.AdminState +} + +type wsConnectedMsg struct{} +type wsErrorMsg struct{ err error } +type wsReconnectMsg struct{} + +// oracleSnapshotMsg is received on WS connect as full state. +type oracleSnapshotMsg oracle.OracleSnapshot + +// oracleUpdateMsg is a partial diff received as oracle_update. +type oracleUpdateMsg oracle.OracleSnapshot + +// renderTickMsg triggers periodic re-rendering while oracle views are active. +type renderTickMsg time.Time + +// --- Shared helpers --- + +func navBack() tea.Cmd { + return func() tea.Msg { return navigateBackMsg{} } +} + +func sortedKeys[M ~map[string]V, V any](m M) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +// --- Shared oracle data helpers --- + +func newOracleSnapshot() *oracle.OracleSnapshot { + return &oracle.OracleSnapshot{ + Prices: make(map[string]*oracle.SnapshotRate), + FeeRates: make(map[string]*oracle.SnapshotRate), + Sources: make(map[string]*oracle.SourceStatus), + } +} + +func getOrCreateSource(d *oracle.OracleSnapshot, name string) *oracle.SourceStatus { + src, ok := d.Sources[name] + if !ok { + src = &oracle.SourceStatus{ + Fetches24h: make(map[string]int), + Quotas: make(map[string]*oracle.Quota), + } + d.Sources[name] = src + } + return src +} + +func getOrCreateRate(m map[string]*oracle.SnapshotRate, key string) *oracle.SnapshotRate { + r, ok := m[key] + if !ok { + r = &oracle.SnapshotRate{ + Contributions: make(map[string]*oracle.SourceContribution), + } + m[key] = r + } + return r +} + +// mergeSnapshot applies a partial oracle.OracleSnapshot to the shared state. +func mergeSnapshot(d *oracle.OracleSnapshot, msg oracle.OracleSnapshot) { + if msg.NodeID != "" { + d.NodeID = msg.NodeID + } + + for name, s := range msg.Sources { + src := getOrCreateSource(d, name) + if s.LastFetch != nil { + src.LastFetch = s.LastFetch + } + if s.NextFetchTime != nil { + src.NextFetchTime = s.NextFetchTime + } + if s.MinFetchInterval != nil { + src.MinFetchInterval = s.MinFetchInterval + } + if s.NetworkSustainableRate != nil { + src.NetworkSustainableRate = s.NetworkSustainableRate + } + if s.NetworkSustainablePeriod != nil { + src.NetworkSustainablePeriod = s.NetworkSustainablePeriod + } + if s.NetworkNextFetchTime != nil { + src.NetworkNextFetchTime = s.NetworkNextFetchTime + } + if s.NextFetchTime != nil { + // Schedule updates from the diviner always carry error + // state. Empty values mean the error was cleared. + src.LastError = s.LastError + src.LastErrorTime = s.LastErrorTime + } else if s.LastError != "" || s.LastErrorTime != nil { + src.LastError = s.LastError + src.LastErrorTime = s.LastErrorTime + } + if s.OrderedNodes != nil { + src.OrderedNodes = s.OrderedNodes + } + if s.Fetches24h != nil { + src.Fetches24h = s.Fetches24h + } + if s.Quotas != nil { + for nodeID, q := range s.Quotas { + src.Quotas[nodeID] = q + } + } + if s.LatestData != nil { + if src.LatestData == nil { + src.LatestData = make(map[string]map[string]string) + } + for dataType, entries := range s.LatestData { + if src.LatestData[dataType] == nil { + src.LatestData[dataType] = make(map[string]string) + } + for id, value := range entries { + src.LatestData[dataType][id] = value + } + } + } + } + + for ticker, sr := range msg.Prices { + rate := getOrCreateRate(d.Prices, ticker) + rate.Value = sr.Value + for source, c := range sr.Contributions { + rate.Contributions[source] = c + } + } + + for network, sr := range msg.FeeRates { + rate := getOrCreateRate(d.FeeRates, network) + rate.Value = sr.Value + for source, c := range sr.Contributions { + rate.Contributions[source] = c + } + } +} + +// updateOracleData applies a WS message to the shared oracle data. +func updateOracleData(d *oracle.OracleSnapshot, msg tea.Msg) { + switch msg := msg.(type) { + case oracleSnapshotMsg: + d.Sources = make(map[string]*oracle.SourceStatus) + d.Prices = make(map[string]*oracle.SnapshotRate) + d.FeeRates = make(map[string]*oracle.SnapshotRate) + mergeSnapshot(d, oracle.OracleSnapshot(msg)) + case oracleUpdateMsg: + mergeSnapshot(d, oracle.OracleSnapshot(msg)) + } +} + +// --- API client --- + +type apiClient struct { + address string + + wsMu sync.Mutex + wsConn *websocket.Conn + wsCancel chan struct{} +} + +func newAPIClient(address string) *apiClient { + return &apiClient{ + address: normalizeAddress(address), + } +} + +func normalizeAddress(addr string) string { + if !strings.HasPrefix(addr, "http://") && !strings.HasPrefix(addr, "https://") { + return "http://" + addr + } + return addr +} + +// wsMessage mirrors admin.WSMessage for client-side parsing. +type wsMessage struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + +func (c *apiClient) connectWebSocket(ch chan<- tea.Msg) tea.Cmd { + return func() tea.Msg { + wsURL, err := url.Parse(c.address) + if err != nil { + return wsErrorMsg{err: fmt.Errorf("invalid address: %w", err)} + } + + if wsURL.Scheme == "https" { + wsURL.Scheme = "wss" + } else { + wsURL.Scheme = "ws" + } + wsURL.Path = "/admin/ws" + + conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) + if err != nil { + return wsErrorMsg{err: fmt.Errorf("failed to connect: %w", err)} + } + + cancel := make(chan struct{}) + + c.wsMu.Lock() + c.wsConn = conn + c.wsCancel = cancel + c.wsMu.Unlock() + + // Start reader goroutine + go func() { + defer conn.Close() + for { + select { + case <-cancel: + return + default: + } + + _, message, err := conn.ReadMessage() + if err != nil { + select { + case <-cancel: + return + default: + } + ch <- wsErrorMsg{err: fmt.Errorf("connection lost: %w", err)} + return + } + + var envelope wsMessage + if err := json.Unmarshal(message, &envelope); err != nil { + continue + } + + var msg tea.Msg + switch envelope.Type { + case "admin_state": + var state admin.AdminState + if err := json.Unmarshal(envelope.Data, &state); err != nil { + continue + } + msg = adminStateMsg{state: &state} + case "oracle_snapshot": + var snapshot oracleSnapshotMsg + if err := json.Unmarshal(envelope.Data, &snapshot); err != nil { + continue + } + msg = snapshot + case "oracle_update": + var update oracleUpdateMsg + if err := json.Unmarshal(envelope.Data, &update); err != nil { + continue + } + msg = update + default: + continue + } + + select { + case ch <- msg: + case <-cancel: + return + } + } + }() + + return wsConnectedMsg{} + } +} + +func (c *apiClient) disconnectWebSocket() { + c.wsMu.Lock() + defer c.wsMu.Unlock() + + if c.wsCancel != nil { + close(c.wsCancel) + c.wsCancel = nil + } + if c.wsConn != nil { + c.wsConn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) + c.wsConn.Close() + c.wsConn = nil + } +} + +func listenForWSUpdates(ch <-chan tea.Msg) tea.Cmd { + return func() tea.Msg { + msg, ok := <-ch + if !ok { + return wsErrorMsg{err: fmt.Errorf("channel closed")} + } + return msg + } +} diff --git a/cmd/tatankactl/connections.go b/cmd/tatankactl/connections.go new file mode 100644 index 0000000..527fd4f --- /dev/null +++ b/cmd/tatankactl/connections.go @@ -0,0 +1,162 @@ +package main + +import ( + "fmt" + "sort" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/bisoncraft/mesh/tatanka/admin" +) + +type connectionsModel struct { + state *admin.AdminState + nodes []admin.NodeInfo + mismatchIndices []int + cursor int // index into mismatchIndices + lastUpdate time.Time + height int +} + +func (m connectionsModel) Update(msg tea.Msg) (connectionsModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + case tea.KeyMsg: + switch msg.String() { + case "up", "k": + if len(m.mismatchIndices) > 0 && m.cursor > 0 { + m.cursor-- + } + case "down", "j": + if len(m.mismatchIndices) > 0 && m.cursor < len(m.mismatchIndices)-1 { + m.cursor++ + } + case "enter": + if len(m.mismatchIndices) > 0 { + idx := m.mismatchIndices[m.cursor] + node := m.nodes[idx] + return m, func() tea.Msg { + return navigateToDiffMsg{node: node} + } + } + case "esc", "q": + return m, navBack() + } + } + return m, nil +} + +func (m *connectionsModel) sortNodes() { + nodes := make([]admin.NodeInfo, 0, len(m.state.Nodes)) + for _, node := range m.state.Nodes { + nodes = append(nodes, node) + } + + stateOrder := map[admin.NodeConnectionState]int{ + admin.StateConnected: 0, + admin.StateWhitelistMismatch: 1, + admin.StateDisconnected: 2, + } + sort.Slice(nodes, func(i, j int) bool { + oi, oj := stateOrder[nodes[i].State], stateOrder[nodes[j].State] + if oi != oj { + return oi < oj + } + return nodes[i].PeerID < nodes[j].PeerID + }) + + m.nodes = nodes + m.mismatchIndices = nil + for i, n := range nodes { + if n.State == admin.StateWhitelistMismatch { + m.mismatchIndices = append(m.mismatchIndices, i) + } + } + + // Keep cursor in bounds + if m.cursor >= len(m.mismatchIndices) { + m.cursor = max(0, len(m.mismatchIndices)-1) + } +} + +func (m connectionsModel) View() string { + var b strings.Builder + + // Header + ts := "" + if !m.lastUpdate.IsZero() { + ts = dimStyle.Render(m.lastUpdate.Format("15:04:05")) + } + b.WriteString(fmt.Sprintf(" %s%s\n\n", + headerStyle.Render("Connections"), + pad(ts, 50))) + + if m.state == nil { + b.WriteString(" Waiting for data...\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return b.String() + } + + // Summary counts + counts := make(map[admin.NodeConnectionState]int) + for _, node := range m.nodes { + counts[node.State]++ + } + b.WriteString(fmt.Sprintf(" Connected: %s | Mismatch: %s | Disconnected: %s\n\n", + connectedStyle.Render(fmt.Sprintf("%d", counts[admin.StateConnected])), + mismatchStyle.Render(fmt.Sprintf("%d", counts[admin.StateWhitelistMismatch])), + disconnectedStyle.Render(fmt.Sprintf("%d", counts[admin.StateDisconnected])), + )) + + if len(m.nodes) == 0 { + b.WriteString(" No nodes in whitelist\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return b.String() + } + + // Build a set of mismatch node indices for cursor display + selectedNodeIdx := -1 + if len(m.mismatchIndices) > 0 { + selectedNodeIdx = m.mismatchIndices[m.cursor] + } + + for i, node := range m.nodes { + icon := getStateIcon(node.State) + stateStr := getStateString(node.State) + + cursorStr := "" + if i == selectedNodeIdx { + cursorStr = cursorStyle.Render(" \u25c0 [Enter for diff]") + } + + b.WriteString(fmt.Sprintf(" %s %-25s %s%s\n", + icon, stateStr, dimStyle.Render(node.PeerID), cursorStr)) + + for _, addr := range node.Addresses { + b.WriteString(fmt.Sprintf(" \u2502 %s\n", dimStyle.Render(addr))) + } + + b.WriteString("\n") + } + + // Help + help := " \u2191\u2193 Navigate mismatch nodes Enter: View diff Esc: Back" + if len(m.mismatchIndices) == 0 { + help = " Esc: Back" + } + b.WriteString(helpStyle.Render(help)) + + return fitToHeight(b.String(), m.height) +} + +func pad(s string, width int) string { + // Right-align s within width by prepending spaces + w := lipgloss.Width(s) + if w >= width { + return " " + s + } + return strings.Repeat(" ", width-w) + s +} diff --git a/cmd/tatankactl/diff.go b/cmd/tatankactl/diff.go new file mode 100644 index 0000000..4c629ab --- /dev/null +++ b/cmd/tatankactl/diff.go @@ -0,0 +1,171 @@ +package main + +import ( + "fmt" + "sort" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/bisoncraft/mesh/tatanka/admin" +) + +type diffModel struct { + node admin.NodeInfo + inBoth []string + onlyOurs []string + onlyPeers []string + scrollOffset int + height int +} + +func newDiffModel(node admin.NodeInfo, state *admin.AdminState) diffModel { + ourSet := make(map[string]bool) + for _, id := range state.OurWhitelist { + ourSet[id] = true + } + + peerSet := make(map[string]bool) + for _, id := range node.PeerWhitelist { + peerSet[id] = true + } + + var inBoth, onlyOurs, onlyPeers []string + + for _, id := range state.OurWhitelist { + if peerSet[id] { + inBoth = append(inBoth, id) + } else { + onlyOurs = append(onlyOurs, id) + } + } + + for _, id := range node.PeerWhitelist { + if !ourSet[id] { + onlyPeers = append(onlyPeers, id) + } + } + + sort.Strings(inBoth) + sort.Strings(onlyOurs) + sort.Strings(onlyPeers) + + return diffModel{ + node: node, + inBoth: inBoth, + onlyOurs: onlyOurs, + onlyPeers: onlyPeers, + height: 40, // default, updated by WindowSizeMsg + } +} + +func (m diffModel) Init() tea.Cmd { + return nil +} + +func (m diffModel) Update(msg tea.Msg) (diffModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + case tea.KeyMsg: + switch msg.String() { + case "up", "k": + if m.scrollOffset > 0 { + m.scrollOffset-- + } + case "down", "j": + if m.scrollOffset < m.maxOffset() { + m.scrollOffset++ + } + case "esc", "q": + return m, navBack() + } + } + return m, nil +} + +func (m diffModel) totalLines() int { + total := 4 // header, blank, separator, blank + if len(m.inBoth) > 0 { + total += 1 + len(m.inBoth) + 1 + } + if len(m.onlyOurs) > 0 { + total += 1 + len(m.onlyOurs) + 1 + } + if len(m.onlyPeers) > 0 { + total += 1 + len(m.onlyPeers) + 1 + } + total++ // help line + return total +} + +func (m diffModel) maxOffset() int { + visible := m.height - 2 + if visible < 1 { + visible = 1 + } + maxOff := m.totalLines() - visible + if maxOff < 0 { + return 0 + } + return maxOff +} + +func (m diffModel) View() string { + var lines []string + + lines = append(lines, + headerStyle.Render(fmt.Sprintf(" Whitelist Diff \u2014 %s", m.node.PeerID)), + "", + dimStyle.Render(" "+strings.Repeat("\u2500", 50)), + "", + ) + + if len(m.inBoth) > 0 { + lines = append(lines, + dimStyle.Render(fmt.Sprintf(" \u2713 In Both Whitelists (%d):", len(m.inBoth)))) + for _, id := range m.inBoth { + lines = append(lines, dimStyle.Render(" "+id)) + } + lines = append(lines, "") + } + + if len(m.onlyOurs) > 0 { + lines = append(lines, + diffGreenStyle.Render(fmt.Sprintf(" + Only in Our Whitelist (%d):", len(m.onlyOurs)))) + for _, id := range m.onlyOurs { + lines = append(lines, diffGreenStyle.Render(" "+id)) + } + lines = append(lines, "") + } + + if len(m.onlyPeers) > 0 { + lines = append(lines, + diffRedStyle.Render(fmt.Sprintf(" - Only in Peer's Whitelist (%d):", len(m.onlyPeers)))) + for _, id := range m.onlyPeers { + lines = append(lines, diffRedStyle.Render(" "+id)) + } + lines = append(lines, "") + } + + lines = append(lines, helpStyle.Render(" \u2191\u2193 Scroll Esc: Back to connections")) + + // Apply scroll + maxOffset := len(lines) - m.height + 2 + if maxOffset < 0 { + maxOffset = 0 + } + if m.scrollOffset > maxOffset { + m.scrollOffset = maxOffset + } + + start := m.scrollOffset + end := start + m.height - 2 + if end > len(lines) { + end = len(lines) + } + if start > len(lines) { + start = len(lines) + } + + return fitToHeight(strings.Join(lines[start:end], "\n"), m.height) +} diff --git a/cmd/tatankactl/main.go b/cmd/tatankactl/main.go index 26bdcd7..98b282d 100644 --- a/cmd/tatankactl/main.go +++ b/cmd/tatankactl/main.go @@ -1,482 +1,255 @@ package main import ( - "encoding/json" + "flag" "fmt" - "net/http" - "net/url" "os" - "os/signal" - "sort" - "strings" - "syscall" "time" - "github.com/gorilla/websocket" - "github.com/jessevdk/go-flags" - "github.com/bisoncraft/mesh/tatanka/admin" + tea "github.com/charmbracelet/bubbletea" + "github.com/bisoncraft/mesh/oracle" ) -// Command definitions -type command struct { - description string - usage string - help string - run func(args []string) error +// rootModel is the top-level bubbletea model that routes between views. +type rootModel struct { + api *apiClient + oracleData *oracle.OracleSnapshot + activeView viewID + menu menuModel + connections connectionsModel + diff diffModel + oracle oracleModel + oracleDetail oracleDetailModel + oracleAggregated oracleAggregatedModel + oracleAggregatedDetail oracleAggregatedDetailModel + height int + wsCh chan tea.Msg } -var commands = map[string]*command{} - -const globalOptions = `Global options: - -a, --address= Admin server address (default: localhost:12366)` - -func init() { - commands["conns"] = &command{ - description: "Display current node connections", - usage: "tatankactl conns", - help: `Options: - (none)`, - run: runConns, - } - commands["watchconns"] = &command{ - description: "Watch node connections in real-time", - usage: "tatankactl watchconns", - help: `Options: - (none) - -Press Ctrl+C to stop watching.`, - run: runWatchConns, - } - commands["diff"] = &command{ - description: "Show whitelist diff for a node with whitelist mismatch", - usage: "tatankactl diff ", - help: `Arguments: - peer_id Peer ID (or prefix) to show diff for - -Options: - (none) - -The peer must be in whitelist_mismatch state to show the diff.`, - run: runDiff, - } - commands["help"] = &command{ - description: "Show help for commands", - usage: "tatankactl help [command]", - help: `Arguments: - command Command to show help for (optional)`, - run: runHelp, +func newRootModel(api *apiClient) rootModel { + return rootModel{ + api: api, + oracleData: newOracleSnapshot(), + activeView: viewMenu, + menu: newMenuModel(), + wsCh: make(chan tea.Msg, 20), } } -func main() { - if len(os.Args) < 2 { - printUsage() - os.Exit(1) - } - - cmdName := os.Args[1] - - // Handle --help or -h at top level - if cmdName == "--help" || cmdName == "-h" { - printUsage() - os.Exit(0) - } - - cmd, ok := commands[cmdName] - if !ok { - fmt.Fprintf(os.Stderr, "Unknown command: %s\n\n", cmdName) - printUsage() - os.Exit(1) - } - - // Pass remaining args to the command - if err := cmd.run(os.Args[2:]); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} - -func printUsage() { - fmt.Println("tatankactl - Tatanka node administration tool") - fmt.Println() - fmt.Println("Usage: tatankactl [options]") - fmt.Println() - fmt.Println(globalOptions) - fmt.Println() - fmt.Println("Commands:") - for name := range commands { - fmt.Printf(" %-12s %s\n", name, commands[name].description) - } - fmt.Println() - fmt.Println("Use \"tatankactl help \" for more information about a command.") -} - -func printCommandUsage(cmd *command) { - fmt.Println(cmd.usage) - fmt.Println() - fmt.Println(cmd.description) - fmt.Println() - fmt.Println(cmd.help) - fmt.Println() - fmt.Println(globalOptions) -} - -// Common options for connection commands -type connOptions struct { - Address string `short:"a" long:"address" description:"Admin server address" default:"localhost:12366"` -} - -func parseConnOptions(args []string) (*connOptions, []string, error) { - var opts connOptions - parser := flags.NewParser(&opts, flags.Default&^flags.PrintErrors) - remaining, err := parser.ParseArgs(args) - if err != nil { - if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { - return nil, nil, err - } - return nil, nil, err - } - return &opts, remaining, nil -} - -func normalizeAddress(addr string) string { - if !strings.HasPrefix(addr, "http://") && !strings.HasPrefix(addr, "https://") { - return "http://" + addr - } - return addr -} - -// conns command -func runConns(args []string) error { - opts, remaining, err := parseConnOptions(args) - if err != nil { - if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { - printCommandUsage(commands["conns"]) - return nil - } - return err - } - - if len(remaining) > 0 { - return fmt.Errorf("conns does not accept additional arguments: %v", remaining) - } - - address := normalizeAddress(opts.Address) - state, err := fetchState(address) - if err != nil { - return err - } - - printState(state) - return nil -} - -// watchconns command -func runWatchConns(args []string) error { - opts, remaining, err := parseConnOptions(args) - if err != nil { - if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { - printCommandUsage(commands["watchconns"]) - return nil - } - return err - } - - if len(remaining) > 0 { - return fmt.Errorf("watchconns does not accept additional arguments: %v", remaining) - } - - address := normalizeAddress(opts.Address) - watchState(address) - return nil -} - -// diff command -func runDiff(args []string) error { - opts, remaining, err := parseConnOptions(args) - if err != nil { - if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { - printCommandUsage(commands["diff"]) - return nil - } - return err - } - - if len(remaining) == 0 { - return fmt.Errorf("diff requires a peer ID argument") - } - if len(remaining) > 1 { - return fmt.Errorf("diff accepts only one peer ID argument, got: %v", remaining) - } - - peerID := remaining[0] - address := normalizeAddress(opts.Address) - - state, err := fetchState(address) - if err != nil { - return err - } - - return showDiff(state, peerID) -} - -// help command -func runHelp(args []string) error { - if len(args) == 0 { - printUsage() - return nil - } - - if len(args) > 1 { - return fmt.Errorf("help accepts at most one argument") - } - - cmdName := args[0] - cmd, ok := commands[cmdName] - if !ok { - return fmt.Errorf("unknown command: %s", cmdName) - } - - printCommandUsage(cmd) - - return nil +func (m rootModel) Init() tea.Cmd { + return m.api.connectWebSocket(m.wsCh) } -func fetchState(address string) (*admin.AdminState, error) { - resp, err := http.Get(address + "/admin/state") - if err != nil { - return nil, fmt.Errorf("failed to connect to admin server: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("server returned status %d", resp.StatusCode) - } - - var state admin.AdminState - if err := json.NewDecoder(resp.Body).Decode(&state); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - return &state, nil +func renderTick() tea.Cmd { + return tea.Tick(time.Second, func(t time.Time) tea.Msg { + return renderTickMsg(t) + }) } -func watchState(address string) { - // Convert HTTP URL to WebSocket URL - wsURL, err := url.Parse(address) - if err != nil { - fmt.Fprintf(os.Stderr, "Invalid address: %v\n", err) - os.Exit(1) - } - - if wsURL.Scheme == "https" { - wsURL.Scheme = "wss" - } else { - wsURL.Scheme = "ws" - } - wsURL.Path = "/admin/ws" - - // Handle interrupt signal - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM) - - fmt.Printf("Connecting to %s...\n", wsURL.String()) - - conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to connect: %v\n", err) - os.Exit(1) - } - defer conn.Close() - - fmt.Println("Connected. Watching for updates (Ctrl+C to exit)...") - - done := make(chan struct{}) - - go func() { - defer close(done) - for { - _, message, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - fmt.Fprintf(os.Stderr, "Connection error: %v\n", err) - } - return - } - - var state admin.AdminState - if err := json.Unmarshal(message, &state); err != nil { - fmt.Fprintf(os.Stderr, "Failed to decode message: %v\n", err) - continue - } - - // Clear screen and print new state - fmt.Print("\033[H\033[2J") - fmt.Printf("Tatanka Admin - %s\n", time.Now().Format("15:04:05")) - fmt.Println(strings.Repeat("=", 60)) - printState(&state) - } - }() - - select { - case <-done: - case <-interrupt: - fmt.Println("\nDisconnecting...") - conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - select { - case <-done: - case <-time.After(time.Second): - } +func (m rootModel) isOracleView() bool { + switch m.activeView { + case viewOracleSources, viewOracleDetail, viewOracleAggregated, viewAggregatedDetail: + return true } + return false } -func printState(state *admin.AdminState) { - nodes := make([]admin.NodeInfo, 0, len(state.Nodes)) - for _, node := range state.Nodes { - nodes = append(nodes, node) - } - - // Sort nodes by state: connected first, then whitelist_mismatch, then disconnected - sort.Slice(nodes, func(i, j int) bool { - stateOrder := map[admin.NodeConnectionState]int{ - admin.StateConnected: 0, - admin.StateWhitelistMismatch: 1, - admin.StateDisconnected: 2, +func (m rootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + if msg.String() == "ctrl+c" { + m.api.disconnectWebSocket() + return m, tea.Quit } - return stateOrder[nodes[i].State] < stateOrder[nodes[j].State] - }) - - // Count by state - counts := make(map[admin.NodeConnectionState]int) - for _, node := range nodes { - counts[node.State]++ - } - - fmt.Printf("Node Connections (%d total)\n", len(nodes)) - fmt.Printf(" Connected: %d | Whitelist Mismatch: %d | Disconnected: %d\n\n", - counts[admin.StateConnected], counts[admin.StateWhitelistMismatch], counts[admin.StateDisconnected]) - - if len(nodes) == 0 { - fmt.Println(" No nodes in whitelist") - return - } - - for _, node := range nodes { - icon := getStateIcon(node.State) - stateStr := getStateString(node.State) - fmt.Printf(" %s %-20s %s\n", icon, stateStr, node.PeerID) - // Print addresses - if len(node.Addresses) > 0 { - for _, addr := range node.Addresses { - fmt.Printf(" │ %s\n", addr) - } + case tea.WindowSizeMsg: + m.height = msg.Height + // Propagate to active child + var cmd tea.Cmd + switch m.activeView { + case viewMenu: + m.menu, cmd = m.menu.Update(msg) + return m, cmd + case viewConnections: + m.connections, cmd = m.connections.Update(msg) + return m, cmd + case viewDiff: + m.diff, cmd = m.diff.Update(msg) + return m, cmd + case viewOracleSources: + m.oracle, cmd = m.oracle.Update(msg) + return m, cmd + case viewOracleDetail: + m.oracleDetail, cmd = m.oracleDetail.Update(msg) + return m, cmd + case viewOracleAggregated: + m.oracleAggregated, cmd = m.oracleAggregated.Update(msg) + return m, cmd + case viewAggregatedDetail: + m.oracleAggregatedDetail, cmd = m.oracleAggregatedDetail.Update(msg) + return m, cmd } - - if node.State == admin.StateWhitelistMismatch && len(node.PeerWhitelist) > 0 { - fmt.Printf(" └─ Use \"tatankactl diff %s\" to see whitelist differences\n", node.PeerID[:12]) + return m, nil + + case navigateMsg: + switch msg.view { + case viewConnections: + m.activeView = viewConnections + m.connections.height = m.height + return m, nil + case viewOracleSources: + m.activeView = viewOracleSources + m.oracle = newOracleModel(m.oracleData) + m.oracle.height = m.height + return m, tea.Batch(m.oracle.Init(), renderTick()) + case viewOracleAggregated: + m.activeView = viewOracleAggregated + m.oracleAggregated = newOracleAggregatedModel(m.oracleData) + m.oracleAggregated.height = m.height + return m, tea.Batch(m.oracleAggregated.Init(), renderTick()) } - } -} - -func showDiff(state *admin.AdminState, peerID string) error { - var targetNode *admin.NodeInfo - for id, node := range state.Nodes { - if strings.HasPrefix(id, peerID) { - nodeCopy := node - targetNode = &nodeCopy - break + return m, nil + + case navigateBackMsg: + switch m.activeView { + case viewConnections: + m.activeView = viewMenu + return m, nil + case viewDiff: + m.activeView = viewConnections + return m, nil + case viewOracleSources: + m.activeView = viewMenu + return m, nil + case viewOracleDetail: + m.activeView = viewOracleSources + m.oracle.rebuildSortedSources() + return m, nil + case viewOracleAggregated: + m.activeView = viewMenu + return m, nil + case viewAggregatedDetail: + m.activeView = viewOracleAggregated + m.oracleAggregated.buildSections() + return m, nil } - } - - if targetNode == nil { - return fmt.Errorf("node not found: %s", peerID) - } - - if targetNode.State != admin.StateWhitelistMismatch { - return fmt.Errorf("node %s is not in whitelist mismatch state", peerID) - } - - if len(targetNode.PeerWhitelist) == 0 { - return fmt.Errorf("no peer whitelist data available for %s", peerID) - } - - ourSet := make(map[string]bool) - for _, id := range state.OurWhitelist { - ourSet[id] = true - } - - peerSet := make(map[string]bool) - for _, id := range targetNode.PeerWhitelist { - peerSet[id] = true - } - - var onlyInOurs, onlyInPeers, inBoth []string - - for _, id := range state.OurWhitelist { - if peerSet[id] { - inBoth = append(inBoth, id) - } else { - onlyInOurs = append(onlyInOurs, id) + return m, nil + + case wsConnectedMsg: + return m, listenForWSUpdates(m.wsCh) + + case wsErrorMsg: + // Reconnect after a brief delay. + return m, tea.Tick(3*time.Second, func(t time.Time) tea.Msg { + return wsReconnectMsg{} + }) + + case wsReconnectMsg: + return m, m.api.connectWebSocket(m.wsCh) + + // Oracle WS messages — update shared data and trigger view rebuilds + case oracleSnapshotMsg, oracleUpdateMsg: + updateOracleData(m.oracleData, msg) + var cmds []tea.Cmd + cmds = append(cmds, listenForWSUpdates(m.wsCh)) + if m.activeView == viewOracleDetail { + m.oracleDetail.buildSections() } - } - - for _, id := range targetNode.PeerWhitelist { - if !ourSet[id] { - onlyInPeers = append(onlyInPeers, id) + if m.activeView == viewOracleAggregated { + m.oracleAggregated.buildSections() } - } - - fmt.Printf("Whitelist Diff for %s\n", targetNode.PeerID) - fmt.Println(strings.Repeat("=", 60)) - - if len(inBoth) > 0 { - fmt.Printf("\n✓ In Both Whitelists (%d):\n", len(inBoth)) - for _, id := range inBoth { - fmt.Printf(" %s\n", id) + if m.activeView == viewOracleSources { + m.oracle.rebuildSortedSources() } - } + return m, tea.Batch(cmds...) - if len(onlyInOurs) > 0 { - fmt.Printf("\n+ Only in Our Whitelist (%d):\n", len(onlyInOurs)) - for _, id := range onlyInOurs { - fmt.Printf(" %s\n", id) + case adminStateMsg: + if msg.state != nil { + m.connections.state = msg.state + m.connections.sortNodes() + m.connections.lastUpdate = time.Now() } - } + return m, listenForWSUpdates(m.wsCh) - if len(onlyInPeers) > 0 { - fmt.Printf("\n- Only in Peer's Whitelist (%d):\n", len(onlyInPeers)) - for _, id := range onlyInPeers { - fmt.Printf(" %s\n", id) + case renderTickMsg: + if m.isOracleView() { + // Re-render for relative time updates + return m, renderTick() } - } - - fmt.Println() - return nil -} - -func getStateIcon(state admin.NodeConnectionState) string { - switch state { - case admin.StateConnected: - return "🟢" - case admin.StateWhitelistMismatch: - return "🟡" - case admin.StateDisconnected: - return "🔴" + return m, nil + + case navigateToSourceDetailMsg: + m.oracleDetail = newOracleDetailModel(m.oracleData, msg.sourceName) + m.oracleDetail.height = m.height + m.activeView = viewOracleDetail + return m, m.oracleDetail.Init() + + case navigateToAggregatedDetailMsg: + m.oracleAggregatedDetail = newOracleAggregatedDetailModel(m.oracleData, msg.dataType, msg.key) + m.oracleAggregatedDetail.height = m.height + m.activeView = viewAggregatedDetail + return m, m.oracleAggregatedDetail.Init() + + case navigateToDiffMsg: + m.diff = newDiffModel(msg.node, m.connections.state) + m.diff.height = m.height + m.activeView = viewDiff + return m, m.diff.Init() + } + + // Delegate to active view + var cmd tea.Cmd + switch m.activeView { + case viewMenu: + m.menu, cmd = m.menu.Update(msg) + case viewConnections: + m.connections, cmd = m.connections.Update(msg) + case viewDiff: + m.diff, cmd = m.diff.Update(msg) + case viewOracleSources: + m.oracle, cmd = m.oracle.Update(msg) + case viewOracleDetail: + m.oracleDetail, cmd = m.oracleDetail.Update(msg) + case viewOracleAggregated: + m.oracleAggregated, cmd = m.oracleAggregated.Update(msg) + case viewAggregatedDetail: + m.oracleAggregatedDetail, cmd = m.oracleAggregatedDetail.Update(msg) + } + return m, cmd +} + +func (m rootModel) View() string { + switch m.activeView { + case viewMenu: + return m.menu.View() + case viewConnections: + return m.connections.View() + case viewDiff: + return m.diff.View() + case viewOracleSources: + return m.oracle.View() + case viewOracleDetail: + return m.oracleDetail.View() + case viewOracleAggregated: + return m.oracleAggregated.View() + case viewAggregatedDetail: + return m.oracleAggregatedDetail.View() default: - return "⚪" + return "" } } -func getStateString(state admin.NodeConnectionState) string { - switch state { - case admin.StateConnected: - return "Connected" - case admin.StateWhitelistMismatch: - return "Whitelist Mismatch" - case admin.StateDisconnected: - return "Disconnected" - default: - return string(state) +func main() { + address := flag.String("a", "localhost:12366", "Admin server address") + flag.StringVar(address, "address", "localhost:12366", "Admin server address") + flag.Parse() + + api := newAPIClient(*address) + model := newRootModel(api) + + p := tea.NewProgram(model, tea.WithAltScreen()) + if _, err := p.Run(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } } diff --git a/cmd/tatankactl/menu.go b/cmd/tatankactl/menu.go new file mode 100644 index 0000000..4a21daa --- /dev/null +++ b/cmd/tatankactl/menu.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" +) + +type menuModel struct { + choices []string + views []viewID + cursor int + height int +} + +func newMenuModel() menuModel { + return menuModel{ + choices: []string{"Connections", "Oracle Sources", "Oracle Data"}, + views: []viewID{viewConnections, viewOracleSources, viewOracleAggregated}, + cursor: 0, + } +} + +func (m menuModel) Init() tea.Cmd { + return nil +} + +func (m menuModel) Update(msg tea.Msg) (menuModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + case tea.KeyMsg: + switch msg.String() { + case "up", "k": + if m.cursor > 0 { + m.cursor-- + } + case "down", "j": + if m.cursor < len(m.choices)-1 { + m.cursor++ + } + case "enter": + return m, func() tea.Msg { + return navigateMsg{view: m.views[m.cursor]} + } + case "q": + return m, tea.Quit + } + } + return m, nil +} + +func (m menuModel) View() string { + var b strings.Builder + + b.WriteString(titleStyle.Render("Tatanka Admin")) + b.WriteString("\n\n") + + for i, choice := range m.choices { + cursor := " " + if i == m.cursor { + cursor = cursorStyle.Render("> ") + } + b.WriteString(fmt.Sprintf("%s%s\n", cursor, choice)) + } + + b.WriteString(helpStyle.Render("\nEnter: select q: quit")) + + return fitToHeight(menuBoxStyle.Render(b.String()), m.height) +} diff --git a/cmd/tatankactl/oracle_aggregated.go b/cmd/tatankactl/oracle_aggregated.go new file mode 100644 index 0000000..caf583f --- /dev/null +++ b/cmd/tatankactl/oracle_aggregated.go @@ -0,0 +1,146 @@ +package main + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/bisoncraft/mesh/oracle" +) + +type oracleAggregatedModel struct { + data *oracle.OracleSnapshot + sections []detailSection + focused int + height int + filter filterState +} + +func newOracleAggregatedModel(data *oracle.OracleSnapshot) oracleAggregatedModel { + m := oracleAggregatedModel{ + data: data, + height: 40, + } + m.buildSections() + return m +} + +func (m oracleAggregatedModel) Init() tea.Cmd { + return nil +} + +func (m oracleAggregatedModel) Update(msg tea.Msg) (oracleAggregatedModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + + case tea.KeyMsg: + if m.filter.active { + if m.filter.handleFilterKey(msg.String()) { + m.buildSections() + } + return m, nil + } + switch msg.String() { + case "up", "k": + if len(m.sections) > 0 { + m.sections[m.focused].cursorUp() + } + case "down", "j": + if len(m.sections) > 0 { + m.sections[m.focused].cursorDown() + } + case "tab": + if len(m.sections) > 0 { + m.focused = (m.focused + 1) % len(m.sections) + } + case "shift+tab": + if len(m.sections) > 0 { + m.focused = (m.focused - 1 + len(m.sections)) % len(m.sections) + } + case "enter": + if len(m.sections) > 0 { + sec := &m.sections[m.focused] + key := sec.selectedKey() + if key != "" { + dataType := oracle.PriceData + if sec.title == "Fee Rates" { + dataType = oracle.FeeRateData + } + return m, func() tea.Msg { + return navigateToAggregatedDetailMsg{ + dataType: dataType, + key: key, + } + } + } + } + case "/": + m.filter.startFiltering() + case "esc", "q": + if m.filter.handleEscOrQ() { + m.buildSections() + } else { + return m, navBack() + } + } + } + return m, nil +} + +func (m *oracleAggregatedModel) buildSections() { + m.sections = nil + + if lines, keys := m.buildRateLines(m.data.Prices); len(lines) > 0 { + m.sections = append(m.sections, detailSection{title: "Prices", lines: lines, keys: keys}) + } + + if lines, keys := m.buildRateLines(m.data.FeeRates); len(lines) > 0 { + m.sections = append(m.sections, detailSection{title: "Fee Rates", lines: lines, keys: keys}) + } + + if m.focused >= len(m.sections) { + m.focused = max(0, len(m.sections)-1) + } +} + +func (m oracleAggregatedModel) buildRateLines(rates map[string]*oracle.SnapshotRate) ([]string, []string) { + var lines, keys []string + for _, key := range sortedKeys(rates) { + if !m.filter.matches(key) { + continue + } + rate := rates[key] + lines = append(lines, fmt.Sprintf(" %-10s %s", key, rate.Value)) + keys = append(keys, key) + } + return lines, keys +} + +func (m oracleAggregatedModel) View() string { + var b strings.Builder + + // Header + b.WriteString(fmt.Sprintf(" %s\n\n", headerStyle.Render("Aggregated Data"))) + + // Filter bar + m.filter.renderFilterBar(&b) + + if len(m.sections) == 0 { + if m.filter.text != "" { + b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("No matches for \""+m.filter.text+"\""))) + } else if len(m.data.Prices) == 0 && len(m.data.FeeRates) == 0 { + b.WriteString(" " + dimStyle.Render("No aggregated data available") + "\n") + } + } + + // Sections + for i, sec := range m.sections { + renderSection(&b, &sec, i == m.focused) + } + + // Help + b.WriteString(buildFilterHelp(m.sections, m.filter, "Enter: Details")) + + return fitToHeight(b.String(), m.height) +} diff --git a/cmd/tatankactl/oracle_aggregated_detail.go b/cmd/tatankactl/oracle_aggregated_detail.go new file mode 100644 index 0000000..7b543d7 --- /dev/null +++ b/cmd/tatankactl/oracle_aggregated_detail.go @@ -0,0 +1,152 @@ +package main + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/bisoncraft/mesh/oracle" +) + +type oracleAggregatedDetailModel struct { + data *oracle.OracleSnapshot + dataType oracle.DataType + key string // ticker or network name + offset int + height int +} + +func newOracleAggregatedDetailModel(data *oracle.OracleSnapshot, dataType oracle.DataType, key string) oracleAggregatedDetailModel { + return oracleAggregatedDetailModel{ + data: data, + dataType: dataType, + key: key, + height: 40, + } +} + +func (m oracleAggregatedDetailModel) Init() tea.Cmd { + return nil +} + +func (m oracleAggregatedDetailModel) Update(msg tea.Msg) (oracleAggregatedDetailModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + + case tea.KeyMsg: + switch msg.String() { + case "up", "k": + if m.offset > 0 { + m.offset-- + } + case "down", "j": + maxOffset := m.maxOffset() + if m.offset < maxOffset { + m.offset++ + } + case "esc", "q": + return m, navBack() + } + } + return m, nil +} + +func (m oracleAggregatedDetailModel) getContributions() map[string]*oracle.SourceContribution { + var rate *oracle.SnapshotRate + if m.dataType == oracle.PriceData { + rate = m.data.Prices[m.key] + } else { + rate = m.data.FeeRates[m.key] + } + if rate == nil { + return nil + } + return rate.Contributions +} + +func (m oracleAggregatedDetailModel) maxOffset() int { + contribs := m.getContributions() + if contribs == nil { + return 0 + } + lines := len(contribs) * 4 + visible := m.height - 8 + if visible < 5 { + visible = 5 + } + maxOff := lines - visible + if maxOff < 0 { + return 0 + } + return maxOff +} + +func (m oracleAggregatedDetailModel) View() string { + var b strings.Builder + + // Header + label := m.key + if m.dataType == oracle.PriceData { + label += " Price" + } else { + label += " Fee Rate" + } + b.WriteString(fmt.Sprintf(" %s\n\n", headerStyle.Render("Sources: "+label))) + + contribs := m.getContributions() + if len(contribs) == 0 { + b.WriteString(" " + dimStyle.Render("No source data available") + "\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return b.String() + } + + // Sort by source name + sources := sortedKeys(contribs) + + // Build content lines + var lines []string + for _, name := range sources { + c := contribs[name] + age := relativeTime(c.Stamp) + agedWeight := dimStyle.Render(fmt.Sprintf("(weight: %.2f, %s)", c.Weight, age)) + lines = append(lines, + fmt.Sprintf(" %s", headerStyle.Render(name)), + fmt.Sprintf(" Value: %s", c.Value), + fmt.Sprintf(" %s", agedWeight), + "", + ) + } + + // Apply scroll offset + visible := m.height - 8 + if visible < 5 { + visible = 5 + } + + start := m.offset + if start > len(lines) { + start = len(lines) + } + end := start + visible + if end > len(lines) { + end = len(lines) + } + + if m.offset > 0 { + b.WriteString(dimStyle.Render(" \u25b2 more above") + "\n") + } + + for _, line := range lines[start:end] { + b.WriteString(line + "\n") + } + + if end < len(lines) { + b.WriteString(dimStyle.Render(" \u25bc more below") + "\n") + } + + // Help + b.WriteString(helpStyle.Render("\n \u2191\u2193 Scroll Esc: Back")) + + return fitToHeight(b.String(), m.height) +} diff --git a/cmd/tatankactl/oracle_detail.go b/cmd/tatankactl/oracle_detail.go new file mode 100644 index 0000000..182356c --- /dev/null +++ b/cmd/tatankactl/oracle_detail.go @@ -0,0 +1,280 @@ +package main + +import ( + "fmt" + "math" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/bisoncraft/mesh/oracle" +) + +type oracleDetailModel struct { + data *oracle.OracleSnapshot + sourceName string + sections []detailSection + focused int + height int + filter filterState +} + +func newOracleDetailModel(data *oracle.OracleSnapshot, sourceName string) oracleDetailModel { + m := oracleDetailModel{ + data: data, + sourceName: sourceName, + height: 40, + } + m.buildSections() + return m +} + +func (m *oracleDetailModel) buildSections() { + m.sections = nil + + if lines := m.buildContribLines(m.data.Prices); len(lines) > 0 { + m.sections = append(m.sections, detailSection{title: "Prices", lines: lines}) + } + + if lines := m.buildContribLines(m.data.FeeRates); len(lines) > 0 { + m.sections = append(m.sections, detailSection{title: "Fee Rates", lines: lines}) + } + + src := m.data.Sources[m.sourceName] + if src != nil && m.sourceHasQuotas(src) { + if lines := m.buildQuotaLines(); len(lines) > 0 { + m.sections = append(m.sections, detailSection{title: "Quotas", lines: lines}) + } + } + + if lines := m.buildFetchLines(); len(lines) > 0 { + m.sections = append(m.sections, detailSection{title: "Fetches (24h)", lines: lines}) + } + + if m.focused >= len(m.sections) { + m.focused = max(0, len(m.sections)-1) + } +} + +func (m oracleDetailModel) Init() tea.Cmd { + return nil +} + +func (m oracleDetailModel) Update(msg tea.Msg) (oracleDetailModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + case tea.KeyMsg: + if m.filter.active { + if m.filter.handleFilterKey(msg.String()) { + m.buildSections() + } + return m, nil + } + switch msg.String() { + case "up", "k": + if len(m.sections) > 0 { + m.sections[m.focused].scrollUp() + } + case "down", "j": + if len(m.sections) > 0 { + m.sections[m.focused].scrollDown() + } + case "tab": + if len(m.sections) > 0 { + m.focused = (m.focused + 1) % len(m.sections) + } + case "shift+tab": + if len(m.sections) > 0 { + m.focused = (m.focused - 1 + len(m.sections)) % len(m.sections) + } + case "/": + m.filter.startFiltering() + case "esc", "q": + if m.filter.handleEscOrQ() { + m.buildSections() + } else { + return m, navBack() + } + } + } + return m, nil +} + +func (m oracleDetailModel) View() string { + var b strings.Builder + + // Header + b.WriteString(fmt.Sprintf(" %s\n\n", headerStyle.Render("Source: "+m.sourceName))) + + src := m.data.Sources[m.sourceName] + if src == nil { + b.WriteString(" " + dimStyle.Render("Source not found") + "\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return b.String() + } + + // Schedule section + b.WriteString(fmt.Sprintf(" %s\n", cursorStyle.Render("Schedule"))) + + lastStr := "never" + if src.LastFetch != nil { + lastStr = relativeTime(*src.LastFetch) + } + b.WriteString(fmt.Sprintf(" Last Fetch: %s\n", dimStyle.Render(lastStr))) + + if src.NetworkNextFetchTime != nil { + b.WriteString(fmt.Sprintf(" Network Next: %s\n", dimStyle.Render(relativeTime(*src.NetworkNextFetchTime)))) + } + + if src.NextFetchTime != nil { + posStr := "" + if len(src.OrderedNodes) > 0 { + orderIndex := -1 + for i, nodeID := range src.OrderedNodes { + if nodeID == m.data.NodeID { + orderIndex = i + break + } + } + if orderIndex >= 0 { + posStr = fmt.Sprintf(" (#%d of %d)", orderIndex+1, len(src.OrderedNodes)) + } + } + b.WriteString(fmt.Sprintf(" Your Next Fetch: %s%s\n", dimStyle.Render(relativeTime(*src.NextFetchTime)), dimStyle.Render(posStr))) + } + + hasQuotas := m.sourceHasQuotas(src) + + if hasQuotas { + if src.NetworkSustainableRate != nil && *src.NetworkSustainableRate > 0 { + b.WriteString(fmt.Sprintf(" Sustainable Rate: %s\n", dimStyle.Render(fmt.Sprintf("%.4f fetches/sec", *src.NetworkSustainableRate)))) + } + if src.NetworkSustainablePeriod != nil && *src.NetworkSustainablePeriod > 0 { + b.WriteString(fmt.Sprintf(" Sustainable Period: %s\n", dimStyle.Render("1 fetch / "+src.NetworkSustainablePeriod.String()))) + } + } else { + b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("This source has no quotas \u2014 fetch interval determined by minimum period"))) + } + + if src.MinFetchInterval != nil && *src.MinFetchInterval > 0 { + b.WriteString(fmt.Sprintf(" Min Period: %s\n", dimStyle.Render(src.MinFetchInterval.String()))) + } + + if src.LastError != "" { + errAge := "" + if src.LastErrorTime != nil { + errAge = " (" + relativeTime(*src.LastErrorTime) + ")" + } + b.WriteString(fmt.Sprintf(" Last Error: %s\n", disconnectedStyle.Render(src.LastError+errAge))) + } + b.WriteString("\n") + + // Fetch Order section + if len(src.OrderedNodes) > 0 { + b.WriteString(fmt.Sprintf(" %s\n", cursorStyle.Render("Fetch Order"))) + for i, nodeID := range src.OrderedNodes { + label := truncatePeerID(nodeID) + marker := " " + if nodeID == m.data.NodeID { + label = "You" + marker = "\u2190 " + } + b.WriteString(fmt.Sprintf(" %d. %-20s %s\n", i+1, label, dimStyle.Render(marker))) + } + b.WriteString("\n") + } + + // Filter bar + m.filter.renderFilterBar(&b) + + if len(m.sections) == 0 { + if m.filter.text != "" { + b.WriteString(fmt.Sprintf(" %s\n", dimStyle.Render("No matches for \""+m.filter.text+"\""))) + } else { + b.WriteString(" " + dimStyle.Render("No data available") + "\n") + } + } + + // Sections + for i, sec := range m.sections { + renderSection(&b, &sec, i == m.focused) + } + + // Help + b.WriteString(buildFilterHelp(m.sections, m.filter)) + + return fitToHeight(b.String(), m.height) +} + +func (m oracleDetailModel) sourceHasQuotas(src *oracle.SourceStatus) bool { + for _, q := range src.Quotas { + if q.FetchesLimit > 0 && q.FetchesLimit < math.MaxInt64 { + return true + } + } + return false +} + +// --- Section content builders --- + +func (m oracleDetailModel) buildContribLines(rates map[string]*oracle.SnapshotRate) []string { + var lines []string + for _, key := range sortedKeys(rates) { + rate := rates[key] + contrib, ok := rate.Contributions[m.sourceName] + if !ok || !m.filter.matches(key) { + continue + } + age := dimStyle.Render("(" + relativeTime(contrib.Stamp) + ")") + lines = append(lines, fmt.Sprintf(" %-8s %s %s", + key, contrib.Value, age)) + } + return lines +} + +func (m oracleDetailModel) buildQuotaLines() []string { + src := m.data.Sources[m.sourceName] + if src == nil { + return nil + } + + var lines []string + for _, nid := range sortedKeys(src.Quotas) { + q := src.Quotas[nid] + if q.FetchesLimit <= 0 { + continue + } + label := truncatePeerID(nid) + if nid == m.data.NodeID { + label += " (ours)" + } + lines = append(lines, + fmt.Sprintf(" %s", dimStyle.Render(label)), + fmt.Sprintf(" Fetches: %d / %d", q.FetchesRemaining, q.FetchesLimit), + fmt.Sprintf(" Resets: %s", dimStyle.Render(relativeTime(q.ResetTime))), + "", + ) + } + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + return lines +} + +func (m oracleDetailModel) buildFetchLines() []string { + src := m.data.Sources[m.sourceName] + if src == nil || len(src.Fetches24h) == 0 { + return nil + } + + var lines []string + for _, nid := range sortedKeys(src.Fetches24h) { + count := src.Fetches24h[nid] + label := truncatePeerID(nid) + if nid == m.data.NodeID { + label += " (ours)" + } + lines = append(lines, fmt.Sprintf(" %-24s %d", label, count)) + } + return lines +} diff --git a/cmd/tatankactl/oracle_view.go b/cmd/tatankactl/oracle_view.go new file mode 100644 index 0000000..176fa8e --- /dev/null +++ b/cmd/tatankactl/oracle_view.go @@ -0,0 +1,171 @@ +package main + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + "github.com/bisoncraft/mesh/oracle" +) + +type oracleModel struct { + data *oracle.OracleSnapshot + cursor int + height int + // sorted source names for stable ordering + sortedSources []string +} + +func newOracleModel(data *oracle.OracleSnapshot) oracleModel { + m := oracleModel{ + data: data, + height: 40, + } + m.rebuildSortedSources() + return m +} + +func (m oracleModel) Init() tea.Cmd { + return nil +} + +func (m *oracleModel) rebuildSortedSources() { + m.sortedSources = sortedKeys(m.data.Sources) +} + +func (m oracleModel) Update(msg tea.Msg) (oracleModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.height = msg.Height + + case tea.KeyMsg: + switch msg.String() { + case "up", "k": + if m.cursor > 0 { + m.cursor-- + } + case "down", "j": + if m.cursor < len(m.sortedSources)-1 { + m.cursor++ + } + case "enter": + if len(m.sortedSources) > 0 { + srcName := m.sortedSources[m.cursor] + return m, func() tea.Msg { + return navigateToSourceDetailMsg{sourceName: srcName} + } + } + case "esc", "q": + return m, navBack() + } + } + return m, nil +} + +func (m oracleModel) View() string { + var b strings.Builder + + // Header + b.WriteString(fmt.Sprintf(" %s\n\n", headerStyle.Render("Oracle Status"))) + + if len(m.data.Sources) == 0 { + b.WriteString(" No oracle sources configured\n") + b.WriteString(helpStyle.Render("\n Esc: Back")) + return b.String() + } + + // Table + b.WriteString(m.renderSourceList(m.sortedSources)) + + b.WriteString("\n") + b.WriteString(helpStyle.Render(" \u2191\u2193 Navigate Enter: Details Esc: Back")) + + return fitToHeight(b.String(), m.height) +} + +func (m oracleModel) renderSourceList(sorted []string) string { + const ( + colSource = 22 + colLast = 16 + colNext = 16 + ) + + border := tableBorderStyle.Render + + hLine := func(left, mid, right, fill string) string { + return border(left) + + border(strings.Repeat(fill, colSource)) + + border(mid) + + border(strings.Repeat(fill, colLast)) + + border(mid) + + border(strings.Repeat(fill, colNext)) + + border(right) + } + + padCell := func(s string, w int) string { + vw := lipgloss.Width(s) + if vw > w-1 { + s = ansi.Truncate(s, w-2, "\u2026") + vw = lipgloss.Width(s) + } + pad := w - 1 - vw + if pad < 0 { + pad = 0 + } + return " " + s + strings.Repeat(" ", pad) + } + + row := func(src, last, next string) string { + return border("\u2502") + + padCell(src, colSource) + + border("\u2502") + + padCell(last, colLast) + + border("\u2502") + + padCell(next, colNext) + + border("\u2502") + } + + var lines []string + + // Top border + lines = append(lines, " "+hLine("\u250c", "\u252c", "\u2510", "\u2500")) + + // Header row + lines = append(lines, " "+row("Source", "Last Fetch", "Next Fetch")) + + for i, name := range sorted { + src := m.data.Sources[name] + + // Separator + lines = append(lines, " "+hLine("\u251c", "\u253c", "\u2524", "\u2500")) + + lastStr := "never" + if src.LastFetch != nil { + lastStr = relativeTime(*src.LastFetch) + } + + nextStr := "\u2014" + if src.NextFetchTime != nil { + nextStr = relativeTime(*src.NextFetchTime) + } + + srcName := name + if src.LastError != "" { + srcName += " " + disconnectedStyle.Render("!") + } + if i == m.cursor { + srcName = "> " + srcName + } else { + srcName = " " + srcName + } + + lines = append(lines, " "+row(srcName, lastStr, nextStr)) + } + + // Bottom border + lines = append(lines, " "+hLine("\u2514", "\u2534", "\u2518", "\u2500")) + + return strings.Join(lines, "\n") +} diff --git a/cmd/tatankactl/section.go b/cmd/tatankactl/section.go new file mode 100644 index 0000000..e5270bd --- /dev/null +++ b/cmd/tatankactl/section.go @@ -0,0 +1,235 @@ +package main + +import ( + "fmt" + "strings" +) + +const sectionMaxVisible = 10 + +// detailSection holds the content lines for one scrollable section. +type detailSection struct { + title string + lines []string + keys []string // parallel to lines; enables item selection when non-empty + offset int + itemCursor int // highlighted item index (only used when keys is non-empty) +} + +func (s *detailSection) scrollDown() { + max := len(s.lines) - sectionMaxVisible + if max < 0 { + max = 0 + } + if s.offset < max { + s.offset++ + } +} + +func (s *detailSection) scrollUp() { + if s.offset > 0 { + s.offset-- + } +} + +func (s *detailSection) cursorDown() { + if s.itemCursor < len(s.lines)-1 { + s.itemCursor++ + } + if s.itemCursor >= s.offset+sectionMaxVisible { + s.offset = s.itemCursor - sectionMaxVisible + 1 + } +} + +func (s *detailSection) cursorUp() { + if s.itemCursor > 0 { + s.itemCursor-- + } + if s.itemCursor < s.offset { + s.offset = s.itemCursor + } +} + +func (s detailSection) selectedKey() string { + if len(s.keys) == 0 || s.itemCursor >= len(s.keys) { + return "" + } + return s.keys[s.itemCursor] +} + +func (s detailSection) needsScroll() bool { + return len(s.lines) > sectionMaxVisible +} + +func (s detailSection) visibleLines() []string { + if !s.needsScroll() { + return s.lines + } + end := s.offset + sectionMaxVisible + if end > len(s.lines) { + end = len(s.lines) + } + return s.lines[s.offset:end] +} + +// renderSection renders a section with header, separator, scroll indicators, +// and content. Sections with keys get cursor highlighting on the selected item. +func renderSection(b *strings.Builder, sec *detailSection, focused bool) { + // Section header + titleStr := sec.title + if focused { + titleStr = cursorStyle.Render("\u25b6 ") + headerStyle.Render(sec.title) + } else { + titleStr = dimStyle.Render(" " + sec.title) + } + b.WriteString(" " + titleStr) + + // Scroll position indicator + if sec.needsScroll() { + b.WriteString(dimStyle.Render(fmt.Sprintf(" (%d-%d of %d)", + sec.offset+1, + min(sec.offset+sectionMaxVisible, len(sec.lines)), + len(sec.lines)))) + } + b.WriteString("\n") + + // Separator + if focused { + b.WriteString(" " + tableBorderStyle.Render(strings.Repeat("\u2500", 50)) + "\n") + } else { + b.WriteString(" " + dimStyle.Render(strings.Repeat("\u2500", 50)) + "\n") + } + + // Up indicator + if sec.needsScroll() && sec.offset > 0 { + b.WriteString(dimStyle.Render(" \u25b2 more above") + "\n") + } + + // Visible content + hasCursor := len(sec.keys) > 0 + visibleStart := sec.offset + for i, line := range sec.visibleLines() { + absIdx := visibleStart + i + if hasCursor && focused && absIdx == sec.itemCursor { + if len(line) > 0 { + b.WriteString(cursorStyle.Render(">") + line[1:] + "\n") + } else { + b.WriteString(cursorStyle.Render(">") + "\n") + } + } else { + b.WriteString(line + "\n") + } + } + + // Down indicator + if sec.needsScroll() && sec.offset+sectionMaxVisible < len(sec.lines) { + b.WriteString(dimStyle.Render(" \u25bc more below") + "\n") + } + + b.WriteString("\n") +} + +// fitToHeight ensures the rendered output is exactly height lines tall. +// It truncates excess content from the bottom (keeping the header visible) +// and pads with empty lines to fill the screen (preventing alt-screen artifacts). +func fitToHeight(content string, height int) string { + if height <= 0 { + return content + } + lines := strings.Split(content, "\n") + // strings.Split on trailing \n produces an extra empty element; trim it + // so we count only visual lines. + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + if len(lines) > height { + lines = lines[:height] + } + for len(lines) < height { + lines = append(lines, "") + } + return strings.Join(lines, "\n") +} + +// buildFilterHelp builds the standard help bar for views with sections and filtering. +func buildFilterHelp(sections []detailSection, filter filterState, extra ...string) string { + parts := []string{"\u2191\u2193 Scroll"} + if len(sections) > 1 { + parts = append(parts, "Tab: Next section") + } + parts = append(parts, extra...) + parts = append(parts, "/: Filter") + if filter.text != "" { + parts = append(parts, "Esc: Clear filter") + } else { + parts = append(parts, "Esc: Back") + } + return helpStyle.Render(" " + strings.Join(parts, " ")) +} + +// filterState manages text filtering shared by multiple views. +type filterState struct { + active bool + text string +} + +func (f *filterState) startFiltering() { + f.active = true + f.text = "" +} + +func (f *filterState) matches(name string) bool { + if f.text == "" { + return true + } + return strings.Contains(strings.ToUpper(name), strings.ToUpper(f.text)) +} + +// handleFilterKey processes a key press while in filter mode. +// Returns true if sections need rebuilding. +func (f *filterState) handleFilterKey(key string) bool { + switch key { + case "enter": + f.active = false + return true + case "esc": + f.active = false + f.text = "" + return true + case "backspace": + if len(f.text) > 0 { + f.text = f.text[:len(f.text)-1] + return true + } + default: + if len(key) == 1 && key[0] >= 32 && key[0] <= 126 { + f.text += key + return true + } + } + return false +} + +// handleEscOrQ handles esc/q when not in filter mode. +// Returns true if the filter was cleared (sections need rebuilding). +// Returns false if navigation back should happen. +func (f *filterState) handleEscOrQ() bool { + if f.text != "" { + f.text = "" + return true + } + return false +} + +// renderFilterBar renders the filter input or active filter indicator. +func (f *filterState) renderFilterBar(b *strings.Builder) { + if f.active { + b.WriteString(fmt.Sprintf(" %s %s\u2588\n\n", + cursorStyle.Render("/"), + f.text)) + } else if f.text != "" { + b.WriteString(fmt.Sprintf(" %s %s\n\n", + dimStyle.Render("Filter:"), + connectedStyle.Render(f.text))) + } +} diff --git a/cmd/tatankactl/styles.go b/cmd/tatankactl/styles.go new file mode 100644 index 0000000..dc27083 --- /dev/null +++ b/cmd/tatankactl/styles.go @@ -0,0 +1,147 @@ +package main + +import ( + "fmt" + "math" + "time" + + "github.com/charmbracelet/lipgloss" + "github.com/bisoncraft/mesh/tatanka/admin" +) + +// Colors +var ( + colorGreen = lipgloss.Color("42") + colorYellow = lipgloss.Color("214") + colorRed = lipgloss.Color("196") + colorDim = lipgloss.Color("241") + colorCyan = lipgloss.Color("86") + colorWhite = lipgloss.Color("255") + colorBorder = lipgloss.Color("63") + colorHeader = lipgloss.Color("99") + colorCursor = lipgloss.Color("214") + colorGreenFg = lipgloss.Color("46") + colorRedFg = lipgloss.Color("196") +) + +// Styles +var ( + titleStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(colorHeader). + MarginBottom(1) + + headerStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(colorWhite) + + connectedStyle = lipgloss.NewStyle(). + Foreground(colorGreen) + + mismatchStyle = lipgloss.NewStyle(). + Foreground(colorYellow) + + disconnectedStyle = lipgloss.NewStyle(). + Foreground(colorRed) + + dimStyle = lipgloss.NewStyle(). + Foreground(colorDim) + + helpStyle = lipgloss.NewStyle(). + Foreground(colorDim). + MarginTop(1) + + cursorStyle = lipgloss.NewStyle(). + Foreground(colorCursor). + Bold(true) + + diffGreenStyle = lipgloss.NewStyle(). + Foreground(colorGreenFg) + + diffRedStyle = lipgloss.NewStyle(). + Foreground(colorRedFg) + + menuBoxStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorBorder). + Padding(1, 3) + + tableBorderStyle = lipgloss.NewStyle(). + Foreground(colorBorder) +) + +func getStateIcon(state admin.NodeConnectionState) string { + switch state { + case admin.StateConnected: + return connectedStyle.Render("●") + case admin.StateWhitelistMismatch: + return mismatchStyle.Render("●") + case admin.StateDisconnected: + return disconnectedStyle.Render("●") + default: + return dimStyle.Render("●") + } +} + +func getStateString(state admin.NodeConnectionState) string { + switch state { + case admin.StateConnected: + return connectedStyle.Render("Connected") + case admin.StateWhitelistMismatch: + return mismatchStyle.Render("Whitelist Mismatch") + case admin.StateDisconnected: + return disconnectedStyle.Render("Disconnected") + default: + return string(state) + } +} + +func relativeTime(t time.Time) string { + now := time.Now() + d := now.Sub(t) + if d < 0 { + // Future time + d = -d + return "in " + formatDuration(d) + } + return formatDuration(d) + " ago" +} + +func formatDuration(d time.Duration) string { + if d < time.Second { + return "<1s" + } + totalSecs := int(math.Round(d.Seconds())) + if totalSecs < 60 { + return fmt.Sprintf("%ds", totalSecs) + } + minutes := totalSecs / 60 + seconds := totalSecs % 60 + if minutes < 60 { + if seconds == 0 { + return fmt.Sprintf("%dm", minutes) + } + return fmt.Sprintf("%dm %ds", minutes, seconds) + } + hours := minutes / 60 + minutes = minutes % 60 + if hours < 24 { + if minutes == 0 { + return fmt.Sprintf("%dh", hours) + } + return fmt.Sprintf("%dh %dm", hours, minutes) + } + days := hours / 24 + hours = hours % 24 + if hours == 0 { + return fmt.Sprintf("%dd", days) + } + return fmt.Sprintf("%dd %dh", days, hours) +} + +func truncatePeerID(id string) string { + if len(id) <= 16 { + return id + } + return id[:8] + ".." + id[len(id)-4:] +} diff --git a/go.mod b/go.mod index ce5872c..859d145 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module github.com/bisoncraft/mesh go 1.24.9 require ( + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/charmbracelet/x/ansi v0.10.1 github.com/decred/slog v1.2.0 github.com/go-chi/chi/v5 v5.2.3 github.com/go-chi/cors v1.2.2 @@ -19,11 +22,16 @@ require ( ) require ( + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/flynn/noise v1.1.0 // indirect github.com/francoispqt/gojay v1.2.13 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -42,12 +50,19 @@ require ( github.com/libp2p/go-netroute v0.3.0 // indirect github.com/libp2p/go-reuseport v0.4.0 // indirect github.com/libp2p/go-yamux/v5 v5.0.1 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/miekg/dns v1.1.66 // indirect github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b // indirect github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect github.com/minio/sha256-simd v1.0.1 // indirect github.com/mr-tron/base58 v1.2.0 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/multiformats/go-base32 v0.1.0 // indirect github.com/multiformats/go-base36 v0.2.0 // indirect github.com/multiformats/go-multiaddr-dns v0.4.1 // indirect @@ -84,9 +99,11 @@ require ( github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.55.0 // indirect github.com/quic-go/webtransport-go v0.9.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.uber.org/dig v1.19.0 // indirect go.uber.org/fx v1.24.0 // indirect go.uber.org/mock v0.5.2 // indirect @@ -96,7 +113,7 @@ require ( golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 // indirect golang.org/x/mod v0.27.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.28.0 // indirect golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.36.0 // indirect diff --git a/go.sum b/go.sum index 6a0d895..011e548 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -18,6 +20,18 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= +github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -32,6 +46,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjY github.com/decred/slog v1.2.0 h1:soHAxV52B54Di3WtKLfPum9OFfWqwtf/ygf9njdfnPM= github.com/decred/slog v1.2.0/go.mod h1:kVXlGnt6DHy2fV5OjSeuvCJ0OmlmTF6LFpEPMu/fOY0= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= @@ -129,12 +145,20 @@ github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQsc github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/libp2p/go-yamux/v5 v5.0.1 h1:f0WoX/bEF2E8SbE4c/k1Mo+/9z0O4oC/hWEA+nfYRSg= github.com/libp2p/go-yamux/v5 v5.0.1/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marcopolo/simnet v0.0.1 h1:rSMslhPz6q9IvJeFWDoMGxMIrlsbXau3NkuIXHGJxfg= github.com/marcopolo/simnet v0.0.1/go.mod h1:WDaQkgLAjqDUEBAOXz22+1j6wXKfGlC5sD5XWt3ddOs= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE= @@ -154,6 +178,12 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/multiformats/go-base32 v0.1.0 h1:pVx9xoSPqEIQG8o+UbAe7DNi51oej1NtK+aGkbLYxPE= github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYgtWibDcT0rExnbI= github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0= @@ -246,6 +276,9 @@ github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9M github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U= github.com/quic-go/webtransport-go v0.9.0 h1:jgys+7/wm6JarGDrW+lD/r9BGqBAmqY/ssklE09bA70= github.com/quic-go/webtransport-go v0.9.0/go.mod h1:4FUYIiUc75XSsF6HShcLeXXYZJ9AGwo/xh3L8M/P1ao= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= @@ -292,6 +325,8 @@ github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMI github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -385,15 +420,17 @@ golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/oracle/.gitignore b/oracle/.gitignore deleted file mode 100644 index cc2ddc8..0000000 --- a/oracle/.gitignore +++ /dev/null @@ -1 +0,0 @@ -coinpap2000.json diff --git a/oracle/buckets.go b/oracle/buckets.go new file mode 100644 index 0000000..1ff474c --- /dev/null +++ b/oracle/buckets.go @@ -0,0 +1,227 @@ +package oracle + +import ( + "math" + "math/big" + "sync" + "sync/atomic" + "time" +) + +const ( + fullValidityPeriod = time.Minute * 5 + validityExpiration = time.Minute * 30 + decayPeriod = validityExpiration - fullValidityPeriod +) + +// agedWeight returns a weight based on the age of an update. +func agedWeight(weight float64, stamp time.Time) float64 { + age := time.Since(stamp) + if age < 0 { + age = 0 + } + + switch { + case age < fullValidityPeriod: + return weight + case age > validityExpiration: + return 0 + default: + remainingValidity := validityExpiration - age + return weight * (float64(remainingValidity) / float64(decayPeriod)) + } +} + +// priceUpdate is the internal message used for when a price update is fetched +// or received from a source. +type priceUpdate struct { + ticker Ticker + price float64 + stamp time.Time + weight float64 +} + +// feeRateUpdate is the internal message used for when a fee rate update is +// fetched or received from a source. +type feeRateUpdate struct { + network Network + feeRate *big.Int + stamp time.Time + weight float64 +} + +// priceBucket is a collection of price updates from a single source +// and the aggregated price. +type priceBucket struct { + latest atomic.Uint64 + + mtx sync.RWMutex + sources map[string]*priceUpdate +} + +func newPriceBucket() *priceBucket { + return &priceBucket{ + latest: atomic.Uint64{}, + sources: make(map[string]*priceUpdate), + } +} + +func aggregatePriceSources(sources map[string]*priceUpdate) float64 { + var weightedSum float64 + var totalWeight float64 + for _, entry := range sources { + weight := agedWeight(entry.weight, entry.stamp) + if weight == 0 { + continue + } + totalWeight += weight + weightedSum += weight * entry.price + } + if totalWeight == 0 { + return 0 + } + return weightedSum / totalWeight +} + +func (b *priceBucket) aggregatedPrice() float64 { + return math.Float64frombits(b.latest.Load()) +} + +// mergeAndUpdateAggregate merges a price update into the bucket and returns +// the new updated aggregated price. updated is true if the aggregated price +// was updated, false otherwise (if the update is older than the latest update +// for the source). +func (b *priceBucket) mergeAndUpdateAggregate(source string, upd *priceUpdate) (updated bool, agg float64) { + b.mtx.Lock() + defer b.mtx.Unlock() + + existing, found := b.sources[source] + if found && !upd.stamp.After(existing.stamp) { + return false, 0 + } + b.sources[source] = upd + + agg = aggregatePriceSources(b.sources) + b.latest.Store(math.Float64bits(agg)) + return true, agg +} + +// feeRateBucket is a collection of fee rate updates from a single source +// and the aggregated fee rate. +type feeRateBucket struct { + latest atomic.Value // *big.Int + + mtx sync.RWMutex + sources map[string]*feeRateUpdate +} + +func newFeeRateBucket() *feeRateBucket { + bucket := &feeRateBucket{ + latest: atomic.Value{}, + sources: make(map[string]*feeRateUpdate), + } + bucket.latest.Store((*big.Int)(nil)) + return bucket +} + +func aggregateFeeRateSources(sources map[string]*feeRateUpdate) *big.Int { + weightedSum := new(big.Float) + var totalWeight float64 + + for _, entry := range sources { + weight := agedWeight(entry.weight, entry.stamp) + if weight == 0 { + continue + } + totalWeight += weight + + // Multiply weight (float64) by feeRate (big.Int) using big.Float. + weightFloat := new(big.Float).SetFloat64(weight) + feeRateFloat := new(big.Float).SetInt(entry.feeRate) + product := new(big.Float).Mul(weightFloat, feeRateFloat) + weightedSum.Add(weightedSum, product) + } + if totalWeight == 0 { + return big.NewInt(0) + } + + totalWeightFloat := new(big.Float).SetFloat64(totalWeight) + avgFloat := new(big.Float).Quo(weightedSum, totalWeightFloat) + + // Round to nearest integer. + if avgFloat.Sign() >= 0 { + avgFloat.Add(avgFloat, new(big.Float).SetFloat64(0.5)) + } else { + avgFloat.Sub(avgFloat, new(big.Float).SetFloat64(0.5)) + } + rounded := new(big.Int) + avgFloat.Int(rounded) + + return rounded +} + +func (b *feeRateBucket) aggregatedRate() *big.Int { + return b.latest.Load().(*big.Int) +} + +// mergeAndUpdateAggregate merges a fee rate update into the bucket and returns +// the new updated aggregated fee rate. updated is true if the aggregated fee rate +// was updated, false otherwise (if the update is older than the latest update +// for the source). +func (b *feeRateBucket) mergeAndUpdateAggregate(source string, upd *feeRateUpdate) (updated bool, agg *big.Int) { + b.mtx.Lock() + defer b.mtx.Unlock() + + existing, found := b.sources[source] + if found && !upd.stamp.After(existing.stamp) { + return false, nil + } + b.sources[source] = upd + + agg = aggregateFeeRateSources(b.sources) + b.latest.Store(agg) + return true, agg +} + +func (o *Oracle) getPriceBucket(ticker Ticker) *priceBucket { + o.pricesMtx.RLock() + bucket := o.prices[ticker] + o.pricesMtx.RUnlock() + return bucket +} + +func (o *Oracle) getOrCreatePriceBucket(ticker Ticker) *priceBucket { + if bucket := o.getPriceBucket(ticker); bucket != nil { + return bucket + } + + o.pricesMtx.Lock() + defer o.pricesMtx.Unlock() + if bucket := o.prices[ticker]; bucket != nil { + return bucket + } + bucket := newPriceBucket() + o.prices[ticker] = bucket + return bucket +} + +func (o *Oracle) getFeeRateBucket(network Network) *feeRateBucket { + o.feeRatesMtx.RLock() + bucket := o.feeRates[network] + o.feeRatesMtx.RUnlock() + return bucket +} + +func (o *Oracle) getOrCreateFeeRateBucket(network Network) *feeRateBucket { + if bucket := o.getFeeRateBucket(network); bucket != nil { + return bucket + } + o.feeRatesMtx.Lock() + defer o.feeRatesMtx.Unlock() + if bucket := o.feeRates[network]; bucket != nil { + return bucket + } + bucket := newFeeRateBucket() + o.feeRates[network] = bucket + return bucket +} diff --git a/oracle/diviner.go b/oracle/diviner.go index 173f4c8..24b5a98 100644 --- a/oracle/diviner.go +++ b/oracle/diviner.go @@ -2,104 +2,105 @@ package oracle import ( "context" - "fmt" - "math" - "math/rand/v2" + "math/big" + "sync/atomic" "time" "github.com/decred/slog" - "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/oracle/sources" ) -// fetcher returns either a list of price updates or a list of fee rate updates. -type fetcher func(ctx context.Context) (any, error) - -// diviner wraps an httpSource and handles periodic fetching and emitting of +// diviner wraps a Source and handles periodic fetching and emitting of // price and fee rate updates. type diviner struct { - name string - fetcher func(ctx context.Context) (any, error) - weight float64 - period time.Duration - errPeriod time.Duration - log slog.Logger - publishUpdate func(ctx context.Context, update *pb.NodeOracleUpdate) error - resetTimer chan struct{} + source sources.Source + log slog.Logger + publishUpdate func(ctx context.Context, update *OracleUpdate) error + onScheduleChanged func(*OracleSnapshot) + resetTimer chan struct{} + nextFetchInfo atomic.Value // networkSchedule + errorInfo atomic.Value // fetchErrorInfo + getNetworkSchedule func() networkSchedule +} + +type fetchErrorInfo struct { + message string + stamp time.Time } -func newDiviner(name string, fetcher fetcher, weight float64, period time.Duration, errPeriod time.Duration, publishUpdate func(ctx context.Context, update *pb.NodeOracleUpdate) error, log slog.Logger) *diviner { +func newDiviner( + src sources.Source, + publishUpdate func(ctx context.Context, update *OracleUpdate) error, + log slog.Logger, + getNetworkSchedule func() networkSchedule, + onScheduleChanged func(*OracleSnapshot), +) *diviner { return &diviner{ - name: name, - fetcher: fetcher, - weight: weight, - period: period, - errPeriod: errPeriod, - log: log, - publishUpdate: publishUpdate, - resetTimer: make(chan struct{}), + source: src, + log: log, + publishUpdate: publishUpdate, + resetTimer: make(chan struct{}), + getNetworkSchedule: getNetworkSchedule, + onScheduleChanged: onScheduleChanged, + } +} + +// fetchScheduleInfo returns the current fetch schedule info. +func (d *diviner) fetchScheduleInfo() networkSchedule { + if v := d.nextFetchInfo.Load(); v != nil { + return v.(networkSchedule) } + return networkSchedule{} +} + +func (d *diviner) fetchErrorInfo() (string, *time.Time) { + if v := d.errorInfo.Load(); v != nil { + info := v.(fetchErrorInfo) + if info.message == "" { + return "", nil + } + stamp := info.stamp + return info.message, &stamp + } + return "", nil } func (d *diviner) fetchUpdates(ctx context.Context) error { - divination, err := d.fetcher(ctx) + rateInfo, err := d.source.FetchRates(ctx) if err != nil { return err } - now := time.Now() + if len(rateInfo.Prices) == 0 && len(rateInfo.FeeRates) == 0 { + return nil + } - switch updates := divination.(type) { - case []*priceUpdate: - prices := make([]*SourcedPrice, 0, len(updates)) - for _, entry := range updates { - prices = append(prices, &SourcedPrice{ - Ticker: entry.ticker, - Price: entry.price, - }) - } + update := &OracleUpdate{ + Source: d.source.Name(), + Stamp: time.Now(), + Quota: d.source.QuotaStatus(), + } - sourcedUpdate := &SourcedPriceUpdate{ - Source: d.name, - Stamp: now, - Weight: d.weight, - Prices: prices, + if len(rateInfo.Prices) > 0 { + update.Prices = make(map[Ticker]float64, len(rateInfo.Prices)) + for _, entry := range rateInfo.Prices { + update.Prices[Ticker(entry.Ticker)] = entry.Price } + } - payload := pbNodePriceUpdate(sourcedUpdate) - go func() { - err := d.publishUpdate(ctx, payload) - if err != nil { - d.log.Errorf("Failed to publish sourced price update: %v", err) - } - }() - - case []*feeRateUpdate: - feeRates := make([]*SourcedFeeRate, 0, len(updates)) - for _, entry := range updates { - feeRates = append(feeRates, &SourcedFeeRate{ - Network: entry.network, - FeeRate: bigIntToBytes(entry.feeRate), - }) + if len(rateInfo.FeeRates) > 0 { + update.FeeRates = make(map[Network]*big.Int, len(rateInfo.FeeRates)) + for _, entry := range rateInfo.FeeRates { + update.FeeRates[Network(entry.Network)] = entry.FeeRate } + } - sourcedUpdate := &SourcedFeeRateUpdate{ - Source: d.name, - Stamp: now, - Weight: d.weight, - FeeRates: feeRates, + go func() { + if err := d.publishUpdate(ctx, update); err != nil { + d.log.Errorf("Failed to publish oracle update: %v", err) } - - payload := pbNodeFeeRateUpdate(sourcedUpdate) - go func() { - err := d.publishUpdate(ctx, payload) - if err != nil { - d.log.Errorf("Failed to publish sourced fee rate update: %v", err) - } - }() - default: - return fmt.Errorf("source %q returned unexpected type %T", d.name, divination) - } + }() return nil } @@ -112,10 +113,7 @@ func (d *diviner) reschedule() { } func (d *diviner) run(ctx context.Context) { - // Initialize with a shorter period to fetch initial oracle updates. - initialPeriod := time.Second * 5 - delay := randomDelay(time.Second) - timer := time.NewTimer(initialPeriod + delay) + timer := time.NewTimer(0) defer timer.Stop() for { @@ -123,58 +121,57 @@ func (d *diviner) run(ctx context.Context) { case <-ctx.Done(): return case <-d.resetTimer: - timer.Reset(d.period) + info := d.getNetworkSchedule() + timer.Reset(time.Until(info.NextFetchTime)) + d.nextFetchInfo.Store(info) + d.fireScheduleChanged(info) case <-timer.C: if err := d.fetchUpdates(ctx); err != nil { d.log.Errorf("Failed to fetch divination: %v", err) - timer.Reset(d.errPeriod) + // Retry after 1 minute on errors. + const errPeriod = time.Minute + errTime := time.Now() + d.errorInfo.Store(fetchErrorInfo{message: err.Error(), stamp: errTime}) + info := d.fetchScheduleInfo() + if info.NextFetchTime.IsZero() { + info = d.getNetworkSchedule() + } + info.NextFetchTime = errTime.Add(errPeriod) + d.nextFetchInfo.Store(info) + d.fireScheduleChanged(info) + timer.Reset(errPeriod) } else { - timer.Reset(d.period) + d.errorInfo.Store(fetchErrorInfo{message: "", stamp: time.Time{}}) + info := d.getNetworkSchedule() + timer.Reset(time.Until(info.NextFetchTime)) + d.nextFetchInfo.Store(info) + d.fireScheduleChanged(info) } } } } -func randomDelay(maxDelay time.Duration) time.Duration { - return time.Duration(math.Round((rand.Float64() * float64(maxDelay)))) -} - -// --- Protobuf Helper Functions --- - -func pbNodePriceUpdate(update *SourcedPriceUpdate) *pb.NodeOracleUpdate { - pbPrices := make([]*pb.SourcedPrice, len(update.Prices)) - for i, p := range update.Prices { - pbPrices[i] = &pb.SourcedPrice{ - Ticker: string(p.Ticker), - Price: p.Price, - } +func (d *diviner) fireScheduleChanged(info networkSchedule) { + errMsg, errStamp := d.fetchErrorInfo() + nft := info.NextFetchTime + minPeriod := info.MinPeriod + nsp := info.NetworkSustainablePeriod + nnft := info.NetworkNextFetchTime + status := &SourceStatus{ + NextFetchTime: &nft, + MinFetchInterval: &minPeriod, + NetworkSustainableRate: &info.NetworkSustainableRate, + NetworkSustainablePeriod: &nsp, + NetworkNextFetchTime: &nnft, + OrderedNodes: info.OrderedNodes, } - return &pb.NodeOracleUpdate{ - Update: &pb.NodeOracleUpdate_PriceUpdate{ - PriceUpdate: &pb.SourcedPriceUpdate{ - Source: update.Source, - Timestamp: update.Stamp.Unix(), - Prices: pbPrices, - }, - }, + if errMsg != "" && errStamp != nil { + status.LastError = errMsg + status.LastErrorTime = errStamp } -} - -func pbNodeFeeRateUpdate(update *SourcedFeeRateUpdate) *pb.NodeOracleUpdate { - pbFeeRates := make([]*pb.SourcedFeeRate, len(update.FeeRates)) - for i, fr := range update.FeeRates { - pbFeeRates[i] = &pb.SourcedFeeRate{ - Network: string(fr.Network), - FeeRate: fr.FeeRate, - } - } - return &pb.NodeOracleUpdate{ - Update: &pb.NodeOracleUpdate_FeeRateUpdate{ - FeeRateUpdate: &pb.SourcedFeeRateUpdate{ - Source: update.Source, - Timestamp: update.Stamp.Unix(), - FeeRates: pbFeeRates, - }, + d.onScheduleChanged(&OracleSnapshot{ + Sources: map[string]*SourceStatus{ + d.source.Name(): status, }, - } + }) } diff --git a/oracle/diviner_test.go b/oracle/diviner_test.go index ae9f163..3bdda34 100644 --- a/oracle/diviner_test.go +++ b/oracle/diviner_test.go @@ -5,392 +5,269 @@ import ( "fmt" "math/big" "os" - "sync" - "sync/atomic" + "reflect" "testing" "time" "github.com/decred/slog" - "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/oracle/sources" ) -func TestDivinerFetchUpdates(t *testing.T) { - t.Run("fetches and emits price updates with weight", func(t *testing.T) { - emitted := make(chan *pb.NodeOracleUpdate, 1) - - fetcher := func(ctx context.Context) (any, error) { - return []*priceUpdate{ - {ticker: "BTC", price: 50000.0}, - {ticker: "ETH", price: 3000.0}, - }, nil - } - - publishUpdate := func(ctx context.Context, update *pb.NodeOracleUpdate) error { - emitted <- update - return nil - } +// mockSource implements sources.Source for testing. +type mockSource struct { + name string + weight float64 + period time.Duration + minPeriod time.Duration + quota *sources.QuotaStatus + fetchFunc func(ctx context.Context) (*sources.RateInfo, error) +} - div := newDiviner( - "test-source", - fetcher, - 0.8, - time.Minute*5, - time.Minute, - publishUpdate, - slog.NewBackend(os.Stdout).Logger("test"), - ) +func (m *mockSource) Name() string { return m.name } +func (m *mockSource) Weight() float64 { return m.weight } +func (m *mockSource) Period() time.Duration { return m.period } +func (m *mockSource) MinPeriod() time.Duration { return m.minPeriod } +func (m *mockSource) QuotaStatus() *sources.QuotaStatus { + if m.quota != nil { + return m.quota + } + return &sources.QuotaStatus{ + FetchesRemaining: 100, + FetchesLimit: 100, + ResetTime: time.Now().Add(24 * time.Hour), + } +} +func (m *mockSource) FetchRates(ctx context.Context) (*sources.RateInfo, error) { + return m.fetchFunc(ctx) +} - err := div.fetchUpdates(context.Background()) - if err != nil { - t.Fatalf("fetchUpdates failed: %v", err) - } +func TestDiviner(t *testing.T) { + tests := []struct { + name string + rateInfo *sources.RateInfo + fetchErr error + quota *sources.QuotaStatus + expectedUpdate *OracleUpdate + expectErrorMsg bool + }{ + { + name: "successful price fetch", + quota: &sources.QuotaStatus{ + FetchesRemaining: 42, + FetchesLimit: 100, + }, + rateInfo: &sources.RateInfo{ + Prices: []*sources.PriceUpdate{ + {Ticker: "BTC", Price: 50000.0}, + {Ticker: "ETH", Price: 3000.0}, + }, + }, + expectedUpdate: &OracleUpdate{ + Source: "test-source", + Prices: map[Ticker]float64{ + "BTC": 50000.0, + "ETH": 3000.0, + }, + Quota: &sources.QuotaStatus{ + FetchesRemaining: 42, + FetchesLimit: 100, + }, + }, + }, + { + name: "successful fee rate fetch", + quota: &sources.QuotaStatus{ + FetchesRemaining: 42, + FetchesLimit: 100, + }, + rateInfo: &sources.RateInfo{ + FeeRates: []*sources.FeeRateUpdate{ + {Network: "BTC", FeeRate: big.NewInt(50)}, + }, + }, + expectedUpdate: &OracleUpdate{ + Source: "test-source", + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(50), + }, + Quota: &sources.QuotaStatus{ + FetchesRemaining: 42, + FetchesLimit: 100, + }, + }, + }, + { + name: "fetch failure", + quota: &sources.QuotaStatus{ + FetchesRemaining: 42, + FetchesLimit: 100, + }, + fetchErr: fmt.Errorf("fetch error"), + expectErrorMsg: true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + log := slog.NewBackend(os.Stdout).Logger("test") + + resetTime := time.Now().Add(10 * time.Minute) + src := &mockSource{ + name: "test-source", + weight: 0.8, + period: 5 * time.Minute, + minPeriod: 30 * time.Second, + quota: &sources.QuotaStatus{ + FetchesRemaining: test.quota.FetchesRemaining, + FetchesLimit: test.quota.FetchesLimit, + ResetTime: resetTime, + }, + fetchFunc: func(ctx context.Context) (*sources.RateInfo, error) { + if test.fetchErr != nil { + return nil, test.fetchErr + } + return test.rateInfo, nil + }, + } - select { - case update := <-emitted: - if update.GetPriceUpdate() == nil { - t.Fatalf("Expected price update, got %T", update.Update) + baseTime := time.Unix(0, 0) + expectedSchedule := networkSchedule{ + NextFetchTime: baseTime.Add(30 * time.Second), + NetworkSustainableRate: 0.5, + MinPeriod: src.minPeriod, + NetworkSustainablePeriod: 2 * time.Second, + NetworkNextFetchTime: baseTime.Add(2 * time.Second), + OrderedNodes: []string{"node-a", "node-b"}, } - priceUpdate := update.GetPriceUpdate() - if priceUpdate.Source != "test-source" { - t.Errorf("Expected source 'test-source', got %s", priceUpdate.Source) + getNetworkSchedule := func() networkSchedule { + return expectedSchedule } - if len(priceUpdate.Prices) != 2 { - t.Errorf("Expected 2 prices, got %d", len(priceUpdate.Prices)) + + updateCh := make(chan *OracleUpdate, 1) + publishUpdate := func(ctx context.Context, update *OracleUpdate) error { + updateCh <- update + return nil } - case <-time.After(100 * time.Millisecond): - t.Error("Expected update to be emitted") - } - }) - t.Run("fetches and emits fee rate updates", func(t *testing.T) { - emitted := make(chan *pb.NodeOracleUpdate, 1) + scheduleCh := make(chan *OracleSnapshot, 1) + onScheduleChanged := func(update *OracleSnapshot) { + scheduleCh <- update + } - fetcher := func(ctx context.Context) (any, error) { - return []*feeRateUpdate{ - {network: "BTC", feeRate: big.NewInt(50)}, - }, nil - } + div := newDiviner(src, publishUpdate, log, getNetworkSchedule, onScheduleChanged) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go div.run(ctx) + + var ( + update *OracleUpdate + scheduleUpdate *OracleSnapshot + ) + + deadline := time.After(10 * time.Second) + for update == nil || scheduleUpdate == nil { + select { + case update = <-updateCh: + case scheduleUpdate = <-scheduleCh: + case <-deadline: + t.Fatal("Timed out waiting for updates") + } + if test.fetchErr != nil && scheduleUpdate != nil { + break + } + } - publishUpdate := func(ctx context.Context, update *pb.NodeOracleUpdate) error { - emitted <- update - return nil - } + if test.fetchErr == nil { + if update == nil { + t.Fatal("Expected a publish update") + } + expectedUpdate := cloneOracleUpdate(test.expectedUpdate) + if expectedUpdate != nil && expectedUpdate.Quota != nil { + expectedUpdate.Quota.ResetTime = resetTime + } + // The diviner sets Stamp to time.Now() at fetch time, so + // copy the actual stamp into the expected value before + // comparing. + expectedUpdate.Stamp = update.Stamp + if !reflect.DeepEqual(update, expectedUpdate) { + t.Errorf("Expected update %+v, got %+v", expectedUpdate, update) + } + } else if update != nil { + t.Fatal("Did not expect a publish update on error") + } - div := newDiviner( - "test-source", - fetcher, - 1.0, - time.Minute*5, - time.Minute, - publishUpdate, - slog.NewBackend(os.Stdout).Logger("test"), - ) + if scheduleUpdate == nil { + t.Fatal("Expected schedule update") + } - err := div.fetchUpdates(context.Background()) - if err != nil { - t.Fatalf("fetchUpdates failed: %v", err) - } + srcStatus, ok := scheduleUpdate.Sources["test-source"] + if !ok { + t.Fatal("Expected schedule update to contain 'test-source' in Sources") + } - select { - case update := <-emitted: - if update.GetFeeRateUpdate() == nil { - t.Fatalf("Expected fee rate update, got %T", update.Update) + minPeriod := expectedSchedule.MinPeriod + nsp := expectedSchedule.NetworkSustainablePeriod + nnft := expectedSchedule.NetworkNextFetchTime + expectedStatus := &SourceStatus{ + MinFetchInterval: &minPeriod, + NetworkSustainableRate: &expectedSchedule.NetworkSustainableRate, + NetworkSustainablePeriod: &nsp, + NetworkNextFetchTime: &nnft, + OrderedNodes: expectedSchedule.OrderedNodes, } - feeUpdate := update.GetFeeRateUpdate() - if feeUpdate.Source != "test-source" { - t.Errorf("Expected source 'test-source', got %s", feeUpdate.Source) + if test.fetchErr == nil { + nft := expectedSchedule.NextFetchTime + expectedStatus.NextFetchTime = &nft + } else { + expectedStatus.NextFetchTime = srcStatus.NextFetchTime } - if len(feeUpdate.FeeRates) == 0 { - t.Error("Expected at least one fee rate") + if test.expectErrorMsg { + expectedStatus.LastError = "fetch error" + expectedStatus.LastErrorTime = srcStatus.LastErrorTime } - case <-time.After(100 * time.Millisecond): - t.Error("Expected update to be emitted") - } - }) - t.Run("returns error on fetch failure", func(t *testing.T) { - fetcher := func(ctx context.Context) (any, error) { - return nil, fmt.Errorf("fetch error") - } - - div := newDiviner( - "test-source", - fetcher, - 1.0, - time.Minute*5, - time.Minute, - func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - err := div.fetchUpdates(context.Background()) - if err == nil { - t.Error("Expected error on fetch failure") - } - }) - - t.Run("includes weight in updates", func(t *testing.T) { - emitted := make(chan *pb.NodeOracleUpdate, 1) - - fetcher := func(ctx context.Context) (any, error) { - return []*priceUpdate{ - {ticker: "BTC", price: 50000.0}, - }, nil - } - - publishUpdate := func(ctx context.Context, update *pb.NodeOracleUpdate) error { - emitted <- update - return nil - } - - div := newDiviner( - "weighted-source", - fetcher, - 0.5, - time.Minute*5, - time.Minute, - publishUpdate, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - err := div.fetchUpdates(context.Background()) - if err != nil { - t.Fatalf("fetchUpdates failed: %v", err) - } - - select { - case update := <-emitted: - if update.GetPriceUpdate() == nil { - t.Fatalf("Expected price update") + if !reflect.DeepEqual(expectedStatus, srcStatus) { + t.Fatalf("Unexpected schedule update source status: %#v", srcStatus) } - // Weight is stored in diviner but not exposed in protobuf - case <-time.After(100 * time.Millisecond): - t.Error("Expected update to be emitted") - } - }) - - t.Run("rejects unexpected divination type", func(t *testing.T) { - fetcher := func(ctx context.Context) (any, error) { - return "invalid type", nil - } - div := newDiviner( - "test-source", - fetcher, - 1.0, - time.Minute*5, - time.Minute, - func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - err := div.fetchUpdates(context.Background()) - if err == nil { - t.Error("Expected error on unexpected divination type") - } - }) - - t.Run("publish error is logged but doesn't block", func(t *testing.T) { - emitted := make(chan *pb.NodeOracleUpdate, 10) - - fetcher := func(ctx context.Context) (any, error) { - return []*priceUpdate{ - {ticker: "BTC", price: 50000.0}, - }, nil - } - - // Publish function that returns error but still buffers to verify it was called - publishUpdate := func(ctx context.Context, update *pb.NodeOracleUpdate) error { - emitted <- update - return fmt.Errorf("publish error") - } - - div := newDiviner( - "test-source", - fetcher, - 1.0, - time.Millisecond, - time.Millisecond, - publishUpdate, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - err := div.fetchUpdates(context.Background()) - if err != nil { - t.Fatalf("fetchUpdates failed: %v", err) - } - - // The fire-and-forget goroutine should still send the update - // even though publishUpdate returns an error - select { - case <-emitted: - // Good, update was sent to publish even though it will fail - case <-time.After(100 * time.Millisecond): - t.Error("Expected publish to be attempted despite error") - } - }) + if test.fetchErr != nil && srcStatus.NextFetchTime.Sub(baseTime) < 50*time.Second { + t.Errorf("Expected retry next fetch to be ~1 minute later, got %v", srcStatus.NextFetchTime.Sub(baseTime)) + } + }) + } } -func TestDivinerRun(t *testing.T) { - t.Run("runs and fetches periodically", func(t *testing.T) { - callCount := int32(0) - - fetcher := func(ctx context.Context) (any, error) { - atomic.AddInt32(&callCount, 1) - return []*priceUpdate{ - {ticker: "BTC", price: 50000.0}, - }, nil - } - - div := newDiviner( - "test-source", - fetcher, - 1.0, - 50*time.Millisecond, - 25*time.Millisecond, - func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - ctx, cancel := context.WithCancel(context.Background()) - - go div.run(ctx) - - // Wait for at least 2 calls. The initial timer has a 5 second interval - // plus a random delay of up to 1 second, then subsequent calls at 50ms intervals. - // We need to wait: 5s (initial) + 1s (max delay) + 100ms (2 periods) = 6.1s - time.Sleep(6200 * time.Millisecond) - cancel() - - count := atomic.LoadInt32(&callCount) - if count < 2 { - t.Errorf("Expected at least 2 calls, got %d", count) - } - }) - - t.Run("stops on context cancellation", func(t *testing.T) { - fetcher := func(ctx context.Context) (any, error) { - return []*priceUpdate{ - {ticker: "BTC", price: 50000.0}, - }, nil - } - - div := newDiviner( - "test-source", - fetcher, - 1.0, - time.Hour, - time.Hour, - func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - ctx, cancel := context.WithCancel(context.Background()) +func cloneOracleUpdate(update *OracleUpdate) *OracleUpdate { + if update == nil { + return nil + } - done := make(chan struct{}) - go func() { - div.run(ctx) - close(done) - }() + clone := &OracleUpdate{ + Source: update.Source, + Stamp: update.Stamp, + } - // Cancel immediately - cancel() - - select { - case <-done: - // Good, run exited - case <-time.After(time.Second): - t.Error("run did not exit after context cancellation") - } - }) - - t.Run("reschedule resets timer", func(t *testing.T) { - callCount := int32(0) - - fetcher := func(ctx context.Context) (any, error) { - atomic.AddInt32(&callCount, 1) - return []*priceUpdate{ - {ticker: "BTC", price: 50000.0}, - }, nil - } - - div := newDiviner( - "test-source", - fetcher, - 1.0, - 500*time.Millisecond, - 500*time.Millisecond, - func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, - slog.NewBackend(os.Stdout).Logger("test"), - ) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go div.run(ctx) - - // Wait a bit, then reschedule multiple times - time.Sleep(50 * time.Millisecond) - div.reschedule() - time.Sleep(50 * time.Millisecond) - div.reschedule() - time.Sleep(50 * time.Millisecond) - - count := atomic.LoadInt32(&callCount) - - // Timer should be continuously reset, so we shouldn't have any calls yet - // (period is 500ms, we only waited ~150ms with resets) - if count > 0 { - t.Logf("Got %d calls (timer may have fired due to initial delay)", count) + if update.Prices != nil { + clone.Prices = make(map[Ticker]float64, len(update.Prices)) + for k, v := range update.Prices { + clone.Prices[k] = v } - }) - - t.Run("uses errPeriod on error", func(t *testing.T) { - callTimes := make([]time.Time, 0, 5) - var mu sync.Mutex - - fetcher := func(ctx context.Context) (any, error) { - mu.Lock() - callTimes = append(callTimes, time.Now()) - mu.Unlock() - // Return error to trigger errPeriod - return nil, fmt.Errorf("fetch error") - } - - div := newDiviner( - "test-source", - fetcher, - 1.0, - 50*time.Millisecond, - 30*time.Millisecond, - func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, - slog.NewBackend(os.Stdout).Logger("test"), - ) + } - ctx, cancel := context.WithCancel(context.Background()) - - go div.run(ctx) - - // Wait for at least 2 error retries. The initial timer has a 5 second interval - // plus a random delay of up to 1 second, then subsequent retries at errPeriod (30ms). - // We need to wait: 5s (initial) + 1s (max delay) + 60ms (2 errPeriods) = 6.06s - time.Sleep(6200 * time.Millisecond) - cancel() - - mu.Lock() - times := callTimes - mu.Unlock() - - if len(times) < 2 { - t.Fatalf("Expected at least 2 calls, got %d", len(times)) + if update.FeeRates != nil { + clone.FeeRates = make(map[Network]*big.Int, len(update.FeeRates)) + for k, v := range update.FeeRates { + clone.FeeRates[k] = new(big.Int).Set(v) } + } - // Check interval between calls - should be closer to errPeriod - interval := times[1].Sub(times[0]) - if interval > 500*time.Millisecond { - t.Errorf("Expected short retry interval (errPeriod), got %v", interval) - } - }) + if update.Quota != nil { + q := *update.Quota + clone.Quota = &q + } + return clone } diff --git a/oracle/fetch_tracker.go b/oracle/fetch_tracker.go new file mode 100644 index 0000000..8eb0620 --- /dev/null +++ b/oracle/fetch_tracker.go @@ -0,0 +1,199 @@ +package oracle + +import ( + "sync" + "time" +) + +const trackingPeriod = 24 * time.Hour + +// fetchRecord represents a single fetch event. +type fetchRecord struct { + SourceID uint16 + NodeID uint16 + Stamp time.Time +} + +// fetchTracker tracks fetch events for the past 24 hours. +type fetchTracker struct { + mtx sync.Mutex + records []fetchRecord + // To reduce memory, records store uint16 IDs rather than full strings. + // ID mappings are append-only and bounded by uint16 max (65535). + sourceIDs map[string]uint16 + nodeIDs map[string]uint16 + sourceNames []string + nodeNames []string + nextSourceID uint16 + nextNodeID uint16 + counts map[uint16]map[uint16]int + latest map[uint16]time.Time +} + +// newFetchTracker creates a new fetchTracker. +func newFetchTracker() *fetchTracker { + return &fetchTracker{ + sourceIDs: make(map[string]uint16), + nodeIDs: make(map[string]uint16), + counts: make(map[uint16]map[uint16]int), + latest: make(map[uint16]time.Time), + } +} + +// recordFetch records a fetch event. +func (ft *fetchTracker) recordFetch(source, nodeID string, stamp time.Time) { + ft.mtx.Lock() + defer ft.mtx.Unlock() + sourceID, ok := assignID(source, ft.sourceIDs, &ft.sourceNames, &ft.nextSourceID) + if !ok { + return + } + nodeIDInt, ok := assignID(nodeID, ft.nodeIDs, &ft.nodeNames, &ft.nextNodeID) + if !ok { + return + } + r := fetchRecord{ + SourceID: sourceID, + NodeID: nodeIDInt, + Stamp: stamp, + } + + // Insert in sorted order by stamp. Almost always appends since records + // arrive roughly chronologically. + i := len(ft.records) + for i > 0 && ft.records[i-1].Stamp.After(stamp) { + i-- + } + ft.records = append(ft.records, r) + if i < len(ft.records)-1 { + copy(ft.records[i+1:], ft.records[i:len(ft.records)-1]) + ft.records[i] = r + } + + // Update the counts and latest fetch for the source. + if ft.counts[sourceID] == nil { + ft.counts[sourceID] = make(map[uint16]int) + } + ft.counts[sourceID][nodeIDInt]++ + if existing, ok := ft.latest[sourceID]; !ok || stamp.After(existing) { + ft.latest[sourceID] = stamp + } +} + +// assignID returns the uint16 ID for name, creating one if needed. Returns +// false if the ID space (uint16) is exhausted. +func assignID(name string, ids map[string]uint16, names *[]string, next *uint16) (uint16, bool) { + if id, ok := ids[name]; ok { + return id, true + } + if *next == ^uint16(0) { + return 0, false + } + id := *next + *next++ + ids[name] = id + *names = append(*names, name) + return id, true +} + +// dropExpired removes records older than the given cutoff. Records are kept +// sorted by stamp, so we scan from the front and stop at the first non-expired. +// ft.mtx MUST be locked when calling this function. +func (ft *fetchTracker) dropExpired(cutoff time.Time) { + expiredCount := 0 + for expiredCount < len(ft.records) && ft.records[expiredCount].Stamp.Before(cutoff) { + r := ft.records[expiredCount] + if nodes, ok := ft.counts[r.SourceID]; ok { + if nodes[r.NodeID] > 1 { + nodes[r.NodeID]-- + } else { + delete(nodes, r.NodeID) + if len(nodes) == 0 { + delete(ft.counts, r.SourceID) + } + } + } + expiredCount++ + } + if expiredCount > 0 { + ft.records = ft.records[expiredCount:] + } + for sourceID, stamp := range ft.latest { + if stamp.Before(cutoff) { + delete(ft.latest, sourceID) + } + } +} + +// sourceFetchCounts returns per-node fetch counts for a single source over the +// past 24 hours. +func (ft *fetchTracker) sourceFetchCounts(source string) map[string]int { + ft.mtx.Lock() + defer ft.mtx.Unlock() + ft.dropExpired(time.Now().Add(-trackingPeriod)) + sourceID, ok := ft.sourceIDs[source] + if !ok { + return nil + } + nodes, ok := ft.counts[sourceID] + if !ok { + return nil + } + result := make(map[string]int, len(nodes)) + for nodeID, count := range nodes { + if name, ok := ft.nodeName(nodeID); ok { + result[name] = count + } + } + return result +} + +// fetchCounts returns per-source, per-node fetch counts for the past 24 hours. +func (ft *fetchTracker) fetchCounts() map[string]map[string]int { + ft.mtx.Lock() + defer ft.mtx.Unlock() + ft.dropExpired(time.Now().Add(-trackingPeriod)) + result := make(map[string]map[string]int, len(ft.counts)) + for sourceID, nodes := range ft.counts { + source, ok := ft.sourceName(sourceID) + if !ok { + continue + } + nodeMap := make(map[string]int, len(nodes)) + for nodeID, count := range nodes { + if name, ok := ft.nodeName(nodeID); ok { + nodeMap[name] = count + } + } + result[source] = nodeMap + } + return result +} + +// latestPerSource returns the most recent fetch timestamp per source. +func (ft *fetchTracker) latestPerSource() map[string]time.Time { + ft.mtx.Lock() + defer ft.mtx.Unlock() + ft.dropExpired(time.Now().Add(-trackingPeriod)) + result := make(map[string]time.Time, len(ft.latest)) + for sourceID, stamp := range ft.latest { + if source, ok := ft.sourceName(sourceID); ok { + result[source] = stamp + } + } + return result +} + +func (ft *fetchTracker) sourceName(id uint16) (string, bool) { + if int(id) >= len(ft.sourceNames) { + return "", false + } + return ft.sourceNames[id], true +} + +func (ft *fetchTracker) nodeName(id uint16) (string, bool) { + if int(id) >= len(ft.nodeNames) { + return "", false + } + return ft.nodeNames[id], true +} diff --git a/oracle/fetch_tracker_test.go b/oracle/fetch_tracker_test.go new file mode 100644 index 0000000..3946a21 --- /dev/null +++ b/oracle/fetch_tracker_test.go @@ -0,0 +1,86 @@ +package oracle + +import ( + "testing" + "time" +) + +func TestFetchTracker_RecordAndCounts(t *testing.T) { + ft := newFetchTracker() + now := time.Now() + + ft.recordFetch("source1", "node-a", now) + ft.recordFetch("source1", "node-a", now.Add(-time.Hour)) + ft.recordFetch("source1", "node-b", now) + ft.recordFetch("source2", "node-a", now) + + counts := ft.fetchCounts() + if counts["source1"]["node-a"] != 2 { + t.Errorf("expected 2, got %d", counts["source1"]["node-a"]) + } + if counts["source1"]["node-b"] != 1 { + t.Errorf("expected 1, got %d", counts["source1"]["node-b"]) + } + if counts["source2"]["node-a"] != 1 { + t.Errorf("expected 1, got %d", counts["source2"]["node-a"]) + } +} + +func TestFetchTracker_LatestPerSource(t *testing.T) { + ft := newFetchTracker() + now := time.Now() + + ft.recordFetch("source1", "node-a", now.Add(-time.Hour)) + ft.recordFetch("source1", "node-b", now) + ft.recordFetch("source2", "node-a", now.Add(-2*time.Hour)) + + latest := ft.latestPerSource() + + if !latest["source1"].Equal(now) { + t.Errorf("expected latest stamp for source1 to be %v, got %v", now, latest["source1"]) + } + if !latest["source2"].Equal(now.Add(-2 * time.Hour)) { + t.Errorf("expected latest stamp for source2 to be %v, got %v", now.Add(-2*time.Hour), latest["source2"]) + } +} + +func TestFetchTracker_CountsExcludes24hOld(t *testing.T) { + ft := newFetchTracker() + now := time.Now() + + ft.recordFetch("source1", "node-a", now.Add(-25*time.Hour)) + ft.recordFetch("source1", "node-a", now) + + counts := ft.fetchCounts() + if counts["source1"]["node-a"] != 1 { + t.Errorf("expected count to be 1 (excluding old record), got %d", counts["source1"]["node-a"]) + } +} + +func TestFetchTracker_OutOfOrderExpiry(t *testing.T) { + ft := newFetchTracker() + now := time.Now() + + // Insert a recent record followed by an expired one (out of order). + ft.recordFetch("source1", "node-a", now) + ft.recordFetch("source1", "node-a", now.Add(-25*time.Hour)) + + counts := ft.fetchCounts() + if counts["source1"]["node-a"] != 1 { + t.Errorf("expected count to be 1 (out-of-order expired record should be dropped), got %d", counts["source1"]["node-a"]) + } +} + +func TestFetchTracker_Empty(t *testing.T) { + ft := newFetchTracker() + + counts := ft.fetchCounts() + if len(counts) != 0 { + t.Errorf("expected empty counts, got %d entries", len(counts)) + } + + latest := ft.latestPerSource() + if len(latest) != 0 { + t.Errorf("expected empty latest, got %d entries", len(latest)) + } +} diff --git a/oracle/oracle.go b/oracle/oracle.go index 269fb56..c007bad 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -2,26 +2,15 @@ package oracle import ( "context" + "fmt" "math/big" "net/http" - "slices" "sync" "time" "github.com/decred/slog" - "github.com/bisoncraft/mesh/tatanka/pb" -) - -const ( - fullValidityPeriod = time.Minute * 5 - validityExpiration = time.Minute * 30 - decayPeriod = validityExpiration - fullValidityPeriod - requestTimeout = time.Second * 5 - - // PriceTopicPrefix is the topic prefix for price updates sent to clients. - PriceTopicPrefix = "price." - // FeeRateTopicPrefix is the topic prefix for fee rate updates sent to clients. - FeeRateTopicPrefix = "fee_rate." + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/providers" ) // Ticker is the upper-case symbol used to indicate an asset. @@ -30,48 +19,20 @@ type Ticker string // Network is the network symbol of a Blockchain. type Network string -// SourcedPrice represents a single price entry within a sourced update batch. -type SourcedPrice struct { - Ticker Ticker - Price float64 -} - -// SourcedPriceUpdate is a batch of price updates from a single source, used for -// sharing with other Tatanka Mesh nodes. -type SourcedPriceUpdate struct { - Source string - Stamp time.Time - Weight float64 - Prices []*SourcedPrice -} - -// SourcedFeeRate represents a single fee rate entry within a sourced update batch. -type SourcedFeeRate struct { - Network Network - FeeRate []byte // big-endian encoded big integer -} - -// SourcedFeeRateUpdate is a batch of fee rate updates from a single source, used -// for sharing with other Tatanka Mesh nodes. -type SourcedFeeRateUpdate struct { +// OracleUpdate is the payload published to the mesh for oracle data. +// At least one of Prices or FeeRates should be populated. +type OracleUpdate struct { Source string Stamp time.Time - Weight float64 - FeeRates []*SourcedFeeRate + Prices map[Ticker]float64 + FeeRates map[Network]*big.Int + Quota *sources.QuotaStatus } -// PriceUpdate is an aggregated price update. These are emitted when an update -// is received from a source. -type PriceUpdate struct { - Ticker Ticker - Price float64 -} - -// FeeRateUpdate is an aggregated fee rate update. These are emitted when an -// update is received from a source. -type FeeRateUpdate struct { - Network Network - FeeRate *big.Int +// MergeResult contains the aggregated rates that changed after a merge. +type MergeResult struct { + Prices map[Ticker]float64 + FeeRates map[Network]*big.Int } // HTTPClient defines the requirements for implementing an http client. @@ -79,347 +40,394 @@ type HTTPClient interface { Do(req *http.Request) (*http.Response, error) } +// Config contains configuration for the Oracle. type Config struct { - Log slog.Logger - CMCKey string - TatumKey string - CryptoApisKey string - HTTPClient HTTPClient // Optional. If nil, http.DefaultClient is used. - PublishUpdate func(ctx context.Context, update *pb.NodeOracleUpdate) error + // NodeID is the ID of the local node running the oracle. + NodeID string + + // PublishUpdate is called when the oracle has fetched new data from a + // source. + PublishUpdate func(ctx context.Context, update *OracleUpdate) error + + // OnStateUpdate is called when some state in the oracle has changed. + // Only the updated fields are populated. The full snapshot can be fetched + // using OracleSnapshot, and then updates received on this function can be + // combined with the full snapshot to get the current state. + OnStateUpdate func(*OracleSnapshot) + + // PublishQuotaHeartbeat is called periodically to update other nodes with + // the current quota status for all sources. + PublishQuotaHeartbeat func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error + + // Log is the logger used to log messages. + Log slog.Logger + + // CMCKey is the token used to fetch data from the CoinMarketCap API. + CMCKey string + + // TatumKey is the token used to fetch data from the Tatum API. + TatumKey string + + // BlockcypherToken is the token used to fetch data from the Blockcypher API. + BlockcypherToken string + + // HTTPClient is the HTTP client used to fetch data from the sources. + // If nil, http.DefaultClient is used. + HTTPClient HTTPClient +} + +// verify validates the Oracle configuration. +func (cfg *Config) verify() error { + if cfg == nil { + return fmt.Errorf("oracle config is nil") + } + if cfg.PublishUpdate == nil { + return fmt.Errorf("publish update callback is required") + } + if cfg.OnStateUpdate == nil { + return fmt.Errorf("state update callback is required") + } + if cfg.PublishQuotaHeartbeat == nil { + return fmt.Errorf("publish quota heartbeat callback is required") + } + if cfg.NodeID == "" { + return fmt.Errorf("node ID is required") + } + return nil } +// Oracle manages price and fee rate data from multiple sources. type Oracle struct { - log slog.Logger - httpClient HTTPClient - httpSources []*httpSource + log slog.Logger + httpClient HTTPClient + srcs []sources.Source feeRatesMtx sync.RWMutex - feeRates map[Network]map[string]*feeRateUpdate + feeRates map[Network]*feeRateBucket pricesMtx sync.RWMutex - prices map[Ticker]map[string]*priceUpdate + prices map[Ticker]*priceBucket divinersMtx sync.RWMutex diviners map[string]*diviner - publishUpdate func(ctx context.Context, update *pb.NodeOracleUpdate) error + publishUpdate func(ctx context.Context, update *OracleUpdate) error + onStateUpdate func(*OracleSnapshot) + quotaManager *quotaManager + fetchTracker *fetchTracker + nodeID string } +// New creates a new Oracle with the given configuration. func New(cfg *Config) (*Oracle, error) { - httpSources := slices.Clone(unauthedHttpSources) + if err := cfg.verify(); err != nil { + return nil, err + } - if cfg.CMCKey != "" { - httpSources = append(httpSources, coinmarketcapSource(cfg.CMCKey)) + httpClient := cfg.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient } - if cfg.TatumKey != "" { - httpSources = append(httpSources, - tatumBitcoinSource(cfg.TatumKey), - tatumLitecoinSource(cfg.TatumKey), - tatumDogecoinSource(cfg.TatumKey), - ) + // Add all sources that don't require an API key. + unlimitedSources := []sources.Source{ + providers.NewDcrdataSource(httpClient, cfg.Log), + providers.NewMempoolDotSpaceSource(httpClient, cfg.Log), + providers.NewCoinpaprikaSource(httpClient, cfg.Log), + providers.NewBitcoreBitcoinCashSource(httpClient, cfg.Log), + providers.NewBitcoreDogecoinSource(httpClient, cfg.Log), + providers.NewBitcoreLitecoinSource(httpClient, cfg.Log), + providers.NewFiroOrgSource(httpClient, cfg.Log), } + allSources := make([]sources.Source, 0, len(unlimitedSources)) + allSources = append(allSources, unlimitedSources...) - if cfg.CryptoApisKey != "" { - httpSources = append(httpSources, - cryptoApisBitcoinSource(cfg.CryptoApisKey), - cryptoApisBitcoinCashSource(cfg.CryptoApisKey), - cryptoApisDogecoinSource(cfg.CryptoApisKey), - cryptoApisDashSource(cfg.CryptoApisKey), - cryptoApisLitecoinSource(cfg.CryptoApisKey), - ) + if cfg.BlockcypherToken != "" { + blockcypherSource := providers.NewBlockcypherLitecoinSource(httpClient, cfg.Log, cfg.BlockcypherToken) + allSources = append(allSources, blockcypherSource) } - if err := setHTTPSourceDefaults(httpSources); err != nil { - return nil, err + if cfg.CMCKey != "" { + cmcSource := providers.NewCoinMarketCapSource(httpClient, cfg.Log, cfg.CMCKey) + allSources = append(allSources, cmcSource) } - httpClient := cfg.HTTPClient - if httpClient == nil { - httpClient = http.DefaultClient + if cfg.TatumKey != "" { + tatumSources := providers.NewTatumSources(providers.TatumConfig{ + HTTPClient: httpClient, + Log: cfg.Log, + APIKey: cfg.TatumKey, + }) + allSources = append(allSources, tatumSources.All()...) } + quotaManager := newQuotaManager("aManagerConfig{ + log: cfg.Log, + nodeID: cfg.NodeID, + publishQuotaHeartbeat: cfg.PublishQuotaHeartbeat, + onStateUpdate: cfg.OnStateUpdate, + sources: allSources, + }) + oracle := &Oracle{ log: cfg.Log, httpClient: httpClient, - httpSources: httpSources, - feeRates: make(map[Network]map[string]*feeRateUpdate), - prices: make(map[Ticker]map[string]*priceUpdate), + srcs: allSources, + feeRates: make(map[Network]*feeRateBucket), + prices: make(map[Ticker]*priceBucket), diviners: make(map[string]*diviner), publishUpdate: cfg.PublishUpdate, + onStateUpdate: cfg.OnStateUpdate, + quotaManager: quotaManager, + fetchTracker: newFetchTracker(), + nodeID: cfg.NodeID, } - for _, source := range httpSources { - src := source - fetcher := func(ctx context.Context) (any, error) { - return src.fetch(ctx, httpClient) - } - div := newDiviner(src.name, fetcher, src.weight, src.period, src.errPeriod, oracle.publishUpdate, oracle.log) - oracle.diviners[div.name] = div + // Create diviners for each source + for _, src := range oracle.srcs { + getNetworkSchedule := func(s sources.Source) func() networkSchedule { + return func() networkSchedule { + return quotaManager.getNetworkSchedule(s.Name(), s.MinPeriod()) + } + }(src) + div := newDiviner(src, oracle.publishUpdate, oracle.log, getNetworkSchedule, cfg.OnStateUpdate) + oracle.diviners[src.Name()] = div } return oracle, nil } -// priceWeightCounter is used to calculate weighted averages for prices. -type priceWeightCounter struct { - weightedSum float64 - totalWeight float64 -} - -// feeRateWeightCounter is used to calculate weighted fee rate averages with arbitrary precision. -type feeRateWeightCounter struct { - weightedSum *big.Float - totalWeight float64 -} - -// agedWeight returns a weight based on the age of an update. -func agedWeight(weight float64, stamp time.Time) float64 { - // Older updates lose weight. - age := time.Since(stamp) - if age < 0 { - age = 0 - } +// allFeeRates returns the aggregated tx fee rates for all known networks. +func (o *Oracle) allFeeRates() map[Network]*big.Int { + o.feeRatesMtx.RLock() + defer o.feeRatesMtx.RUnlock() - switch { - case age < fullValidityPeriod: - return weight - case age > validityExpiration: - return 0 - default: - // Calculate remaining validity as a fraction of the decay period. - remainingValidity := validityExpiration - age - return weight * (float64(remainingValidity) / float64(decayPeriod)) + feeRates := make(map[Network]*big.Int, len(o.feeRates)) + for net, bucket := range o.feeRates { + if rate := bucket.aggregatedRate(); rate != nil && rate.Sign() > 0 { + feeRates[net] = rate + } } + return feeRates } -func (o *Oracle) getFeeRates(nets map[Network]bool) map[Network]*big.Int { - o.feeRatesMtx.RLock() - size := len(nets) - if nets == nil { - size = len(o.feeRates) +// Merge merges an oracle update from another node into this oracle. +// Returns the aggregated rates that changed. +func (o *Oracle) Merge(update *OracleUpdate, senderID string) *MergeResult { + if update == nil || (len(update.Prices) == 0 && len(update.FeeRates) == 0) { + return nil } - counters := make(map[Network]*feeRateWeightCounter, size) - for net, updates := range o.feeRates { - if nets != nil && !nets[net] { - continue - } - - counter, found := counters[net] - if !found { - counter = &feeRateWeightCounter{ - weightedSum: new(big.Float), - } - counters[net] = counter - } - - for _, entry := range updates { - weight := agedWeight(entry.weight, entry.stamp) - if weight == 0 { - continue - } - counter.totalWeight += weight + weight := o.sourceWeight(update.Source) + result := &MergeResult{} - // Multiply weight (float64) by feeRate (big.Int) using big.Float - weightFloat := new(big.Float).SetFloat64(weight) - feeRateFloat := new(big.Float).SetInt(entry.feeRate) - product := new(big.Float).Mul(weightFloat, feeRateFloat) - counter.weightedSum.Add(counter.weightedSum, product) - } + if len(update.FeeRates) > 0 { + result.FeeRates = o.mergeFeeRates(update, weight) } - - o.feeRatesMtx.RUnlock() - - // Calculate weighted averages. - feeRates := make(map[Network]*big.Int, len(counters)) - for net, counter := range counters { - if counter.totalWeight == 0 { - continue - } - - // Divide weightedSum (big.Float) by totalWeight (float64) - totalWeightFloat := new(big.Float).SetFloat64(counter.totalWeight) - avgFloat := new(big.Float).Quo(counter.weightedSum, totalWeightFloat) - - // Round to nearest integer - if avgFloat.Sign() >= 0 { - avgFloat.Add(avgFloat, new(big.Float).SetFloat64(0.5)) - } else { - avgFloat.Sub(avgFloat, new(big.Float).SetFloat64(0.5)) - } - - // Convert to big.Int (this truncates towards zero after rounding) - rounded := new(big.Int) - avgFloat.Int(rounded) - feeRates[net] = rounded + if len(update.Prices) > 0 { + result.Prices = o.mergePrices(update, weight) } - return feeRates -} + o.fetchTracker.recordFetch(update.Source, senderID, update.Stamp) + o.rescheduleDiviner(update.Source, senderID) -// FeeRates returns the aggregated tx fee rates for all known networks. -func (o *Oracle) FeeRates() map[Network]*big.Int { - return o.getFeeRates(nil) + return result } -// MergeFeeRates merges fee rates from another oracle into this oracle. -// Returns a map of the networks whose aggregated fee rates were updated. -func (o *Oracle) MergeFeeRates(sourcedUpdate *SourcedFeeRateUpdate) map[Network]*big.Int { - if sourcedUpdate == nil || len(sourcedUpdate.FeeRates) == 0 { +func (o *Oracle) mergeFeeRates(update *OracleUpdate, weight float64) map[Network]*big.Int { + if len(update.FeeRates) == 0 { return nil } - o.feeRatesMtx.Lock() - updatedNetworks := make(map[Network]bool) + updatedFeeRates := make(map[Network]*big.Int) + snapshotFeeRates := make(map[string]*SnapshotRate) + var latestFeeRates map[string]string - for _, fr := range sourcedUpdate.FeeRates { + for network, feeRate := range update.FeeRates { proposedUpdate := &feeRateUpdate{ - network: fr.Network, - feeRate: bytesToBigInt(fr.FeeRate), - stamp: sourcedUpdate.Stamp, - weight: sourcedUpdate.Weight, + network: network, + feeRate: feeRate, + stamp: update.Stamp, + weight: weight, } - netSources, found := o.feeRates[fr.Network] - if !found { - o.feeRates[fr.Network] = map[string]*feeRateUpdate{ - sourcedUpdate.Source: proposedUpdate, + + bucket := o.getOrCreateFeeRateBucket(network) + updated, agg := bucket.mergeAndUpdateAggregate(update.Source, proposedUpdate) + if updated && agg.Sign() > 0 { + updatedFeeRates[network] = agg + snapshotFeeRates[string(network)] = &SnapshotRate{ + Value: agg.String(), + Contributions: map[string]*SourceContribution{ + update.Source: { + Value: feeRate.String(), + Stamp: update.Stamp, + Weight: weight, + }, + }, } - updatedNetworks[fr.Network] = true - continue - } - existingUpdate, found := netSources[sourcedUpdate.Source] - if !found { - netSources[sourcedUpdate.Source] = proposedUpdate - updatedNetworks[fr.Network] = true - continue - } - if sourcedUpdate.Stamp.After(existingUpdate.stamp) { - netSources[sourcedUpdate.Source] = proposedUpdate - updatedNetworks[fr.Network] = true + if latestFeeRates == nil { + latestFeeRates = make(map[string]string) + } + latestFeeRates[string(network)] = feeRate.String() } } - o.feeRatesMtx.Unlock() - o.rescheduleDiviner(sourcedUpdate.Source) + if len(snapshotFeeRates) > 0 { + fetchCounts := o.fetchTracker.sourceFetchCounts(update.Source) + stamp := update.Stamp + o.onStateUpdate(&OracleSnapshot{ + Sources: map[string]*SourceStatus{ + update.Source: { + LastFetch: &stamp, + Fetches24h: fetchCounts, + LatestData: map[string]map[string]string{ + FeeRateData: latestFeeRates, + }, + }, + }, + FeeRates: snapshotFeeRates, + }) + } - return o.getFeeRates(updatedNetworks) + return updatedFeeRates } -func (o *Oracle) getPrices(tickers map[Ticker]bool) map[Ticker]float64 { +// allPrices returns the aggregated prices for all known tickers. +func (o *Oracle) allPrices() map[Ticker]float64 { o.pricesMtx.RLock() - size := len(tickers) - if tickers == nil { - size = len(o.prices) - } - counters := make(map[Ticker]*priceWeightCounter, size) + defer o.pricesMtx.RUnlock() - for ticker, updates := range o.prices { - if tickers != nil && !tickers[ticker] { - continue - } - counter, found := counters[ticker] - if !found { - counter = &priceWeightCounter{} - counters[ticker] = counter - } - for _, entry := range updates { - weight := agedWeight(entry.weight, entry.stamp) - if weight == 0 { - continue - } - counter.totalWeight += weight - counter.weightedSum += weight * entry.price + prices := make(map[Ticker]float64, len(o.prices)) + for ticker, bucket := range o.prices { + if price := bucket.aggregatedPrice(); price > 0 { + prices[ticker] = price } } - o.pricesMtx.RUnlock() + return prices +} - priceMap := make(map[Ticker]float64, len(counters)) - for ticker, counter := range counters { - if counter.totalWeight == 0 { - continue - } - priceMap[ticker] = counter.weightedSum / counter.totalWeight +// Price returns the cached aggregated price for a single ticker. +func (o *Oracle) Price(ticker Ticker) (float64, bool) { + bucket := o.getPriceBucket(ticker) + if bucket == nil { + return 0, false } - - return priceMap + if price := bucket.aggregatedPrice(); price > 0 { + return price, true + } + return 0, false } -// Prices returns the aggregated prices for all known tickers. -func (o *Oracle) Prices() map[Ticker]float64 { - return o.getPrices(nil) +// FeeRate returns the cached aggregated fee rate for a single network. +func (o *Oracle) FeeRate(network Network) (*big.Int, bool) { + bucket := o.getFeeRateBucket(network) + if bucket == nil { + return nil, false + } + if rate := bucket.aggregatedRate(); rate != nil && rate.Sign() > 0 { + return rate, true + } + return nil, false } -func (o *Oracle) rescheduleDiviner(name string) { +func (o *Oracle) rescheduleDiviner(name string, lastFetchNodeID string) { + // diviner reschedules itself after a fetch. + if lastFetchNodeID == o.nodeID { + return + } + o.divinersMtx.RLock() div, found := o.diviners[name] o.divinersMtx.RUnlock() if !found { - // Do nothing. return } div.reschedule() } -// GetSourceWeight returns the configured weight for a source by name. +// sourceWeight returns the configured weight for a source by name. // If the source is not found, returns 1.0 as a default weight. -func (o *Oracle) GetSourceWeight(sourceName string) float64 { +func (o *Oracle) sourceWeight(sourceName string) float64 { o.divinersMtx.RLock() div, found := o.diviners[sourceName] o.divinersMtx.RUnlock() if !found { return 1.0 } - return div.weight + return div.source.Weight() } -// MergePrices merges prices from another oracle into this oracle. -// Returns a map of the tickers whose aggregated prices were updated. -func (o *Oracle) MergePrices(sourcedUpdate *SourcedPriceUpdate) map[Ticker]float64 { - if sourcedUpdate == nil || len(sourcedUpdate.Prices) == 0 { +func (o *Oracle) mergePrices(update *OracleUpdate, weight float64) map[Ticker]float64 { + if len(update.Prices) == 0 { return nil } - o.pricesMtx.Lock() - updatedTickers := make(map[Ticker]bool) + updatedPrices := make(map[Ticker]float64) + snapshotPrices := make(map[string]*SnapshotRate) + var latestPrices map[string]string - for _, p := range sourcedUpdate.Prices { + for ticker, price := range update.Prices { proposedUpdate := &priceUpdate{ - ticker: p.Ticker, - price: p.Price, - stamp: sourcedUpdate.Stamp, - weight: sourcedUpdate.Weight, + ticker: ticker, + price: price, + stamp: update.Stamp, + weight: weight, } - tickerSources, found := o.prices[p.Ticker] - if !found { - o.prices[p.Ticker] = map[string]*priceUpdate{ - sourcedUpdate.Source: proposedUpdate, + + bucket := o.getOrCreatePriceBucket(ticker) + updated, agg := bucket.mergeAndUpdateAggregate(update.Source, proposedUpdate) + if updated && agg > 0 { + updatedPrices[ticker] = agg + snapshotPrices[string(ticker)] = &SnapshotRate{ + Value: fmt.Sprintf("%f", agg), + Contributions: map[string]*SourceContribution{ + update.Source: { + Value: fmt.Sprintf("%f", price), + Stamp: update.Stamp, + Weight: weight, + }, + }, } - updatedTickers[p.Ticker] = true - continue - } - existingUpdate, found := tickerSources[sourcedUpdate.Source] - if !found { - tickerSources[sourcedUpdate.Source] = proposedUpdate - updatedTickers[p.Ticker] = true - continue - } - if sourcedUpdate.Stamp.After(existingUpdate.stamp) { - tickerSources[sourcedUpdate.Source] = proposedUpdate - updatedTickers[p.Ticker] = true + if latestPrices == nil { + latestPrices = make(map[string]string) + } + latestPrices[string(ticker)] = fmt.Sprintf("%f", price) } } - o.pricesMtx.Unlock() - o.rescheduleDiviner(sourcedUpdate.Source) + if len(snapshotPrices) > 0 { + fetchCounts := o.fetchTracker.sourceFetchCounts(update.Source) + stamp := update.Stamp + o.onStateUpdate(&OracleSnapshot{ + Sources: map[string]*SourceStatus{ + update.Source: { + LastFetch: &stamp, + Fetches24h: fetchCounts, + LatestData: map[string]map[string]string{ + PriceData: latestPrices, + }, + }, + }, + Prices: snapshotPrices, + }) + } - return o.getPrices(updatedTickers) + return updatedPrices } +// Run starts the oracle and blocks until the context is done. func (o *Oracle) Run(ctx context.Context) { var wg sync.WaitGroup + // Run quota manager. + wg.Add(1) + go func() { + defer wg.Done() + o.quotaManager.run(ctx) + }() + + // Run all diviners o.divinersMtx.RLock() for _, div := range o.diviners { wg.Add(1) @@ -433,23 +441,12 @@ func (o *Oracle) Run(ctx context.Context) { wg.Wait() } -// bytesToBigInt converts big-endian encoded bytes to big.Int. -func bytesToBigInt(b []byte) *big.Int { - if len(b) == 0 { - return big.NewInt(0) - } - return new(big.Int).SetBytes(b) -} - -// bigIntToBytes converts big.Int to big-endian encoded bytes. -func bigIntToBytes(bi *big.Int) []byte { - if bi == nil || bi.Sign() == 0 { - return []byte{0} - } - return bi.Bytes() +// GetLocalQuotas returns all local source quotas for handshake/heartbeat. +func (o *Oracle) GetLocalQuotas() map[string]*sources.QuotaStatus { + return o.quotaManager.getLocalQuotas() } -// uint64ToBigInt converts uint64 to big.Int. -func uint64ToBigInt(val uint64) *big.Int { - return new(big.Int).SetUint64(val) +// UpdatePeerSourceQuota processes a single source's quota from a peer node. +func (o *Oracle) UpdatePeerSourceQuota(peerID string, quota *TimestampedQuotaStatus, source string) { + o.quotaManager.handlePeerSourceQuota(peerID, quota, source) } diff --git a/oracle/oracle_test.go b/oracle/oracle_test.go index 841f6c4..5145410 100644 --- a/oracle/oracle_test.go +++ b/oracle/oracle_test.go @@ -10,289 +10,49 @@ import ( "time" "github.com/decred/slog" - "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/bisoncraft/mesh/oracle/sources" ) - -func TestGetPrices(t *testing.T) { - backend := slog.NewBackend(os.Stdout) - log := backend.Logger("test") - now := time.Now() - - tests := []struct { - name string - prices map[Ticker]map[string]*priceUpdate - filter map[Ticker]bool - expected map[Ticker]float64 - }{ - { - name: "single source per ticker", - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, - }, - "ETH": { - "source1": {ticker: "ETH", price: 3000.0, stamp: now, weight: 1.0}, - }, - }, - filter: nil, - expected: map[Ticker]float64{ - "BTC": 50000.0, - "ETH": 3000.0, - }, - }, - { - name: "multiple sources weighted average", - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, - "source2": {ticker: "BTC", price: 52000.0, stamp: now, weight: 1.0}, - }, - }, - filter: nil, - expected: map[Ticker]float64{ - "BTC": 51000.0, - }, - }, - { - name: "different weights", - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 0.25}, - "source2": {ticker: "BTC", price: 52000.0, stamp: now, weight: 0.75}, - }, - }, - filter: nil, - expected: map[Ticker]float64{ - "BTC": 51500.0, - }, - }, - { - name: "aged weights", - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, - "source2": {ticker: "BTC", price: 30000.0, stamp: now.Add(-validityExpiration - time.Second), weight: 1.0}, - }, - }, - filter: nil, - expected: map[Ticker]float64{ - "BTC": 50000.0, - }, - }, - { - name: "filtered tickers", - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, - }, - "ETH": { - "source1": {ticker: "ETH", price: 3000.0, stamp: now, weight: 1.0}, - }, - "DCR": { - "source1": {ticker: "DCR", price: 25.0, stamp: now, weight: 1.0}, - }, - }, - filter: map[Ticker]bool{ - "BTC": true, - "ETH": true, - }, - expected: map[Ticker]float64{ - "BTC": 50000.0, - "ETH": 3000.0, - }, - }, - { - name: "all expired sources", - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now.Add(-validityExpiration - time.Second), weight: 1.0}, - "source2": {ticker: "BTC", price: 52000.0, stamp: now.Add(-validityExpiration - time.Second), weight: 1.0}, - }, - }, - filter: nil, - expected: map[Ticker]float64{}, - }, - { - name: "empty oracle", - prices: map[Ticker]map[string]*priceUpdate{}, - filter: nil, - expected: map[Ticker]float64{}, - }, +// makePriceBuckets converts a test-friendly format to the Oracle's bucket format. +func makePriceBuckets(m map[Ticker]map[string]*priceUpdate) map[Ticker]*priceBucket { + result := make(map[Ticker]*priceBucket, len(m)) + for ticker, sources := range m { + bucket := newPriceBucket() + for source, update := range sources { + bucket.mergeAndUpdateAggregate(source, update) + } + result[ticker] = bucket } + return result +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: tt.prices, - } - - result := oracle.getPrices(tt.filter) - - if len(result) != len(tt.expected) { - t.Errorf("Expected %d tickers, got %d", len(tt.expected), len(result)) - } - - for ticker, expectedPrice := range tt.expected { - actualPrice, found := result[ticker] - if !found { - t.Errorf("Expected ticker %s to be in result", ticker) - continue - } - if actualPrice != expectedPrice { - t.Errorf("For ticker %s, expected price %.2f, got %.2f", - ticker, expectedPrice, actualPrice) - } - } - - for ticker := range result { - if _, expected := tt.expected[ticker]; !expected { - t.Errorf("Unexpected ticker %s in result", ticker) - } - } - }) +// makeFeeRateBuckets converts a test-friendly format to the Oracle's bucket format. +func makeFeeRateBuckets(m map[Network]map[string]*feeRateUpdate) map[Network]*feeRateBucket { + result := make(map[Network]*feeRateBucket, len(m)) + for network, sources := range m { + bucket := newFeeRateBucket() + for source, update := range sources { + bucket.mergeAndUpdateAggregate(source, update) + } + result[network] = bucket } + return result } -func TestGetFeeRates(t *testing.T) { - backend := slog.NewBackend(os.Stdout) - log := backend.Logger("test") - now := time.Now() - - tests := []struct { - name string - feeRates map[Network]map[string]*feeRateUpdate - filter map[Network]bool - expected map[Network]*big.Int - }{ - { - name: "single source per network", - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, - }, - "ETH": { - "source1": {network: "ETH", feeRate: big.NewInt(200), stamp: now, weight: 1.0}, - }, - }, - filter: nil, - expected: map[Network]*big.Int{ - "BTC": big.NewInt(100), - "ETH": big.NewInt(200), - }, - }, - { - name: "multiple sources weighted average", - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, - "source2": {network: "BTC", feeRate: big.NewInt(200), stamp: now, weight: 1.0}, - }, - }, - filter: nil, - expected: map[Network]*big.Int{ - "BTC": big.NewInt(150), - }, - }, - { - name: "different weights", - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 0.25}, - "source2": {network: "BTC", feeRate: big.NewInt(200), stamp: now, weight: 0.75}, - }, - }, - filter: nil, - expected: map[Network]*big.Int{ - "BTC": big.NewInt(175), - }, - }, - { - name: "aged weights", - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, - "source2": {network: "BTC", feeRate: big.NewInt(200), stamp: now.Add(-validityExpiration - time.Second), weight: 1.0}, - }, - }, - filter: nil, - expected: map[Network]*big.Int{ - "BTC": big.NewInt(100), - }, - }, - { - name: "filtered networks", - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, - }, - "ETH": { - "source1": {network: "ETH", feeRate: big.NewInt(200), stamp: now, weight: 1.0}, - }, - "DCR": { - "source1": {network: "DCR", feeRate: big.NewInt(50), stamp: now, weight: 1.0}, - }, - }, - filter: map[Network]bool{ - "BTC": true, - "ETH": true, - }, - expected: map[Network]*big.Int{ - "BTC": big.NewInt(100), - "ETH": big.NewInt(200), - }, - }, - { - name: "all expired sources", - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now.Add(-validityExpiration - time.Second), weight: 1.0}, - "source2": {network: "BTC", feeRate: big.NewInt(200), stamp: now.Add(-validityExpiration - time.Second), weight: 1.0}, - }, - }, - filter: nil, - expected: map[Network]*big.Int{}, - }, - { - name: "empty oracle", - feeRates: map[Network]map[string]*feeRateUpdate{}, - filter: nil, - expected: map[Network]*big.Int{}, - }, +func newTestOracle(log slog.Logger) *Oracle { + return &Oracle{ + log: log, + prices: make(map[Ticker]*priceBucket), + feeRates: make(map[Network]*feeRateBucket), + diviners: make(map[string]*diviner), + fetchTracker: newFetchTracker(), + onStateUpdate: func(*OracleSnapshot) {}, } +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: tt.feeRates, - } - - result := oracle.getFeeRates(tt.filter) - - if len(result) != len(tt.expected) { - t.Errorf("Expected %d networks, got %d", len(tt.expected), len(result)) - } - - for network, expectedRate := range tt.expected { - actualRate, found := result[network] - if !found { - t.Errorf("Expected network %s to be in result", network) - continue - } - if actualRate.Cmp(expectedRate) != 0 { - t.Errorf("For network %s, expected fee rate %s, got %s", - network, expectedRate.String(), actualRate.String()) - } - } - - for network := range result { - if _, expected := tt.expected[network]; !expected { - t.Errorf("Unexpected network %s in result", network) - } - } - }) +func setSourceWeights(oracle *Oracle, weights map[string]float64) { + for name, weight := range weights { + oracle.diviners[name] = &diviner{source: &mockSource{name: name, weight: weight}} } } @@ -304,19 +64,19 @@ func TestMergePrices(t *testing.T) { tests := []struct { name string existingPrices map[Ticker]map[string]*priceUpdate - sourcedUpdate *SourcedPriceUpdate + update *OracleUpdate + sourceWeights map[string]float64 expectedPrices map[Ticker]map[string]*priceUpdate expectedResult map[Ticker]float64 }{ { name: "new ticker from external source", existingPrices: map[Ticker]map[string]*priceUpdate{}, - sourcedUpdate: &SourcedPriceUpdate{ + update: &OracleUpdate{ Source: "external-oracle", Stamp: now, - Weight: 1.0, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: 50000.0}, + Prices: map[Ticker]float64{ + "BTC": 50000.0, }, }, expectedPrices: map[Ticker]map[string]*priceUpdate{ @@ -345,12 +105,11 @@ func TestMergePrices(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedPriceUpdate{ + update: &OracleUpdate{ Source: "external-oracle", Stamp: newerStamp, - Weight: 1.0, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: 50000.0}, + Prices: map[Ticker]float64{ + "BTC": 50000.0, }, }, expectedPrices: map[Ticker]map[string]*priceUpdate{ @@ -379,12 +138,11 @@ func TestMergePrices(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedPriceUpdate{ + update: &OracleUpdate{ Source: "external-oracle", Stamp: oldStamp, - Weight: 1.0, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: 48000.0}, + Prices: map[Ticker]float64{ + "BTC": 48000.0, }, }, expectedPrices: map[Ticker]map[string]*priceUpdate{ @@ -411,15 +169,17 @@ func TestMergePrices(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedPriceUpdate{ + update: &OracleUpdate{ Source: "source2", Stamp: now, - Weight: 0.8, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: 51000.0}, - {Ticker: "ETH", Price: 3000.0}, + Prices: map[Ticker]float64{ + "BTC": 51000.0, + "ETH": 3000.0, }, }, + sourceWeights: map[string]float64{ + "source2": 0.8, + }, expectedPrices: map[Ticker]map[string]*priceUpdate{ "BTC": { "source1": { @@ -461,12 +221,11 @@ func TestMergePrices(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedPriceUpdate{ + update: &OracleUpdate{ Source: "source2", Stamp: now, - Weight: 1.0, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: 51000.0}, + Prices: map[Ticker]float64{ + "BTC": 51000.0, }, }, expectedPrices: map[Ticker]map[string]*priceUpdate{ @@ -496,13 +255,19 @@ func TestMergePrices(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: tt.existingPrices, - diviners: make(map[string]*diviner), + oracle := newTestOracle(log) + oracle.prices = makePriceBuckets(tt.existingPrices) + if len(tt.sourceWeights) > 0 { + setSourceWeights(oracle, tt.sourceWeights) } - result := oracle.MergePrices(tt.sourcedUpdate) + mergeResult := oracle.Merge(tt.update, "test-sender") + + // Extract price results + var result map[Ticker]float64 + if mergeResult != nil { + result = mergeResult.Prices + } // Verify the merged prices match expected if len(oracle.prices) != len(tt.expectedPrices) { @@ -510,19 +275,19 @@ func TestMergePrices(t *testing.T) { } for ticker, expectedSources := range tt.expectedPrices { - actualSources, found := oracle.prices[ticker] + actualBucket, found := oracle.prices[ticker] if !found { t.Errorf("Expected ticker %s to be in oracle.prices", ticker) continue } - if len(actualSources) != len(expectedSources) { + if len(actualBucket.sources) != len(expectedSources) { t.Errorf("For ticker %s, expected %d sources, got %d", - ticker, len(expectedSources), len(actualSources)) + ticker, len(expectedSources), len(actualBucket.sources)) } for source, expectedUpdate := range expectedSources { - actualUpdate, found := actualSources[source] + actualUpdate, found := actualBucket.sources[source] if !found { t.Errorf("Expected source %s for ticker %s", source, ticker) continue @@ -587,19 +352,19 @@ func TestMergeFeeRates(t *testing.T) { tests := []struct { name string existingFeeRates map[Network]map[string]*feeRateUpdate - sourcedUpdate *SourcedFeeRateUpdate + update *OracleUpdate + sourceWeights map[string]float64 expectedFeeRates map[Network]map[string]*feeRateUpdate expectedResult map[Network]*big.Int }{ { name: "new network from external source", existingFeeRates: map[Network]map[string]*feeRateUpdate{}, - sourcedUpdate: &SourcedFeeRateUpdate{ + update: &OracleUpdate{ Source: "external-oracle", Stamp: now, - Weight: 1.0, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: []byte{0, 0, 0, 100}}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(100), }, }, expectedFeeRates: map[Network]map[string]*feeRateUpdate{ @@ -628,12 +393,11 @@ func TestMergeFeeRates(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedFeeRateUpdate{ + update: &OracleUpdate{ Source: "external-oracle", Stamp: newerStamp, - Weight: 1.0, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: []byte{0, 0, 0, 100}}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(100), }, }, expectedFeeRates: map[Network]map[string]*feeRateUpdate{ @@ -662,12 +426,11 @@ func TestMergeFeeRates(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedFeeRateUpdate{ + update: &OracleUpdate{ Source: "external-oracle", Stamp: oldStamp, - Weight: 1.0, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: []byte{0, 0, 0, 80}}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(80), }, }, expectedFeeRates: map[Network]map[string]*feeRateUpdate{ @@ -694,15 +457,17 @@ func TestMergeFeeRates(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedFeeRateUpdate{ + update: &OracleUpdate{ Source: "source2", Stamp: now, - Weight: 0.8, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: []byte{0, 0, 0, 120}}, - {Network: "ETH", FeeRate: []byte{0, 0, 0, 50}}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(120), + "ETH": big.NewInt(50), }, }, + sourceWeights: map[string]float64{ + "source2": 0.8, + }, expectedFeeRates: map[Network]map[string]*feeRateUpdate{ "BTC": { "source1": { @@ -744,12 +509,11 @@ func TestMergeFeeRates(t *testing.T) { }, }, }, - sourcedUpdate: &SourcedFeeRateUpdate{ + update: &OracleUpdate{ Source: "source2", Stamp: now, - Weight: 1.0, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: []byte{0, 0, 0, 120}}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(120), }, }, expectedFeeRates: map[Network]map[string]*feeRateUpdate{ @@ -779,13 +543,19 @@ func TestMergeFeeRates(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: tt.existingFeeRates, - diviners: make(map[string]*diviner), + oracle := newTestOracle(log) + oracle.feeRates = makeFeeRateBuckets(tt.existingFeeRates) + if len(tt.sourceWeights) > 0 { + setSourceWeights(oracle, tt.sourceWeights) } - result := oracle.MergeFeeRates(tt.sourcedUpdate) + mergeResult := oracle.Merge(tt.update, "test-sender") + + // Extract fee rate results + var result map[Network]*big.Int + if mergeResult != nil { + result = mergeResult.FeeRates + } // Verify the merged fee rates match expected if len(oracle.feeRates) != len(tt.expectedFeeRates) { @@ -793,19 +563,19 @@ func TestMergeFeeRates(t *testing.T) { } for network, expectedSources := range tt.expectedFeeRates { - actualSources, found := oracle.feeRates[network] + actualBucket, found := oracle.feeRates[network] if !found { t.Errorf("Expected network %s to be in oracle.feeRates", network) continue } - if len(actualSources) != len(expectedSources) { + if len(actualBucket.sources) != len(expectedSources) { t.Errorf("For network %s, expected %d sources, got %d", - network, len(expectedSources), len(actualSources)) + network, len(expectedSources), len(actualBucket.sources)) } for source, expectedUpdate := range expectedSources { - actualUpdate, found := actualSources[source] + actualUpdate, found := actualBucket.sources[source] if !found { t.Errorf("Expected source %s for network %s", source, network) continue @@ -862,26 +632,22 @@ func TestMergeFeeRates(t *testing.T) { } } - func TestConcurrency(t *testing.T) { backend := slog.NewBackend(os.Stdout) log := backend.Logger("test") t.Run("multiple goroutines reading prices simultaneously", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: make(map[Ticker]map[string]*priceUpdate), - } - now := time.Now() - // Pre-populate with some price data - oracle.prices["BTC"] = map[string]*priceUpdate{ - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, - "source2": {ticker: "BTC", price: 51000.0, stamp: now, weight: 1.0}, - } - oracle.prices["ETH"] = map[string]*priceUpdate{ - "source1": {ticker: "ETH", price: 3000.0, stamp: now, weight: 1.0}, - } + oracle := newTestOracle(log) + oracle.prices = makePriceBuckets(map[Ticker]map[string]*priceUpdate{ + "BTC": { + "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, + "source2": {ticker: "BTC", price: 51000.0, stamp: now, weight: 1.0}, + }, + "ETH": { + "source1": {ticker: "ETH", price: 3000.0, stamp: now, weight: 1.0}, + }, + }) // Launch multiple readers concurrently const numReaders = 50 @@ -890,7 +656,7 @@ func TestConcurrency(t *testing.T) { for i := 0; i < numReaders; i++ { go func() { for j := 0; j < 100; j++ { - prices := oracle.Prices() + prices := oracle.allPrices() if len(prices) > 0 { // Verify data integrity if btcPrice, found := prices["BTC"]; found { @@ -911,20 +677,17 @@ func TestConcurrency(t *testing.T) { }) t.Run("multiple goroutines reading fee rates simultaneously", func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: make(map[Network]map[string]*feeRateUpdate), - } - now := time.Now() - // Pre-populate with some fee rate data - oracle.feeRates["BTC"] = map[string]*feeRateUpdate{ - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, - "source2": {network: "BTC", feeRate: big.NewInt(120), stamp: now, weight: 1.0}, - } - oracle.feeRates["ETH"] = map[string]*feeRateUpdate{ - "source1": {network: "ETH", feeRate: big.NewInt(50), stamp: now, weight: 1.0}, - } + oracle := newTestOracle(log) + oracle.feeRates = makeFeeRateBuckets(map[Network]map[string]*feeRateUpdate{ + "BTC": { + "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, + "source2": {network: "BTC", feeRate: big.NewInt(120), stamp: now, weight: 1.0}, + }, + "ETH": { + "source1": {network: "ETH", feeRate: big.NewInt(50), stamp: now, weight: 1.0}, + }, + }) const numReaders = 50 done := make(chan bool, numReaders) @@ -932,7 +695,7 @@ func TestConcurrency(t *testing.T) { for i := 0; i < numReaders; i++ { go func() { for j := 0; j < 100; j++ { - feeRates := oracle.FeeRates() + feeRates := oracle.allFeeRates() if len(feeRates) > 0 { // Verify data integrity if btcRate, found := feeRates["BTC"]; found { @@ -953,11 +716,7 @@ func TestConcurrency(t *testing.T) { }) t.Run("concurrent reads and writes of prices", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: make(map[Ticker]map[string]*priceUpdate), - diviners: make(map[string]*diviner), - } + oracle := newTestOracle(log) const numReaders = 20 const numWriters = 5 @@ -969,8 +728,8 @@ func TestConcurrency(t *testing.T) { for i := 0; i < numReaders; i++ { go func() { for j := 0; j < 50; j++ { - _ = oracle.Prices() - _ = oracle.getPrices(map[Ticker]bool{"BTC": true}) + _ = oracle.allPrices() + _, _ = oracle.Price("BTC") } done <- true }() @@ -981,16 +740,15 @@ func TestConcurrency(t *testing.T) { writerID := i go func() { for j := 0; j < 10; j++ { - sourcedUpdate := &SourcedPriceUpdate{ + update := &OracleUpdate{ Source: fmt.Sprintf("writer-%d", writerID), Stamp: now.Add(time.Duration(j) * time.Millisecond), - Weight: 1.0, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: float64(50000 + j)}, - {Ticker: "ETH", Price: float64(3000 + j)}, + Prices: map[Ticker]float64{ + "BTC": float64(50000 + j), + "ETH": float64(3000 + j), }, } - oracle.MergePrices(sourcedUpdate) + oracle.Merge(update, fmt.Sprintf("writer-%d", writerID)) } done <- true }() @@ -1003,11 +761,7 @@ func TestConcurrency(t *testing.T) { }) t.Run("concurrent reads and writes of fee rates", func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: make(map[Network]map[string]*feeRateUpdate), - diviners: make(map[string]*diviner), - } + oracle := newTestOracle(log) const numReaders = 20 const numWriters = 5 @@ -1019,8 +773,8 @@ func TestConcurrency(t *testing.T) { for i := 0; i < numReaders; i++ { go func() { for j := 0; j < 50; j++ { - _ = oracle.FeeRates() - _ = oracle.getFeeRates(map[Network]bool{"BTC": true}) + _ = oracle.allFeeRates() + _, _ = oracle.FeeRate("BTC") } done <- true }() @@ -1031,16 +785,15 @@ func TestConcurrency(t *testing.T) { writerID := i go func() { for j := 0; j < 10; j++ { - sourcedUpdate := &SourcedFeeRateUpdate{ + update := &OracleUpdate{ Source: fmt.Sprintf("writer-%d", writerID), Stamp: now.Add(time.Duration(j) * time.Millisecond), - Weight: 1.0, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: bigIntToBytes(big.NewInt(int64(100 + j)))}, - {Network: "ETH", FeeRate: bigIntToBytes(big.NewInt(int64(50 + j)))}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(int64(100 + j)), + "ETH": big.NewInt(int64(50 + j)), }, } - oracle.MergeFeeRates(sourcedUpdate) + oracle.Merge(update, fmt.Sprintf("writer-%d", writerID)) } done <- true }() @@ -1053,12 +806,7 @@ func TestConcurrency(t *testing.T) { }) t.Run("concurrent merge and read operations", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: make(map[Ticker]map[string]*priceUpdate), - feeRates: make(map[Network]map[string]*feeRateUpdate), - diviners: make(map[string]*diviner), - } + oracle := newTestOracle(log) const numReaders = 20 const numMergers = 10 @@ -1070,8 +818,8 @@ func TestConcurrency(t *testing.T) { for i := 0; i < numReaders; i++ { go func() { for j := 0; j < 50; j++ { - _ = oracle.Prices() - _ = oracle.FeeRates() + _ = oracle.allPrices() + _ = oracle.allFeeRates() } done <- true }() @@ -1082,25 +830,17 @@ func TestConcurrency(t *testing.T) { mergerID := i go func() { for j := 0; j < 10; j++ { - sourcedPrices := &SourcedPriceUpdate{ + update := &OracleUpdate{ Source: fmt.Sprintf("merger-%d", mergerID), Stamp: now.Add(time.Duration(j) * time.Millisecond), - Weight: 1.0, - Prices: []*SourcedPrice{ - {Ticker: "BTC", Price: float64(50000 + j)}, + Prices: map[Ticker]float64{ + "BTC": float64(50000 + j), }, - } - oracle.MergePrices(sourcedPrices) - - sourcedFeeRates := &SourcedFeeRateUpdate{ - Source: fmt.Sprintf("merger-%d", mergerID), - Stamp: now.Add(time.Duration(j) * time.Millisecond), - Weight: 1.0, - FeeRates: []*SourcedFeeRate{ - {Network: "BTC", FeeRate: bigIntToBytes(big.NewInt(int64(100 + j)))}, + FeeRates: map[Network]*big.Int{ + "BTC": big.NewInt(int64(100 + j)), }, } - oracle.MergeFeeRates(sourcedFeeRates) + oracle.Merge(update, fmt.Sprintf("merger-%d", mergerID)) } done <- true }() @@ -1119,19 +859,17 @@ func TestPublicPrices(t *testing.T) { now := time.Now() t.Run("returns all prices", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: map[Ticker]map[string]*priceUpdate{ - "BTC": { - "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, - }, - "ETH": { - "source1": {ticker: "ETH", price: 3000.0, stamp: now, weight: 1.0}, - }, + oracle := newTestOracle(log) + oracle.prices = makePriceBuckets(map[Ticker]map[string]*priceUpdate{ + "BTC": { + "source1": {ticker: "BTC", price: 50000.0, stamp: now, weight: 1.0}, }, - } + "ETH": { + "source1": {ticker: "ETH", price: 3000.0, stamp: now, weight: 1.0}, + }, + }) - result := oracle.Prices() + result := oracle.allPrices() if len(result) != 2 { t.Errorf("Expected 2 prices, got %d", len(result)) @@ -1147,12 +885,9 @@ func TestPublicPrices(t *testing.T) { }) t.Run("returns empty map for empty oracle", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: make(map[Ticker]map[string]*priceUpdate), - } + oracle := newTestOracle(log) - result := oracle.Prices() + result := oracle.allPrices() if len(result) != 0 { t.Errorf("Expected 0 prices, got %d", len(result)) @@ -1166,19 +901,17 @@ func TestPublicFeeRates(t *testing.T) { now := time.Now() t.Run("returns all fee rates", func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: map[Network]map[string]*feeRateUpdate{ - "BTC": { - "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, - }, - "ETH": { - "source1": {network: "ETH", feeRate: big.NewInt(50), stamp: now, weight: 1.0}, - }, + oracle := newTestOracle(log) + oracle.feeRates = makeFeeRateBuckets(map[Network]map[string]*feeRateUpdate{ + "BTC": { + "source1": {network: "BTC", feeRate: big.NewInt(100), stamp: now, weight: 1.0}, }, - } + "ETH": { + "source1": {network: "ETH", feeRate: big.NewInt(50), stamp: now, weight: 1.0}, + }, + }) - result := oracle.FeeRates() + result := oracle.allFeeRates() if len(result) != 2 { t.Errorf("Expected 2 fee rates, got %d", len(result)) @@ -1194,12 +927,9 @@ func TestPublicFeeRates(t *testing.T) { }) t.Run("returns empty map for empty oracle", func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: make(map[Network]map[string]*feeRateUpdate), - } + oracle := newTestOracle(log) - result := oracle.FeeRates() + result := oracle.allFeeRates() if len(result) != 0 { t.Errorf("Expected 0 fee rates, got %d", len(result)) @@ -1211,15 +941,11 @@ func TestMergeWithEmptyUpdates(t *testing.T) { backend := slog.NewBackend(os.Stdout) log := backend.Logger("test") - t.Run("MergePrices with nil", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: make(map[Ticker]map[string]*priceUpdate), - diviners: make(map[string]*diviner), - } + t.Run("Merge with nil", func(t *testing.T) { + oracle := newTestOracle(log) // Should not panic - result := oracle.MergePrices(nil) + result := oracle.Merge(nil, "test-sender") if result != nil { t.Errorf("Expected nil result, got %v", result) @@ -1230,53 +956,26 @@ func TestMergeWithEmptyUpdates(t *testing.T) { } }) - t.Run("MergeFeeRates with nil", func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: make(map[Network]map[string]*feeRateUpdate), - diviners: make(map[string]*diviner), - } - - // Should not panic - result := oracle.MergeFeeRates(nil) - - if result != nil { - t.Errorf("Expected nil result, got %v", result) - } - - if len(oracle.feeRates) != 0 { - t.Errorf("Expected no fee rates, got %d", len(oracle.feeRates)) - } - }) + t.Run("Merge with empty prices map", func(t *testing.T) { + oracle := newTestOracle(log) - t.Run("MergePrices with empty prices slice", func(t *testing.T) { - oracle := &Oracle{ - log: log, - prices: make(map[Ticker]map[string]*priceUpdate), - diviners: make(map[string]*diviner), - } - - result := oracle.MergePrices(&SourcedPriceUpdate{ + result := oracle.Merge(&OracleUpdate{ Source: "test", - Prices: []*SourcedPrice{}, - }) + Prices: map[Ticker]float64{}, + }, "test-sender") if result != nil { t.Errorf("Expected nil result, got %v", result) } }) - t.Run("MergeFeeRates with empty fee rates slice", func(t *testing.T) { - oracle := &Oracle{ - log: log, - feeRates: make(map[Network]map[string]*feeRateUpdate), - diviners: make(map[string]*diviner), - } + t.Run("Merge with empty fee rates map", func(t *testing.T) { + oracle := newTestOracle(log) - result := oracle.MergeFeeRates(&SourcedFeeRateUpdate{ + result := oracle.Merge(&OracleUpdate{ Source: "test", - FeeRates: []*SourcedFeeRate{}, - }) + FeeRates: map[Network]*big.Int{}, + }, "test-sender") if result != nil { t.Errorf("Expected nil result, got %v", result) @@ -1419,76 +1118,22 @@ func TestAgedWeightBoundaries(t *testing.T) { }) } -func TestGetSourceWeight(t *testing.T) { - backend := slog.NewBackend(os.Stdout) - log := backend.Logger("test") - - t.Run("returns weight for existing source", func(t *testing.T) { - div1 := &diviner{name: "source1", weight: 0.8} - div2 := &diviner{name: "source2", weight: 0.5} - - oracle := &Oracle{ - log: log, - diviners: map[string]*diviner{ - "source1": div1, - "source2": div2, - }, - } - - weight := oracle.GetSourceWeight("source1") - if weight != 0.8 { - t.Errorf("Expected weight 0.8, got %.1f", weight) - } - - weight = oracle.GetSourceWeight("source2") - if weight != 0.5 { - t.Errorf("Expected weight 0.5, got %.1f", weight) - } - }) - - t.Run("returns default weight for non-existent source", func(t *testing.T) { - oracle := &Oracle{ - log: log, - diviners: make(map[string]*diviner), - } - - weight := oracle.GetSourceWeight("non-existent") - if weight != 1.0 { - t.Errorf("Expected default weight 1.0, got %.1f", weight) - } - }) - - t.Run("returns default weight when diviners is empty", func(t *testing.T) { - oracle := &Oracle{ - log: log, - diviners: make(map[string]*diviner), - } - - weight := oracle.GetSourceWeight("any-source") - if weight != 1.0 { - t.Errorf("Expected default weight 1.0, got %.1f", weight) - } - }) -} - func TestRescheduleDiviner(t *testing.T) { backend := slog.NewBackend(os.Stdout) log := backend.Logger("test") t.Run("reschedules existing diviner", func(t *testing.T) { mockDiv := &diviner{ - name: "test-source", + source: &mockSource{name: "test-source"}, resetTimer: make(chan struct{}, 1), } - oracle := &Oracle{ - log: log, - diviners: map[string]*diviner{ - "test-source": mockDiv, - }, + oracle := newTestOracle(log) + oracle.diviners = map[string]*diviner{ + "test-source": mockDiv, } - oracle.rescheduleDiviner("test-source") + oracle.rescheduleDiviner("test-source", "other-node") // Verify the reschedule signal was sent select { @@ -1500,23 +1145,17 @@ func TestRescheduleDiviner(t *testing.T) { }) t.Run("does nothing for non-existent diviner", func(t *testing.T) { - oracle := &Oracle{ - log: log, - diviners: make(map[string]*diviner), - } + oracle := newTestOracle(log) // Should not panic - oracle.rescheduleDiviner("non-existent") + oracle.rescheduleDiviner("non-existent", "other-node") }) t.Run("does nothing when diviners is empty", func(t *testing.T) { - oracle := &Oracle{ - log: log, - diviners: make(map[string]*diviner), - } + oracle := newTestOracle(log) // Should not panic - oracle.rescheduleDiviner("any-source") + oracle.rescheduleDiviner("any-source", "other-node") }) } @@ -1525,13 +1164,15 @@ func TestRun(t *testing.T) { log := backend.Logger("test") t.Run("Run completes with no diviners", func(t *testing.T) { - oracle := &Oracle{ - log: log, - diviners: make(map[string]*diviner), - } + qm := newQuotaManager("aManagerConfig{ + log: log, + nodeID: "test-node", + }) + oracle := newTestOracle(log) + oracle.diviners = make(map[string]*diviner) + oracle.quotaManager = qm ctx, cancel := context.WithCancel(context.Background()) - defer cancel() done := make(chan struct{}) go func() { @@ -1539,26 +1180,51 @@ func TestRun(t *testing.T) { close(done) }() + // Cancel immediately since there are no diviners + cancel() + select { case <-done: - // Success - Run completed immediately + // Success - Run exited after cancel case <-time.After(time.Second): - t.Error("Run did not complete with empty diviners") + t.Error("Run did not complete after context cancellation") } }) t.Run("Run waits for diviners and exits on context cancellation", func(t *testing.T) { + qm := newQuotaManager("aManagerConfig{ + log: log, + nodeID: "test-node", + }) + // Create mock diviners that wait for context mockDiviners := make(map[string]*diviner) for i := 0; i < 2; i++ { name := fmt.Sprintf("source%d", i) - mockDiviners[name] = &diviner{name: name} + localName := name + mockDiviners[name] = &diviner{ + source: &mockSource{ + name: name, + minPeriod: time.Hour, // Long period to avoid immediate fetch + fetchFunc: func(ctx context.Context) (*sources.RateInfo, error) { + <-ctx.Done() // Block until context cancelled + return nil, ctx.Err() + }, + }, + resetTimer: make(chan struct{}), + log: log, + getNetworkSchedule: func() networkSchedule { + now := time.Now() + activePeers := qm.getActivePeersForSource(localName, now) + return computeNetworkSchedule(activePeers, "local", time.Hour, now) + }, + onScheduleChanged: func(*OracleSnapshot) {}, + } } - oracle := &Oracle{ - log: log, - diviners: mockDiviners, - } + oracle := newTestOracle(log) + oracle.diviners = mockDiviners + oracle.quotaManager = qm ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) @@ -1589,8 +1255,11 @@ func TestNewOracle(t *testing.T) { t.Run("creates oracle with default sources", func(t *testing.T) { cfg := &Config{ - Log: log, - PublishUpdate: func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, + Log: log, + NodeID: "test-node", + PublishUpdate: func(ctx context.Context, update *OracleUpdate) error { return nil }, + OnStateUpdate: func(*OracleSnapshot) {}, + PublishQuotaHeartbeat: func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { return nil }, } oracle, err := New(cfg) @@ -1617,8 +1286,11 @@ func TestNewOracle(t *testing.T) { t.Run("initializes with unauthed sources", func(t *testing.T) { cfg := &Config{ - Log: log, - PublishUpdate: func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, + Log: log, + NodeID: "test-node", + PublishUpdate: func(ctx context.Context, update *OracleUpdate) error { return nil }, + OnStateUpdate: func(*OracleSnapshot) {}, + PublishQuotaHeartbeat: func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { return nil }, } oracle, err := New(cfg) @@ -1634,9 +1306,12 @@ func TestNewOracle(t *testing.T) { t.Run("nil http client uses default client", func(t *testing.T) { cfg := &Config{ - Log: log, - HTTPClient: nil, - PublishUpdate: func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, + Log: log, + NodeID: "test-node", + HTTPClient: nil, + PublishUpdate: func(ctx context.Context, update *OracleUpdate) error { return nil }, + OnStateUpdate: func(*OracleSnapshot) {}, + PublishQuotaHeartbeat: func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { return nil }, } oracle, err := New(cfg) @@ -1652,9 +1327,12 @@ func TestNewOracle(t *testing.T) { t.Run("custom http client is used", func(t *testing.T) { customClient := &mockHTTPClient{} cfg := &Config{ - Log: log, - HTTPClient: customClient, - PublishUpdate: func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, + Log: log, + NodeID: "test-node", + HTTPClient: customClient, + PublishUpdate: func(ctx context.Context, update *OracleUpdate) error { return nil }, + OnStateUpdate: func(*OracleSnapshot) {}, + PublishQuotaHeartbeat: func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { return nil }, } oracle, err := New(cfg) @@ -1669,8 +1347,11 @@ func TestNewOracle(t *testing.T) { t.Run("initializes empty price and fee rate maps", func(t *testing.T) { cfg := &Config{ - Log: log, - PublishUpdate: func(ctx context.Context, update *pb.NodeOracleUpdate) error { return nil }, + Log: log, + NodeID: "test-node", + PublishUpdate: func(ctx context.Context, update *OracleUpdate) error { return nil }, + OnStateUpdate: func(*OracleSnapshot) {}, + PublishQuotaHeartbeat: func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { return nil }, } oracle, err := New(cfg) diff --git a/oracle/quota_manager.go b/oracle/quota_manager.go new file mode 100644 index 0000000..6a1d12a --- /dev/null +++ b/oracle/quota_manager.go @@ -0,0 +1,364 @@ +package oracle + +import ( + "context" + "crypto/sha256" + "fmt" + "math/big" + "sort" + "sync" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" +) + +// TimestampedQuotaStatus wraps a QuotaStatus with the time it was received. +type TimestampedQuotaStatus struct { + *sources.QuotaStatus + ReceivedAt time.Time +} + +// networkSchedule contains the coordinated fetch schedule for a source. +type networkSchedule struct { + NextFetchTime time.Time + NetworkSustainableRate float64 + MinPeriod time.Duration + NetworkSustainablePeriod time.Duration + NetworkNextFetchTime time.Time + OrderedNodes []string +} + +const ( + // maxPeriod is the maximum period between fetches for a source. + maxPeriod = 1 * time.Hour + // quotaPeerActiveThreshold is the threshold for a peer to be considered "active". + // If there have been no quota updates from a peer within this period, they will + // not be considered as participating in the fetching for this source. + quotaPeerActiveThreshold = 6 * time.Minute + // quotaHeartbeatInterval is the interval at which the quota manager will broadcast + // the quotas for all sources to the network. + quotaHeartbeatInterval = 5 * time.Minute + // networkSafetyMargin is the buffer for network rate calculations. We do not account + // for this proportion of the quota when calculating the sustainable rate. + networkSafetyMargin = 0.1 + // propagationDelay is the amount of time we wait to receive results from the previous + // node in the fetch order before the next node attempts to fetch. + propagationDelay = 3 * time.Second +) + +// quotaManager coordinates quota tracking and network-wide quota sharing for +// oracle sources. It supports network-coordinated fetch scheduling where nodes +// deterministically order themselves to avoid redundant fetches. +type quotaManager struct { + log slog.Logger + nodeID string + onStateUpdate func(*OracleSnapshot) + publishHeartbeat func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error + + srcsMtx sync.RWMutex + srcs map[string]sources.Source + + peerQuotasMtx sync.RWMutex + peerQuotas map[string]map[string]*TimestampedQuotaStatus +} + +// quotaManagerConfig contains configuration for the quota manager. +type quotaManagerConfig struct { + log slog.Logger + nodeID string + publishQuotaHeartbeat func(ctx context.Context, quotas map[string]*sources.QuotaStatus) error + onStateUpdate func(*OracleSnapshot) + sources []sources.Source +} + +// newQuotaManager creates a new quota manager. +func newQuotaManager(cfg *quotaManagerConfig) *quotaManager { + srcs := make(map[string]sources.Source, len(cfg.sources)) + for _, src := range cfg.sources { + srcs[src.Name()] = src + } + + return "aManager{ + log: cfg.log, + nodeID: cfg.nodeID, + srcs: srcs, + peerQuotas: make(map[string]map[string]*TimestampedQuotaStatus), + publishHeartbeat: cfg.publishQuotaHeartbeat, + onStateUpdate: cfg.onStateUpdate, + } +} + +// Run starts the quota manager's background tasks. +func (qm *quotaManager) run(ctx context.Context) { + heartbeatTicker := time.NewTicker(quotaHeartbeatInterval) + defer heartbeatTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-heartbeatTicker.C: + if err := qm.publishHeartbeat(ctx, qm.getLocalQuotas()); err != nil { + qm.log.Warnf("Failed to publish quota heartbeat: %v", err) + } + qm.expireStalePeerQuotas() + } + } +} + +// HandlePeerSourceQuota processes an update to a peer's quota for a given source. +func (qm *quotaManager) handlePeerSourceQuota(peerID string, quota *TimestampedQuotaStatus, source string) { + qm.peerQuotasMtx.Lock() + if qm.peerQuotas[peerID] == nil { + qm.peerQuotas[peerID] = make(map[string]*TimestampedQuotaStatus) + } + qm.peerQuotas[peerID][source] = quota + qm.peerQuotasMtx.Unlock() + + qm.onStateUpdate(&OracleSnapshot{ + Sources: map[string]*SourceStatus{ + source: { + Quotas: map[string]*Quota{ + peerID: { + FetchesRemaining: quota.FetchesRemaining, + FetchesLimit: quota.FetchesLimit, + ResetTime: quota.ResetTime, + }, + }, + }, + }, + }) +} + +// expireStalePeerQuotas expires peer quotas that have not been updated within +// the active threshold and will no longer be used in fetch scheduling. +func (qm *quotaManager) expireStalePeerQuotas() { + qm.peerQuotasMtx.Lock() + defer qm.peerQuotasMtx.Unlock() + + now := time.Now() + for peerID, srcs := range qm.peerQuotas { + for source, quota := range srcs { + if now.Sub(quota.ReceivedAt) > quotaPeerActiveThreshold { + delete(srcs, source) + } + } + if len(srcs) == 0 { + delete(qm.peerQuotas, peerID) + } + } +} + +// getLocalQuotas returns this node's quota status for all sources. +func (qm *quotaManager) getLocalQuotas() map[string]*sources.QuotaStatus { + qm.srcsMtx.RLock() + defer qm.srcsMtx.RUnlock() + + result := make(map[string]*sources.QuotaStatus) + for name, src := range qm.srcs { + result[name] = src.QuotaStatus() + } + return result +} + +// getNetworkQuotas returns all nodes' quotas for all sources. +func (qm *quotaManager) getNetworkQuotas() map[string]map[string]*TimestampedQuotaStatus { + qm.peerQuotasMtx.RLock() + defer qm.peerQuotasMtx.RUnlock() + + // Copy map structure so callers can modify the map without affecting the original. + result := make(map[string]map[string]*TimestampedQuotaStatus) + for peerID, srcs := range qm.peerQuotas { + result[peerID] = make(map[string]*TimestampedQuotaStatus) + for source, quota := range srcs { + q := *quota + result[peerID][source] = &q + } + } + return result +} + +// getActivePeersForSource returns quotas for peers that shared their quota within +// the active threshold. +func (qm *quotaManager) getActivePeersForSource(source string, now time.Time) map[string]*TimestampedQuotaStatus { + result := make(map[string]*TimestampedQuotaStatus) + + // Add local node's quota + qm.srcsMtx.RLock() + if src, ok := qm.srcs[source]; ok { + result[qm.nodeID] = &TimestampedQuotaStatus{ + QuotaStatus: src.QuotaStatus(), + ReceivedAt: now, + } + } + qm.srcsMtx.RUnlock() + + // Add active peer quotas + qm.peerQuotasMtx.RLock() + defer qm.peerQuotasMtx.RUnlock() + + for peerID, srcs := range qm.peerQuotas { + if quota, ok := srcs[source]; ok { + if now.Sub(quota.ReceivedAt) <= quotaPeerActiveThreshold { + q := *quota + result[peerID] = &q + } + } + } + + return result +} + +func (qm *quotaManager) getNetworkSchedule(source string, minPeriod time.Duration) networkSchedule { + now := time.Now() + activePeers := qm.getActivePeersForSource(source, now) + return computeNetworkSchedule(activePeers, qm.nodeID, minPeriod, now) +} + +// computeNetworkSchedule computes a coordinated fetch schedule for a source +// across all active peers. The algorithm works in three steps: +// +// 1. Sustainable rate: Each peer's quota yields a rate (fetches/sec) after +// applying a safety margin. The network rate is the sum of all peer rates, +// and its reciprocal gives the sustainable period — clamped between +// minPeriod and maxPeriod. +// +// 2. Deterministic ordering: Peers are ranked by score = SHA256(timeWindow, +// nodeID) / rate. The time window rotates every minPeriod seconds so the +// ordering reshuffles periodically, while dividing by rate biases nodes +// with more remaining quota toward the front. Every node computes the +// same ordering independently. +// +// 3. Fetch timing: The first node in the order fetches after the clamped +// period. Each subsequent node adds a propagation delay, giving the +// earlier node time to share results before the next one attempts a +// redundant fetch. +func computeNetworkSchedule(activePeers map[string]*TimestampedQuotaStatus, nodeID string, minPeriod time.Duration, now time.Time) networkSchedule { + // Pre-compute sustainable rate for each active peer. + peerRates := make(map[string]float64, len(activePeers)) + var networkRate float64 + for id, quota := range activePeers { + rate := sustainableRate(quota, now) + peerRates[id] = rate + networkRate += rate + } + + // Raw sustainable period = 1 / network_rate (with maxPeriod fallback). + var sustainablePeriod time.Duration + if networkRate <= 0 { + sustainablePeriod = maxPeriod + } else { + sustainablePeriod = time.Duration(float64(time.Second) / networkRate) + } + clampedPeriod := clamp(sustainablePeriod, minPeriod, maxPeriod) + + // For a deterministic consistent changing value across the network, + // we use a time window based on the minimum period of the source. + windowSecs := int64(minPeriod.Seconds()) + if windowSecs <= 0 { + windowSecs = 1 + } + timeWindow := now.Unix() / windowSecs + + // Next we calculate a randomized score weighted by the sustainable rate of the peer + // to create an ordering of peers for their next fetch time. + type nodeScore struct { + id string + score *big.Int + } + scores := make([]nodeScore, 0, len(activePeers)) + for id := range activePeers { + rate := peerRates[id] + if rate <= 0 { + rate = 0.00001 // avoid division by zero + } + + // hash = SHA256(timeWindow || nodeID) + h := sha256.Sum256(fmt.Appendf(nil, "%d:%s", timeWindow, id)) + hashInt := new(big.Int).SetBytes(h[:]) + + // score = hash / rate, scaled to 9 decimal places to avoid floating + // point precision issues + scaledHash := new(big.Int).Mul(hashInt, big.NewInt(1e9)) + if rate > 1e9 { + rate = 1e9 // Cap to prevent int64 overflow when scaling. + } + rateInt := big.NewInt(int64(rate * 1e9)) + if rateInt.Cmp(big.NewInt(0)) <= 0 { + rateInt = big.NewInt(1) + } + score := new(big.Int).Div(scaledHash, rateInt) + + scores = append(scores, nodeScore{id, score}) + } + + // Sort by score ascending (lower = fetches first) + sort.Slice(scores, func(i, j int) bool { + c := scores[i].score.Cmp(scores[j].score) + if c != 0 { + return c < 0 + } + return scores[i].id < scores[j].id + }) + + // Extract ordered node IDs and find local node's position + orderedNodes := make([]string, len(scores)) + order := len(scores) + for i, s := range scores { + orderedNodes[i] = s.id + if s.id == nodeID { + order = i + } + } + + // Calculate next fetch time: clamped period + (order * delay) + nextFetchAfter := clampedPeriod + time.Duration(order)*propagationDelay + + return networkSchedule{ + NextFetchTime: now.Add(nextFetchAfter), + NetworkSustainableRate: networkRate, + MinPeriod: minPeriod, + NetworkSustainablePeriod: sustainablePeriod, + NetworkNextFetchTime: now.Add(clampedPeriod), + OrderedNodes: orderedNodes, + } +} + +// sustainableRate returns the sustainable fetch rate (fetches/second) for a peer. +// Applies safety margin to prevent quota exhaustion. +func sustainableRate(quota *TimestampedQuotaStatus, now time.Time) float64 { + // Unlimited quota + if quota.FetchesRemaining >= 1<<62 { + return 1.0 // Cap at 1 fetch/second for unlimited sources + } + + // Exhausted quota + if quota.FetchesRemaining <= 0 { + return 0 + } + + timeRemaining := quota.ResetTime.Sub(now) + if timeRemaining <= 0 { + return 1.0 // Quota should have reset, assume fresh + } + + // Apply safety margin: effective_remaining = remaining * (1 - margin) + effectiveRemaining := float64(quota.FetchesRemaining) * (1 - networkSafetyMargin) + if effectiveRemaining <= 0 { + return 0 + } + + // Rate = effective_remaining / time_remaining + return effectiveRemaining / timeRemaining.Seconds() +} + +func clamp(d, lo, hi time.Duration) time.Duration { + if d < lo { + return lo + } + if d > hi { + return hi + } + return d +} diff --git a/oracle/quota_manager_test.go b/oracle/quota_manager_test.go new file mode 100644 index 0000000..fc2d18b --- /dev/null +++ b/oracle/quota_manager_test.go @@ -0,0 +1,501 @@ +package oracle + +import ( + "context" + "math" + "testing" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" +) + +func newTestQuotaManager(nodeID string, srcs []sources.Source) (*quotaManager, *[]*OracleSnapshot) { + var updates []*OracleSnapshot + return newQuotaManager("aManagerConfig{ + log: slog.Disabled, + nodeID: nodeID, + sources: srcs, + publishQuotaHeartbeat: func(_ context.Context, _ map[string]*sources.QuotaStatus) error { + return nil + }, + onStateUpdate: func(snap *OracleSnapshot) { + updates = append(updates, snap) + }, + }), &updates +} + +func makeQuota(remaining, limit int64, resetIn time.Duration) *sources.QuotaStatus { + return &sources.QuotaStatus{ + FetchesRemaining: remaining, + FetchesLimit: limit, + ResetTime: time.Now().Add(resetIn), + } +} + +func makeTimestampedQuota(remaining, limit int64, resetIn time.Duration, receivedAt time.Time) *TimestampedQuotaStatus { + return &TimestampedQuotaStatus{ + QuotaStatus: makeQuota(remaining, limit, resetIn), + ReceivedAt: receivedAt, + } +} + +// --- computeNetworkSchedule tests (pure function, no quotaManager) --- + +func TestComputeNetworkScheduleSingleNode(t *testing.T) { + now := time.Now() + peers := map[string]*TimestampedQuotaStatus{ + "node-A": makeTimestampedQuota(100, 200, 24*time.Hour, now), + } + + sched := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + + if len(sched.OrderedNodes) != 1 { + t.Fatalf("expected 1 ordered node, got %d", len(sched.OrderedNodes)) + } + if sched.OrderedNodes[0] != "node-A" { + t.Errorf("expected node-A, got %s", sched.OrderedNodes[0]) + } + if sched.MinPeriod != 30*time.Second { + t.Errorf("expected min period 30s, got %v", sched.MinPeriod) + } + if sched.NetworkSustainableRate <= 0 { + t.Error("expected positive sustainable rate") + } + // Single node at position 0: no propagation delay. + expectedNext := sched.NetworkNextFetchTime + if sched.NextFetchTime != expectedNext { + t.Errorf("single node should have no propagation delay, got diff %v", + sched.NextFetchTime.Sub(expectedNext)) + } +} + +func TestComputeNetworkScheduleDeterministicOrder(t *testing.T) { + now := time.Now() + peers := map[string]*TimestampedQuotaStatus{ + "node-A": makeTimestampedQuota(100, 200, 24*time.Hour, now), + "node-B": makeTimestampedQuota(100, 200, 24*time.Hour, now), + } + + sched1 := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + sched2 := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + + if len(sched1.OrderedNodes) != 2 { + t.Fatalf("expected 2 ordered nodes, got %d", len(sched1.OrderedNodes)) + } + for i := range sched1.OrderedNodes { + if sched1.OrderedNodes[i] != sched2.OrderedNodes[i] { + t.Error("expected deterministic ordering across calls") + break + } + } +} + +func TestComputeNetworkScheduleConsistentAcrossNodes(t *testing.T) { + now := time.Now() + peers := map[string]*TimestampedQuotaStatus{ + "node-A": makeTimestampedQuota(100, 200, 24*time.Hour, now), + "node-B": makeTimestampedQuota(100, 200, 24*time.Hour, now), + "node-C": makeTimestampedQuota(100, 200, 24*time.Hour, now), + } + + // Different nodes calling with the same peer set should produce the same ordering. + schedA := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + schedB := computeNetworkSchedule(peers, "node-B", 30*time.Second, now) + + for i := range schedA.OrderedNodes { + if schedA.OrderedNodes[i] != schedB.OrderedNodes[i] { + t.Error("ordering should be the same regardless of which node computes it") + break + } + } +} + +func TestComputeNetworkScheduleRespectsMinPeriod(t *testing.T) { + now := time.Now() + // Unlimited quota — sustainable period would be 1s (rate=1.0), far below minPeriod. + peers := map[string]*TimestampedQuotaStatus{ + "node-A": { + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1 << 62, + FetchesLimit: 1 << 62, + ResetTime: now.Add(24 * time.Hour), + }, + ReceivedAt: now, + }, + } + + sched := computeNetworkSchedule(peers, "node-A", 5*time.Minute, now) + + expectedMin := now.Add(5*time.Minute - time.Second) + if sched.NetworkNextFetchTime.Before(expectedMin) { + t.Error("network next fetch time should respect min period") + } +} + +func TestComputeNetworkScheduleRespectsMaxPeriod(t *testing.T) { + now := time.Now() + // Exhausted quota — rate is 0, sustainable period would be infinite. + peers := map[string]*TimestampedQuotaStatus{ + "node-A": makeTimestampedQuota(0, 100, 24*time.Hour, now), + } + + sched := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + + if sched.NetworkSustainablePeriod != maxPeriod { + t.Errorf("expected sustainable period capped at maxPeriod (%v), got %v", + maxPeriod, sched.NetworkSustainablePeriod) + } +} + +func TestComputeNetworkSchedulePropagationDelay(t *testing.T) { + now := time.Now() + peers := map[string]*TimestampedQuotaStatus{ + "node-A": makeTimestampedQuota(100, 200, 24*time.Hour, now), + "node-B": makeTimestampedQuota(100, 200, 24*time.Hour, now), + } + + sched := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + + myOrder := -1 + for i, id := range sched.OrderedNodes { + if id == "node-A" { + myOrder = i + break + } + } + if myOrder < 0 { + t.Fatal("local node not found in ordered nodes") + } + + expectedDelay := time.Duration(myOrder) * propagationDelay + diff := sched.NextFetchTime.Sub(sched.NetworkNextFetchTime) + if diff < expectedDelay-time.Millisecond || diff > expectedDelay+time.Millisecond { + t.Errorf("expected propagation delay of %v for order %d, got %v", expectedDelay, myOrder, diff) + } +} + +func TestComputeNetworkScheduleHigherRateBiasesOrder(t *testing.T) { + now := time.Now() + // Give node-A much more remaining quota than node-B. + // Over many time windows, node-A should appear first more often. + aFirst := 0 + trials := 100 + for i := range trials { + trialTime := now.Add(time.Duration(i) * time.Hour) + peers := map[string]*TimestampedQuotaStatus{ + "node-A": { + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 10000, + FetchesLimit: 10000, + ResetTime: trialTime.Add(24 * time.Hour), + }, + ReceivedAt: trialTime, + }, + "node-B": { + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 10, + FetchesLimit: 10000, + ResetTime: trialTime.Add(24 * time.Hour), + }, + ReceivedAt: trialTime, + }, + } + sched := computeNetworkSchedule(peers, "node-A", 30*time.Second, trialTime) + if sched.OrderedNodes[0] == "node-A" { + aFirst++ + } + } + // node-A has ~1000x the rate, so it should be first most of the time. + if aFirst < trials/2 { + t.Errorf("node with higher rate should be ordered first more often, but was first only %d/%d times", aFirst, trials) + } +} + +func TestComputeNetworkScheduleNoPeers(t *testing.T) { + now := time.Now() + peers := map[string]*TimestampedQuotaStatus{} + + sched := computeNetworkSchedule(peers, "node-A", 30*time.Second, now) + + if len(sched.OrderedNodes) != 0 { + t.Errorf("expected 0 ordered nodes with no peers, got %d", len(sched.OrderedNodes)) + } + if sched.NetworkSustainablePeriod != maxPeriod { + t.Errorf("expected maxPeriod with no peers, got %v", sched.NetworkSustainablePeriod) + } +} + +func TestComputeNetworkScheduleNetworkRate(t *testing.T) { + now := time.Now() + resetTime := now.Add(time.Hour) + + single := map[string]*TimestampedQuotaStatus{ + "node-A": { + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1000, + FetchesLimit: 1000, + ResetTime: resetTime, + }, + ReceivedAt: now, + }, + } + double := map[string]*TimestampedQuotaStatus{ + "node-A": { + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1000, + FetchesLimit: 1000, + ResetTime: resetTime, + }, + ReceivedAt: now, + }, + "node-B": { + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1000, + FetchesLimit: 1000, + ResetTime: resetTime, + }, + ReceivedAt: now, + }, + } + + sched1 := computeNetworkSchedule(single, "node-A", time.Second, now) + sched2 := computeNetworkSchedule(double, "node-A", time.Second, now) + + // Two identical peers should have ~2x the network rate. + ratio := sched2.NetworkSustainableRate / sched1.NetworkSustainableRate + if math.Abs(ratio-2.0) > 0.01 { + t.Errorf("expected 2x network rate with 2 peers, got ratio %.2f", ratio) + } +} + +// --- sustainableRate tests --- + +func TestSustainableRate(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + quota *TimestampedQuotaStatus + wantRate float64 + wantZero bool + }{ + { + name: "unlimited quota returns capped rate", + quota: &TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1 << 62, + FetchesLimit: 1 << 62, + ResetTime: now.Add(24 * time.Hour), + }, + }, + wantRate: 1.0, + }, + { + name: "exhausted quota returns zero", + quota: &TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 0, + FetchesLimit: 100, + ResetTime: now.Add(time.Hour), + }, + }, + wantZero: true, + }, + { + name: "negative remaining returns zero", + quota: &TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: -5, + FetchesLimit: 100, + ResetTime: now.Add(time.Hour), + }, + }, + wantZero: true, + }, + { + name: "expired reset time returns capped rate", + quota: &TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 50, + FetchesLimit: 100, + ResetTime: now.Add(-time.Hour), + }, + }, + wantRate: 1.0, + }, + { + name: "normal quota calculates rate with safety margin", + quota: &TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1000, + FetchesLimit: 1000, + ResetTime: now.Add(time.Hour), + }, + }, + // effective = 1000 * 0.9 = 900, time = 3600s, rate = 0.25 + wantRate: 900.0 / 3600.0, + }, + { + name: "very low remaining with margin", + quota: &TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: 1, + FetchesLimit: 100, + ResetTime: now.Add(time.Hour), + }, + }, + // effective = 1 * 0.9 = 0.9, time = 3600s + wantRate: 0.9 / 3600.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rate := sustainableRate(tt.quota, now) + if tt.wantZero { + if rate != 0 { + t.Errorf("expected 0, got %f", rate) + } + return + } + if math.Abs(rate-tt.wantRate) > 1e-9 { + t.Errorf("expected %f, got %f", tt.wantRate, rate) + } + }) + } +} + +// --- quotaManager tests (external interface only) --- + +func TestQuotaManagerHandlePeerQuota(t *testing.T) { + src := &mockSource{ + name: "blockcypher", + quota: makeQuota(100, 200, 24*time.Hour), + } + qm, updates := newTestQuotaManager("node-A", []sources.Source{src}) + + qm.handlePeerSourceQuota("node-B", &TimestampedQuotaStatus{ + QuotaStatus: makeQuota(50, 200, 12*time.Hour), + ReceivedAt: time.Now(), + }, "blockcypher") + + // Should emit a state update. + if len(*updates) != 1 { + t.Fatalf("expected 1 update, got %d", len(*updates)) + } + snap := (*updates)[0] + q, ok := snap.Sources["blockcypher"].Quotas["node-B"] + if !ok { + t.Fatal("expected node-B quota in snapshot") + } + if q.FetchesRemaining != 50 || q.FetchesLimit != 200 { + t.Errorf("unexpected quota values: remaining=%d, limit=%d", q.FetchesRemaining, q.FetchesLimit) + } + + // Should be stored in network quotas. + peers := qm.getNetworkQuotas() + if _, ok := peers["node-B"]["blockcypher"]; !ok { + t.Error("expected peer quota stored for node-B/blockcypher") + } +} + +func TestQuotaManagerHandlePeerQuotaOverwrite(t *testing.T) { + qm, _ := newTestQuotaManager("node-A", nil) + now := time.Now() + + qm.handlePeerSourceQuota("node-B", makeTimestampedQuota(100, 200, 12*time.Hour, now), "blockcypher") + qm.handlePeerSourceQuota("node-B", makeTimestampedQuota(50, 200, 12*time.Hour, now.Add(time.Minute)), "blockcypher") + + peers := qm.getNetworkQuotas() + if peers["node-B"]["blockcypher"].FetchesRemaining != 50 { + t.Errorf("expected overwritten quota with 50 remaining, got %d", + peers["node-B"]["blockcypher"].FetchesRemaining) + } +} + +func TestQuotaManagerGetLocalQuotas(t *testing.T) { + qm, _ := newTestQuotaManager("node-A", []sources.Source{ + &mockSource{name: "blockcypher", quota: makeQuota(80, 200, 24*time.Hour)}, + &mockSource{name: "coinpaprika", quota: makeQuota(500, 1000, 12*time.Hour)}, + }) + + quotas := qm.getLocalQuotas() + if len(quotas) != 2 { + t.Fatalf("expected 2 local quotas, got %d", len(quotas)) + } + if quotas["blockcypher"].FetchesRemaining != 80 { + t.Errorf("expected 80 remaining for blockcypher, got %d", quotas["blockcypher"].FetchesRemaining) + } + if quotas["coinpaprika"].FetchesRemaining != 500 { + t.Errorf("expected 500 remaining for coinpaprika, got %d", quotas["coinpaprika"].FetchesRemaining) + } +} + +func TestQuotaManagerGetNetworkQuotasMapIndependence(t *testing.T) { + qm, _ := newTestQuotaManager("node-A", nil) + qm.handlePeerSourceQuota("node-B", makeTimestampedQuota(50, 100, time.Hour, time.Now()), "blockcypher") + + copy1 := qm.getNetworkQuotas() + delete(copy1, "node-B") + + copy2 := qm.getNetworkQuotas() + if _, ok := copy2["node-B"]; !ok { + t.Error("deleting from returned map should not affect internal state") + } +} + +func TestQuotaManagerRunContextCancellation(t *testing.T) { + qm, _ := newTestQuotaManager("node-A", nil) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + qm.run(ctx) + close(done) + }() + + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("run() did not exit after context cancellation") + } +} + +func TestQuotaManagerRunHeartbeatAndExpiration(t *testing.T) { + src := &mockSource{name: "blockcypher", quota: makeQuota(100, 200, 24*time.Hour)} + var publishedQuotas map[string]*sources.QuotaStatus + + qm := newQuotaManager("aManagerConfig{ + log: slog.Disabled, + nodeID: "node-A", + sources: []sources.Source{src}, + publishQuotaHeartbeat: func(_ context.Context, quotas map[string]*sources.QuotaStatus) error { + publishedQuotas = quotas + return nil + }, + onStateUpdate: func(_ *OracleSnapshot) {}, + }) + + // Add a stale peer. + qm.peerQuotasMtx.Lock() + qm.peerQuotas["stale-peer"] = map[string]*TimestampedQuotaStatus{ + "blockcypher": makeTimestampedQuota(10, 100, time.Hour, time.Now().Add(-quotaPeerActiveThreshold-time.Minute)), + } + qm.peerQuotasMtx.Unlock() + + // Simulate what run() does each tick. + ctx := context.Background() + if err := qm.publishHeartbeat(ctx, qm.getLocalQuotas()); err != nil { + t.Fatalf("publishHeartbeat failed: %v", err) + } + qm.expireStalePeerQuotas() + + if _, ok := publishedQuotas["blockcypher"]; !ok { + t.Error("expected blockcypher quota in heartbeat") + } + if _, ok := qm.getNetworkQuotas()["stale-peer"]; ok { + t.Error("expected stale peer to be expired") + } +} diff --git a/oracle/snapshot.go b/oracle/snapshot.go new file mode 100644 index 0000000..11b85fa --- /dev/null +++ b/oracle/snapshot.go @@ -0,0 +1,245 @@ +package oracle + +import ( + "fmt" + "time" +) + +// DataType identifies the kind of data point (price or fee rate). +type DataType = string + +const ( + PriceData DataType = "price" + FeeRateData DataType = "fee_rate" +) + +// SourceStatus is the per-source view. +type SourceStatus struct { + LastFetch *time.Time `json:"last_fetch,omitempty"` + NextFetchTime *time.Time `json:"next_fetch_time,omitempty"` + MinFetchInterval *time.Duration `json:"min_fetch_interval,omitempty"` + NetworkSustainableRate *float64 `json:"network_sustainable_rate,omitempty"` + NetworkSustainablePeriod *time.Duration `json:"network_sustainable_period,omitempty"` + NetworkNextFetchTime *time.Time `json:"network_next_fetch_time,omitempty"` + LastError string `json:"last_error,omitempty"` + LastErrorTime *time.Time `json:"last_error_time,omitempty"` + OrderedNodes []string `json:"ordered_nodes,omitempty"` // Node IDs in fetch order + Fetches24h map[string]int `json:"fetches_24h,omitempty"` + Quotas map[string]*Quota `json:"quotas,omitempty"` + // LatestData holds the most recent values from this source, keyed by + // data type ("price" or "fee_rate") then by identifier (ticker or + // network name), with the formatted value string as the map value. + LatestData map[string]map[string]string `json:"latest_data,omitempty"` +} + +// Quota is per-node quota info embedded in each source. +type Quota struct { + FetchesRemaining int64 `json:"fetches_remaining"` + FetchesLimit int64 `json:"fetches_limit"` + ResetTime time.Time `json:"reset_time"` +} + +// SourceContribution represents a single source's contribution to an +// aggregated price or fee rate. +type SourceContribution struct { + Value string `json:"value,omitempty"` + Stamp time.Time `json:"stamp,omitempty"` + Weight float64 `json:"weight,omitempty"` +} + +// sourcesStatus assembles the per-source status data. +func (o *Oracle) sourcesStatus() map[string]*SourceStatus { + fetchCounts := o.fetchTracker.fetchCounts() + latestPerSource := o.fetchTracker.latestPerSource() + localQuotas := o.quotaManager.getLocalQuotas() + networkQuotas := o.quotaManager.getNetworkQuotas() + + // Collect all source names. + sourceNames := make(map[string]bool) + o.divinersMtx.RLock() + for name := range o.diviners { + sourceNames[name] = true + } + o.divinersMtx.RUnlock() + + sources := make(map[string]*SourceStatus, len(sourceNames)) + + for name := range sourceNames { + status := &SourceStatus{ + Fetches24h: make(map[string]int), + Quotas: make(map[string]*Quota), + } + + // Latest fetch. + if stamp, ok := latestPerSource[name]; ok { + status.LastFetch = &stamp + } + + // Next fetch time and intervals (only for our diviners). + o.divinersMtx.RLock() + if div, ok := o.diviners[name]; ok { + info := div.fetchScheduleInfo() + if !info.NextFetchTime.IsZero() { + nft := info.NextFetchTime + status.NextFetchTime = &nft + } + if !info.NetworkNextFetchTime.IsZero() { + nnft := info.NetworkNextFetchTime + status.NetworkNextFetchTime = &nnft + } + minPeriod := info.MinPeriod + status.MinFetchInterval = &minPeriod + status.NetworkSustainableRate = &info.NetworkSustainableRate + nsp := info.NetworkSustainablePeriod + status.NetworkSustainablePeriod = &nsp + status.OrderedNodes = info.OrderedNodes + if errMsg, errTime := div.fetchErrorInfo(); errMsg != "" && errTime != nil { + status.LastError = errMsg + status.LastErrorTime = errTime + } + } + o.divinersMtx.RUnlock() + + // Per-node fetch counts. + if counts, ok := fetchCounts[name]; ok { + status.Fetches24h = counts + } + + // Local quotas (our node). + if lq, ok := localQuotas[name]; ok { + status.Quotas[o.nodeID] = &Quota{ + FetchesRemaining: lq.FetchesRemaining, + FetchesLimit: lq.FetchesLimit, + ResetTime: lq.ResetTime, + } + } + + // Network quotas (peers). + for peerID, sourceQuotas := range networkQuotas { + if pq, ok := sourceQuotas[name]; ok { + status.Quotas[peerID] = &Quota{ + FetchesRemaining: pq.FetchesRemaining, + FetchesLimit: pq.FetchesLimit, + ResetTime: pq.ResetTime, + } + } + } + + // Latest data from this source. + latestData := make(map[string]map[string]string) + o.pricesMtx.RLock() + for ticker, bucket := range o.prices { + bucket.mtx.RLock() + if entry, ok := bucket.sources[name]; ok { + if latestData[PriceData] == nil { + latestData[PriceData] = make(map[string]string) + } + latestData[PriceData][string(ticker)] = fmt.Sprintf("%f", entry.price) + } + bucket.mtx.RUnlock() + } + o.pricesMtx.RUnlock() + + o.feeRatesMtx.RLock() + for network, bucket := range o.feeRates { + bucket.mtx.RLock() + if entry, ok := bucket.sources[name]; ok { + if latestData[FeeRateData] == nil { + latestData[FeeRateData] = make(map[string]string) + } + latestData[FeeRateData][string(network)] = entry.feeRate.String() + } + bucket.mtx.RUnlock() + } + o.feeRatesMtx.RUnlock() + + if len(latestData) > 0 { + status.LatestData = latestData + } + + sources[name] = status + } + + return sources +} + +// SnapshotRate holds the aggregated value and all source contributions for a rate. +type SnapshotRate struct { + Value string `json:"value,omitempty"` + Contributions map[string]*SourceContribution `json:"contributions,omitempty"` +} + +// priceContributions returns all prices with their source contributions. +func (o *Oracle) priceContributions() map[string]*SnapshotRate { + result := make(map[string]*SnapshotRate) + o.pricesMtx.RLock() + defer o.pricesMtx.RUnlock() + + for ticker, bucket := range o.prices { + bucket.mtx.RLock() + contribs := make(map[string]*SourceContribution, len(bucket.sources)) + for name, upd := range bucket.sources { + contribs[name] = &SourceContribution{ + Value: fmt.Sprintf("%f", upd.price), + Stamp: upd.stamp, + Weight: upd.weight, + } + } + agg := bucket.aggregatedPrice() + bucket.mtx.RUnlock() + + result[string(ticker)] = &SnapshotRate{ + Value: fmt.Sprintf("%f", agg), + Contributions: contribs, + } + } + return result +} + +// feeRateContributions returns all fee rates with their source contributions. +func (o *Oracle) feeRateContributions() map[string]*SnapshotRate { + result := make(map[string]*SnapshotRate) + o.feeRatesMtx.RLock() + defer o.feeRatesMtx.RUnlock() + + for network, bucket := range o.feeRates { + bucket.mtx.RLock() + contribs := make(map[string]*SourceContribution, len(bucket.sources)) + for name, upd := range bucket.sources { + contribs[name] = &SourceContribution{ + Value: upd.feeRate.String(), + Stamp: upd.stamp, + Weight: upd.weight, + } + } + agg := bucket.aggregatedRate() + bucket.mtx.RUnlock() + if agg == nil { + continue + } + + result[string(network)] = &SnapshotRate{ + Value: agg.String(), + Contributions: contribs, + } + } + return result +} + +// OracleSnapshot contains the current state of the oracle. +type OracleSnapshot struct { + NodeID string `json:"node_id,omitempty"` + Sources map[string]*SourceStatus `json:"sources,omitempty"` + Prices map[string]*SnapshotRate `json:"prices,omitempty"` + FeeRates map[string]*SnapshotRate `json:"fee_rates,omitempty"` +} + +// OracleSnapshot returns the current state of the oracle. +func (o *Oracle) OracleSnapshot() *OracleSnapshot { + return &OracleSnapshot{ + NodeID: o.nodeID, + Sources: o.sourcesStatus(), + Prices: o.priceContributions(), + FeeRates: o.feeRateContributions(), + } +} diff --git a/oracle/sources.go b/oracle/sources.go deleted file mode 100644 index 4a9a796..0000000 --- a/oracle/sources.go +++ /dev/null @@ -1,446 +0,0 @@ -package oracle - -import ( - "context" - "encoding/json" - "fmt" - "io" - "math" - "math/big" - "net/http" - "strconv" - "time" -) - -// priceUpdate is the internal message used for when a price update is fetched -// or received from a source. -type priceUpdate struct { - ticker Ticker - price float64 - - // Added by Oracle loops - stamp time.Time - weight float64 -} - -// feeRateUpdate is the internal message used for when a fee rate update is -// fetched or received from a source. -type feeRateUpdate struct { - network Network - feeRate *big.Int - - // Added by Oracle loops - stamp time.Time - weight float64 -} - -// divination is an update from a source, which could be fee rates or prices. -type divination any // []*priceUpdate or []*feeRateUpdate - -// httpSource is a source from which http requests will be performed on some -// interval. -type httpSource struct { - name string - url string - parse func(io.Reader) (divination, error) - period time.Duration // default 5 minutes - errPeriod time.Duration // default 1 minute - weight float64 // range: [0, 1], default 1 - headers []http.Header -} - -func (h *httpSource) fetch(ctx context.Context, client HTTPClient) (any, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.url, nil) - if err != nil { - return nil, fmt.Errorf("error generating request %q: %v", h.url, err) - } - - for _, header := range h.headers { - for k, vs := range header { - for _, v := range vs { - req.Header.Add(k, v) - } - } - } - - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("error fetching %q: %v", h.url, err) - } - defer resp.Body.Close() - - return h.parse(resp.Body) -} - -// setHTTPSourceDefaults sets default values for HTTP sources. -func setHTTPSourceDefaults(sources []*httpSource) error { - for _, s := range sources { - const defaultWeight = 1.0 - if s.weight == 0 { - s.weight = defaultWeight - } else if s.weight < 0 { - return fmt.Errorf("http source '%s' has a negative weight", s.name) - } else if s.weight > 1 { - return fmt.Errorf("http source '%s' has a weight > 1", s.name) - } - const defaultHttpRequestInterval = time.Minute * 5 - if s.period == 0 { - s.period = defaultHttpRequestInterval - } - const defaultHttpErrorInterval = time.Minute - if s.errPeriod == 0 { - s.errPeriod = defaultHttpErrorInterval - } - } - - return nil -} - -// unauthedHttpSources are HTTP sources that don't require any kind of -// authorization e.g. registration or API keys. -var unauthedHttpSources = []*httpSource{ - { - name: "dcrdata", - url: "https://explorer.dcrdata.org/insight/api/utils/estimatefee?nbBlocks=2", - parse: dcrdataParser, - }, - { - name: "btc.mempooldotspace", - url: "https://mempool.space/api/v1/fees/recommended", - parse: mempoolDotSpaceParser, - }, - { - // You can make up to 20,000 requests per month on the free plan, which - // works out to one request every ~2m10s, but we'll stick with the - // default of 5m. - name: "coinpaprika", - url: "https://api.coinpaprika.com/v1/tickers", - parse: coinpaprikaParser, - }, - // Bitcore APIs not well-documented, and I believe that they use - // estimatesmartfee, which is known to be a little wild. Use with caution. - { - name: "bch.bitcore", - url: "https://api.bitcore.io/api/BCH/mainnet/fee/2", - parse: bitcoreBitcoinCashParser, - weight: 0.25, - }, - { - name: "doge.bitcore", - url: "https://api.bitcore.io/api/DOGE/mainnet/fee/2", - parse: bitcoreDogecoinParser, - weight: 0.25, - }, - { - name: "ltc.bitcore", - url: "https://api.bitcore.io/api/LTC/mainnet/fee/2", - parse: bitcoreLitecoinParser, - weight: 0.25, - }, - { - name: "firo.org", - url: "https://explorer.firo.org/insight-api-zcoin/utils/estimatefee", - parse: firoOrgParser, - weight: 0.25, // Also an estimatesmartfee source, I believe. - }, - { - name: "ltc.blockcypher", - url: "https://api.blockcypher.com/v1/ltc/main", - parse: blockcypherLitecoinParser, - weight: 0.25, - }, -} - -func dcrdataParser(r io.Reader) (u divination, err error) { - var resp map[string]float64 - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - if len(resp) != 1 || resp["2"] == 0 { - return nil, fmt.Errorf("unexpected response format: %+v", resp) - } - if resp["2"] <= 0 { - return nil, fmt.Errorf("fee rate cannot be negative or zero") - } - return []*feeRateUpdate{{network: "DCR", feeRate: uint64ToBigInt(uint64(math.Round(resp["2"] * 1e5)))}}, nil -} - -func mempoolDotSpaceParser(r io.Reader) (u divination, err error) { - var resp struct { - FastestFee uint64 `json:"fastestFee"` - } - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - if resp.FastestFee == 0 { - return nil, fmt.Errorf("zero fee rate returned") - } - return []*feeRateUpdate{{network: "BTC", feeRate: uint64ToBigInt(resp.FastestFee)}}, nil -} - -func coinpaprikaParser(r io.Reader) (u divination, err error) { - var prices []*struct { - Symbol string `json:"symbol"` - Quotes struct { - USD struct { - Price float64 `json:"price"` - } `json:"USD"` - } `json:"quotes"` - } - if err := streamDecodeJSON(r, &prices); err != nil { - return nil, err - } - seen := make(map[string]bool, len(prices)) - us := make([]*priceUpdate, 0, len(prices)) - for _, p := range prices { - if seen[p.Symbol] { - continue - } - seen[p.Symbol] = true - us = append(us, &priceUpdate{ - ticker: Ticker(p.Symbol), - price: p.Quotes.USD.Price, - }) - } - return us, nil -} - -func coinmarketcapSource(key string) *httpSource { - // Coinmarketcap free plan gives 10,000 credits per month. This endpoint - // uses 1 credit per call per 200 assets requested. So if we request the - // top 400 assets, we can call 5,000 times per month, which comes to - // about 1 call per every 8.9 minutes. We'll call every 10 minutes. - const requestInterval = time.Minute * 10 - return &httpSource{ - name: "coinmarketcap", - url: "https://pro-api.coinmarketcap.com/v1/cryptocurrency/listings/latest?limit=400", - parse: coinmarketcapParser, - headers: []http.Header{{"X-CMC_PRO_API_KEY": []string{key}}}, - period: requestInterval, - } -} - -func coinmarketcapParser(r io.Reader) (u divination, err error) { - var resp struct { - Data []*struct { - Symbol string `json:"symbol"` - Quote struct { - USD struct { - Price float64 `json:"price"` - } `json:"USD"` - } `json:"quote"` - } `json:"data"` - } - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - prices := resp.Data - seen := make(map[string]bool, len(prices)) - us := make([]*priceUpdate, 0, len(prices)) - for _, p := range prices { - if seen[p.Symbol] { - continue - } - seen[p.Symbol] = true - us = append(us, &priceUpdate{ - ticker: Ticker(p.Symbol), - price: p.Quote.USD.Price, - }) - } - return us, nil -} - -func parseBitcoreResponse(netName Network, r io.Reader) (u divination, err error) { - var resp struct { - RatePerKB float64 `json:"feerate"` - } - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - if resp.RatePerKB <= 0 { - return nil, fmt.Errorf("fee rate cannot be negative or zero") - } - return []*feeRateUpdate{{network: netName, feeRate: uint64ToBigInt(uint64(resp.RatePerKB * 1e5))}}, nil -} - -func bitcoreBitcoinCashParser(r io.Reader) (u divination, err error) { - return parseBitcoreResponse("BCH", r) -} - -func bitcoreDogecoinParser(r io.Reader) (u divination, err error) { - return parseBitcoreResponse("DOGE", r) -} - -func bitcoreLitecoinParser(r io.Reader) (u divination, err error) { - return parseBitcoreResponse("LTC", r) -} - -func firoOrgParser(r io.Reader) (u divination, err error) { - var resp map[string]float64 - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - if len(resp) != 1 || resp["2"] == 0 { - return nil, fmt.Errorf("unexpected response format: %+v", resp) - } - if resp["2"] <= 0 { - return nil, fmt.Errorf("fee rate cannot be negative or zero") - } - return []*feeRateUpdate{{network: "FIRO", feeRate: uint64ToBigInt(uint64(math.Round(resp["2"] * 1e5)))}}, nil -} - -func blockcypherLitecoinParser(r io.Reader) (u divination, err error) { - var resp struct { - // Low float64 `json:"low_fee_per_kb"` - Medium float64 `json:"medium_fee_per_kb"` - // High float64 `json:"high_fee_per_kb"` - } - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - if resp.Medium <= 0 { - return nil, fmt.Errorf("fee rate cannot be negative or zero") - } - return []*feeRateUpdate{{network: "LTC", feeRate: uint64ToBigInt(uint64(resp.Medium * 1e5))}}, nil -} - -func tatumSource(key, coin, name string, parser func(io.Reader) (divination, error)) *httpSource { - // Tatum free tier provides 100,000 lifetime API credits. With 3 sources - // (BTC, LTC, DOGE) making requests every 5 minutes, this equals ~864 - // requests/day, which will exhaust the free tier in approximately 116 days. - // A paid plan will be required for use in production. - return &httpSource{ - name: name, - url: fmt.Sprintf("https://api.tatum.io/v3/blockchain/fee/%s", coin), - parse: parser, - headers: []http.Header{{"x-api-key": []string{key}}}, - period: time.Minute * 5, - errPeriod: time.Minute, - weight: 1.0, - } -} - -func tatumBitcoinSource(key string) *httpSource { - return tatumSource(key, "BTC", "tatum.btc", tatumBitcoinParser) -} - -func tatumLitecoinSource(key string) *httpSource { - return tatumSource(key, "LTC", "tatum.ltc", tatumLitecoinParser) -} - -func tatumDogecoinSource(key string) *httpSource { - return tatumSource(key, "DOGE", "tatum.doge", tatumDogecoinParser) -} - -func tatumParser(r io.Reader, network Network) (u divination, err error) { - var resp struct { - Fast float64 `json:"fast"` - // Medium float64 `json:"medium"` - // Slow float64 `json:"slow"` - } - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - if resp.Fast <= 0 { - return nil, fmt.Errorf("fee rate cannot be negative or zero") - } - return []*feeRateUpdate{{network: network, feeRate: uint64ToBigInt(uint64(resp.Fast))}}, nil -} - -func tatumBitcoinParser(r io.Reader) (u divination, err error) { - return tatumParser(r, "BTC") -} - -func tatumLitecoinParser(r io.Reader) (u divination, err error) { - return tatumParser(r, "LTC") -} - -func tatumDogecoinParser(r io.Reader) (u divination, err error) { - return tatumParser(r, "DOGE") -} - -func cryptoApisSource(key, blockchain, name string, parser func(io.Reader) (divination, error)) *httpSource { - // Crypto APIs free tier provides 100 requests per day. With 5 sources - // (BTC, BCH, DOGE, DASH, LTC) making requests every 5 minutes, this equals - // ~1,440 requests/day, which exceeds the free tier limit. A paid plan is - // required for production use. - return &httpSource{ - name: name, - url: fmt.Sprintf("https://rest.cryptoapis.io/blockchain-fees/utxo/%s/mainnet/mempool", blockchain), - parse: parser, - headers: []http.Header{{"X-API-Key": []string{key}}}, - period: time.Minute * 5, - errPeriod: time.Minute, - weight: 1.0, - } -} - -func cryptoApisBitcoinSource(key string) *httpSource { - return cryptoApisSource(key, "BTC", "cryptoapis.btc", cryptoApisBitcoinParser) -} - -func cryptoApisBitcoinCashSource(key string) *httpSource { - return cryptoApisSource(key, "BCH", "cryptoapis.bch", cryptoApisBitcoinCashParser) -} - -func cryptoApisDogecoinSource(key string) *httpSource { - return cryptoApisSource(key, "DOGE", "cryptoapis.doge", cryptoApisDogecoinParser) -} - -func cryptoApisDashSource(key string) *httpSource { - return cryptoApisSource(key, "DASH", "cryptoapis.dash", cryptoApisDashParser) -} - -func cryptoApisLitecoinSource(key string) *httpSource { - return cryptoApisSource(key, "LTC", "cryptoapis.ltc", cryptoApisLitecoinParser) -} - -func cryptoApisParser(r io.Reader, network Network) (u divination, err error) { - var resp struct { - Data struct { - Item struct { - Fast string `json:"fast"` - // Standard float64 `json:"standard"` - // Slow float64 `json:"slow"` - } `json:"item"` - } `json:"data"` - } - if err := streamDecodeJSON(r, &resp); err != nil { - return nil, err - } - // The API returns fees in the coin's base unit (e.g., BTC, LTC, DOGE). - // Convert to satoshis per byte by multiplying by 1e8. - feeRate, err := strconv.ParseFloat(resp.Data.Item.Fast, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse fee rate: %v", err) - } - if feeRate <= 0 { - return nil, fmt.Errorf("fee rate cannot be negative or zero") - } - feeRateSatoshis := uint64(feeRate * 1e8) - return []*feeRateUpdate{{network: network, feeRate: uint64ToBigInt(feeRateSatoshis)}}, nil -} - -func cryptoApisBitcoinParser(r io.Reader) (u divination, err error) { - return cryptoApisParser(r, "BTC") -} - -func cryptoApisBitcoinCashParser(r io.Reader) (u divination, err error) { - return cryptoApisParser(r, "BCH") -} - -func cryptoApisDogecoinParser(r io.Reader) (u divination, err error) { - return cryptoApisParser(r, "DOGE") -} - -func cryptoApisDashParser(r io.Reader) (u divination, err error) { - return cryptoApisParser(r, "DASH") -} - -func cryptoApisLitecoinParser(r io.Reader) (u divination, err error) { - return cryptoApisParser(r, "LTC") -} - -func streamDecodeJSON(stream io.Reader, thing any) error { - return json.NewDecoder(stream).Decode(thing) -} diff --git a/oracle/sources/interface.go b/oracle/sources/interface.go new file mode 100644 index 0000000..27518c8 --- /dev/null +++ b/oracle/sources/interface.go @@ -0,0 +1,58 @@ +package sources + +import ( + "context" + "math/big" + "time" +) + +// Ticker is the upper-case symbol used to indicate an asset. +type Ticker string + +// Network is the network symbol of a Blockchain. +type Network string + +// PriceUpdate represents a price update from a source. +type PriceUpdate struct { + Ticker Ticker + Price float64 +} + +// FeeRateUpdate represents a fee rate update from a source. +type FeeRateUpdate struct { + Network Network + FeeRate *big.Int +} + +// RateInfo is a union type that can hold either price updates or fee rate updates. +type RateInfo struct { + Prices []*PriceUpdate + FeeRates []*FeeRateUpdate +} + +// QuotaStatus represents the current quota state for an API source. +// Values represent fetches, not raw API credits. +type QuotaStatus struct { + FetchesRemaining int64 + FetchesLimit int64 + ResetTime time.Time +} + +// Source is the interface that all oracle data sources must implement. +type Source interface { + // Name returns the source identifier. + Name() string + + // FetchRates fetches current rates/data. + FetchRates(ctx context.Context) (*RateInfo, error) + + // QuotaStatus returns the current quota status. Always returns a valid status. + QuotaStatus() *QuotaStatus + + // Weight returns the configured weight for this source (0-1 range). + Weight() float64 + + // MinPeriod returns the minimum allowed fetch period for this source. + // This is based on the API's data refresh rate and rate limits. + MinPeriod() time.Duration +} diff --git a/oracle/sources/providers/bitcore.go b/oracle/sources/providers/bitcore.go new file mode 100644 index 0000000..7546a37 --- /dev/null +++ b/oracle/sources/providers/bitcore.go @@ -0,0 +1,70 @@ +package providers + +import ( + "context" + "fmt" + "io" + "math" + "strings" + "time" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +func newBitcoreSource(coin string, client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + coin = strings.ToUpper(coin) + name := fmt.Sprintf("%s.bitcore", strings.ToLower(coin)) + + url := fmt.Sprintf("https://api.bitcore.io/api/%s/mainnet/fee/2", coin) + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, client, url, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return parseBitcoreResponse(sources.Network(coin), resp.Body) + } + + return utils.NewUnlimitedSource(utils.UnlimitedSourceConfig{ + Name: name, + Weight: 0.25, // Lower weight due to estimatesmartfee variability + MinPeriod: 30 * time.Second, + FetchRates: fetchRates, + }) +} + +func parseBitcoreResponse(netName sources.Network, r io.Reader) (*sources.RateInfo, error) { + var resp struct { + RatePerKB float64 `json:"feerate"` + } + if err := utils.StreamDecodeJSON(r, &resp); err != nil { + return nil, err + } + if resp.RatePerKB <= 0 { + return nil, fmt.Errorf("fee rate cannot be negative or zero") + } + return &sources.RateInfo{ + FeeRates: []*sources.FeeRateUpdate{{ + Network: netName, + FeeRate: uint64ToBigInt(uint64(math.Round(resp.RatePerKB * 1e5))), + }}, + }, nil +} + +// NewBitcoreBitcoinCashSource creates a Bitcore Bitcoin Cash fee rate source. +func NewBitcoreBitcoinCashSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + return newBitcoreSource("BCH", client, log) +} + +// NewBitcoreDogecoinSource creates a Bitcore Dogecoin fee rate source. +func NewBitcoreDogecoinSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + return newBitcoreSource("DOGE", client, log) +} + +// NewBitcoreLitecoinSource creates a Bitcore Litecoin fee rate source. +func NewBitcoreLitecoinSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + return newBitcoreSource("LTC", client, log) +} diff --git a/oracle/sources/providers/blockcypher.go b/oracle/sources/providers/blockcypher.go new file mode 100644 index 0000000..7e5031a --- /dev/null +++ b/oracle/sources/providers/blockcypher.go @@ -0,0 +1,108 @@ +package providers + +import ( + "context" + "fmt" + "io" + "math" + "net/url" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +// NewBlockcypherLitecoinSource creates a BlockCypher Litecoin fee rate source. +// BlockCypher has a free quota endpoint at /v1/tokens/$TOKEN that resets hourly. +// Free tier: 100 requests/hour = ~36 second interval minimum. +func NewBlockcypherLitecoinSource(httpClient utils.HTTPClient, log slog.Logger, token string) sources.Source { + dataURL := "https://api.blockcypher.com/v1/ltc/main" + if token != "" { + dataURL += "?token=" + url.QueryEscape(token) + } + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, httpClient, dataURL, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return blockcypherLitecoinParser(resp.Body) + } + + tracker := utils.NewQuotaTracker(&utils.QuotaTrackerConfig{ + Name: "blockcypher", + FetchQuota: blockcypherQuotaFetcher(httpClient, token), + ReconcileInterval: 30 * time.Second, + Log: log, + }) + return utils.NewTrackedSource(utils.TrackedSourceConfig{ + Name: "ltc.blockcypher", + Weight: 0.25, + MinPeriod: 36 * time.Second, + FetchRates: fetchRates, + Tracker: tracker, + CreditsPerRequest: 1, + }) +} + +func blockcypherQuotaFetcher(client utils.HTTPClient, token string) func(ctx context.Context) (*sources.QuotaStatus, error) { + return func(ctx context.Context) (*sources.QuotaStatus, error) { + if token == "" { + return utils.UnlimitedQuotaStatus(), nil + } + + url := fmt.Sprintf("https://api.blockcypher.com/v1/tokens/%s", token) + resp, err := utils.DoGet(ctx, client, url, nil) + if err != nil { + return nil, fmt.Errorf("error fetching quota: %v", err) + } + defer resp.Body.Close() + + var result struct { + Limits struct { + APIHour int64 `json:"api/hour"` + } `json:"limits"` + Hits struct { + APIHour int64 `json:"api/hour"` + } `json:"hits"` + } + + if err := utils.StreamDecodeJSON(resp.Body, &result); err != nil { + return nil, fmt.Errorf("error parsing quota response: %v", err) + } + + // Calculate remaining from limit and current hour's usage + limit := result.Limits.APIHour + used := result.Hits.APIHour + + // Reset at top of next hour + now := time.Now().UTC() + resetTime := now.Truncate(time.Hour).Add(time.Hour) + + return &sources.QuotaStatus{ + FetchesRemaining: max(limit-used, 0), + FetchesLimit: limit, + ResetTime: resetTime, + }, nil + } +} + +func blockcypherLitecoinParser(r io.Reader) (*sources.RateInfo, error) { + var resp struct { + Medium float64 `json:"medium_fee_per_kb"` + } + if err := utils.StreamDecodeJSON(r, &resp); err != nil { + return nil, err + } + if resp.Medium <= 0 { + return nil, fmt.Errorf("fee rate cannot be negative or zero") + } + return &sources.RateInfo{ + FeeRates: []*sources.FeeRateUpdate{{ + Network: "LTC", + // medium_fee_per_kb is in litoshis/kB. Divide by 1000 to get litoshis/byte. + FeeRate: uint64ToBigInt(uint64(math.Round(resp.Medium / 1000))), + }}, + }, nil +} diff --git a/oracle/sources/providers/coinmarketcap.go b/oracle/sources/providers/coinmarketcap.go new file mode 100644 index 0000000..a9ff315 --- /dev/null +++ b/oracle/sources/providers/coinmarketcap.go @@ -0,0 +1,118 @@ +package providers + +import ( + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +// NewCoinMarketCapSource creates a CoinMarketCap price source. +// Free plan gives 10,000 credits per month. The listings endpoint uses 1 credit +// per call per 200 assets. With 400 assets, we can call ~5,000 times per month, +// which is about 1 call per 8.9 minutes. We call every 10 minutes to be safe. +// MinPeriod is 60s because CoinMarketCap data only updates every minute. +func NewCoinMarketCapSource(httpClient utils.HTTPClient, log slog.Logger, apiKey string) sources.Source { + url := "https://pro-api.coinmarketcap.com/v1/cryptocurrency/listings/latest?limit=400" + headers := []http.Header{{"X-CMC_PRO_API_KEY": []string{apiKey}}} + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, httpClient, url, headers) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return coinmarketcapParser(resp.Body) + } + + tracker := utils.NewQuotaTracker(&utils.QuotaTrackerConfig{ + Name: "coinmarketcap", + FetchQuota: cmcQuotaFetcher(httpClient, apiKey), + ReconcileInterval: 30 * time.Second, + Log: log, + }) + return utils.NewTrackedSource(utils.TrackedSourceConfig{ + Name: "coinmarketcap", + MinPeriod: 60 * time.Second, + FetchRates: fetchRates, + Tracker: tracker, + CreditsPerRequest: 2, // 1 per 200 assets, fetching 400 + }) +} + +func cmcQuotaFetcher(client utils.HTTPClient, apiKey string) func(ctx context.Context) (*sources.QuotaStatus, error) { + return func(ctx context.Context) (*sources.QuotaStatus, error) { + url := "https://pro-api.coinmarketcap.com/v1/key/info" + resp, err := utils.DoGet(ctx, client, url, []http.Header{{"X-CMC_PRO_API_KEY": []string{apiKey}}}) + if err != nil { + return nil, fmt.Errorf("error fetching quota: %v", err) + } + defer resp.Body.Close() + + var result struct { + Data struct { + Plan struct { + CreditLimitMonthly int64 `json:"credit_limit_monthly"` + CreditLimitMonthlyResetTS string `json:"credit_limit_monthly_reset_timestamp"` + } `json:"plan"` + Usage struct { + CurrentMonth struct { + CreditsUsed int64 `json:"credits_used"` + } `json:"current_month"` + } `json:"usage"` + } `json:"data"` + } + + if err := utils.StreamDecodeJSON(resp.Body, &result); err != nil { + return nil, fmt.Errorf("error parsing quota response: %v", err) + } + + // Parse reset timestamp from API response + resetTime, err := time.Parse(time.RFC3339, result.Data.Plan.CreditLimitMonthlyResetTS) + if err != nil { + // Fallback to first of next month if parsing fails + now := time.Now().UTC() + resetTime = time.Date(now.Year(), now.Month()+1, 1, 0, 0, 0, 0, time.UTC) + } + + return &sources.QuotaStatus{ + FetchesRemaining: max(result.Data.Plan.CreditLimitMonthly-result.Data.Usage.CurrentMonth.CreditsUsed, 0), + FetchesLimit: result.Data.Plan.CreditLimitMonthly, + ResetTime: resetTime, + }, nil + } +} + +func coinmarketcapParser(r io.Reader) (*sources.RateInfo, error) { + var resp struct { + Data []*struct { + Symbol string `json:"symbol"` + Quote struct { + USD struct { + Price float64 `json:"price"` + } `json:"USD"` + } `json:"quote"` + } `json:"data"` + } + if err := utils.StreamDecodeJSON(r, &resp); err != nil { + return nil, err + } + prices := resp.Data + seen := make(map[string]bool, len(prices)) + us := make([]*sources.PriceUpdate, 0, len(prices)) + for _, p := range prices { + if seen[p.Symbol] { + continue + } + seen[p.Symbol] = true + us = append(us, &sources.PriceUpdate{ + Ticker: sources.Ticker(p.Symbol), + Price: p.Quote.USD.Price, + }) + } + return &sources.RateInfo{Prices: us}, nil +} diff --git a/oracle/sources/providers/coinpaprika.go b/oracle/sources/providers/coinpaprika.go new file mode 100644 index 0000000..d26dd9e --- /dev/null +++ b/oracle/sources/providers/coinpaprika.go @@ -0,0 +1,59 @@ +package providers + +import ( + "context" + "io" + "time" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +// NewCoinpaprikaSource creates a Coinpaprika price source. +func NewCoinpaprikaSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + url := "https://api.coinpaprika.com/v1/tickers" + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, client, url, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return coinpaprikaParser(resp.Body) + } + return utils.NewUnlimitedSource(utils.UnlimitedSourceConfig{ + Name: "coinpaprika", + // Free tier updates data up to every 10 minutes, and + // allows 20,000 monthly requests (2m10s interval). + MinPeriod: (5 * time.Minute) / 2, + FetchRates: fetchRates, + }) +} + +func coinpaprikaParser(r io.Reader) (*sources.RateInfo, error) { + var prices []*struct { + Symbol string `json:"symbol"` + Quotes struct { + USD struct { + Price float64 `json:"price"` + } `json:"USD"` + } `json:"quotes"` + } + if err := utils.StreamDecodeJSON(r, &prices); err != nil { + return nil, err + } + seen := make(map[string]bool, len(prices)) + us := make([]*sources.PriceUpdate, 0, len(prices)) + for _, p := range prices { + if seen[p.Symbol] { + continue + } + seen[p.Symbol] = true + us = append(us, &sources.PriceUpdate{ + Ticker: sources.Ticker(p.Symbol), + Price: p.Quotes.USD.Price, + }) + } + return &sources.RateInfo{Prices: us}, nil +} diff --git a/oracle/sources/providers/dcrdata.go b/oracle/sources/providers/dcrdata.go new file mode 100644 index 0000000..bbeb460 --- /dev/null +++ b/oracle/sources/providers/dcrdata.go @@ -0,0 +1,63 @@ +package providers + +import ( + "context" + "fmt" + "io" + "math" + "math/big" + "time" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +// NewDcrdataSource creates a dcrdata fee rate source. +// Decred blocks average ~5 minutes. Self-hosted infrastructure with configurable rate limits. +// MinPeriod is 30s as a reasonable default for block-based fee estimates. +func NewDcrdataSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + url := "https://explorer.dcrdata.org/insight/api/utils/estimatefee?nbBlocks=2" + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, client, url, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return dcrdataParser(resp.Body) + } + return utils.NewUnlimitedSource(utils.UnlimitedSourceConfig{ + Name: "dcrdata", + MinPeriod: 30 * time.Second, // Block-based fee data, ~5 min blocks + FetchRates: fetchRates, + }) +} + +var dcrdataParser = estimateFeeParser("DCR") + +func estimateFeeParser(network sources.Network) func(io.Reader) (*sources.RateInfo, error) { + return func(r io.Reader) (*sources.RateInfo, error) { + var resp map[string]float64 + if err := utils.StreamDecodeJSON(r, &resp); err != nil { + return nil, err + } + rate, ok := resp["2"] + if !ok || len(resp) != 1 { + return nil, fmt.Errorf("unexpected response format: %+v", resp) + } + if rate <= 0 { + return nil, fmt.Errorf("fee rate must be positive, got %v", rate) + } + return &sources.RateInfo{ + FeeRates: []*sources.FeeRateUpdate{{ + Network: network, + FeeRate: uint64ToBigInt(uint64(math.Round(rate * 1e5))), + }}, + }, nil + } +} + +func uint64ToBigInt(val uint64) *big.Int { + return new(big.Int).SetUint64(val) +} diff --git a/oracle/sources/providers/firo.go b/oracle/sources/providers/firo.go new file mode 100644 index 0000000..9f2b520 --- /dev/null +++ b/oracle/sources/providers/firo.go @@ -0,0 +1,34 @@ +package providers + +import ( + "context" + "time" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +// NewFiroOrgSource creates a Firo network fee rate source. +// Third-party explorer with ~1 req/sec limit (CoinExplorer). +// MinPeriod is 30s as a conservative default for third-party explorers. +func NewFiroOrgSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + url := "https://explorer.firo.org/insight-api-zcoin/utils/estimatefee" + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, client, url, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return firoOrgParser(resp.Body) + } + return utils.NewUnlimitedSource(utils.UnlimitedSourceConfig{ + Name: "firo.org", + Weight: 0.25, // Lower weight due to estimatesmartfee variability + MinPeriod: 30 * time.Second, + FetchRates: fetchRates, + }) +} + +var firoOrgParser = estimateFeeParser("FIRO") diff --git a/oracle/sources/providers/live_test.go b/oracle/sources/providers/live_test.go new file mode 100644 index 0000000..aedbb42 --- /dev/null +++ b/oracle/sources/providers/live_test.go @@ -0,0 +1,268 @@ +//go:build live + +package providers_test + +import ( + "context" + "math" + "net/http" + "testing" + "time" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/providers" +) + +// These tests make real HTTP requests to external APIs. +// Run with: go test -tags=live -v ./oracle/sources/providers + +// Supply API keys before running authenticated source tests. +const ( + coinmarketcapAPIKey = "b2def99762ca4df2b5d557ae6bf1a4a5" + tatumAPIKey = "" + blockcypherToken = "91ded84bd49348688d319245a62388af" +) + +func liveTestLogger() slog.Logger { return slog.Disabled } + +func httpClient() *http.Client { return &http.Client{Timeout: 30 * time.Second} } + +// === Unlimited Sources (no API key required) === + +func TestLiveDcrdataSource(t *testing.T) { + src := providers.NewDcrdataSource(httpClient(), liveTestLogger()) + testFeeRateSource(t, src, "DCR") + testMinPeriod(t, src, 30*time.Second) + testUnlimitedQuota(t, src) +} + +func TestLiveMempoolDotSpaceSource(t *testing.T) { + src := providers.NewMempoolDotSpaceSource(httpClient(), liveTestLogger()) + testFeeRateSource(t, src, "BTC") + testMinPeriod(t, src, 10*time.Second) + testUnlimitedQuota(t, src) +} + +func TestLiveCoinpaprikaSource(t *testing.T) { + src := providers.NewCoinpaprikaSource(httpClient(), liveTestLogger()) + testPriceSource(t, src) + testMinPeriod(t, src, 60*time.Second) + testUnlimitedQuota(t, src) +} + +func TestLiveBitcoreBitcoinCashSource(t *testing.T) { + src := providers.NewBitcoreBitcoinCashSource(httpClient(), liveTestLogger()) + testFeeRateSource(t, src, "BCH") + testMinPeriod(t, src, 30*time.Second) + testUnlimitedQuota(t, src) +} + +func TestLiveBitcoreDogecoinSource(t *testing.T) { + src := providers.NewBitcoreDogecoinSource(httpClient(), liveTestLogger()) + testFeeRateSource(t, src, "DOGE") + testMinPeriod(t, src, 30*time.Second) + testUnlimitedQuota(t, src) +} + +func TestLiveBitcoreLitecoinSource(t *testing.T) { + src := providers.NewBitcoreLitecoinSource(httpClient(), liveTestLogger()) + testFeeRateSource(t, src, "LTC") + testMinPeriod(t, src, 30*time.Second) + testUnlimitedQuota(t, src) +} + +func TestLiveFiroOrgSource(t *testing.T) { + src := providers.NewFiroOrgSource(httpClient(), liveTestLogger()) + testFeeRateSource(t, src, "FIRO") + testMinPeriod(t, src, 30*time.Second) + testUnlimitedQuota(t, src) +} + +// === Authenticated Sources (API key required) === + +func TestLiveBlockcypherLitecoinSource(t *testing.T) { + src := providers.NewBlockcypherLitecoinSource(httpClient(), liveTestLogger(), blockcypherToken) + testFeeRateSource(t, src, "LTC") + testMinPeriod(t, src, 36*time.Second) + + if blockcypherToken != "" { + testPooledQuota(t, src) + } else { + t.Log("Skipping quota test: blockcypher token not provided") + testUnlimitedQuota(t, src) + } +} + +func TestLiveCoinMarketCapSource(t *testing.T) { + if coinmarketcapAPIKey == "" { + t.Skip("coinmarketcap API key not provided") + } + src := providers.NewCoinMarketCapSource(httpClient(), liveTestLogger(), coinmarketcapAPIKey) + testPriceSource(t, src) + testMinPeriod(t, src, 60*time.Second) + testPooledQuota(t, src) +} + +func TestLiveTatumSources(t *testing.T) { + if tatumAPIKey == "" { + t.Skip("tatum API key not provided") + } + tatumSources := providers.NewTatumSources(providers.TatumConfig{ + HTTPClient: httpClient(), + Log: liveTestLogger(), + APIKey: tatumAPIKey, + }) + + t.Run("bitcoin", func(t *testing.T) { + testFeeRateSource(t, tatumSources.Bitcoin, "BTC") + testMinPeriod(t, tatumSources.Bitcoin, 10*time.Second) + }) + t.Run("litecoin", func(t *testing.T) { + testFeeRateSource(t, tatumSources.Litecoin, "LTC") + testMinPeriod(t, tatumSources.Litecoin, 10*time.Second) + }) + t.Run("dogecoin", func(t *testing.T) { + testFeeRateSource(t, tatumSources.Dogecoin, "DOGE") + testMinPeriod(t, tatumSources.Dogecoin, 10*time.Second) + }) + t.Run("shared quota", func(t *testing.T) { + // Give reconciliation time to complete. + time.Sleep(2 * time.Second) + for _, src := range tatumSources.All() { + testPooledQuota(t, src) + } + }) +} + +// === Helper Functions === + +func testFeeRateSource(t *testing.T, src sources.Source, expectedNetwork sources.Network) { + t.Helper() + + testSourceInterface(t, src) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := src.FetchRates(ctx) + if err != nil { + t.Fatalf("FetchRates failed: %v", err) + } + if len(result.FeeRates) == 0 { + t.Fatal("no fee rates returned") + } + + found := false + for _, fr := range result.FeeRates { + if fr.Network == expectedNetwork { + found = true + if fr.FeeRate == nil || fr.FeeRate.Sign() <= 0 { + t.Errorf("fee rate for %s is nil or non-positive", expectedNetwork) + } + t.Logf("[%s] %s fee rate: %s", src.Name(), expectedNetwork, fr.FeeRate.String()) + } + } + if !found { + t.Errorf("expected network %s not found in fee rates", expectedNetwork) + } +} + +func testPriceSource(t *testing.T, src sources.Source) { + t.Helper() + + testSourceInterface(t, src) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := src.FetchRates(ctx) + if err != nil { + t.Fatalf("FetchRates failed: %v", err) + } + if len(result.Prices) == 0 { + t.Fatal("no prices returned") + } + + t.Logf("[%s] total prices returned: %d", src.Name(), len(result.Prices)) + + // Log some common tickers + commonTickers := []sources.Ticker{"BTC", "ETH", "LTC", "DCR", "DOGE"} + for _, ticker := range commonTickers { + for _, p := range result.Prices { + if p.Ticker == ticker { + if p.Price <= 0 { + t.Errorf("price for %s is <= 0: %f", ticker, p.Price) + } + t.Logf("[%s] %s price: $%.2f", src.Name(), ticker, p.Price) + break + } + } + } +} + +func testSourceInterface(t *testing.T, src sources.Source) { + t.Helper() + + name := src.Name() + if name == "" { + t.Error("Name() returned empty string") + } + + weight := src.Weight() + if weight <= 0 || weight > 1 { + t.Errorf("Weight() returned %f, expected (0, 1]", weight) + } + + minPeriod := src.MinPeriod() + if minPeriod <= 0 { + t.Errorf("MinPeriod() returned %v, expected > 0", minPeriod) + } + + quota := src.QuotaStatus() + if quota == nil { + t.Error("QuotaStatus() returned nil") + } + + t.Logf("[%s] interface: weight=%.2f, minPeriod=%v", name, weight, minPeriod) +} + +func testMinPeriod(t *testing.T, src sources.Source, expected time.Duration) { + t.Helper() + actual := src.MinPeriod() + if actual != expected { + t.Errorf("[%s] MinPeriod() = %v, expected %v", src.Name(), actual, expected) + } +} + +func testUnlimitedQuota(t *testing.T, src sources.Source) { + t.Helper() + status := src.QuotaStatus() + if status == nil { + t.Fatal("QuotaStatus() returned nil") + } + if status.FetchesRemaining != math.MaxInt64 { + t.Errorf("[%s] expected unlimited fetches (MaxInt64), got %d", src.Name(), status.FetchesRemaining) + } + t.Logf("[%s] quota: unlimited (fetches=%d)", src.Name(), status.FetchesRemaining) +} + +func testPooledQuota(t *testing.T, src sources.Source) { + t.Helper() + status := src.QuotaStatus() + if status == nil { + t.Fatal("QuotaStatus() returned nil") + } + if status.FetchesLimit <= 0 { + t.Errorf("[%s] expected positive FetchesLimit, got %d", src.Name(), status.FetchesLimit) + } + if status.FetchesRemaining < 0 { + t.Errorf("[%s] expected non-negative FetchesRemaining, got %d", src.Name(), status.FetchesRemaining) + } + if status.ResetTime.IsZero() { + t.Errorf("[%s] expected ResetTime to be set", src.Name()) + } + t.Logf("[%s] pooled quota: %d/%d fetches remaining (per source), resets at %v", + src.Name(), status.FetchesRemaining, status.FetchesLimit, status.ResetTime.Format(time.RFC3339)) +} diff --git a/oracle/sources/providers/mempool.go b/oracle/sources/providers/mempool.go new file mode 100644 index 0000000..58c1ddc --- /dev/null +++ b/oracle/sources/providers/mempool.go @@ -0,0 +1,52 @@ +package providers + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +// NewMempoolDotSpaceSource creates a mempool.space Bitcoin fee rate source. +// Real-time fee data updated per block. Rate limits undisclosed but enforced. +// MinPeriod is 1 minute since data is real-time and they recommend self-hosting +// for heavy use. +func NewMempoolDotSpaceSource(client utils.HTTPClient, log slog.Logger) *utils.UnlimitedSource { + url := "https://mempool.space/api/v1/fees/recommended" + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, client, url, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return mempoolDotSpaceParser(resp.Body) + } + return utils.NewUnlimitedSource(utils.UnlimitedSourceConfig{ + Name: "btc.mempooldotspace", + MinPeriod: time.Minute, + FetchRates: fetchRates, + }) +} + +func mempoolDotSpaceParser(r io.Reader) (*sources.RateInfo, error) { + var resp struct { + FastestFee uint64 `json:"fastestFee"` + } + if err := utils.StreamDecodeJSON(r, &resp); err != nil { + return nil, err + } + if resp.FastestFee == 0 { + return nil, fmt.Errorf("zero fee rate returned") + } + return &sources.RateInfo{ + FeeRates: []*sources.FeeRateUpdate{{ + Network: "BTC", + FeeRate: uint64ToBigInt(resp.FastestFee), + }}, + }, nil +} diff --git a/oracle/sources/providers/source_test.go b/oracle/sources/providers/source_test.go new file mode 100644 index 0000000..b1f4385 --- /dev/null +++ b/oracle/sources/providers/source_test.go @@ -0,0 +1,220 @@ +package providers_test + +import ( + "bytes" + "context" + "io" + "math" + "math/big" + "net/http" + "testing" + + "github.com/decred/slog" + + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/providers" +) + +// tHTTPClient implements sources.HTTPClient for testing. +type tHTTPClient struct { + response *http.Response + err error +} + +func (tc *tHTTPClient) Do(*http.Request) (*http.Response, error) { return tc.response, tc.err } + +func newMockResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + } +} + +func testLogger() slog.Logger { return slog.Disabled } + +func TestDcrdataSource(t *testing.T) { + client := &tHTTPClient{response: newMockResponse(`{"2": 0.0001}`)} + src := providers.NewDcrdataSource(client, testLogger()) + + t.Run("valid response", func(t *testing.T) { + client.response = newMockResponse(`{"2": 0.0001}`) + result, err := src.FetchRates(context.Background()) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if len(result.FeeRates) != 1 { + t.Fatalf("expected 1 fee rate, got %d", len(result.FeeRates)) + } + + if result.FeeRates[0].Network != "DCR" { + t.Errorf("expected network DCR, got %s", result.FeeRates[0].Network) + } + + // 0.0001 DCR/kB * 1e5 = 10 atoms/byte + if result.FeeRates[0].FeeRate.Cmp(big.NewInt(10)) != 0 { + t.Errorf("expected fee rate 10, got %s", result.FeeRates[0].FeeRate.String()) + } + }) + + t.Run("quota status is unlimited", func(t *testing.T) { + status := src.QuotaStatus() + if status.FetchesRemaining != math.MaxInt64 { + t.Errorf("expected unlimited fetches, got %d", status.FetchesRemaining) + } + }) +} + +func TestMempoolDotSpaceSource(t *testing.T) { + client := &tHTTPClient{} + src := providers.NewMempoolDotSpaceSource(client, testLogger()) + + t.Run("valid response", func(t *testing.T) { + client.response = newMockResponse(`{"fastestFee": 25}`) + result, err := src.FetchRates(context.Background()) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if len(result.FeeRates) != 1 { + t.Fatalf("expected 1 fee rate, got %d", len(result.FeeRates)) + } + + if result.FeeRates[0].Network != "BTC" { + t.Errorf("expected network BTC, got %s", result.FeeRates[0].Network) + } + + if result.FeeRates[0].FeeRate.Cmp(big.NewInt(25)) != 0 { + t.Errorf("expected fee rate 25, got %s", result.FeeRates[0].FeeRate.String()) + } + }) +} + +func TestCoinpaprikaSource(t *testing.T) { + client := &tHTTPClient{} + src := providers.NewCoinpaprikaSource(client, testLogger()) + + t.Run("valid response", func(t *testing.T) { + body := `[ + {"id":"btc-bitcoin","symbol":"BTC","quotes":{"USD":{"price":87838.55}}}, + {"id":"eth-ethereum","symbol":"ETH","quotes":{"USD":{"price":2954.14}}} + ]` + client.response = newMockResponse(body) + result, err := src.FetchRates(context.Background()) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if len(result.Prices) != 2 { + t.Fatalf("expected 2 prices, got %d", len(result.Prices)) + } + + prices := make(map[sources.Ticker]float64) + for _, p := range result.Prices { + prices[p.Ticker] = p.Price + } + + if prices["BTC"] != 87838.55 { + t.Errorf("expected BTC price 87838.55, got %f", prices["BTC"]) + } + }) +} + +func TestCoinMarketCapSource(t *testing.T) { + client := &tHTTPClient{} + src := providers.NewCoinMarketCapSource(client, testLogger(), "test-api-key") + + t.Run("valid response", func(t *testing.T) { + body := `{ + "data": [ + {"symbol":"BTC","quote":{"USD":{"price":90000.50}}}, + {"symbol":"ETH","quote":{"USD":{"price":3100.25}}} + ] + }` + client.response = newMockResponse(body) + result, err := src.FetchRates(context.Background()) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if len(result.Prices) != 2 { + t.Fatalf("expected 2 prices, got %d", len(result.Prices)) + } + }) + + t.Run("quota status refreshes", func(t *testing.T) { + // Initially should return unlimited (no quota fetched yet) + status := src.QuotaStatus() + if status == nil { + t.Fatal("expected quota status") + } + }) +} + +func TestTatumSources(t *testing.T) { + client := &tHTTPClient{} + tatumSources := providers.NewTatumSources(providers.TatumConfig{ + HTTPClient: client, + Log: testLogger(), + APIKey: "test-api-key", + }) + + t.Run("all sources returned", func(t *testing.T) { + all := tatumSources.All() + if len(all) != 3 { + t.Fatalf("expected 3 sources, got %d", len(all)) + } + }) + + t.Run("btc valid response", func(t *testing.T) { + client.response = newMockResponse(`{"fast": 25}`) + result, err := tatumSources.Bitcoin.FetchRates(context.Background()) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if len(result.FeeRates) != 1 { + t.Fatalf("expected 1 fee rate, got %d", len(result.FeeRates)) + } + + if result.FeeRates[0].Network != "BTC" { + t.Errorf("expected network BTC, got %s", result.FeeRates[0].Network) + } + + if result.FeeRates[0].FeeRate.Cmp(big.NewInt(25)) != 0 { + t.Errorf("expected fee rate 25, got %s", result.FeeRates[0].FeeRate.String()) + } + }) + + t.Run("pool tracks consumption", func(t *testing.T) { + // After a fetch, pool should have consumed 10 credits. + // Before reconciliation, quota is unlimited. + status := tatumSources.Bitcoin.QuotaStatus() + if status == nil { + t.Fatal("expected quota status") + } + }) +} + +func TestBlockcypherSource(t *testing.T) { + client := &tHTTPClient{} + src := providers.NewBlockcypherLitecoinSource(client, testLogger(), "test-token") + + t.Run("valid response", func(t *testing.T) { + body := `{"medium_fee_per_kb": 10000}` + client.response = newMockResponse(body) + result, err := src.FetchRates(context.Background()) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if len(result.FeeRates) != 1 { + t.Fatalf("expected 1 fee rate, got %d", len(result.FeeRates)) + } + + if result.FeeRates[0].Network != "LTC" { + t.Errorf("expected network LTC, got %s", result.FeeRates[0].Network) + } + }) +} diff --git a/oracle/sources/providers/tatum.go b/oracle/sources/providers/tatum.go new file mode 100644 index 0000000..4a9bcc8 --- /dev/null +++ b/oracle/sources/providers/tatum.go @@ -0,0 +1,131 @@ +package providers + +import ( + "context" + "fmt" + "io" + "math" + "net/http" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" + "github.com/bisoncraft/mesh/oracle/sources/utils" +) + +const ( + tatumCreditsPerRequest = 10 // Each fee estimation call costs 10 credits. + tatumReconcileInterval = 10 * time.Minute + tatumMinPeriod = 10 * time.Second // Real-time fee data, 3 req/sec rate limit. +) + +// TatumConfig configures the Tatum source group. +type TatumConfig struct { + HTTPClient utils.HTTPClient + Log slog.Logger + APIKey string +} + +// TatumSources holds all Tatum-powered fee rate sources that share +// a single API key and quota tracker. +type TatumSources struct { + Bitcoin sources.Source + Litecoin sources.Source + Dogecoin sources.Source + pool *utils.QuotaTracker +} + +// All returns all Tatum sources. +func (ts *TatumSources) All() []sources.Source { + return []sources.Source{ts.Bitcoin, ts.Litecoin, ts.Dogecoin} +} + +// NewTatumSources creates a Tatum source group with a shared quota tracker. +func NewTatumSources(cfg TatumConfig) *TatumSources { + tracker := utils.NewQuotaTracker(&utils.QuotaTrackerConfig{ + Name: "tatum", + FetchQuota: tatumQuotaFetcher(cfg.HTTPClient, cfg.APIKey), + ReconcileInterval: tatumReconcileInterval, + Log: cfg.Log, + }) + + headers := []http.Header{{"x-api-key": []string{cfg.APIKey}}} + + mkSource := func(coin string, network sources.Network, name string) sources.Source { + url := fmt.Sprintf("https://api.tatum.io/v3/blockchain/fee/%s", coin) + parse := tatumParserForNetwork(network) + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + resp, err := utils.DoGet(ctx, cfg.HTTPClient, url, headers) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return parse(resp.Body) + } + return utils.NewTrackedSource(utils.TrackedSourceConfig{ + Name: name, + MinPeriod: tatumMinPeriod, + FetchRates: fetchRates, + Tracker: tracker, + CreditsPerRequest: tatumCreditsPerRequest, + }) + } + + return &TatumSources{ + Bitcoin: mkSource("BTC", "BTC", "tatum.btc"), + Litecoin: mkSource("LTC", "LTC", "tatum.ltc"), + Dogecoin: mkSource("DOGE", "DOGE", "tatum.doge"), + pool: tracker, + } +} + +// tatumQuotaFetcher returns a function that fetches quota from Tatum's usage endpoint. +func tatumQuotaFetcher(client utils.HTTPClient, apiKey string) func(ctx context.Context) (*sources.QuotaStatus, error) { + return func(ctx context.Context) (*sources.QuotaStatus, error) { + url := "https://api.tatum.io/v3/tatum/usage" + resp, err := utils.DoGet(ctx, client, url, []http.Header{{"x-api-key": []string{apiKey}}}) + if err != nil { + return nil, fmt.Errorf("error fetching tatum quota: %v", err) + } + defer resp.Body.Close() + + var result struct { + Used int64 `json:"used"` + Limit int64 `json:"limit"` + } + + if err := utils.StreamDecodeJSON(resp.Body, &result); err != nil { + return nil, fmt.Errorf("error parsing tatum quota response: %v", err) + } + + // Reset at first of next month. + now := time.Now().UTC() + nextMonth := time.Date(now.Year(), now.Month()+1, 1, 0, 0, 0, 0, time.UTC) + + return &sources.QuotaStatus{ + FetchesRemaining: max(result.Limit-result.Used, 0), + FetchesLimit: result.Limit, + ResetTime: nextMonth, + }, nil + } +} + +func tatumParserForNetwork(network sources.Network) func(io.Reader) (*sources.RateInfo, error) { + return func(r io.Reader) (*sources.RateInfo, error) { + var resp struct { + Fast float64 `json:"fast"` + } + if err := utils.StreamDecodeJSON(r, &resp); err != nil { + return nil, err + } + if resp.Fast <= 0 { + return nil, fmt.Errorf("fee rate cannot be negative or zero") + } + return &sources.RateInfo{ + FeeRates: []*sources.FeeRateUpdate{{ + Network: network, + FeeRate: uint64ToBigInt(uint64(math.Round(resp.Fast))), + }}, + }, nil + } +} diff --git a/oracle/sources/utils/http.go b/oracle/sources/utils/http.go new file mode 100644 index 0000000..dceb6b4 --- /dev/null +++ b/oracle/sources/utils/http.go @@ -0,0 +1,96 @@ +package utils + +import ( + "context" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "strings" + "time" + + "github.com/bisoncraft/mesh/oracle/sources" +) + +const ( + defaultMinPeriod = 30 * time.Second + defaultWeight = 1.0 + + // httpErrBodySnippetLimit is the max bytes of response body to include in a + // non-2xx HTTP error. + httpErrBodySnippetLimit = 4 << 10 // 4 KiB + + // maxJSONBytes is a safety cap for JSON decoding from HTTP responses. + // Note: callers generally decode a small subset of fields, so responses + // should be modest in size. + maxJSONBytes = 10 << 20 // 10 MiB +) + +// HTTPClient defines the requirements for implementing an http client. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// DoGet performs an HTTP GET request, returning the response or an error for +// non-2xx status codes. +func DoGet(ctx context.Context, client HTTPClient, url string, headers []http.Header) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("error generating request %q: %w", url, err) + } + + for _, header := range headers { + for k, vs := range header { + for _, v := range vs { + req.Header.Add(k, v) + } + } + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error fetching %q: %w", url, err) + } + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + snippetBytes, _ := io.ReadAll(io.LimitReader(resp.Body, httpErrBodySnippetLimit)) + _ = resp.Body.Close() + snippet := strings.TrimSpace(string(snippetBytes)) + if snippet != "" { + return nil, fmt.Errorf("http %d fetching %q: %s", resp.StatusCode, url, snippet) + } + return nil, fmt.Errorf("http %d fetching %q", resp.StatusCode, url) + } + + return resp, nil +} + +// UnlimitedQuotaStatus returns a quota status indicating unlimited fetches. +func UnlimitedQuotaStatus() *sources.QuotaStatus { + now := time.Now().UTC() + return &sources.QuotaStatus{ + FetchesRemaining: math.MaxInt64, + FetchesLimit: math.MaxInt64, + ResetTime: now.Add(24 * time.Hour), + } +} + +// StreamDecodeJSON decodes JSON from a stream. +func StreamDecodeJSON(stream io.Reader, thing any) error { + dec := json.NewDecoder(io.LimitReader(stream, maxJSONBytes)) + if err := dec.Decode(thing); err != nil { + return err + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + if err == nil { + return fmt.Errorf("unexpected trailing JSON") + } + return err + } + return nil +} diff --git a/oracle/sources/utils/quota_tracker.go b/oracle/sources/utils/quota_tracker.go new file mode 100644 index 0000000..19d6131 --- /dev/null +++ b/oracle/sources/utils/quota_tracker.go @@ -0,0 +1,279 @@ +package utils + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" +) + +const ( + defaultReconcileInterval = time.Hour + reconcileTimeout = 10 * time.Second +) + +// QuotaTracker tracks quota for one or more sources that share a single API +// key/credit pool. It divides the available quota evenly among registered +// sources. It tracks consumption locally and periodically reconciles with an +// API endpoint. +type QuotaTracker struct { + mtx sync.RWMutex + creditsRemaining int64 + creditsLimit int64 + resetTime time.Time + sourceCount int + initialized bool + + name string + fetchQuota func(ctx context.Context) (*sources.QuotaStatus, error) + reconcileInterval time.Duration + lastReconcile time.Time + reconciling atomic.Bool + initOnce sync.Once + log slog.Logger +} + +// QuotaTrackerConfig configures a QuotaTracker. +type QuotaTrackerConfig struct { + // Name identifies this quota tracker in log messages. + Name string + + // FetchQuota fetches quota status from the server. + FetchQuota func(ctx context.Context) (*sources.QuotaStatus, error) + + // ReconcileInterval is how often to reconcile quota with the API. + ReconcileInterval time.Duration + + Log slog.Logger +} + +// verify validates that all fields are set. Panics if any field is missing. +func (cfg *QuotaTrackerConfig) verify() { + if cfg == nil { + panic("quota tracker config is nil") + } + if cfg.FetchQuota == nil { + panic("fetch quota function is required") + } + if cfg.Log == nil { + panic("logger is required") + } +} + +// NewQuotaTracker creates a new quota tracker. +func NewQuotaTracker(cfg *QuotaTrackerConfig) *QuotaTracker { + cfg.verify() + + reconcileInterval := cfg.ReconcileInterval + if reconcileInterval == 0 { + reconcileInterval = defaultReconcileInterval + } + + return &QuotaTracker{ + name: cfg.Name, + fetchQuota: cfg.FetchQuota, + reconcileInterval: reconcileInterval, + log: cfg.Log, + } +} + +// ConsumeCredits decrements the tracker's credit counter. +func (p *QuotaTracker) ConsumeCredits(n int64) { + p.mtx.Lock() + defer p.mtx.Unlock() + + if p.creditsRemaining <= n { + p.creditsRemaining = 0 + } else { + p.creditsRemaining -= n + } +} + +// AddSource increments the source count. +func (p *QuotaTracker) AddSource() { + p.mtx.Lock() + p.sourceCount++ + p.mtx.Unlock() +} + +// QuotaStatus returns the quota divided by source count. +// Each source gets an equal share of the available credits. +// The first call blocks until reconciliation completes so that callers +// always receive accurate quota data. Subsequent reconciliations are +// triggered in the background when the interval elapses. +// Returns a zero-valued status if the tracker has not been initialized via +// reconciliation. +func (p *QuotaTracker) QuotaStatus() *sources.QuotaStatus { + // Block on first reconciliation to ensure accurate initial quota data. + p.initOnce.Do(func() { + p.reconciling.Store(true) + p.reconcile() + }) + + p.mtx.RLock() + defer p.mtx.RUnlock() + + // Trigger async reconciliation if stale. + now := time.Now().UTC() + if now.Sub(p.lastReconcile) > p.reconcileInterval { + if p.reconciling.CompareAndSwap(false, true) { + go p.reconcile() + } + } + + sourceCount := p.sourceCount + if sourceCount == 0 { + sourceCount = 1 + } + + return &sources.QuotaStatus{ + FetchesRemaining: p.creditsRemaining / int64(sourceCount), + FetchesLimit: p.creditsLimit / int64(sourceCount), + ResetTime: p.resetTime, + } +} + +// reconcile fetches the current quota from the server and merges it with +// local state. On the first successful sync it adopts the server's values +// unconditionally. On subsequent syncs it conservatively keeps the lower +// of the two remaining counts to avoid over-fetching when another consumer +// shares the same API key. +func (p *QuotaTracker) reconcile() { + defer p.reconciling.Store(false) + + ctx, cancel := context.WithTimeout(context.Background(), reconcileTimeout) + defer cancel() + + serverQuota, err := p.fetchQuota(ctx) + now := time.Now().UTC() + + if err != nil { + p.log.Errorf("[%s] Failed to reconcile quota: %v", p.name, err) + // Update lastReconcile to avoid hammering the endpoint. + p.mtx.Lock() + p.lastReconcile = now + p.mtx.Unlock() + return + } + + p.mtx.Lock() + defer p.mtx.Unlock() + + if serverQuota == nil { + p.log.Warnf("[%s] Quota reconcile: server returned nil quota", p.name) + p.lastReconcile = now + return + } + + firstSync := !p.initialized + + p.creditsLimit = serverQuota.FetchesLimit + p.resetTime = serverQuota.ResetTime + p.initialized = true + + // For some sources, the API's counter is eventually consistent and can lag + // behind our local consumption tracking. Only adopt the API's remaining + // count when it reports more usage than we've tracked locally (e.g. + // another consumer sharing the same key) or on the first sync. + if firstSync { + p.log.Infof("[%s] Quota initial sync: server remaining = %d, limit = %d", + p.name, serverQuota.FetchesRemaining, serverQuota.FetchesLimit) + p.creditsRemaining = serverQuota.FetchesRemaining + } else if serverQuota.FetchesRemaining < p.creditsRemaining { + // Server reports more usage than we tracked — another consumer + // may be sharing this key. + p.log.Warnf("[%s] Quota reconcile: server remaining (%d) < local estimate (%d), syncing down", + p.name, serverQuota.FetchesRemaining, p.creditsRemaining) + p.creditsRemaining = serverQuota.FetchesRemaining + } else if serverQuota.FetchesRemaining > p.creditsRemaining { + // Server hasn't caught up with our local consumption yet. + p.log.Infof("[%s] Quota reconcile: server remaining (%d) > local estimate (%d), keeping local", + p.name, serverQuota.FetchesRemaining, p.creditsRemaining) + } + + p.lastReconcile = now +} + +// TrackedSourceConfig configures a TrackedSource. +type TrackedSourceConfig struct { + Name string + Weight float64 + MinPeriod time.Duration + FetchRates FetchRatesFunc + Tracker *QuotaTracker + CreditsPerRequest int64 +} + +// TrackedSource is a source whose quota is managed by a shared QuotaTracker. +type TrackedSource struct { + name string + weight float64 + minPeriod time.Duration + fetchRates FetchRatesFunc + tracker *QuotaTracker + creditsPerRequest int64 +} + +// NewTrackedSource creates a new tracked source. It validates config, applies +// defaults for Weight and MinPeriod, and registers itself with the tracker. +func NewTrackedSource(cfg TrackedSourceConfig) *TrackedSource { + if cfg.Name == "" { + panic("tracked source: name is required") + } + if cfg.FetchRates == nil { + panic("tracked source: FetchRates is required") + } + if cfg.Tracker == nil { + panic("tracked source: Tracker is required") + } + + weight := cfg.Weight + if weight == 0 { + weight = defaultWeight + } + minPeriod := cfg.MinPeriod + if minPeriod == 0 { + minPeriod = defaultMinPeriod + } + + if cfg.CreditsPerRequest <= 0 { + cfg.CreditsPerRequest = 1 + } + + cfg.Tracker.AddSource() + + return &TrackedSource{ + name: cfg.Name, + weight: weight, + minPeriod: minPeriod, + fetchRates: cfg.FetchRates, + tracker: cfg.Tracker, + creditsPerRequest: cfg.CreditsPerRequest, + } +} + +func (s *TrackedSource) Name() string { return s.name } +func (s *TrackedSource) Weight() float64 { return s.weight } +func (s *TrackedSource) MinPeriod() time.Duration { return s.minPeriod } + +func (s *TrackedSource) FetchRates(ctx context.Context) (*sources.RateInfo, error) { + rates, err := s.fetchRates(ctx) + if err != nil { + return nil, err + } + + s.tracker.ConsumeCredits(s.creditsPerRequest) + return rates, nil +} + +func (s *TrackedSource) QuotaStatus() *sources.QuotaStatus { + status := s.tracker.QuotaStatus() + if s.creditsPerRequest > 1 { + status.FetchesRemaining /= s.creditsPerRequest + status.FetchesLimit /= s.creditsPerRequest + } + return status +} diff --git a/oracle/sources/utils/quota_tracker_test.go b/oracle/sources/utils/quota_tracker_test.go new file mode 100644 index 0000000..b50eb76 --- /dev/null +++ b/oracle/sources/utils/quota_tracker_test.go @@ -0,0 +1,481 @@ +package utils + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/decred/slog" + "github.com/bisoncraft/mesh/oracle/sources" +) + +// newTestPool creates a QuotaTracker whose FetchQuota always errors. +// Useful for tests that don't need initialized quota (e.g. panic validation, +// field accessor tests, the "zero before initialized" test). +func newTestPool(t *testing.T) *QuotaTracker { + t.Helper() + return NewQuotaTracker(&QuotaTrackerConfig{ + Name: "test", + FetchQuota: func(ctx context.Context) (*sources.QuotaStatus, error) { + return nil, fmt.Errorf("test pool: no server") + }, + Log: slog.Disabled, + }) +} + +// newTestPoolWithQuota creates a QuotaTracker whose FetchQuota returns the +// given values. The first QuotaStatus() call triggers reconciliation and +// seeds the tracker. +func newTestPoolWithQuota(t *testing.T, remaining, limit int64) *QuotaTracker { + t.Helper() + return NewQuotaTracker(&QuotaTrackerConfig{ + Name: "test", + FetchQuota: func(ctx context.Context) (*sources.QuotaStatus, error) { + return &sources.QuotaStatus{ + FetchesRemaining: remaining, + FetchesLimit: limit, + }, nil + }, + Log: slog.Disabled, + }) +} + +func TestNewQuotaTracker_NilConfigPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for nil config") + } + }() + NewQuotaTracker(nil) +} + +func TestNewQuotaTracker_MissingFetchQuotaPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for missing FetchQuota") + } + }() + NewQuotaTracker(&QuotaTrackerConfig{ + Log: slog.Disabled, + }) +} + +func TestNewQuotaTracker_MissingLogPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for missing Log") + } + }() + NewQuotaTracker(&QuotaTrackerConfig{ + FetchQuota: func(ctx context.Context) (*sources.QuotaStatus, error) { + return nil, nil + }, + }) +} + +func TestNewQuotaTracker_Defaults(t *testing.T) { + p := NewQuotaTracker(&QuotaTrackerConfig{ + FetchQuota: func(ctx context.Context) (*sources.QuotaStatus, error) { + return nil, fmt.Errorf("test") + }, + Log: slog.Disabled, + }) + if p.reconcileInterval != defaultReconcileInterval { + t.Errorf("expected default reconcile interval %v, got %v", + defaultReconcileInterval, p.reconcileInterval) + } +} + +func TestQuotaTracker_ConsumeCredits(t *testing.T) { + p := newTestPoolWithQuota(t, 100, 100) + p.AddSource() + // Trigger initial reconciliation to seed values. + _ = p.QuotaStatus() + + p.ConsumeCredits(30) + status := p.QuotaStatus() + if status.FetchesRemaining != 70 { + t.Errorf("expected 70 remaining, got %d", status.FetchesRemaining) + } + + p.ConsumeCredits(50) + status = p.QuotaStatus() + if status.FetchesRemaining != 20 { + t.Errorf("expected 20 remaining, got %d", status.FetchesRemaining) + } +} + +func TestQuotaTracker_ConsumeCreditsFloorAtZero(t *testing.T) { + p := newTestPoolWithQuota(t, 10, 100) + p.AddSource() + _ = p.QuotaStatus() + + p.ConsumeCredits(50) // exceeds remaining + status := p.QuotaStatus() + if status.FetchesRemaining != 0 { + t.Errorf("expected 0 remaining (floor), got %d", status.FetchesRemaining) + } +} + +func TestQuotaTracker_DividesAmongSources(t *testing.T) { + p := newTestPoolWithQuota(t, 300, 900) + p.AddSource() + p.AddSource() + p.AddSource() + + status := p.QuotaStatus() + if status.FetchesRemaining != 100 { + t.Errorf("expected 100 remaining per source (300/3), got %d", status.FetchesRemaining) + } + if status.FetchesLimit != 300 { + t.Errorf("expected 300 limit per source (900/3), got %d", status.FetchesLimit) + } +} + +func TestQuotaTracker_ZeroSourcesDefaultsToOne(t *testing.T) { + p := newTestPoolWithQuota(t, 200, 500) + // No AddSource calls. + + status := p.QuotaStatus() + if status.FetchesRemaining != 200 { + t.Errorf("expected 200 remaining (no division), got %d", status.FetchesRemaining) + } +} + +func TestQuotaTracker_ZeroBeforeInitialized(t *testing.T) { + p := newTestPool(t) + p.AddSource() + // Pool's FetchQuota returns error, so reconciliation fails and the + // pool stays uninitialized. + status := p.QuotaStatus() + if status.FetchesRemaining != 0 { + t.Errorf("expected 0 fetches for uninitialized pool, got %d", status.FetchesRemaining) + } + if status.FetchesLimit != 0 { + t.Errorf("expected 0 limit for uninitialized pool, got %d", status.FetchesLimit) + } +} + +func TestQuotaTracker_Reconcile(t *testing.T) { + fetchQuota := func(ctx context.Context) (*sources.QuotaStatus, error) { + return &sources.QuotaStatus{ + FetchesRemaining: 800, + FetchesLimit: 1000, + ResetTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), + }, nil + } + + p := NewQuotaTracker(&QuotaTrackerConfig{ + FetchQuota: fetchQuota, + ReconcileInterval: time.Millisecond, + Log: slog.Disabled, + }) + p.AddSource() + + // First call blocks until reconciliation completes. + status := p.QuotaStatus() + if status.FetchesRemaining != 800 { + t.Errorf("expected 800 remaining after reconcile, got %d", status.FetchesRemaining) + } + if status.FetchesLimit != 1000 { + t.Errorf("expected 1000 limit after reconcile, got %d", status.FetchesLimit) + } +} + +func TestQuotaTracker_ReconcileError(t *testing.T) { + fetchQuota := func(ctx context.Context) (*sources.QuotaStatus, error) { + return nil, fmt.Errorf("network error") + } + + p := NewQuotaTracker(&QuotaTrackerConfig{ + FetchQuota: fetchQuota, + ReconcileInterval: time.Millisecond, + Log: slog.Disabled, + }) + p.AddSource() + + // First call blocks on reconciliation which fails. + // Pool remains uninitialized, so returns zero quota. + status := p.QuotaStatus() + if status.FetchesRemaining != 0 { + t.Errorf("expected 0 after failed reconcile, got %d", status.FetchesRemaining) + } +} + +func TestQuotaTracker_ReconcileSyncsToServer(t *testing.T) { + t.Run("server shows more usage than local", func(t *testing.T) { + // Server reports fewer remaining than our local estimate, + // meaning another consumer used credits. We should sync down. + var call int + fetchQuota := func(ctx context.Context) (*sources.QuotaStatus, error) { + call++ + if call == 1 { + // Initial sync: seed with 1000/1000. + return &sources.QuotaStatus{ + FetchesRemaining: 1000, + FetchesLimit: 1000, + }, nil + } + // Subsequent: server says 700 remaining. + return &sources.QuotaStatus{ + FetchesRemaining: 700, + FetchesLimit: 1000, + }, nil + } + + p := NewQuotaTracker(&QuotaTrackerConfig{ + FetchQuota: fetchQuota, + ReconcileInterval: time.Hour, // prevent auto-reconcile + Log: slog.Disabled, + }) + p.AddSource() + + // Trigger initial sync. + _ = p.QuotaStatus() // call=1, remaining=1000 + + p.ConsumeCredits(200) // local: 800 + + p.reconcile() // call=2, server=700 < local=800, adopt 700 + + status := p.QuotaStatus() + // Server says 700, local says 800. Adopt server (lower = more usage). + if status.FetchesRemaining != 700 { + t.Errorf("expected 700 remaining after reconcile sync, got %d", status.FetchesRemaining) + } + }) + + t.Run("server lags behind local consumption", func(t *testing.T) { + // Server's hits counter is eventually consistent and hasn't + // caught up with our local consumption. Keep local estimate. + var call int + fetchQuota := func(ctx context.Context) (*sources.QuotaStatus, error) { + call++ + if call == 1 { + return &sources.QuotaStatus{ + FetchesRemaining: 1000, + FetchesLimit: 1000, + }, nil + } + return &sources.QuotaStatus{ + FetchesRemaining: 900, + FetchesLimit: 1000, + }, nil + } + + p := NewQuotaTracker(&QuotaTrackerConfig{ + FetchQuota: fetchQuota, + ReconcileInterval: time.Hour, + Log: slog.Disabled, + }) + p.AddSource() + + _ = p.QuotaStatus() // call=1, remaining=1000 + + p.ConsumeCredits(200) // local: 800 + + p.reconcile() // call=2, server=900 > local=800, keep local + + status := p.QuotaStatus() + // Server says 900, local says 800. Keep local (more conservative). + if status.FetchesRemaining != 800 { + t.Errorf("expected 800 remaining (local estimate preserved), got %d", status.FetchesRemaining) + } + }) +} + +// --- TrackedSource tests --- + +func TestNewTrackedSource_RegistersWithTracker(t *testing.T) { + p := newTestPool(t) + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + } + + _ = NewTrackedSource(TrackedSourceConfig{ + Name: "test1", FetchRates: fetchRates, Tracker: p, CreditsPerRequest: 1, + }) + if p.sourceCount != 1 { + t.Errorf("expected sourceCount 1, got %d", p.sourceCount) + } + + _ = NewTrackedSource(TrackedSourceConfig{ + Name: "test2", FetchRates: fetchRates, Tracker: p, CreditsPerRequest: 1, + }) + if p.sourceCount != 2 { + t.Errorf("expected sourceCount 2, got %d", p.sourceCount) + } +} + +func TestNewTrackedSource_PanicsOnMissingName(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for missing name") + } + }() + NewTrackedSource(TrackedSourceConfig{ + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { return nil, nil }, + Tracker: newTestPool(t), + }) +} + +func TestNewTrackedSource_PanicsOnMissingFetchRates(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for missing FetchRates") + } + }() + NewTrackedSource(TrackedSourceConfig{ + Name: "test", + Tracker: newTestPool(t), + }) +} + +func TestNewTrackedSource_PanicsOnMissingTracker(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for missing Tracker") + } + }() + NewTrackedSource(TrackedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { return nil, nil }, + }) +} + +func TestTrackedSource_ConsumesCreditsOnFetch(t *testing.T) { + p := newTestPoolWithQuota(t, 100, 100) + + pooled := NewTrackedSource(TrackedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + Tracker: p, + CreditsPerRequest: 10, + }) + + // Trigger initial reconciliation to seed values. + _ = pooled.QuotaStatus() + + _, err := pooled.FetchRates(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Pool: 100 - 10 = 90 raw credits, divided by creditsPerRequest=10 = 9 fetches. + status := pooled.QuotaStatus() + if status.FetchesRemaining != 9 { + t.Errorf("expected 9 fetches remaining, got %d", status.FetchesRemaining) + } +} + +func TestTrackedSource_NoConsumeOnError(t *testing.T) { + p := newTestPoolWithQuota(t, 100, 100) + + pooled := NewTrackedSource(TrackedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return nil, fmt.Errorf("fetch error") + }, + Tracker: p, + CreditsPerRequest: 10, + }) + + // Trigger initial reconciliation to seed values. + _ = pooled.QuotaStatus() + + _, err := pooled.FetchRates(context.Background()) + if err == nil { + t.Fatal("expected error") + } + + // Credits should not have been consumed. + // Pool: 100 raw credits, divided by creditsPerRequest=10 = 10 fetches. + status := pooled.QuotaStatus() + if status.FetchesRemaining != 10 { + t.Errorf("expected 10 fetches remaining after failed fetch, got %d", status.FetchesRemaining) + } +} + +func TestTrackedSource_FieldAccessors(t *testing.T) { + p := newTestPool(t) + + pooled := NewTrackedSource(TrackedSourceConfig{ + Name: "inner-source", + Weight: 0.75, + MinPeriod: 42 * time.Second, + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + Tracker: p, + CreditsPerRequest: 1, + }) + if pooled.Name() != "inner-source" { + t.Errorf("expected Name() = inner-source, got %s", pooled.Name()) + } + if pooled.Weight() != 0.75 { + t.Errorf("expected Weight() = 0.75, got %f", pooled.Weight()) + } + if pooled.MinPeriod() != 42*time.Second { + t.Errorf("expected MinPeriod() = 42s, got %v", pooled.MinPeriod()) + } +} + +func TestTrackedSource_QuotaStatusFromTracker(t *testing.T) { + p := newTestPoolWithQuota(t, 600, 1200) + + fetchRates := func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + } + + p1 := NewTrackedSource(TrackedSourceConfig{ + Name: "test1", FetchRates: fetchRates, Tracker: p, CreditsPerRequest: 1, + }) + p2 := NewTrackedSource(TrackedSourceConfig{ + Name: "test2", FetchRates: fetchRates, Tracker: p, CreditsPerRequest: 1, + }) + + // 2 sources registered, so each gets 600/2=300 remaining, 1200/2=600 limit. + s1 := p1.QuotaStatus() + s2 := p2.QuotaStatus() + + if s1.FetchesRemaining != 300 { + t.Errorf("p1: expected 300 remaining, got %d", s1.FetchesRemaining) + } + if s2.FetchesLimit != 600 { + t.Errorf("p2: expected 600 limit, got %d", s2.FetchesLimit) + } +} + +func TestTrackedSource_ConcurrentFetches(t *testing.T) { + p := newTestPoolWithQuota(t, 1000, 1000) + + pooled := NewTrackedSource(TrackedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + Tracker: p, + CreditsPerRequest: 1, + }) + + // Trigger initial reconciliation to seed values. + _ = pooled.QuotaStatus() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = pooled.FetchRates(context.Background()) + }() + } + wg.Wait() + + status := pooled.QuotaStatus() + if status.FetchesRemaining != 900 { + t.Errorf("expected 900 remaining after 100 fetches, got %d", status.FetchesRemaining) + } +} diff --git a/oracle/sources/utils/unlimited.go b/oracle/sources/utils/unlimited.go new file mode 100644 index 0000000..cb9e5be --- /dev/null +++ b/oracle/sources/utils/unlimited.go @@ -0,0 +1,66 @@ +package utils + +import ( + "context" + "time" + + "github.com/bisoncraft/mesh/oracle/sources" +) + +// FetchRatesFunc fetches rates. Used by quota-aware wrappers that don't want to +// know about URLs, headers, or parsing details. +type FetchRatesFunc func(ctx context.Context) (*sources.RateInfo, error) + +// UnlimitedSourceConfig configures a source without quota constraints. +type UnlimitedSourceConfig struct { + Name string + Weight float64 + MinPeriod time.Duration + FetchRates FetchRatesFunc +} + +// UnlimitedSource is a source without quota constraints. +type UnlimitedSource struct { + name string + weight float64 + minPeriod time.Duration + fetchRates FetchRatesFunc +} + +// NewUnlimitedSource creates a new unlimited source. +func NewUnlimitedSource(cfg UnlimitedSourceConfig) *UnlimitedSource { + if cfg.Name == "" { + panic("unlimited source: name is required") + } + if cfg.FetchRates == nil { + panic("unlimited source: FetchRates is required") + } + + weight := cfg.Weight + if weight == 0 { + weight = defaultWeight + } + minPeriod := cfg.MinPeriod + if minPeriod == 0 { + minPeriod = defaultMinPeriod + } + + return &UnlimitedSource{ + name: cfg.Name, + weight: weight, + minPeriod: minPeriod, + fetchRates: cfg.FetchRates, + } +} + +func (s *UnlimitedSource) Name() string { return s.name } +func (s *UnlimitedSource) Weight() float64 { return s.weight } +func (s *UnlimitedSource) MinPeriod() time.Duration { return s.minPeriod } + +func (s *UnlimitedSource) FetchRates(ctx context.Context) (*sources.RateInfo, error) { + return s.fetchRates(ctx) +} + +func (s *UnlimitedSource) QuotaStatus() *sources.QuotaStatus { + return UnlimitedQuotaStatus() +} diff --git a/oracle/sources/utils/unlimited_test.go b/oracle/sources/utils/unlimited_test.go new file mode 100644 index 0000000..0811c1a --- /dev/null +++ b/oracle/sources/utils/unlimited_test.go @@ -0,0 +1,180 @@ +package utils + +import ( + "context" + "fmt" + "math" + "testing" + "time" + + "github.com/bisoncraft/mesh/oracle/sources" +) + +func TestNewUnlimitedSource_FullConfig(t *testing.T) { + called := false + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "test", + Weight: 0.5, + MinPeriod: 10 * time.Second, + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + called = true + return &sources.RateInfo{}, nil + }, + }) + + if s.Name() != "test" { + t.Errorf("expected Name() = test, got %s", s.Name()) + } + if s.Weight() != 0.5 { + t.Errorf("expected Weight() = 0.5, got %f", s.Weight()) + } + if s.MinPeriod() != 10*time.Second { + t.Errorf("expected MinPeriod() = 10s, got %v", s.MinPeriod()) + } + + _, err := s.FetchRates(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("FetchRates did not call the underlying function") + } +} + +func TestNewUnlimitedSource_DefaultWeightAndMinPeriod(t *testing.T) { + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "defaults", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + }) + + if s.Weight() != defaultWeight { + t.Errorf("expected default weight %f, got %f", defaultWeight, s.Weight()) + } + if s.MinPeriod() != defaultMinPeriod { + t.Errorf("expected default min period %v, got %v", defaultMinPeriod, s.MinPeriod()) + } +} + +func TestNewUnlimitedSource_EmptyNamePanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for empty name") + } + }() + NewUnlimitedSource(UnlimitedSourceConfig{ + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + }) +} + +func TestNewUnlimitedSource_NilFetchRatesPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for nil FetchRates") + } + }() + NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "test", + }) +} + +func TestUnlimitedSource_QuotaStatusIsUnlimited(t *testing.T) { + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + }) + + status := s.QuotaStatus() + if status == nil { + t.Fatal("expected non-nil QuotaStatus") + } + if status.FetchesRemaining != math.MaxInt64 { + t.Errorf("expected unlimited fetches, got %d", status.FetchesRemaining) + } + if status.FetchesLimit != math.MaxInt64 { + t.Errorf("expected unlimited fetches limit, got %d", status.FetchesLimit) + } + if status.ResetTime.IsZero() { + t.Error("expected non-zero reset time") + } +} + +func TestUnlimitedSource_FetchRatesPropagatesError(t *testing.T) { + fetchErr := fmt.Errorf("upstream failure") + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return nil, fetchErr + }, + }) + + _, err := s.FetchRates(context.Background()) + if err != fetchErr { + t.Errorf("expected fetchErr, got %v", err) + } +} + +func TestUnlimitedSource_FetchRatesReturnsPrices(t *testing.T) { + expected := &sources.RateInfo{ + Prices: []*sources.PriceUpdate{ + {Ticker: "BTC", Price: 50000}, + {Ticker: "ETH", Price: 3000}, + }, + } + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return expected, nil + }, + }) + + result, err := s.FetchRates(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.Prices) != 2 { + t.Fatalf("expected 2 prices, got %d", len(result.Prices)) + } + if result.Prices[0].Ticker != "BTC" { + t.Errorf("expected BTC, got %s", result.Prices[0].Ticker) + } +} + +func TestUnlimitedSource_FetchRatesRespectsContext(t *testing.T) { + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return &sources.RateInfo{}, nil + } + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, err := s.FetchRates(ctx) + if err == nil { + t.Error("expected error from cancelled context") + } +} + +func TestUnlimitedSource_ImplementsSourceInterface(t *testing.T) { + s := NewUnlimitedSource(UnlimitedSourceConfig{ + Name: "iface-test", + FetchRates: func(ctx context.Context) (*sources.RateInfo, error) { + return &sources.RateInfo{}, nil + }, + }) + + // Compile-time check: *UnlimitedSource must satisfy sources.Source. + var _ sources.Source = s +} diff --git a/oracle/sources_test.go b/oracle/sources_test.go deleted file mode 100644 index b97c480..0000000 --- a/oracle/sources_test.go +++ /dev/null @@ -1,1061 +0,0 @@ -package oracle - -import ( - "bytes" - "context" - "io" - "math/big" - "net/http" - "strings" - "testing" -) - -// tHTTPClient implements HTTPClient for testing. -type tHTTPClient struct { - response *http.Response - err error -} - -func (tc *tHTTPClient) Do(*http.Request) (*http.Response, error) { - return tc.response, tc.err -} - -// newMockResponse creates a mock HTTP response with the given body. -func newMockResponse(body string) *http.Response { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(body)), - } -} - -// testHTTPSource tests an httpSource with a mock response. -func testHTTPSource(t *testing.T, src *httpSource, mockBody string) (divination, error) { - t.Helper() - client := &tHTTPClient{response: newMockResponse(mockBody)} - return src.fetch(context.Background(), client) -} - -func TestDcrdataParser(t *testing.T) { - src := &httpSource{ - name: "dcrdata", - url: "https://explorer.dcrdata.org/insight/api/utils/estimatefee?nbBlocks=2", - parse: dcrdataParser, - } - - t.Run("valid response", func(t *testing.T) { - // Fee rate in DCR/kB, 0.0001 DCR/kB = 10 atoms/byte - result, err := testHTTPSource(t, src, `{"2": 0.0001}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if len(updates) != 1 { - t.Fatalf("expected 1 update, got %d", len(updates)) - } - - if updates[0].network != "DCR" { - t.Errorf("expected network DCR, got %s", updates[0].network) - } - - // 0.0001 DCR/kB * 1e5 = 10 atoms/byte - if updates[0].feeRate.Cmp(big.NewInt(10)) != 0 { - t.Errorf("expected fee rate 10, got %s", updates[0].feeRate.String()) - } - }) - - t.Run("higher fee rate", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"2": 0.00025}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*feeRateUpdate) - // 0.00025 * 1e5 = 25 - if updates[0].feeRate.Cmp(big.NewInt(25)) != 0 { - t.Errorf("expected fee rate 25, got %s", updates[0].feeRate) - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"2": 0}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - }) - - t.Run("empty response", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{}`) - if err == nil { - t.Error("expected error for empty response") - } - }) - - t.Run("wrong key", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"3": 0.0001}`) - if err == nil { - t.Error("expected error for wrong key") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{invalid}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestMempoolDotSpaceParser(t *testing.T) { - src := &httpSource{ - name: "btc.mempooldotspace", - url: "https://mempool.space/api/v1/fees/recommended", - parse: mempoolDotSpaceParser, - } - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"fastestFee": 25, "halfHourFee": 20, "hourFee": 15, "economyFee": 10, "minimumFee": 5}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if len(updates) != 1 { - t.Fatalf("expected 1 update, got %d", len(updates)) - } - - if updates[0].network != "BTC" { - t.Errorf("expected network BTC, got %s", updates[0].network) - } - - if updates[0].feeRate.Cmp(big.NewInt(25)) != 0 { - t.Errorf("expected fee rate 25, got %s", updates[0].feeRate) - } - }) - - t.Run("high fee environment", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"fastestFee": 150}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*feeRateUpdate) - if updates[0].feeRate.Cmp(big.NewInt(150)) != 0 { - t.Errorf("expected fee rate 150, got %s", updates[0].feeRate) - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"fastestFee": 0}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - }) - - t.Run("missing field", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"halfHourFee": 20}`) - if err == nil { - t.Error("expected error for missing fastestFee") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `not json`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestCoinpaprikaParser(t *testing.T) { - src := &httpSource{ - name: "coinpaprika", - url: "https://api.coinpaprika.com/v1/tickers", - parse: coinpaprikaParser, - } - - t.Run("valid response", func(t *testing.T) { - body := `[ - {"id":"btc-bitcoin","symbol":"BTC","quotes":{"USD":{"price":87838.55}}}, - {"id":"eth-ethereum","symbol":"ETH","quotes":{"USD":{"price":2954.14}}}, - {"id":"ltc-litecoin","symbol":"LTC","quotes":{"USD":{"price":77.22}}} - ]` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*priceUpdate) - if !ok { - t.Fatalf("expected []*priceUpdate, got %T", result) - } - - if len(updates) != 3 { - t.Fatalf("expected 3 updates, got %d", len(updates)) - } - - prices := make(map[Ticker]float64) - for _, u := range updates { - prices[u.ticker] = u.price - } - - if prices["BTC"] != 87838.55 { - t.Errorf("expected BTC price 87838.55, got %f", prices["BTC"]) - } - if prices["ETH"] != 2954.14 { - t.Errorf("expected ETH price 2954.14, got %f", prices["ETH"]) - } - if prices["LTC"] != 77.22 { - t.Errorf("expected LTC price 77.22, got %f", prices["LTC"]) - } - }) - - t.Run("handles duplicate symbols", func(t *testing.T) { - body := `[ - {"id":"btc-bitcoin","symbol":"BTC","quotes":{"USD":{"price":50000.0}}}, - {"id":"btc-other","symbol":"BTC","quotes":{"USD":{"price":51000.0}}}, - {"id":"eth-ethereum","symbol":"ETH","quotes":{"USD":{"price":3000.0}}} - ]` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*priceUpdate) - if len(updates) != 2 { - t.Errorf("expected 2 updates after deduplication, got %d", len(updates)) - } - - // First BTC should be kept - for _, u := range updates { - if u.ticker == "BTC" && u.price != 50000.0 { - t.Errorf("expected first BTC price 50000.0, got %f", u.price) - } - } - }) - - t.Run("empty array", func(t *testing.T) { - result, err := testHTTPSource(t, src, `[]`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*priceUpdate) - if len(updates) != 0 { - t.Errorf("expected 0 updates, got %d", len(updates)) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{invalid}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestCoinmarketcapParser(t *testing.T) { - src := coinmarketcapSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - body := `{ - "data": [ - {"symbol":"BTC","quote":{"USD":{"price":90000.50}}}, - {"symbol":"ETH","quote":{"USD":{"price":3100.25}}}, - {"symbol":"DCR","quote":{"USD":{"price":15.75}}} - ] - }` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*priceUpdate) - if !ok { - t.Fatalf("expected []*priceUpdate, got %T", result) - } - - if len(updates) != 3 { - t.Fatalf("expected 3 updates, got %d", len(updates)) - } - - prices := make(map[Ticker]float64) - for _, u := range updates { - prices[u.ticker] = u.price - } - - if prices["BTC"] != 90000.50 { - t.Errorf("expected BTC price 90000.50, got %f", prices["BTC"]) - } - if prices["ETH"] != 3100.25 { - t.Errorf("expected ETH price 3100.25, got %f", prices["ETH"]) - } - if prices["DCR"] != 15.75 { - t.Errorf("expected DCR price 15.75, got %f", prices["DCR"]) - } - }) - - t.Run("handles duplicate symbols", func(t *testing.T) { - body := `{ - "data": [ - {"symbol":"BTC","quote":{"USD":{"price":90000.0}}}, - {"symbol":"BTC","quote":{"USD":{"price":91000.0}}} - ] - }` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*priceUpdate) - if len(updates) != 1 { - t.Errorf("expected 1 update after deduplication, got %d", len(updates)) - } - }) - - t.Run("empty data array", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"data": []}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*priceUpdate) - if len(updates) != 0 { - t.Errorf("expected 0 updates, got %d", len(updates)) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `not json`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) - - t.Run("verifies headers are set", func(t *testing.T) { - if len(src.headers) == 0 { - t.Error("expected headers to be set") - } - found := false - for _, h := range src.headers { - if keys, ok := h["X-CMC_PRO_API_KEY"]; ok { - if len(keys) > 0 && keys[0] == "test-api-key" { - found = true - } - } - } - if !found { - t.Error("expected X-CMC_PRO_API_KEY header") - } - }) -} - -func TestBitcoreBitcoinCashParser(t *testing.T) { - src := &httpSource{ - name: "bch.bitcore", - url: "https://api.bitcore.io/api/BCH/mainnet/fee/2", - parse: bitcoreBitcoinCashParser, - } - - t.Run("valid response", func(t *testing.T) { - // Fee rate in BCH/kB - result, err := testHTTPSource(t, src, `{"feerate": 0.00001}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if len(updates) != 1 { - t.Fatalf("expected 1 update, got %d", len(updates)) - } - - if updates[0].network != "BCH" { - t.Errorf("expected network BCH, got %s", updates[0].network) - } - - // 0.00001 BCH/kB * 1e5 = 1 sat/byte - if updates[0].feeRate.Cmp(big.NewInt(1)) != 0 { - t.Errorf("expected fee rate 1, got %s", updates[0].feeRate) - } - }) - - t.Run("higher fee rate", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"feerate": 0.0001}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*feeRateUpdate) - // 0.0001 * 1e5 = 10 - if updates[0].feeRate.Cmp(big.NewInt(10)) != 0 { - t.Errorf("expected fee rate 10, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{bad json}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestBitcoreDogecoinParser(t *testing.T) { - src := &httpSource{ - name: "doge.bitcore", - url: "https://api.bitcore.io/api/DOGE/mainnet/fee/2", - parse: bitcoreDogecoinParser, - } - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"feerate": 0.01}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "DOGE" { - t.Errorf("expected network DOGE, got %s", updates[0].network) - } - - // 0.01 DOGE/kB * 1e5 = 1000 sat/byte - if updates[0].feeRate.Cmp(big.NewInt(1000)) != 0 { - t.Errorf("expected fee rate 1000, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `invalid`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestBitcoreLitecoinParser(t *testing.T) { - src := &httpSource{ - name: "ltc.bitcore", - url: "https://api.bitcore.io/api/LTC/mainnet/fee/2", - parse: bitcoreLitecoinParser, - } - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"feerate": 0.0001}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "LTC" { - t.Errorf("expected network LTC, got %s", updates[0].network) - } - - // 0.0001 LTC/kB * 1e5 = 10 sat/byte - if updates[0].feeRate.Cmp(big.NewInt(10)) != 0 { - t.Errorf("expected fee rate 10, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `[]`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestFiroOrgParser(t *testing.T) { - src := &httpSource{ - name: "firo.org", - url: "https://explorer.firo.org/insight-api-zcoin/utils/estimatefee", - parse: firoOrgParser, - } - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"2": 0.0001}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if len(updates) != 1 { - t.Fatalf("expected 1 update, got %d", len(updates)) - } - - if updates[0].network != "FIRO" { - t.Errorf("expected network FIRO, got %s", updates[0].network) - } - - // 0.0001 FIRO/kB * 1e5 = 10 sat/byte - if updates[0].feeRate.Cmp(big.NewInt(10)) != 0 { - t.Errorf("expected fee rate 10, got %s", updates[0].feeRate) - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"2": 0}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - }) - - t.Run("empty response", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{}`) - if err == nil { - t.Error("expected error for empty response") - } - }) - - t.Run("wrong key", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"1": 0.0001}`) - if err == nil { - t.Error("expected error for wrong key") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `not json`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestBlockcypherLitecoinParser(t *testing.T) { - src := &httpSource{ - name: "ltc.blockcypher", - url: "https://api.blockcypher.com/v1/ltc/main", - parse: blockcypherLitecoinParser, - } - - t.Run("valid response", func(t *testing.T) { - body := `{ - "name": "LTC.main", - "height": 2500000, - "low_fee_per_kb": 10000, - "medium_fee_per_kb": 25000, - "high_fee_per_kb": 50000 - }` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if len(updates) != 1 { - t.Fatalf("expected 1 update, got %d", len(updates)) - } - - if updates[0].network != "LTC" { - t.Errorf("expected network LTC, got %s", updates[0].network) - } - - // medium_fee_per_kb 25000 * 1e5 = 2500000000 (this seems wrong in the parser) - // Actually the parser does: res.Medium * 1e5, so 25000 * 1e5 = 2500000000 - // Let me check the actual parser logic... - // The response is in satoshis/kB already, so we should just use it directly - // But the parser multiplies by 1e5, which suggests the API returns in coins/kB - // Let me use a more realistic value - }) - - t.Run("realistic response", func(t *testing.T) { - // Blockcypher returns fees in satoshis/kB - body := `{"medium_fee_per_kb": 10000}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*feeRateUpdate) - // 10000 sat/kB * 1e5 = 1000000000 - this seems like a bug in the parser - // The parser treats the value as coins/kB but blockcypher returns sat/kB - // For now, test what the parser actually does - expected := big.NewInt(int64(10000 * 1e5)) - if updates[0].feeRate.Cmp(expected) != 0 { - t.Errorf("expected fee rate %s, got %s", expected.String(), updates[0].feeRate.String()) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{invalid}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestTatumBitcoinParser(t *testing.T) { - src := tatumBitcoinSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"fast": 25, "medium": 15, "slow": 5}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if len(updates) != 1 { - t.Fatalf("expected 1 update, got %d", len(updates)) - } - - if updates[0].network != "BTC" { - t.Errorf("expected network BTC, got %s", updates[0].network) - } - - if updates[0].feeRate.Cmp(big.NewInt(25)) != 0 { - t.Errorf("expected fee rate 25, got %s", updates[0].feeRate) - } - }) - - t.Run("high fee environment", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"fast": 150}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates := result.([]*feeRateUpdate) - if updates[0].feeRate.Cmp(big.NewInt(150)) != 0 { - t.Errorf("expected fee rate 150, got %s", updates[0].feeRate) - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"fast": 0}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - if !strings.Contains(err.Error(), "fee rate cannot be negative or zero") { - t.Errorf("expected 'fee rate cannot be negative or zero' error, got: %v", err) - } - }) - - t.Run("negative fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"fast": -5}`) - if err == nil { - t.Error("expected error for negative fee rate") - } - }) - - t.Run("missing fast field", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"medium": 15}`) - if err == nil { - t.Error("expected error for missing fast field") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{invalid}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) - - t.Run("verifies headers are set", func(t *testing.T) { - found := false - for _, h := range src.headers { - if keys, ok := h["x-api-key"]; ok { - if len(keys) > 0 && keys[0] == "test-api-key" { - found = true - } - } - } - if !found { - t.Error("expected x-api-key header") - } - }) -} - -func TestTatumLitecoinParser(t *testing.T) { - src := tatumLitecoinSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"fast": 42}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "LTC" { - t.Errorf("expected network LTC, got %s", updates[0].network) - } - - if updates[0].feeRate.Cmp(big.NewInt(42)) != 0 { - t.Errorf("expected fee rate 42, got %s", updates[0].feeRate) - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"fast": 0}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `not json`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestTatumDogecoinParser(t *testing.T) { - src := tatumDogecoinSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - result, err := testHTTPSource(t, src, `{"fast": 1000}`) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "DOGE" { - t.Errorf("expected network DOGE, got %s", updates[0].network) - } - - if updates[0].feeRate.Cmp(big.NewInt(1000)) != 0 { - t.Errorf("expected fee rate 1000, got %s", updates[0].feeRate) - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"fast": 0}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"fast": "not a number"}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestCryptoApisBitcoinParser(t *testing.T) { - src := cryptoApisBitcoinSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - body := `{"data":{"item":{"fast":"0.000025","standard":"0.000015","slow":"0.000010"}}}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "BTC" { - t.Errorf("expected network BTC, got %s", updates[0].network) - } - - // 0.000025 BTC/byte * 1e8 = 2500 satoshis/byte - if updates[0].feeRate.Cmp(big.NewInt(2500)) != 0 { - t.Errorf("expected fee rate 2500, got %s", updates[0].feeRate) - } - }) - - t.Run("various fee rates", func(t *testing.T) { - testCases := []struct { - input string - expected int64 - }{ - {"0.000010", 1000}, - {"0.000050", 5000}, - {"0.0001", 10000}, - {"0.00000001", 1}, - } - - for _, tc := range testCases { - body := `{"data":{"item":{"fast":"` + tc.input + `"}}}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed for input %s: %v", tc.input, err) - } - - updates := result.([]*feeRateUpdate) - expectedBigInt := big.NewInt(tc.expected) - if updates[0].feeRate.Cmp(expectedBigInt) != 0 { - t.Errorf("input %s: expected fee rate %s, got %s", tc.input, expectedBigInt.String(), updates[0].feeRate.String()) - } - } - }) - - t.Run("zero fee rate", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"data":{"item":{"fast":"0"}}}`) - if err == nil { - t.Error("expected error for zero fee rate") - } - }) - - t.Run("invalid fee rate string", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{"data":{"item":{"fast":"not a number"}}}`) - if err == nil { - t.Error("expected error for invalid fee rate") - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{invalid}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) - - t.Run("verifies headers are set", func(t *testing.T) { - found := false - for _, h := range src.headers { - if keys, ok := h["X-API-Key"]; ok { - if len(keys) > 0 && keys[0] == "test-api-key" { - found = true - } - } - } - if !found { - t.Error("expected X-API-Key header") - } - }) -} - -func TestCryptoApisBitcoinCashParser(t *testing.T) { - src := cryptoApisBitcoinCashSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - body := `{"data":{"item":{"fast":"0.00001"}}}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "BCH" { - t.Errorf("expected network BCH, got %s", updates[0].network) - } - - // 0.00001 BCH/byte * 1e8 = 1000 satoshis/byte - if updates[0].feeRate.Cmp(big.NewInt(1000)) != 0 { - t.Errorf("expected fee rate 1000, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `invalid`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestCryptoApisDogecoinParser(t *testing.T) { - src := cryptoApisDogecoinSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - body := `{"data":{"item":{"fast":"0.001"}}}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "DOGE" { - t.Errorf("expected network DOGE, got %s", updates[0].network) - } - - // 0.001 DOGE/byte * 1e8 = 100000 satoshis/byte - if updates[0].feeRate.Cmp(big.NewInt(100000)) != 0 { - t.Errorf("expected fee rate 100000, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `[]`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestCryptoApisDashParser(t *testing.T) { - src := cryptoApisDashSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - body := `{"data":{"item":{"fast":"0.0001"}}}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "DASH" { - t.Errorf("expected network DASH, got %s", updates[0].network) - } - - // 0.0001 DASH/byte * 1e8 = 10000 satoshis/byte - if updates[0].feeRate.Cmp(big.NewInt(10000)) != 0 { - t.Errorf("expected fee rate 10000, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `{bad}`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestCryptoApisLitecoinParser(t *testing.T) { - src := cryptoApisLitecoinSource("test-api-key") - - t.Run("valid response", func(t *testing.T) { - body := `{"data":{"item":{"fast":"0.00001"}}}` - result, err := testHTTPSource(t, src, body) - if err != nil { - t.Fatalf("fetch failed: %v", err) - } - - updates, ok := result.([]*feeRateUpdate) - if !ok { - t.Fatalf("expected []*feeRateUpdate, got %T", result) - } - - if updates[0].network != "LTC" { - t.Errorf("expected network LTC, got %s", updates[0].network) - } - - // 0.00001 LTC/byte * 1e8 = 1000 satoshis/byte - if updates[0].feeRate.Cmp(big.NewInt(1000)) != 0 { - t.Errorf("expected fee rate 1000, got %s", updates[0].feeRate) - } - }) - - t.Run("malformed JSON", func(t *testing.T) { - _, err := testHTTPSource(t, src, `not valid`) - if err == nil { - t.Error("expected error for malformed JSON") - } - }) -} - -func TestSetHTTPSourceDefaults(t *testing.T) { - t.Run("sets default values", func(t *testing.T) { - sources := []*httpSource{ - {name: "test"}, - } - err := setHTTPSourceDefaults(sources) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if sources[0].weight != 1.0 { - t.Errorf("expected default weight 1.0, got %f", sources[0].weight) - } - if sources[0].period != 5*60*1e9 { // 5 minutes in nanoseconds - t.Errorf("expected default period 5m, got %v", sources[0].period) - } - if sources[0].errPeriod != 60*1e9 { // 1 minute in nanoseconds - t.Errorf("expected default errPeriod 1m, got %v", sources[0].errPeriod) - } - }) - - t.Run("preserves custom values", func(t *testing.T) { - sources := []*httpSource{ - {name: "test", weight: 0.5, period: 10 * 60 * 1e9, errPeriod: 30 * 1e9}, - } - err := setHTTPSourceDefaults(sources) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if sources[0].weight != 0.5 { - t.Errorf("expected weight 0.5, got %f", sources[0].weight) - } - if sources[0].period != 10*60*1e9 { - t.Errorf("expected period 10m, got %v", sources[0].period) - } - if sources[0].errPeriod != 30*1e9 { - t.Errorf("expected errPeriod 30s, got %v", sources[0].errPeriod) - } - }) - - t.Run("returns error on negative weight", func(t *testing.T) { - sources := []*httpSource{ - {name: "test", weight: -0.5}, - } - err := setHTTPSourceDefaults(sources) - if err == nil { - t.Error("expected error for negative weight") - } - if !strings.Contains(err.Error(), "negative weight") { - t.Errorf("expected 'negative weight' in error, got: %v", err) - } - }) - - t.Run("returns error on weight > 1", func(t *testing.T) { - sources := []*httpSource{ - {name: "test", weight: 1.5}, - } - err := setHTTPSourceDefaults(sources) - if err == nil { - t.Error("expected error for weight > 1") - } - if !strings.Contains(err.Error(), "weight > 1") { - t.Errorf("expected 'weight > 1' in error, got: %v", err) - } - }) -} diff --git a/protocols/protocols.go b/protocols/protocols.go index 178475a..888b25f 100644 --- a/protocols/protocols.go +++ b/protocols/protocols.go @@ -34,3 +34,10 @@ const ( // tatanka nodes in the mesh that this node is connected to. AvailableMeshNodesProtocol = "/tatanka/available-mesh-nodes/1.0.0" ) + +var ( + // PriceTopicPrefix is the prefix for price topics. + PriceTopicPrefix = "price." + // FeeRateTopicPrefix is the prefix for fee rate topics. + FeeRateTopicPrefix = "fee_rate." +) diff --git a/tatanka/admin/server.go b/tatanka/admin/server.go index 797b5b8..7c89d7c 100644 --- a/tatanka/admin/server.go +++ b/tatanka/admin/server.go @@ -10,6 +10,7 @@ import ( "github.com/decred/slog" "github.com/gorilla/websocket" "github.com/libp2p/go-libp2p/core/peer" + "github.com/bisoncraft/mesh/oracle" ) // NodeConnectionState defines the status of a peer connection @@ -47,10 +48,16 @@ func (s AdminState) DeepCopy() AdminState { return newState } +// WSMessage is the envelope for all WebSocket messages. +type WSMessage struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` +} + // Client represents a connected WebSocket user. type Client struct { conn *websocket.Conn - send chan AdminState + send chan WSMessage } // Server manages the admin server for a tatanka node. @@ -64,11 +71,18 @@ type Server struct { clientsMtx sync.RWMutex clients map[*Client]bool + + oracle Oracle } -// NewServer initializes the admin server -func NewServer(log slog.Logger, addr string) *Server { - return &Server{ +// Oracle supplies data for admin oracle endpoints. +type Oracle interface { + OracleSnapshot() *oracle.OracleSnapshot +} + +// NewServer initializes the admin server. +func NewServer(log slog.Logger, addr string, oracle Oracle) *Server { + server := &Server{ log: log, upgrader: websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}, clients: make(map[*Client]bool), @@ -77,7 +91,10 @@ func NewServer(log slog.Logger, addr string) *Server { OurWhitelist: []string{}, }, httpServer: &http.Server{Addr: addr}, + oracle: oracle, } + + return server } // Start launches the HTTP server @@ -127,14 +144,42 @@ func (s *Server) UpdateWhitelist(whitelist []string) { s.broadcastState(snapshot) } -// broadcastState sends the state to all clients. +// broadcastState sends the admin state to all clients. func (s *Server) broadcastState(state AdminState) { + data, err := json.Marshal(state) + if err != nil { + s.log.Errorf("Failed to marshal admin state: %v", err) + return + } + msg := WSMessage{ + Type: "admin_state", + Data: json.RawMessage(data), + } + s.broadcast(msg) +} + +// BroadcastOracleUpdate broadcasts a typed oracle update to all connected clients. +func (s *Server) BroadcastOracleUpdate(msgType string, snapshotDiff *oracle.OracleSnapshot) { + data, err := json.Marshal(snapshotDiff) + if err != nil { + s.log.Errorf("Failed to marshal oracle update (%s): %v", msgType, err) + return + } + msg := WSMessage{ + Type: msgType, + Data: json.RawMessage(data), + } + s.broadcast(msg) +} + +// broadcast sends a WSMessage to all connected clients. +func (s *Server) broadcast(msg WSMessage) { s.clientsMtx.RLock() defer s.clientsMtx.RUnlock() for client := range s.clients { select { - case client.send <- state: + case client.send <- msg: default: s.log.Errorf("Client buffer full, skipping update") } @@ -164,27 +209,42 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { client := &Client{ conn: conn, - send: make(chan AdminState, 10), + send: make(chan WSMessage, 10), } s.clientsMtx.Lock() s.clients[client] = true s.clientsMtx.Unlock() - // Send initial state immediately + // Send initial admin state s.stateMtx.RLock() initialState := s.state.DeepCopy() s.stateMtx.RUnlock() - select { - case client.send <- initialState: - default: + stateData, err := json.Marshal(initialState) + if err == nil { + select { + case client.send <- WSMessage{Type: "admin_state", Data: json.RawMessage(stateData)}: + default: + } + } + + // Send oracle snapshot + snapshot := s.oracle.OracleSnapshot() + if snapshot != nil { + snapshotData, err := json.Marshal(snapshot) + if err == nil { + select { + case client.send <- WSMessage{Type: "oracle_snapshot", Data: json.RawMessage(snapshotData)}: + default: + } + } } // 1. Writer Goroutine go func() { defer conn.Close() - for state := range client.send { - if err := conn.WriteJSON(state); err != nil { + for msg := range client.send { + if err := conn.WriteJSON(msg); err != nil { return } } diff --git a/tatanka/gossipsub.go b/tatanka/gossipsub.go index 0178cfb..641f6b6 100644 --- a/tatanka/gossipsub.go +++ b/tatanka/gossipsub.go @@ -10,6 +10,8 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/oracle/sources" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" pb "github.com/bisoncraft/mesh/tatanka/pb" "golang.org/x/sync/errgroup" @@ -28,6 +30,10 @@ const ( // oracleUpdatesTopicName is the name of the pubsub topic used to // propagate oracle updates between tatanka nodes. oracleUpdatesTopicName = "oracle_updates" + + // quotaHeartbeatTopicName is the name of the pubsub topic used to + // periodically share quota information between tatanka nodes. + quotaHeartbeatTopicName = "quota_heartbeat" ) type clientConnectionUpdate struct { @@ -71,7 +77,8 @@ type gossipSubCfg struct { getWhitelistPeers func() map[peer.ID]struct{} handleBroadcastMessage func(msg *protocolsPb.PushMessage) handleClientConnectionMessage func(update *clientConnectionUpdate) - handleOracleUpdate func(update *pb.NodeOracleUpdate) + handleOracleUpdate func(senderID peer.ID, update *pb.NodeOracleUpdate) + handleQuotaHeartbeat func(senderID peer.ID, heartbeat *pb.QuotaHandshake) } // gossipSub manages the nodes connection to a gossip sub network between tatanka @@ -84,6 +91,7 @@ type gossipSub struct { clientMessageTopic *pubsub.Topic clientConnectionsTopic *pubsub.Topic oracleUpdatesTopic *pubsub.Topic + quotaHeartbeatTopic *pubsub.Topic zstdEncoder *zstd.Encoder zstdDecoder *zstd.Decoder } @@ -122,6 +130,11 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { return nil, fmt.Errorf("failed to join oracle updates topic: %w", err) } + quotaHeartbeatTopic, err := ps.Join(quotaHeartbeatTopicName) + if err != nil { + return nil, fmt.Errorf("failed to join quota heartbeat topic: %w", err) + } + zstdEncoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) if err != nil { return nil, fmt.Errorf("failed to create zstd encoder: %w", err) @@ -139,6 +152,7 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { clientMessageTopic: clientMessageTopic, clientConnectionsTopic: clientConnectionsTopic, oracleUpdatesTopic: oracleUpdatesTopic, + quotaHeartbeatTopic: quotaHeartbeatTopic, zstdEncoder: zstdEncoder, zstdDecoder: zstdDecoder, }, nil @@ -235,7 +249,7 @@ func (gs *gossipSub) listenForOracleUpdates(ctx context.Context) error { continue } - gs.cfg.handleOracleUpdate(oracleUpdate) + gs.cfg.handleOracleUpdate(msg.GetFrom(), oracleUpdate) } } } @@ -257,8 +271,13 @@ func (gs *gossipSub) publishClientConnectionMessage(ctx context.Context, msg *cl return gs.clientConnectionsTopic.Publish(ctx, data) } -func (gs *gossipSub) publishOracleUpdate(ctx context.Context, update *pb.NodeOracleUpdate) error { - data, err := proto.Marshal(update) +func (gs *gossipSub) publishOracleUpdate(ctx context.Context, update *oracle.OracleUpdate) error { + pbUpdate, err := oracleUpdateToPb(update) + if err != nil { + return err + } + + data, err := proto.Marshal(pbUpdate) if err != nil { return fmt.Errorf("failed to marshal oracle update: %w", err) } @@ -268,6 +287,43 @@ func (gs *gossipSub) publishOracleUpdate(ctx context.Context, update *pb.NodeOra return gs.oracleUpdatesTopic.Publish(ctx, compressed) } +func (gs *gossipSub) listenForQuotaHeartbeats(ctx context.Context) error { + sub, err := gs.quotaHeartbeatTopic.Subscribe() + if err != nil { + return fmt.Errorf("failed to subscribe to quota heartbeat topic: %w", err) + } + + for { + msg, err := sub.Next(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + + if msg != nil && gs.cfg.handleQuotaHeartbeat != nil { + heartbeat := &pb.QuotaHandshake{} + if err := proto.Unmarshal(msg.Data, heartbeat); err != nil { + gs.log.Errorf("Failed to unmarshal quota heartbeat: %v", err) + continue + } + gs.cfg.handleQuotaHeartbeat(msg.GetFrom(), heartbeat) + } + } +} + +func (gs *gossipSub) publishQuotaHeartbeat(ctx context.Context, quotas map[string]*sources.QuotaStatus) error { + heartbeat := &pb.QuotaHandshake{ + Quotas: quotaStatusesToPb(quotas), + } + data, err := proto.Marshal(heartbeat) + if err != nil { + return fmt.Errorf("failed to marshal quota heartbeat: %w", err) + } + return gs.quotaHeartbeatTopic.Publish(ctx, data) +} + func (gs *gossipSub) run(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) @@ -289,5 +345,11 @@ func (gs *gossipSub) run(ctx context.Context) error { return err }) + g.Go(func() error { + err := gs.listenForQuotaHeartbeats(ctx) + gs.log.Debug("Quota heartbeat listener stopped.") + return err + }) + return g.Wait() } diff --git a/tatanka/handlers.go b/tatanka/handlers.go index a3f517e..2758a34 100644 --- a/tatanka/handlers.go +++ b/tatanka/handlers.go @@ -2,7 +2,6 @@ package tatanka import ( "context" - "fmt" "math/big" "strings" "time" @@ -15,48 +14,10 @@ import ( "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/bisoncraft/mesh/tatanka/pb" - ma "github.com/multiformats/go-multiaddr" "google.golang.org/protobuf/proto" ) -const ( - defaultTimeout = time.Second * 30 -) - -// libp2pPeerInfoToPb converts a peer.AddrInfo to a protocolsPb.PeerInfo. -func libp2pPeerInfoToPb(peerInfo peer.AddrInfo) *protocolsPb.PeerInfo { - addrBytes := make([][]byte, len(peerInfo.Addrs)) - for i, addr := range peerInfo.Addrs { - addrBytes[i] = addr.Bytes() - } - - return &protocolsPb.PeerInfo{ - Id: []byte(peerInfo.ID), - Addrs: addrBytes, - } -} - -// pbPeerInfoToLibp2p converts a protocolsPb.PeerInfo to a peer.AddrInfo. -func pbPeerInfoToLibp2p(pbPeer *protocolsPb.PeerInfo) (peer.AddrInfo, error) { - peerID, err := peer.IDFromBytes(pbPeer.Id) - if err != nil { - return peer.AddrInfo{}, fmt.Errorf("failed to parse peer ID: %w", err) - } - - addrs := make([]ma.Multiaddr, 0, len(pbPeer.Addrs)) - for _, addrBytes := range pbPeer.Addrs { - addr, err := ma.NewMultiaddrBytes(addrBytes) - if err != nil { - return peer.AddrInfo{}, fmt.Errorf("failed to parse multiaddr: %w", err) - } - addrs = append(addrs, addr) - } - - return peer.AddrInfo{ - ID: peerID, - Addrs: addrs, - }, nil -} +const defaultTimeout = time.Second * 30 // handleClientPush is called when the client opens a push stream to the node. func (t *TatankaNode) handleClientPush(s network.Stream) { @@ -125,9 +86,9 @@ func (t *TatankaNode) handleClientSubscribe(s network.Stream) { // Update the subscribing client immediately if subscribing for oracle updates. // Check for prefixed price or fee rate topics. - if strings.HasPrefix(subscribeMessage.Topic, oracle.PriceTopicPrefix) { + if strings.HasPrefix(subscribeMessage.Topic, protocols.PriceTopicPrefix) { t.sendCurrentOracleUpdate(client, subscribeMessage.Topic) - } else if strings.HasPrefix(subscribeMessage.Topic, oracle.FeeRateTopicPrefix) { + } else if strings.HasPrefix(subscribeMessage.Topic, protocols.FeeRateTopicPrefix) { t.sendCurrentOracleUpdate(client, subscribeMessage.Topic) } } @@ -145,19 +106,18 @@ func (t *TatankaNode) sendCurrentOracleUpdate(client peer.ID, topic string) { var err error // Check for prefixed price subscription. - if strings.HasPrefix(topic, oracle.PriceTopicPrefix) { - ticker := topic[len(oracle.PriceTopicPrefix):] - prices := t.oracle.Prices() - if price, ok := prices[oracle.Ticker(ticker)]; ok { + if strings.HasPrefix(topic, protocols.PriceTopicPrefix) { + ticker := topic[len(protocols.PriceTopicPrefix):] + if price, ok := t.oracle.Price(oracle.Ticker(ticker)); ok { clientUpdate := &protocolsPb.ClientPriceUpdate{ Price: price, } data, err = proto.Marshal(clientUpdate) } - } else if strings.HasPrefix(topic, oracle.FeeRateTopicPrefix) { + } else if strings.HasPrefix(topic, protocols.FeeRateTopicPrefix) { // Check for prefixed fee rate subscription. - network := topic[len(oracle.FeeRateTopicPrefix):] - if feeRate, ok := t.oracle.FeeRates()[oracle.Network(network)]; ok { + network := topic[len(protocols.FeeRateTopicPrefix):] + if feeRate, ok := t.oracle.FeeRate(oracle.Network(network)); ok { clientUpdate := &protocolsPb.ClientFeeRateUpdate{ FeeRate: bigIntToBytes(feeRate), } @@ -196,8 +156,8 @@ func (t *TatankaNode) handleClientPublish(s network.Stream) { return } - if strings.HasPrefix(publishMessage.Topic, oracle.PriceTopicPrefix) || - strings.HasPrefix(publishMessage.Topic, oracle.FeeRateTopicPrefix) { + if strings.HasPrefix(publishMessage.Topic, protocols.PriceTopicPrefix) || + strings.HasPrefix(publishMessage.Topic, protocols.FeeRateTopicPrefix) { t.log.Warnf("Client %s attempted to publish to restricted oracle topic %s", client.ShortString(), publishMessage.Topic) return @@ -456,7 +416,7 @@ func (t *TatankaNode) handleForwardRelay(s network.Stream) { func (t *TatankaNode) findSubscribedPriceTopics(prices map[oracle.Ticker]float64) map[string][]peer.ID { candidates := make(map[string]struct{}, len(prices)) for ticker := range prices { - candidates[oracle.PriceTopicPrefix+string(ticker)] = struct{}{} + candidates[protocols.PriceTopicPrefix+string(ticker)] = struct{}{} } return t.subscriptionManager.subscribedTopics(candidates) @@ -466,7 +426,7 @@ func (t *TatankaNode) findSubscribedFeeRateTopics(feeRates map[oracle.Network]*b candidates := make(map[string]struct{}, len(feeRates)) for network := range feeRates { - candidates[oracle.FeeRateTopicPrefix+string(network)] = struct{}{} + candidates[protocols.FeeRateTopicPrefix+string(network)] = struct{}{} } return t.subscriptionManager.subscribedTopics(candidates) @@ -514,148 +474,79 @@ func (t *TatankaNode) distributeFeeRateUpdate(topic string, candidates []peer.ID t.pushStreamManager.distribute(candidates, pushMsg) } -func (t *TatankaNode) handleOracleUpdate(oracleUpdate *pb.NodeOracleUpdate) { - switch update := oracleUpdate.Update.(type) { - case *pb.NodeOracleUpdate_PriceUpdate: - pbUpdate := update.PriceUpdate - // Validate source-level fields - if pbUpdate.Source == "" { - t.log.Warn("Skipping price update with empty source") - return - } - if pbUpdate.Timestamp <= 0 { - t.log.Warnf("Skipping price update with invalid timestamp: %d", pbUpdate.Timestamp) - return - } - - // Convert and validate individual prices - prices := make([]*oracle.SourcedPrice, 0, len(pbUpdate.Prices)) - for _, p := range pbUpdate.Prices { - if p.Price <= 0 { - t.log.Warnf("Skipping price with invalid value: %f", p.Price) - continue - } - if p.Ticker == "" { - t.log.Warn("Skipping price with empty ticker") - continue - } - prices = append(prices, &oracle.SourcedPrice{ - Ticker: oracle.Ticker(p.Ticker), - Price: p.Price, - }) - } - - if len(prices) == 0 { - t.log.Warn("No valid prices to merge from gossipsub") - return - } - - sourcedUpdate := &oracle.SourcedPriceUpdate{ - Source: pbUpdate.Source, - Stamp: time.Unix(pbUpdate.Timestamp, 0), - Weight: t.oracle.GetSourceWeight(pbUpdate.Source), - Prices: prices, - } - - // Merge prices and get only the updated ones - updatedPrices := t.oracle.MergePrices(sourcedUpdate) - t.log.Debugf("Merged %d price updates from gossipsub", len(prices)) - - // Distribute updated prices to clients via per-ticker topics. - if len(updatedPrices) == 0 { - // Nothing to do. - return - } - - priceSubs := t.findSubscribedPriceTopics(updatedPrices) - if len(priceSubs) == 0 { - // Nothing to do. - return - } - - for topic, candidates := range priceSubs { - ticker := topic[len(oracle.PriceTopicPrefix):] - price, ok := updatedPrices[oracle.Ticker(ticker)] - if !ok { - t.log.Errorf("No update price found for %s", ticker) - } +func (t *TatankaNode) handleOracleUpdate(senderID peer.ID, oracleUpdate *pb.NodeOracleUpdate) { + if oracleUpdate.Source == "" { + t.log.Warn("Skipping oracle update with empty source") + return + } + if oracleUpdate.Timestamp <= 0 { + t.log.Warnf("Skipping oracle update with invalid timestamp: %d", oracleUpdate.Timestamp) + return + } - go func(topic string, price float64, candidates []peer.ID) { - t.distributePriceUpdate(topic, candidates, price) - }(topic, price, candidates) - } + // Extract piggybacked quota status and forward to oracle. + if oracleUpdate.Quota != nil { + t.oracle.UpdatePeerSourceQuota(senderID.String(), pbToTimestampedQuotaStatus(oracleUpdate.Quota), oracleUpdate.Source) + } - case *pb.NodeOracleUpdate_FeeRateUpdate: - pbUpdate := update.FeeRateUpdate - // Validate source-level fields - if pbUpdate.Source == "" { - t.log.Warn("Skipping fee rate update with empty source") - return - } - if pbUpdate.Timestamp <= 0 { - t.log.Warnf("Skipping fee rate update with invalid timestamp: %d", pbUpdate.Timestamp) - return - } + update := pbToOracleUpdate(oracleUpdate) + if len(update.Prices) == 0 && len(update.FeeRates) == 0 { + t.log.Warn("Skipping oracle update with no prices or fee rates") + return + } - // Convert and validate individual fee rates - feeRates := make([]*oracle.SourcedFeeRate, 0, len(pbUpdate.FeeRates)) - for _, fr := range pbUpdate.FeeRates { - if len(fr.FeeRate) == 0 { - t.log.Warn("Skipping fee rate with empty value") - continue - } - if fr.Network == "" { - t.log.Warn("Skipping fee rate with empty network") - continue - } - feeRates = append(feeRates, &oracle.SourcedFeeRate{ - Network: oracle.Network(fr.Network), - FeeRate: fr.FeeRate, - }) - } + result := t.oracle.Merge(update, senderID.String()) + if result == nil { + return + } - if len(feeRates) == 0 { - t.log.Warn("No valid fee rates to merge from gossipsub") - return - } + t.distributePriceUpdates(result.Prices) + t.distributeFeeRateUpdates(result.FeeRates) +} - sourcedUpdate := &oracle.SourcedFeeRateUpdate{ - Source: pbUpdate.Source, - Stamp: time.Unix(pbUpdate.Timestamp, 0), - Weight: t.oracle.GetSourceWeight(pbUpdate.Source), - FeeRates: feeRates, - } +func (t *TatankaNode) distributePriceUpdates(updatedPrices map[oracle.Ticker]float64) { + if len(updatedPrices) == 0 { + return + } - // Merge fee rates and get only the updated ones - updatedFeeRates := t.oracle.MergeFeeRates(sourcedUpdate) - t.log.Debugf("Merged %d fee rate updates from gossipsub", len(feeRates)) + priceSubs := t.findSubscribedPriceTopics(updatedPrices) + if len(priceSubs) == 0 { + return + } - // Distribute updated fee rates to clients via per-ticker topics. - if len(updatedFeeRates) == 0 { - // Nothing to do. - return + for topic, candidates := range priceSubs { + ticker := topic[len(protocols.PriceTopicPrefix):] + price, ok := updatedPrices[oracle.Ticker(ticker)] + if !ok { + t.log.Errorf("No updated price found for %s", ticker) + continue } + go func(topic string, price float64, candidates []peer.ID) { + t.distributePriceUpdate(topic, candidates, price) + }(topic, price, candidates) + } +} - feeRateSubs := t.findSubscribedFeeRateTopics(updatedFeeRates) - if len(feeRateSubs) == 0 { - // Nothing to do. - return - } +func (t *TatankaNode) distributeFeeRateUpdates(updatedFeeRates map[oracle.Network]*big.Int) { + if len(updatedFeeRates) == 0 { + return + } - for topic, candidates := range feeRateSubs { - network := topic[len(oracle.FeeRateTopicPrefix):] - feeRate, ok := updatedFeeRates[oracle.Network(network)] - if !ok { - t.log.Errorf("No updated fee rate found for %s", network) - } + feeRateSubs := t.findSubscribedFeeRateTopics(updatedFeeRates) + if len(feeRateSubs) == 0 { + return + } - go func(topic string, feeRate *big.Int, candidates []peer.ID) { - t.distributeFeeRateUpdate(topic, candidates, feeRate) - }(topic, feeRate, candidates) + for topic, candidates := range feeRateSubs { + network := topic[len(protocols.FeeRateTopicPrefix):] + feeRate, ok := updatedFeeRates[oracle.Network(network)] + if !ok { + t.log.Errorf("No updated fee rate found for %s", network) + continue } - - default: - t.log.Warnf("Received unknown oracle update type %T", update) + go func(topic string, feeRate *big.Int, candidates []peer.ID) { + t.distributeFeeRateUpdate(topic, candidates, feeRate) + }(topic, feeRate, candidates) } } @@ -784,222 +675,36 @@ func (t *TatankaNode) handleAvailableMeshNodes(s network.Stream) { } } -// --- Protobuf Helper Functions --- - -func pbPushMessageSubscription(topic string, client peer.ID, subscribed bool) *protocolsPb.PushMessage { - messageType := protocolsPb.PushMessage_SUBSCRIBE - if !subscribed { - messageType = protocolsPb.PushMessage_UNSUBSCRIBE - } - return &protocolsPb.PushMessage{ - MessageType: messageType, - Topic: topic, - Sender: []byte(client), - } -} - -func pbPushMessageBroadcast(topic string, data []byte, sender peer.ID) *protocolsPb.PushMessage { - return &protocolsPb.PushMessage{ - MessageType: protocolsPb.PushMessage_BROADCAST, - Topic: topic, - Data: data, - Sender: []byte(sender), - } -} - -func pbResponseError(err error) *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_Error{ - Error: &protocolsPb.Error{ - Error: &protocolsPb.Error_Message{ - Message: err.Error(), - }, - }, - }, - } -} - -func pbResponseUnauthorizedError() *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_Error{ - Error: &protocolsPb.Error{ - Error: &protocolsPb.Error_Unauthorized{ - Unauthorized: &protocolsPb.UnauthorizedError{}, - }, - }, - }, - } -} - -func pbResponseClientAddr(addrs [][]byte) *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_AddrResponse{ - AddrResponse: &protocolsPb.ClientAddrResponse{ - Addrs: addrs, - }, - }, - } -} - -func pbResponsePostBondError(index uint32) *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_Error{ - Error: &protocolsPb.Error{ - Error: &protocolsPb.Error_PostBondError{ - PostBondError: &protocolsPb.PostBondError{ - InvalidBondIndex: index, - }, - }, - }, - }, - } -} - -func pbResponsePostBond(bondStrength uint32) *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_PostBondResponse{ - PostBondResponse: &protocolsPb.PostBondResponse{ - BondStrength: bondStrength, - }, - }, - } -} - -func pbAvailableMeshNodesResponse(peers []*protocolsPb.PeerInfo) *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_AvailableMeshNodesResponse{ - AvailableMeshNodesResponse: &protocolsPb.AvailableMeshNodesResponse{ - Peers: peers, - }, - }, - } -} - -func pbResponseSuccess() *protocolsPb.Response { - return &protocolsPb.Response{ - Response: &protocolsPb.Response_Success{ - Success: &protocolsPb.Success{}, - }, - } -} - -func pbClientRelayMessageSuccess(message []byte) *protocolsPb.ClientRelayMessageResponse { - return &protocolsPb.ClientRelayMessageResponse{ - Response: &protocolsPb.ClientRelayMessageResponse_Message{ - Message: message, - }, +// handleQuotaHeartbeat handles a quota heartbeat message from another tatanka node. +// This is used to periodically share quota information via gossipsub. +func (t *TatankaNode) handleQuotaHeartbeat(senderID peer.ID, heartbeat *pb.QuotaHandshake) { + for source, q := range heartbeat.Quotas { + t.oracle.UpdatePeerSourceQuota(senderID.String(), pbToTimestampedQuotaStatus(q), source) } } -func pbClientRelayMessageError(err *protocolsPb.Error) *protocolsPb.ClientRelayMessageResponse { - return &protocolsPb.ClientRelayMessageResponse{ - Response: &protocolsPb.ClientRelayMessageResponse_Error{ - Error: err, - }, - } -} - -func pbClientRelayMessageErrorMessage(message string) *protocolsPb.ClientRelayMessageResponse { - return pbClientRelayMessageError(&protocolsPb.Error{ - Error: &protocolsPb.Error_Message{ - Message: message, - }, - }) -} - -func pbClientRelayMessageCounterpartyNotFound() *protocolsPb.ClientRelayMessageResponse { - return pbClientRelayMessageError(&protocolsPb.Error{ - Error: &protocolsPb.Error_CpNotFoundError{ - CpNotFoundError: &protocolsPb.CounterpartyNotFoundError{}, - }, - }) -} - -func pbClientRelayMessageCounterpartyRejected() *protocolsPb.ClientRelayMessageResponse { - return pbClientRelayMessageError(&protocolsPb.Error{ - Error: &protocolsPb.Error_CpRejectedError{ - CpRejectedError: &protocolsPb.CounterpartyRejectedError{}, - }, - }) -} - -func pbTatankaForwardRelaySuccess(message []byte) *pb.TatankaForwardRelayResponse { - return &pb.TatankaForwardRelayResponse{ - Response: &pb.TatankaForwardRelayResponse_Success{ - Success: message, - }, - } -} - -func pbTatankaForwardRelayClientNotFound() *pb.TatankaForwardRelayResponse { - return &pb.TatankaForwardRelayResponse{ - Response: &pb.TatankaForwardRelayResponse_ClientNotFound_{ - ClientNotFound: &pb.TatankaForwardRelayResponse_ClientNotFound{}, - }, - } -} - -func pbTatankaForwardRelayClientRejected() *pb.TatankaForwardRelayResponse { - return &pb.TatankaForwardRelayResponse{ - Response: &pb.TatankaForwardRelayResponse_ClientRejected_{ - ClientRejected: &pb.TatankaForwardRelayResponse_ClientRejected{}, - }, - } -} - -func pbTatankaForwardRelayError(message string) *pb.TatankaForwardRelayResponse { - return &pb.TatankaForwardRelayResponse{ - Response: &pb.TatankaForwardRelayResponse_Error{ - Error: message, - }, - } -} - -func pbWhitelistResponseSuccess() *pb.WhitelistResponse { - return &pb.WhitelistResponse{ - Response: &pb.WhitelistResponse_Success_{ - Success: &pb.WhitelistResponse_Success{}, - }, - } -} - -func pbWhitelistResponseMismatch(mismatchedPeerIDs [][]byte) *pb.WhitelistResponse { - return &pb.WhitelistResponse{ - Response: &pb.WhitelistResponse_Mismatch_{ - Mismatch: &pb.WhitelistResponse_Mismatch{ - PeerIDs: mismatchedPeerIDs, - }, - }, - } -} - -func pbDiscoveryResponseNotFound() *pb.DiscoveryResponse { - return &pb.DiscoveryResponse{ - Response: &pb.DiscoveryResponse_NotFound_{ - NotFound: &pb.DiscoveryResponse_NotFound{}, - }, - } -} +// handleQuotaHandshake handles a quota handshake request from another tatanka node. +// This is used to exchange quota information on connection. +func (t *TatankaNode) handleQuotaHandshake(s network.Stream) { + defer func() { _ = s.Close() }() + peerID := s.Conn().RemotePeer() -func pbDiscoveryResponseSuccess(addrs []ma.Multiaddr) *pb.DiscoveryResponse { - addrBytes := make([][]byte, 0, len(addrs)) - for _, addr := range addrs { - addrBytes = append(addrBytes, addr.Bytes()) + // Read peer's quotas + req := &pb.QuotaHandshake{} + if err := codec.ReadLengthPrefixedMessage(s, req); err != nil { + t.log.Warnf("Failed to read quota handshake from %s: %v", peerID.ShortString(), err) + return } - return &pb.DiscoveryResponse{ - Response: &pb.DiscoveryResponse_Success_{ - Success: &pb.DiscoveryResponse_Success{ - Addrs: addrBytes, - }, - }, + // Process peer quotas + for source, q := range req.Quotas { + t.oracle.UpdatePeerSourceQuota(peerID.String(), pbToTimestampedQuotaStatus(q), source) } -} -// bigIntToBytes converts big.Int to big-endian encoded bytes. -func bigIntToBytes(bi *big.Int) []byte { - if bi == nil || bi.Sign() == 0 { - return []byte{0} + // Send our quotas + localQuotas := quotaStatusesToPb(t.oracle.GetLocalQuotas()) + resp := &pb.QuotaHandshake{Quotas: localQuotas} + if err := codec.WriteLengthPrefixedMessage(s, resp); err != nil { + t.log.Warnf("Failed to send quota handshake to %s: %v", peerID.ShortString(), err) } - return bi.Bytes() } diff --git a/tatanka/mesh_connection_manager.go b/tatanka/mesh_connection_manager.go index bc9c7a4..1d2dedb 100644 --- a/tatanka/mesh_connection_manager.go +++ b/tatanka/mesh_connection_manager.go @@ -217,9 +217,39 @@ func (t *peerTracker) connect() error { return fmt.Errorf("failed to verify whitelist for peer %s: %w", t.peerID, err) } + go t.exchangeOracleQuotas() + return nil } +// exchangeOracleQuotas sends local quota information to the peer and receives theirs. +func (t *peerTracker) exchangeOracleQuotas() { + ctx, cancel := context.WithTimeout(t.ctx, 10*time.Second) + defer cancel() + + stream, err := t.m.node.NewStream(ctx, t.peerID, quotaHandshakeProtocol) + if err != nil { + t.m.log.Debugf("Quota handshake stream to %s failed: %v", t.peerID, err) + return + } + defer func() { _ = stream.Close() }() + + localQuotas := t.m.getLocalQuotas() + req := &pb.QuotaHandshake{Quotas: localQuotas} + if err := codec.WriteLengthPrefixedMessage(stream, req); err != nil { + t.m.log.Debugf("Failed to send quota handshake to %s: %v", t.peerID, err) + return + } + + resp := &pb.QuotaHandshake{} + if err := codec.ReadLengthPrefixedMessage(stream, resp); err != nil { + t.m.log.Debugf("Failed to read quota handshake from %s: %v", t.peerID, err) + return + } + + t.m.handlePeerQuotas(t.peerID, resp.Quotas) +} + // discoverAddresses asks connected whitelist peers for the address of the target. func (t *peerTracker) discoverAddresses() bool { whitelist := t.m.getWhitelist() @@ -319,15 +349,28 @@ type meshConnectionManager struct { initialOnce sync.Once initialErr atomic.Value // error adminCallback AdminUpdateCallback + + // Quota exchange callbacks + getLocalQuotas func() map[string]*pb.QuotaStatus + handlePeerQuotas func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) } -func newMeshConnectionManager(log slog.Logger, node host.Host, whitelist *whitelist, adminCallback AdminUpdateCallback) *meshConnectionManager { +func newMeshConnectionManager( + log slog.Logger, + node host.Host, + whitelist *whitelist, + adminCallback AdminUpdateCallback, + getLocalQuotas func() map[string]*pb.QuotaStatus, + handlePeerQuotas func(peerID peer.ID, quotas map[string]*pb.QuotaStatus), +) *meshConnectionManager { m := &meshConnectionManager{ - log: log, - node: node, - peerTrackers: make(map[peer.ID]*peerTracker), - initialCh: make(chan struct{}), - adminCallback: adminCallback, + log: log, + node: node, + peerTrackers: make(map[peer.ID]*peerTracker), + initialCh: make(chan struct{}), + adminCallback: adminCallback, + getLocalQuotas: getLocalQuotas, + handlePeerQuotas: handlePeerQuotas, } m.whitelist.Store(whitelist) diff --git a/tatanka/pb/messages.pb.go b/tatanka/pb/messages.pb.go index 5106ee2..013182f 100644 --- a/tatanka/pb/messages.pb.go +++ b/tatanka/pb/messages.pb.go @@ -516,29 +516,32 @@ func (*WhitelistResponse_Success_) isWhitelistResponse_Response() {} func (*WhitelistResponse_Mismatch_) isWhitelistResponse_Response() {} -// SourcedPrice represents a single price entry within a sourced update batch. -type SourcedPrice struct { +// NodeOracleUpdate contains oracle data for sharing between Tatanka mesh nodes. +type NodeOracleUpdate struct { state protoimpl.MessageState `protogen:"open.v1"` - Ticker string `protobuf:"bytes,1,opt,name=ticker,proto3" json:"ticker,omitempty"` - Price float64 `protobuf:"fixed64,2,opt,name=price,proto3" json:"price,omitempty"` + Source string `protobuf:"bytes,1,opt,name=source,proto3" json:"source,omitempty"` + Timestamp int64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + Prices map[string]float64 `protobuf:"bytes,3,rep,name=prices,proto3" json:"prices,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"fixed64,2,opt,name=value"` // ticker -> price + FeeRates map[string][]byte `protobuf:"bytes,4,rep,name=fee_rates,json=feeRates,proto3" json:"fee_rates,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // network -> big-endian encoded big.Int + Quota *QuotaStatus `protobuf:"bytes,5,opt,name=quota,proto3" json:"quota,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *SourcedPrice) Reset() { - *x = SourcedPrice{} +func (x *NodeOracleUpdate) Reset() { + *x = NodeOracleUpdate{} mi := &file_tatanka_pb_messages_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *SourcedPrice) String() string { +func (x *NodeOracleUpdate) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SourcedPrice) ProtoMessage() {} +func (*NodeOracleUpdate) ProtoMessage() {} -func (x *SourcedPrice) ProtoReflect() protoreflect.Message { +func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { mi := &file_tatanka_pb_messages_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -550,166 +553,71 @@ func (x *SourcedPrice) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SourcedPrice.ProtoReflect.Descriptor instead. -func (*SourcedPrice) Descriptor() ([]byte, []int) { +// Deprecated: Use NodeOracleUpdate.ProtoReflect.Descriptor instead. +func (*NodeOracleUpdate) Descriptor() ([]byte, []int) { return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{7} } -func (x *SourcedPrice) GetTicker() string { - if x != nil { - return x.Ticker - } - return "" -} - -func (x *SourcedPrice) GetPrice() float64 { - if x != nil { - return x.Price - } - return 0 -} - -// SourcedPriceUpdate is a batch of price updates from a single source for sharing -// between Tatanka Mesh nodes. -type SourcedPriceUpdate struct { - state protoimpl.MessageState `protogen:"open.v1"` - Source string `protobuf:"bytes,1,opt,name=source,proto3" json:"source,omitempty"` - Timestamp int64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` - Prices []*SourcedPrice `protobuf:"bytes,3,rep,name=prices,proto3" json:"prices,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *SourcedPriceUpdate) Reset() { - *x = SourcedPriceUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *SourcedPriceUpdate) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*SourcedPriceUpdate) ProtoMessage() {} - -func (x *SourcedPriceUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[8] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use SourcedPriceUpdate.ProtoReflect.Descriptor instead. -func (*SourcedPriceUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{8} -} - -func (x *SourcedPriceUpdate) GetSource() string { +func (x *NodeOracleUpdate) GetSource() string { if x != nil { return x.Source } return "" } -func (x *SourcedPriceUpdate) GetTimestamp() int64 { +func (x *NodeOracleUpdate) GetTimestamp() int64 { if x != nil { return x.Timestamp } return 0 } -func (x *SourcedPriceUpdate) GetPrices() []*SourcedPrice { +func (x *NodeOracleUpdate) GetPrices() map[string]float64 { if x != nil { return x.Prices } return nil } -// SourcedFeeRate represents a single fee rate entry within a sourced update batch. -type SourcedFeeRate struct { - state protoimpl.MessageState `protogen:"open.v1"` - Network string `protobuf:"bytes,1,opt,name=network,proto3" json:"network,omitempty"` - FeeRate []byte `protobuf:"bytes,2,opt,name=fee_rate,json=feeRate,proto3" json:"fee_rate,omitempty"` // big-endian encoded big integer - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *SourcedFeeRate) Reset() { - *x = SourcedFeeRate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *SourcedFeeRate) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*SourcedFeeRate) ProtoMessage() {} - -func (x *SourcedFeeRate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[9] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use SourcedFeeRate.ProtoReflect.Descriptor instead. -func (*SourcedFeeRate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{9} -} - -func (x *SourcedFeeRate) GetNetwork() string { +func (x *NodeOracleUpdate) GetFeeRates() map[string][]byte { if x != nil { - return x.Network + return x.FeeRates } - return "" + return nil } -func (x *SourcedFeeRate) GetFeeRate() []byte { +func (x *NodeOracleUpdate) GetQuota() *QuotaStatus { if x != nil { - return x.FeeRate + return x.Quota } return nil } -// SourcedFeeRateUpdate is a batch of fee rate updates from a single source for sharing -// between Tatanka Mesh nodes. -type SourcedFeeRateUpdate struct { - state protoimpl.MessageState `protogen:"open.v1"` - Source string `protobuf:"bytes,1,opt,name=source,proto3" json:"source,omitempty"` - Timestamp int64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` - FeeRates []*SourcedFeeRate `protobuf:"bytes,3,rep,name=fee_rates,json=feeRates,proto3" json:"fee_rates,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache +// QuotaStatus represents quota state for an API source. +type QuotaStatus struct { + state protoimpl.MessageState `protogen:"open.v1"` + FetchesRemaining int64 `protobuf:"varint,1,opt,name=fetches_remaining,json=fetchesRemaining,proto3" json:"fetches_remaining,omitempty"` + FetchesLimit int64 `protobuf:"varint,2,opt,name=fetches_limit,json=fetchesLimit,proto3" json:"fetches_limit,omitempty"` + ResetTimestamp int64 `protobuf:"varint,3,opt,name=reset_timestamp,json=resetTimestamp,proto3" json:"reset_timestamp,omitempty"` // Unix timestamp + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } -func (x *SourcedFeeRateUpdate) Reset() { - *x = SourcedFeeRateUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[10] +func (x *QuotaStatus) Reset() { + *x = QuotaStatus{} + mi := &file_tatanka_pb_messages_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *SourcedFeeRateUpdate) String() string { +func (x *QuotaStatus) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SourcedFeeRateUpdate) ProtoMessage() {} +func (*QuotaStatus) ProtoMessage() {} -func (x *SourcedFeeRateUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[10] +func (x *QuotaStatus) ProtoReflect() protoreflect.Message { + mi := &file_tatanka_pb_messages_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -720,59 +628,56 @@ func (x *SourcedFeeRateUpdate) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SourcedFeeRateUpdate.ProtoReflect.Descriptor instead. -func (*SourcedFeeRateUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{10} +// Deprecated: Use QuotaStatus.ProtoReflect.Descriptor instead. +func (*QuotaStatus) Descriptor() ([]byte, []int) { + return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{8} } -func (x *SourcedFeeRateUpdate) GetSource() string { +func (x *QuotaStatus) GetFetchesRemaining() int64 { if x != nil { - return x.Source + return x.FetchesRemaining } - return "" + return 0 } -func (x *SourcedFeeRateUpdate) GetTimestamp() int64 { +func (x *QuotaStatus) GetFetchesLimit() int64 { if x != nil { - return x.Timestamp + return x.FetchesLimit } return 0 } -func (x *SourcedFeeRateUpdate) GetFeeRates() []*SourcedFeeRate { +func (x *QuotaStatus) GetResetTimestamp() int64 { if x != nil { - return x.FeeRates + return x.ResetTimestamp } - return nil + return 0 } -// NodeOracleUpdate contains oracle data for sharing between Tatanka mesh nodes. -type NodeOracleUpdate struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Types that are valid to be assigned to Update: - // - // *NodeOracleUpdate_PriceUpdate - // *NodeOracleUpdate_FeeRateUpdate - Update isNodeOracleUpdate_Update `protobuf_oneof:"update"` +// QuotaHandshake is exchanged between nodes on connection and periodically +// via heartbeat to share quota information for network-coordinated scheduling. +type QuotaHandshake struct { + state protoimpl.MessageState `protogen:"open.v1"` + Quotas map[string]*QuotaStatus `protobuf:"bytes,1,rep,name=quotas,proto3" json:"quotas,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // source -> quota unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *NodeOracleUpdate) Reset() { - *x = NodeOracleUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[11] +func (x *QuotaHandshake) Reset() { + *x = QuotaHandshake{} + mi := &file_tatanka_pb_messages_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *NodeOracleUpdate) String() string { +func (x *QuotaHandshake) String() string { return protoimpl.X.MessageStringOf(x) } -func (*NodeOracleUpdate) ProtoMessage() {} +func (*QuotaHandshake) ProtoMessage() {} -func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[11] +func (x *QuotaHandshake) ProtoReflect() protoreflect.Message { + mi := &file_tatanka_pb_messages_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -783,52 +688,18 @@ func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use NodeOracleUpdate.ProtoReflect.Descriptor instead. -func (*NodeOracleUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{11} -} - -func (x *NodeOracleUpdate) GetUpdate() isNodeOracleUpdate_Update { - if x != nil { - return x.Update - } - return nil -} - -func (x *NodeOracleUpdate) GetPriceUpdate() *SourcedPriceUpdate { - if x != nil { - if x, ok := x.Update.(*NodeOracleUpdate_PriceUpdate); ok { - return x.PriceUpdate - } - } - return nil +// Deprecated: Use QuotaHandshake.ProtoReflect.Descriptor instead. +func (*QuotaHandshake) Descriptor() ([]byte, []int) { + return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{9} } -func (x *NodeOracleUpdate) GetFeeRateUpdate() *SourcedFeeRateUpdate { +func (x *QuotaHandshake) GetQuotas() map[string]*QuotaStatus { if x != nil { - if x, ok := x.Update.(*NodeOracleUpdate_FeeRateUpdate); ok { - return x.FeeRateUpdate - } + return x.Quotas } return nil } -type isNodeOracleUpdate_Update interface { - isNodeOracleUpdate_Update() -} - -type NodeOracleUpdate_PriceUpdate struct { - PriceUpdate *SourcedPriceUpdate `protobuf:"bytes,1,opt,name=price_update,json=priceUpdate,proto3,oneof"` -} - -type NodeOracleUpdate_FeeRateUpdate struct { - FeeRateUpdate *SourcedFeeRateUpdate `protobuf:"bytes,2,opt,name=fee_rate_update,json=feeRateUpdate,proto3,oneof"` -} - -func (*NodeOracleUpdate_PriceUpdate) isNodeOracleUpdate_Update() {} - -func (*NodeOracleUpdate_FeeRateUpdate) isNodeOracleUpdate_Update() {} - type TatankaForwardRelayResponse_ClientNotFound struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -837,7 +708,7 @@ type TatankaForwardRelayResponse_ClientNotFound struct { func (x *TatankaForwardRelayResponse_ClientNotFound) Reset() { *x = TatankaForwardRelayResponse_ClientNotFound{} - mi := &file_tatanka_pb_messages_proto_msgTypes[12] + mi := &file_tatanka_pb_messages_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -849,7 +720,7 @@ func (x *TatankaForwardRelayResponse_ClientNotFound) String() string { func (*TatankaForwardRelayResponse_ClientNotFound) ProtoMessage() {} func (x *TatankaForwardRelayResponse_ClientNotFound) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[12] + mi := &file_tatanka_pb_messages_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -873,7 +744,7 @@ type TatankaForwardRelayResponse_ClientRejected struct { func (x *TatankaForwardRelayResponse_ClientRejected) Reset() { *x = TatankaForwardRelayResponse_ClientRejected{} - mi := &file_tatanka_pb_messages_proto_msgTypes[13] + mi := &file_tatanka_pb_messages_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -885,7 +756,7 @@ func (x *TatankaForwardRelayResponse_ClientRejected) String() string { func (*TatankaForwardRelayResponse_ClientRejected) ProtoMessage() {} func (x *TatankaForwardRelayResponse_ClientRejected) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[13] + mi := &file_tatanka_pb_messages_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -910,7 +781,7 @@ type DiscoveryResponse_Success struct { func (x *DiscoveryResponse_Success) Reset() { *x = DiscoveryResponse_Success{} - mi := &file_tatanka_pb_messages_proto_msgTypes[14] + mi := &file_tatanka_pb_messages_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -922,7 +793,7 @@ func (x *DiscoveryResponse_Success) String() string { func (*DiscoveryResponse_Success) ProtoMessage() {} func (x *DiscoveryResponse_Success) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[14] + mi := &file_tatanka_pb_messages_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -953,7 +824,7 @@ type DiscoveryResponse_NotFound struct { func (x *DiscoveryResponse_NotFound) Reset() { *x = DiscoveryResponse_NotFound{} - mi := &file_tatanka_pb_messages_proto_msgTypes[15] + mi := &file_tatanka_pb_messages_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -965,7 +836,7 @@ func (x *DiscoveryResponse_NotFound) String() string { func (*DiscoveryResponse_NotFound) ProtoMessage() {} func (x *DiscoveryResponse_NotFound) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[15] + mi := &file_tatanka_pb_messages_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -989,7 +860,7 @@ type WhitelistResponse_Success struct { func (x *WhitelistResponse_Success) Reset() { *x = WhitelistResponse_Success{} - mi := &file_tatanka_pb_messages_proto_msgTypes[16] + mi := &file_tatanka_pb_messages_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1001,7 +872,7 @@ func (x *WhitelistResponse_Success) String() string { func (*WhitelistResponse_Success) ProtoMessage() {} func (x *WhitelistResponse_Success) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[16] + mi := &file_tatanka_pb_messages_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1026,7 +897,7 @@ type WhitelistResponse_Mismatch struct { func (x *WhitelistResponse_Mismatch) Reset() { *x = WhitelistResponse_Mismatch{} - mi := &file_tatanka_pb_messages_proto_msgTypes[17] + mi := &file_tatanka_pb_messages_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1038,7 +909,7 @@ func (x *WhitelistResponse_Mismatch) String() string { func (*WhitelistResponse_Mismatch) ProtoMessage() {} func (x *WhitelistResponse_Mismatch) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[17] + mi := &file_tatanka_pb_messages_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1105,25 +976,28 @@ const file_tatanka_pb_messages_proto_rawDesc = "" + "\bMismatch\x12\x18\n" + "\apeerIDs\x18\x01 \x03(\fR\apeerIDsB\n" + "\n" + - "\bresponse\"<\n" + - "\fSourcedPrice\x12\x16\n" + - "\x06ticker\x18\x01 \x01(\tR\x06ticker\x12\x14\n" + - "\x05price\x18\x02 \x01(\x01R\x05price\"t\n" + - "\x12SourcedPriceUpdate\x12\x16\n" + + "\bresponse\"\xe2\x02\n" + + "\x10NodeOracleUpdate\x12\x16\n" + "\x06source\x18\x01 \x01(\tR\x06source\x12\x1c\n" + - "\ttimestamp\x18\x02 \x01(\x03R\ttimestamp\x12(\n" + - "\x06prices\x18\x03 \x03(\v2\x10.pb.SourcedPriceR\x06prices\"E\n" + - "\x0eSourcedFeeRate\x12\x18\n" + - "\anetwork\x18\x01 \x01(\tR\anetwork\x12\x19\n" + - "\bfee_rate\x18\x02 \x01(\fR\afeeRate\"}\n" + - "\x14SourcedFeeRateUpdate\x12\x16\n" + - "\x06source\x18\x01 \x01(\tR\x06source\x12\x1c\n" + - "\ttimestamp\x18\x02 \x01(\x03R\ttimestamp\x12/\n" + - "\tfee_rates\x18\x03 \x03(\v2\x12.pb.SourcedFeeRateR\bfeeRates\"\x9d\x01\n" + - "\x10NodeOracleUpdate\x12;\n" + - "\fprice_update\x18\x01 \x01(\v2\x16.pb.SourcedPriceUpdateH\x00R\vpriceUpdate\x12B\n" + - "\x0ffee_rate_update\x18\x02 \x01(\v2\x18.pb.SourcedFeeRateUpdateH\x00R\rfeeRateUpdateB\b\n" + - "\x06updateB'Z%github.com/bisoncraft/mesh/tatanka/pbb\x06proto3" + "\ttimestamp\x18\x02 \x01(\x03R\ttimestamp\x128\n" + + "\x06prices\x18\x03 \x03(\v2 .pb.NodeOracleUpdate.PricesEntryR\x06prices\x12?\n" + + "\tfee_rates\x18\x04 \x03(\v2\".pb.NodeOracleUpdate.FeeRatesEntryR\bfeeRates\x12%\n" + + "\x05quota\x18\x05 \x01(\v2\x0f.pb.QuotaStatusR\x05quota\x1a9\n" + + "\vPricesEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\x01R\x05value:\x028\x01\x1a;\n" + + "\rFeeRatesEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\fR\x05value:\x028\x01\"\x88\x01\n" + + "\vQuotaStatus\x12+\n" + + "\x11fetches_remaining\x18\x01 \x01(\x03R\x10fetchesRemaining\x12#\n" + + "\rfetches_limit\x18\x02 \x01(\x03R\ffetchesLimit\x12'\n" + + "\x0freset_timestamp\x18\x03 \x01(\x03R\x0eresetTimestamp\"\x94\x01\n" + + "\x0eQuotaHandshake\x126\n" + + "\x06quotas\x18\x01 \x03(\v2\x1e.pb.QuotaHandshake.QuotasEntryR\x06quotas\x1aJ\n" + + "\vQuotasEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12%\n" + + "\x05value\x18\x02 \x01(\v2\x0f.pb.QuotaStatusR\x05value:\x028\x01B'Z%github.com/bisoncraft/mesh/tatanka/pbb\x06proto3" var ( file_tatanka_pb_messages_proto_rawDescOnce sync.Once @@ -1137,7 +1011,7 @@ func file_tatanka_pb_messages_proto_rawDescGZIP() []byte { return file_tatanka_pb_messages_proto_rawDescData } -var file_tatanka_pb_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 18) +var file_tatanka_pb_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 19) var file_tatanka_pb_messages_proto_goTypes = []any{ (*ClientConnectionMsg)(nil), // 0: pb.ClientConnectionMsg (*TatankaForwardRelayRequest)(nil), // 1: pb.TatankaForwardRelayRequest @@ -1146,34 +1020,36 @@ var file_tatanka_pb_messages_proto_goTypes = []any{ (*DiscoveryResponse)(nil), // 4: pb.DiscoveryResponse (*WhitelistRequest)(nil), // 5: pb.WhitelistRequest (*WhitelistResponse)(nil), // 6: pb.WhitelistResponse - (*SourcedPrice)(nil), // 7: pb.SourcedPrice - (*SourcedPriceUpdate)(nil), // 8: pb.SourcedPriceUpdate - (*SourcedFeeRate)(nil), // 9: pb.SourcedFeeRate - (*SourcedFeeRateUpdate)(nil), // 10: pb.SourcedFeeRateUpdate - (*NodeOracleUpdate)(nil), // 11: pb.NodeOracleUpdate - (*TatankaForwardRelayResponse_ClientNotFound)(nil), // 12: pb.TatankaForwardRelayResponse.ClientNotFound - (*TatankaForwardRelayResponse_ClientRejected)(nil), // 13: pb.TatankaForwardRelayResponse.ClientRejected - (*DiscoveryResponse_Success)(nil), // 14: pb.DiscoveryResponse.Success - (*DiscoveryResponse_NotFound)(nil), // 15: pb.DiscoveryResponse.NotFound - (*WhitelistResponse_Success)(nil), // 16: pb.WhitelistResponse.Success - (*WhitelistResponse_Mismatch)(nil), // 17: pb.WhitelistResponse.Mismatch + (*NodeOracleUpdate)(nil), // 7: pb.NodeOracleUpdate + (*QuotaStatus)(nil), // 8: pb.QuotaStatus + (*QuotaHandshake)(nil), // 9: pb.QuotaHandshake + (*TatankaForwardRelayResponse_ClientNotFound)(nil), // 10: pb.TatankaForwardRelayResponse.ClientNotFound + (*TatankaForwardRelayResponse_ClientRejected)(nil), // 11: pb.TatankaForwardRelayResponse.ClientRejected + (*DiscoveryResponse_Success)(nil), // 12: pb.DiscoveryResponse.Success + (*DiscoveryResponse_NotFound)(nil), // 13: pb.DiscoveryResponse.NotFound + (*WhitelistResponse_Success)(nil), // 14: pb.WhitelistResponse.Success + (*WhitelistResponse_Mismatch)(nil), // 15: pb.WhitelistResponse.Mismatch + nil, // 16: pb.NodeOracleUpdate.PricesEntry + nil, // 17: pb.NodeOracleUpdate.FeeRatesEntry + nil, // 18: pb.QuotaHandshake.QuotasEntry } var file_tatanka_pb_messages_proto_depIdxs = []int32{ - 12, // 0: pb.TatankaForwardRelayResponse.client_not_found:type_name -> pb.TatankaForwardRelayResponse.ClientNotFound - 13, // 1: pb.TatankaForwardRelayResponse.client_rejected:type_name -> pb.TatankaForwardRelayResponse.ClientRejected - 14, // 2: pb.DiscoveryResponse.success:type_name -> pb.DiscoveryResponse.Success - 15, // 3: pb.DiscoveryResponse.not_found:type_name -> pb.DiscoveryResponse.NotFound - 16, // 4: pb.WhitelistResponse.success:type_name -> pb.WhitelistResponse.Success - 17, // 5: pb.WhitelistResponse.mismatch:type_name -> pb.WhitelistResponse.Mismatch - 7, // 6: pb.SourcedPriceUpdate.prices:type_name -> pb.SourcedPrice - 9, // 7: pb.SourcedFeeRateUpdate.fee_rates:type_name -> pb.SourcedFeeRate - 8, // 8: pb.NodeOracleUpdate.price_update:type_name -> pb.SourcedPriceUpdate - 10, // 9: pb.NodeOracleUpdate.fee_rate_update:type_name -> pb.SourcedFeeRateUpdate - 10, // [10:10] is the sub-list for method output_type - 10, // [10:10] is the sub-list for method input_type - 10, // [10:10] is the sub-list for extension type_name - 10, // [10:10] is the sub-list for extension extendee - 0, // [0:10] is the sub-list for field type_name + 10, // 0: pb.TatankaForwardRelayResponse.client_not_found:type_name -> pb.TatankaForwardRelayResponse.ClientNotFound + 11, // 1: pb.TatankaForwardRelayResponse.client_rejected:type_name -> pb.TatankaForwardRelayResponse.ClientRejected + 12, // 2: pb.DiscoveryResponse.success:type_name -> pb.DiscoveryResponse.Success + 13, // 3: pb.DiscoveryResponse.not_found:type_name -> pb.DiscoveryResponse.NotFound + 14, // 4: pb.WhitelistResponse.success:type_name -> pb.WhitelistResponse.Success + 15, // 5: pb.WhitelistResponse.mismatch:type_name -> pb.WhitelistResponse.Mismatch + 16, // 6: pb.NodeOracleUpdate.prices:type_name -> pb.NodeOracleUpdate.PricesEntry + 17, // 7: pb.NodeOracleUpdate.fee_rates:type_name -> pb.NodeOracleUpdate.FeeRatesEntry + 8, // 8: pb.NodeOracleUpdate.quota:type_name -> pb.QuotaStatus + 18, // 9: pb.QuotaHandshake.quotas:type_name -> pb.QuotaHandshake.QuotasEntry + 8, // 10: pb.QuotaHandshake.QuotasEntry.value:type_name -> pb.QuotaStatus + 11, // [11:11] is the sub-list for method output_type + 11, // [11:11] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name } func init() { file_tatanka_pb_messages_proto_init() } @@ -1195,17 +1071,13 @@ func file_tatanka_pb_messages_proto_init() { (*WhitelistResponse_Success_)(nil), (*WhitelistResponse_Mismatch_)(nil), } - file_tatanka_pb_messages_proto_msgTypes[11].OneofWrappers = []any{ - (*NodeOracleUpdate_PriceUpdate)(nil), - (*NodeOracleUpdate_FeeRateUpdate)(nil), - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_tatanka_pb_messages_proto_rawDesc), len(file_tatanka_pb_messages_proto_rawDesc)), NumEnums: 0, - NumMessages: 18, + NumMessages: 19, NumExtensions: 0, NumServices: 0, }, diff --git a/tatanka/pb/messages.proto b/tatanka/pb/messages.proto index 51472c6..c86e9a9 100644 --- a/tatanka/pb/messages.proto +++ b/tatanka/pb/messages.proto @@ -68,38 +68,24 @@ message WhitelistResponse { } } -// SourcedPrice represents a single price entry within a sourced update batch. -message SourcedPrice { - string ticker = 1; - double price = 2; -} - -// SourcedPriceUpdate is a batch of price updates from a single source for sharing -// between Tatanka Mesh nodes. -message SourcedPriceUpdate { +// NodeOracleUpdate contains oracle data for sharing between Tatanka mesh nodes. +message NodeOracleUpdate { string source = 1; int64 timestamp = 2; - repeated SourcedPrice prices = 3; + map prices = 3; // ticker -> price + map fee_rates = 4; // network -> big-endian encoded big.Int + QuotaStatus quota = 5; } -// SourcedFeeRate represents a single fee rate entry within a sourced update batch. -message SourcedFeeRate { - string network = 1; - bytes fee_rate = 2; // big-endian encoded big integer +// QuotaStatus represents quota state for an API source. +message QuotaStatus { + int64 fetches_remaining = 1; + int64 fetches_limit = 2; + int64 reset_timestamp = 3; // Unix timestamp } -// SourcedFeeRateUpdate is a batch of fee rate updates from a single source for sharing -// between Tatanka Mesh nodes. -message SourcedFeeRateUpdate { - string source = 1; - int64 timestamp = 2; - repeated SourcedFeeRate fee_rates = 3; -} - -// NodeOracleUpdate contains oracle data for sharing between Tatanka mesh nodes. -message NodeOracleUpdate { - oneof update { - SourcedPriceUpdate price_update = 1; - SourcedFeeRateUpdate fee_rate_update = 2; - } +// QuotaHandshake is exchanged between nodes on connection and periodically +// via heartbeat to share quota information for network-coordinated scheduling. +message QuotaHandshake { + map quotas = 1; // source -> quota } diff --git a/tatanka/pb_helpers.go b/tatanka/pb_helpers.go new file mode 100644 index 0000000..e4f28ca --- /dev/null +++ b/tatanka/pb_helpers.go @@ -0,0 +1,357 @@ +package tatanka + +import ( + "fmt" + "math/big" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/oracle/sources" + protocolsPb "github.com/bisoncraft/mesh/protocols/pb" + "github.com/bisoncraft/mesh/tatanka/pb" + ma "github.com/multiformats/go-multiaddr" +) + +// libp2pPeerInfoToPb converts a peer.AddrInfo to a protocolsPb.PeerInfo. +func libp2pPeerInfoToPb(peerInfo peer.AddrInfo) *protocolsPb.PeerInfo { + addrBytes := make([][]byte, len(peerInfo.Addrs)) + for i, addr := range peerInfo.Addrs { + addrBytes[i] = addr.Bytes() + } + + return &protocolsPb.PeerInfo{ + Id: []byte(peerInfo.ID), + Addrs: addrBytes, + } +} + +// pbPeerInfoToLibp2p converts a protocolsPb.PeerInfo to a peer.AddrInfo. +func pbPeerInfoToLibp2p(pbPeer *protocolsPb.PeerInfo) (peer.AddrInfo, error) { + peerID, err := peer.IDFromBytes(pbPeer.Id) + if err != nil { + return peer.AddrInfo{}, fmt.Errorf("failed to parse peer ID: %w", err) + } + + addrs := make([]ma.Multiaddr, 0, len(pbPeer.Addrs)) + for _, addrBytes := range pbPeer.Addrs { + addr, err := ma.NewMultiaddrBytes(addrBytes) + if err != nil { + return peer.AddrInfo{}, fmt.Errorf("failed to parse multiaddr: %w", err) + } + addrs = append(addrs, addr) + } + + return peer.AddrInfo{ + ID: peerID, + Addrs: addrs, + }, nil +} + +func oracleUpdateToPb(update *oracle.OracleUpdate) (*pb.NodeOracleUpdate, error) { + if update == nil { + return nil, fmt.Errorf("oracle update is nil") + } + + msg := &pb.NodeOracleUpdate{ + Source: update.Source, + Timestamp: update.Stamp.Unix(), + } + + if len(update.Prices) > 0 { + msg.Prices = make(map[string]float64, len(update.Prices)) + for ticker, price := range update.Prices { + msg.Prices[string(ticker)] = price + } + } + + if len(update.FeeRates) > 0 { + msg.FeeRates = make(map[string][]byte, len(update.FeeRates)) + for network, feeRate := range update.FeeRates { + msg.FeeRates[string(network)] = bigIntToBytes(feeRate) + } + } + + if update.Quota != nil { + msg.Quota = quotaStatusToPb(update.Quota) + } + + return msg, nil +} + +func pbToOracleUpdate(pbUpdate *pb.NodeOracleUpdate) *oracle.OracleUpdate { + update := &oracle.OracleUpdate{ + Source: pbUpdate.Source, + Stamp: time.Unix(pbUpdate.Timestamp, 0), + } + + if len(pbUpdate.Prices) > 0 { + update.Prices = make(map[oracle.Ticker]float64, len(pbUpdate.Prices)) + for ticker, price := range pbUpdate.Prices { + update.Prices[oracle.Ticker(ticker)] = price + } + } + + if len(pbUpdate.FeeRates) > 0 { + update.FeeRates = make(map[oracle.Network]*big.Int, len(pbUpdate.FeeRates)) + for network, feeRateBytes := range pbUpdate.FeeRates { + update.FeeRates[oracle.Network(network)] = new(big.Int).SetBytes(feeRateBytes) + } + } + + return update +} + +func pbToTimestampedQuotaStatus(q *pb.QuotaStatus) *oracle.TimestampedQuotaStatus { + return &oracle.TimestampedQuotaStatus{ + QuotaStatus: &sources.QuotaStatus{ + FetchesRemaining: q.FetchesRemaining, + FetchesLimit: q.FetchesLimit, + ResetTime: time.Unix(q.ResetTimestamp, 0), + }, + ReceivedAt: time.Now(), + } +} + +func quotaStatusToPb(quota *sources.QuotaStatus) *pb.QuotaStatus { + if quota == nil { + return nil + } + return &pb.QuotaStatus{ + FetchesRemaining: quota.FetchesRemaining, + FetchesLimit: quota.FetchesLimit, + ResetTimestamp: quota.ResetTime.Unix(), + } +} + +func quotaStatusesToPb(quotas map[string]*sources.QuotaStatus) map[string]*pb.QuotaStatus { + if len(quotas) == 0 { + return nil + } + result := make(map[string]*pb.QuotaStatus, len(quotas)) + for source, quota := range quotas { + if quota == nil { + continue + } + result[source] = quotaStatusToPb(quota) + } + return result +} + +func pbPushMessageSubscription(topic string, client peer.ID, subscribed bool) *protocolsPb.PushMessage { + messageType := protocolsPb.PushMessage_SUBSCRIBE + if !subscribed { + messageType = protocolsPb.PushMessage_UNSUBSCRIBE + } + return &protocolsPb.PushMessage{ + MessageType: messageType, + Topic: topic, + Sender: []byte(client), + } +} + +func pbPushMessageBroadcast(topic string, data []byte, sender peer.ID) *protocolsPb.PushMessage { + return &protocolsPb.PushMessage{ + MessageType: protocolsPb.PushMessage_BROADCAST, + Topic: topic, + Data: data, + Sender: []byte(sender), + } +} + +func pbResponseError(err error) *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_Error{ + Error: &protocolsPb.Error{ + Error: &protocolsPb.Error_Message{ + Message: err.Error(), + }, + }, + }, + } +} + +func pbResponseUnauthorizedError() *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_Error{ + Error: &protocolsPb.Error{ + Error: &protocolsPb.Error_Unauthorized{ + Unauthorized: &protocolsPb.UnauthorizedError{}, + }, + }, + }, + } +} + +func pbResponseClientAddr(addrs [][]byte) *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_AddrResponse{ + AddrResponse: &protocolsPb.ClientAddrResponse{ + Addrs: addrs, + }, + }, + } +} + +func pbResponsePostBondError(index uint32) *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_Error{ + Error: &protocolsPb.Error{ + Error: &protocolsPb.Error_PostBondError{ + PostBondError: &protocolsPb.PostBondError{ + InvalidBondIndex: index, + }, + }, + }, + }, + } +} + +func pbResponsePostBond(bondStrength uint32) *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_PostBondResponse{ + PostBondResponse: &protocolsPb.PostBondResponse{ + BondStrength: bondStrength, + }, + }, + } +} + +func pbAvailableMeshNodesResponse(peers []*protocolsPb.PeerInfo) *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_AvailableMeshNodesResponse{ + AvailableMeshNodesResponse: &protocolsPb.AvailableMeshNodesResponse{ + Peers: peers, + }, + }, + } +} + +func pbResponseSuccess() *protocolsPb.Response { + return &protocolsPb.Response{ + Response: &protocolsPb.Response_Success{ + Success: &protocolsPb.Success{}, + }, + } +} + +func pbClientRelayMessageSuccess(message []byte) *protocolsPb.ClientRelayMessageResponse { + return &protocolsPb.ClientRelayMessageResponse{ + Response: &protocolsPb.ClientRelayMessageResponse_Message{ + Message: message, + }, + } +} + +func pbClientRelayMessageError(err *protocolsPb.Error) *protocolsPb.ClientRelayMessageResponse { + return &protocolsPb.ClientRelayMessageResponse{ + Response: &protocolsPb.ClientRelayMessageResponse_Error{ + Error: err, + }, + } +} + +func pbClientRelayMessageErrorMessage(message string) *protocolsPb.ClientRelayMessageResponse { + return pbClientRelayMessageError(&protocolsPb.Error{ + Error: &protocolsPb.Error_Message{ + Message: message, + }, + }) +} + +func pbClientRelayMessageCounterpartyNotFound() *protocolsPb.ClientRelayMessageResponse { + return pbClientRelayMessageError(&protocolsPb.Error{ + Error: &protocolsPb.Error_CpNotFoundError{ + CpNotFoundError: &protocolsPb.CounterpartyNotFoundError{}, + }, + }) +} + +func pbClientRelayMessageCounterpartyRejected() *protocolsPb.ClientRelayMessageResponse { + return pbClientRelayMessageError(&protocolsPb.Error{ + Error: &protocolsPb.Error_CpRejectedError{ + CpRejectedError: &protocolsPb.CounterpartyRejectedError{}, + }, + }) +} + +func pbTatankaForwardRelaySuccess(message []byte) *pb.TatankaForwardRelayResponse { + return &pb.TatankaForwardRelayResponse{ + Response: &pb.TatankaForwardRelayResponse_Success{ + Success: message, + }, + } +} + +func pbTatankaForwardRelayClientNotFound() *pb.TatankaForwardRelayResponse { + return &pb.TatankaForwardRelayResponse{ + Response: &pb.TatankaForwardRelayResponse_ClientNotFound_{ + ClientNotFound: &pb.TatankaForwardRelayResponse_ClientNotFound{}, + }, + } +} + +func pbTatankaForwardRelayClientRejected() *pb.TatankaForwardRelayResponse { + return &pb.TatankaForwardRelayResponse{ + Response: &pb.TatankaForwardRelayResponse_ClientRejected_{ + ClientRejected: &pb.TatankaForwardRelayResponse_ClientRejected{}, + }, + } +} + +func pbTatankaForwardRelayError(message string) *pb.TatankaForwardRelayResponse { + return &pb.TatankaForwardRelayResponse{ + Response: &pb.TatankaForwardRelayResponse_Error{ + Error: message, + }, + } +} + +func pbWhitelistResponseSuccess() *pb.WhitelistResponse { + return &pb.WhitelistResponse{ + Response: &pb.WhitelistResponse_Success_{ + Success: &pb.WhitelistResponse_Success{}, + }, + } +} + +func pbWhitelistResponseMismatch(mismatchedPeerIDs [][]byte) *pb.WhitelistResponse { + return &pb.WhitelistResponse{ + Response: &pb.WhitelistResponse_Mismatch_{ + Mismatch: &pb.WhitelistResponse_Mismatch{ + PeerIDs: mismatchedPeerIDs, + }, + }, + } +} + +func pbDiscoveryResponseNotFound() *pb.DiscoveryResponse { + return &pb.DiscoveryResponse{ + Response: &pb.DiscoveryResponse_NotFound_{ + NotFound: &pb.DiscoveryResponse_NotFound{}, + }, + } +} + +func pbDiscoveryResponseSuccess(addrs []ma.Multiaddr) *pb.DiscoveryResponse { + addrBytes := make([][]byte, 0, len(addrs)) + for _, addr := range addrs { + addrBytes = append(addrBytes, addr.Bytes()) + } + + return &pb.DiscoveryResponse{ + Response: &pb.DiscoveryResponse_Success_{ + Success: &pb.DiscoveryResponse_Success{ + Addrs: addrBytes, + }, + }, + } +} + +// bigIntToBytes converts big.Int to big-endian encoded bytes. +func bigIntToBytes(bi *big.Int) []byte { + if bi == nil || bi.Sign() == 0 { + return []byte{0} + } + return bi.Bytes() +} diff --git a/tatanka/tatanka.go b/tatanka/tatanka.go index 9f6c900..9170a1e 100644 --- a/tatanka/tatanka.go +++ b/tatanka/tatanka.go @@ -20,9 +20,11 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/oracle/sources" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/bisoncraft/mesh/tatanka/admin" + "github.com/bisoncraft/mesh/tatanka/pb" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -40,6 +42,9 @@ const ( // whitelistProtocol is the protocol used to verify the whitelist alignment of a tatanka node. whitelistProtocol = "/tatanka/whitelist/1.0.0" + + // quotaHandshakeProtocol is the protocol used to exchange quota information between tatanka nodes. + quotaHandshakeProtocol = "/tatanka/quota-handshake/1.0.0" ) // Config is the configuration for the tatanka node @@ -53,9 +58,9 @@ type Config struct { WhitelistPath string // Oracle Configuration - CMCKey string - TatumKey string - CryptoApisKey string + CMCKey string + TatumKey string + BlockcypherToken string } // Option is a functional option for configuring TatankaNode. @@ -71,11 +76,12 @@ func WithHost(h host.Host) Option { // Oracle defines the requirements for implementing an oracle. type Oracle interface { Run(ctx context.Context) - MergePrices(sourcedUpdate *oracle.SourcedPriceUpdate) map[oracle.Ticker]float64 - MergeFeeRates(sourcedUpdate *oracle.SourcedFeeRateUpdate) map[oracle.Network]*big.Int - Prices() map[oracle.Ticker]float64 - FeeRates() map[oracle.Network]*big.Int - GetSourceWeight(sourceName string) float64 + Merge(update *oracle.OracleUpdate, senderID string) *oracle.MergeResult + Price(ticker oracle.Ticker) (float64, bool) + FeeRate(network oracle.Network) (*big.Int, bool) + GetLocalQuotas() map[string]*sources.QuotaStatus + UpdatePeerSourceQuota(peerID string, quota *oracle.TimestampedQuotaStatus, source string) + OracleSnapshot() *oracle.OracleSnapshot } // TatankaNode is a permissioned node in the tatanka mesh @@ -192,6 +198,7 @@ func (t *TatankaNode) Run(ctx context.Context) error { handleBroadcastMessage: t.handleBroadcastMessage, handleClientConnectionMessage: t.handleClientConnectionMessage, handleOracleUpdate: t.handleOracleUpdate, + handleQuotaHeartbeat: t.handleQuotaHeartbeat, }) if err != nil { t.markReady(err) @@ -213,11 +220,18 @@ func (t *TatankaNode) Run(ctx context.Context) error { // Only create oracle if not provided (e.g., via test setup) if t.oracle == nil { t.oracle, err = oracle.New(&oracle.Config{ - Log: t.config.Logger, - CMCKey: t.config.CMCKey, - TatumKey: t.config.TatumKey, - CryptoApisKey: t.config.CryptoApisKey, - PublishUpdate: t.gossipSub.publishOracleUpdate, + Log: t.config.Logger, + CMCKey: t.config.CMCKey, + TatumKey: t.config.TatumKey, + BlockcypherToken: t.config.BlockcypherToken, + NodeID: t.node.ID().String(), + PublishUpdate: t.gossipSub.publishOracleUpdate, + OnStateUpdate: func(update *oracle.OracleSnapshot) { + if t.adminServer != nil { + t.adminServer.BroadcastOracleUpdate("oracle_update", update) + } + }, + PublishQuotaHeartbeat: t.gossipSub.publishQuotaHeartbeat, }) if err != nil { return fmt.Errorf("failed to create oracle: %v", err) @@ -229,7 +243,7 @@ func (t *TatankaNode) Run(ctx context.Context) error { } if t.config.AdminPort > 0 { adminAddr := fmt.Sprintf(":%d", t.config.AdminPort) - server := admin.NewServer(t.config.Logger, adminAddr) + server := admin.NewServer(t.config.Logger, adminAddr, t.oracle) whitelistIDs := t.getWhitelist().allPeerIDs() whitelist := make([]string, 0, len(whitelistIDs)) for id := range whitelistIDs { @@ -251,7 +265,17 @@ func (t *TatankaNode) Run(ctx context.Context) error { t.adminServer = server } - t.connectionManager = newMeshConnectionManager(t.config.Logger, t.node, t.getWhitelist(), adminCallback) + t.connectionManager = newMeshConnectionManager( + t.config.Logger, t.node, t.getWhitelist(), adminCallback, + func() map[string]*pb.QuotaStatus { + return quotaStatusesToPb(t.oracle.GetLocalQuotas()) + }, + func(peerID peer.ID, quotas map[string]*pb.QuotaStatus) { + for source, q := range quotas { + t.oracle.UpdatePeerSourceQuota(peerID.String(), pbToTimestampedQuotaStatus(q), source) + } + }, + ) t.log.Infof("Admin interface available (or not) on :%d", t.config.AdminPort) @@ -391,6 +415,7 @@ func (t *TatankaNode) setupStreamHandlers() { t.setStreamHandler(protocols.AvailableMeshNodesProtocol, t.handleAvailableMeshNodes, t.requireBonds) t.setStreamHandler(discoveryProtocol, t.handleDiscovery, t.isWhitelistPeer) t.setStreamHandler(whitelistProtocol, t.handleWhitelist, t.isWhitelistPeer) + t.setStreamHandler(quotaHandshakeProtocol, t.handleQuotaHandshake, t.isWhitelistPeer) } func (t *TatankaNode) setupObservability() { diff --git a/tatanka/tatanka_test.go b/tatanka/tatanka_test.go index 93f34ff..0786e48 100644 --- a/tatanka/tatanka_test.go +++ b/tatanka/tatanka_test.go @@ -23,9 +23,9 @@ import ( "github.com/bisoncraft/mesh/bond" "github.com/bisoncraft/mesh/codec" "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/oracle/sources" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" - "github.com/bisoncraft/mesh/tatanka/pb" "google.golang.org/protobuf/proto" ) @@ -58,37 +58,32 @@ func (to *testOracle) Run(ctx context.Context) { <-ctx.Done() } -func (to *testOracle) Next() <-chan any { - return nil -} - -func (to *testOracle) MergePrices(sourcedUpdate *oracle.SourcedPriceUpdate) map[oracle.Ticker]float64 { - return make(map[oracle.Ticker]float64) +func (to *testOracle) Merge(update *oracle.OracleUpdate, senderID string) *oracle.MergeResult { + return &oracle.MergeResult{} } -func (to *testOracle) MergeFeeRates(sourcedUpdate *oracle.SourcedFeeRateUpdate) map[oracle.Network]*big.Int { - return make(map[oracle.Network]*big.Int) +func (to *testOracle) Price(oracle.Ticker) (float64, bool) { return 0, false } +func (to *testOracle) FeeRate(oracle.Network) (*big.Int, bool) { + return nil, false } -func (to *testOracle) Prices() map[oracle.Ticker]float64 { return make(map[oracle.Ticker]float64) } -func (to *testOracle) FeeRates() map[oracle.Network]*big.Int { return make(map[oracle.Network]*big.Int) } -func (to *testOracle) GetSourceWeight(sourceName string) float64 { return 1.0 } +func (to *testOracle) GetLocalQuotas() map[string]*sources.QuotaStatus { return nil } +func (to *testOracle) UpdatePeerSourceQuota(string, *oracle.TimestampedQuotaStatus, string) {} +func (to *testOracle) OracleSnapshot() *oracle.OracleSnapshot { return nil } -// tOracle is a test oracle that tracks merged price and fee rate updates. +// tOracle is a test oracle that tracks merged updates. type tOracle struct { - mtx sync.Mutex - mergedPrices []*oracle.SourcedPriceUpdate - mergedFeeRates []*oracle.SourcedFeeRateUpdate - prices map[oracle.Ticker]float64 - feeRates map[oracle.Network]*big.Int + mtx sync.Mutex + merged []*oracle.OracleUpdate + prices map[oracle.Ticker]float64 + feeRates map[oracle.Network]*big.Int } var _ Oracle = (*tOracle)(nil) func newTOracle() *tOracle { return &tOracle{ - mergedPrices: make([]*oracle.SourcedPriceUpdate, 0), - mergedFeeRates: make([]*oracle.SourcedFeeRateUpdate, 0), - prices: make(map[oracle.Ticker]float64), - feeRates: make(map[oracle.Network]*big.Int), + merged: make([]*oracle.OracleUpdate, 0), + prices: make(map[oracle.Ticker]float64), + feeRates: make(map[oracle.Network]*big.Int), } } @@ -96,70 +91,61 @@ func (t *tOracle) Run(ctx context.Context) { <-ctx.Done() } -func (t *tOracle) MergePrices(sourcedUpdate *oracle.SourcedPriceUpdate) map[oracle.Ticker]float64 { +func (t *tOracle) Merge(update *oracle.OracleUpdate, senderID string) *oracle.MergeResult { t.mtx.Lock() defer t.mtx.Unlock() - t.mergedPrices = append(t.mergedPrices, sourcedUpdate) + t.merged = append(t.merged, update) - // Return the prices that were updated - updated := make(map[oracle.Ticker]float64) - for _, p := range sourcedUpdate.Prices { - updated[p.Ticker] = p.Price - t.prices[p.Ticker] = p.Price - } - return updated -} + result := &oracle.MergeResult{} -func (t *tOracle) MergeFeeRates(sourcedUpdate *oracle.SourcedFeeRateUpdate) map[oracle.Network]*big.Int { - t.mtx.Lock() - defer t.mtx.Unlock() - t.mergedFeeRates = append(t.mergedFeeRates, sourcedUpdate) + if len(update.Prices) > 0 { + result.Prices = make(map[oracle.Ticker]float64, len(update.Prices)) + for ticker, price := range update.Prices { + result.Prices[ticker] = price + t.prices[ticker] = price + } + } - // Return the fee rates that were updated - updated := make(map[oracle.Network]*big.Int) - for _, fr := range sourcedUpdate.FeeRates { - // Decode the big-endian bytes to big.Int - bigIntValue := new(big.Int).SetBytes(fr.FeeRate) - updated[fr.Network] = bigIntValue - t.feeRates[fr.Network] = bigIntValue + if len(update.FeeRates) > 0 { + result.FeeRates = make(map[oracle.Network]*big.Int, len(update.FeeRates)) + for network, feeRate := range update.FeeRates { + result.FeeRates[network] = feeRate + t.feeRates[network] = feeRate + } } - return updated + + return result } -func (t *tOracle) Prices() map[oracle.Ticker]float64 { +func (t *tOracle) Price(ticker oracle.Ticker) (float64, bool) { t.mtx.Lock() defer t.mtx.Unlock() - // Return a copy to avoid races with concurrent modifications - result := make(map[oracle.Ticker]float64) - for k, v := range t.prices { - result[k] = v - } - return result + price, found := t.prices[ticker] + return price, found } -func (t *tOracle) FeeRates() map[oracle.Network]*big.Int { +func (t *tOracle) FeeRate(network oracle.Network) (*big.Int, bool) { t.mtx.Lock() defer t.mtx.Unlock() - // Return a copy to avoid races with concurrent modifications - result := make(map[oracle.Network]*big.Int) - for k, v := range t.feeRates { - result[k] = v + value, found := t.feeRates[network] + if !found { + return nil, false } - return result + return new(big.Int).Set(value), true } -func (t *tOracle) GetSourceWeight(sourceName string) float64 { - return 1.0 -} +func (t *tOracle) GetLocalQuotas() map[string]*sources.QuotaStatus { return nil } + +func (t *tOracle) UpdatePeerSourceQuota(string, *oracle.TimestampedQuotaStatus, string) {} + +func (t *tOracle) OracleSnapshot() *oracle.OracleSnapshot { return nil } -// SetPrices sets the prices map with proper locking. func (t *tOracle) SetPrices(prices map[oracle.Ticker]float64) { t.mtx.Lock() defer t.mtx.Unlock() t.prices = prices } -// SetFeeRates sets the fee rates map with proper locking. func (t *tOracle) SetFeeRates(feeRates map[oracle.Network]*big.Int) { t.mtx.Lock() defer t.mtx.Unlock() @@ -769,43 +755,21 @@ func requireEventually(t *testing.T, condition func() bool, timeout, tick time.D t.Fatalf("Condition failed after %v: %s", timeout, fmt.Sprintf(msg, args...)) } -// pbNodePriceUpdate converts a SourcedPriceUpdate to a NodeOracleUpdate for testing. -func pbNodePriceUpdate(update *oracle.SourcedPriceUpdate) *pb.NodeOracleUpdate { - pbPrices := make([]*pb.SourcedPrice, len(update.Prices)) - for i, p := range update.Prices { - pbPrices[i] = &pb.SourcedPrice{ - Ticker: string(p.Ticker), - Price: p.Price, - } - } - return &pb.NodeOracleUpdate{ - Update: &pb.NodeOracleUpdate_PriceUpdate{ - PriceUpdate: &pb.SourcedPriceUpdate{ - Source: update.Source, - Timestamp: update.Stamp.Unix(), - Prices: pbPrices, - }, - }, +// newPriceUpdate creates an OracleUpdate with only prices for testing. +func newPriceUpdate(source string, stamp time.Time, prices map[oracle.Ticker]float64) *oracle.OracleUpdate { + return &oracle.OracleUpdate{ + Source: source, + Stamp: stamp, + Prices: prices, } } -// pbNodeFeeRateUpdate converts a SourcedFeeRateUpdate to a NodeOracleUpdate for testing. -func pbNodeFeeRateUpdate(update *oracle.SourcedFeeRateUpdate) *pb.NodeOracleUpdate { - pbFeeRates := make([]*pb.SourcedFeeRate, len(update.FeeRates)) - for i, fr := range update.FeeRates { - pbFeeRates[i] = &pb.SourcedFeeRate{ - Network: string(fr.Network), - FeeRate: fr.FeeRate, - } - } - return &pb.NodeOracleUpdate{ - Update: &pb.NodeOracleUpdate_FeeRateUpdate{ - FeeRateUpdate: &pb.SourcedFeeRateUpdate{ - Source: update.Source, - Timestamp: update.Stamp.Unix(), - FeeRates: pbFeeRates, - }, - }, +// newFeeRateUpdate creates an OracleUpdate with only fee rates for testing. +func newFeeRateUpdate(source string, stamp time.Time, feeRates map[oracle.Network]*big.Int) *oracle.OracleUpdate { + return &oracle.OracleUpdate{ + Source: source, + Stamp: stamp, + FeeRates: feeRates, } } @@ -1197,18 +1161,11 @@ func TestGossipSubOracleUpdates_PriceUpdates(t *testing.T) { // Node 0 publishes price updates now := time.Now() - sourcedUpdate := &oracle.SourcedPriceUpdate{ - Source: "test-source", - Stamp: now, - Weight: 1.0, - Prices: []*oracle.SourcedPrice{ - {Ticker: "BTC", Price: 50000.0}, - {Ticker: "ETH", Price: 3000.0}, - }, - } - - oracleUpdate := pbNodePriceUpdate(sourcedUpdate) - if err := nodes[0].gossipSub.publishOracleUpdate(ctx, oracleUpdate); err != nil { + update := newPriceUpdate("test-source", now, map[oracle.Ticker]float64{ + "BTC": 50000.0, + "ETH": 3000.0, + }) + if err := nodes[0].gossipSub.publishOracleUpdate(ctx, update); err != nil { t.Fatalf("Failed to publish oracle update: %v", err) } @@ -1218,19 +1175,18 @@ func TestGossipSubOracleUpdates_PriceUpdates(t *testing.T) { // Verify that all nodes received and merged the updates for i := 0; i < numMeshNodes; i++ { oracles[i].mtx.Lock() - mergedCount := len(oracles[i].mergedPrices) + mergedCount := len(oracles[i].merged) oracles[i].mtx.Unlock() if mergedCount != 1 { - t.Errorf("Node %d: expected 1 merged price update, got %d", i, mergedCount) + t.Errorf("Node %d: expected 1 merged update, got %d", i, mergedCount) continue } oracles[i].mtx.Lock() - merged := oracles[i].mergedPrices[0] + merged := oracles[i].merged[0] oracles[i].mtx.Unlock() - // Verify the merged update if merged.Source != "test-source" { t.Errorf("Node %d: expected source 'test-source', got %s", i, merged.Source) } @@ -1238,11 +1194,11 @@ func TestGossipSubOracleUpdates_PriceUpdates(t *testing.T) { t.Errorf("Node %d: expected 2 prices, got %d", i, len(merged.Prices)) continue } - if merged.Prices[0].Ticker != "BTC" || merged.Prices[0].Price != 50000.0 { - t.Errorf("Node %d: first price incorrect: %+v", i, merged.Prices[0]) + if merged.Prices["BTC"] != 50000.0 { + t.Errorf("Node %d: BTC price incorrect: %v", i, merged.Prices["BTC"]) } - if merged.Prices[1].Ticker != "ETH" || merged.Prices[1].Price != 3000.0 { - t.Errorf("Node %d: second price incorrect: %+v", i, merged.Prices[1]) + if merged.Prices["ETH"] != 3000.0 { + t.Errorf("Node %d: ETH price incorrect: %v", i, merged.Prices["ETH"]) } } } @@ -1286,18 +1242,11 @@ func TestGossipSubOracleUpdates_FeeRateUpdates(t *testing.T) { // Node 1 publishes fee rate updates now := time.Now() - sourcedUpdate := &oracle.SourcedFeeRateUpdate{ - Source: "test-source", - Stamp: now, - Weight: 1.0, - FeeRates: []*oracle.SourcedFeeRate{ - {Network: "Bitcoin", FeeRate: big.NewInt(100).Bytes()}, - {Network: "Ethereum", FeeRate: big.NewInt(50).Bytes()}, - }, - } - - oracleUpdate := pbNodeFeeRateUpdate(sourcedUpdate) - if err := nodes[1].gossipSub.publishOracleUpdate(ctx, oracleUpdate); err != nil { + update := newFeeRateUpdate("test-source", now, map[oracle.Network]*big.Int{ + "Bitcoin": big.NewInt(100), + "Ethereum": big.NewInt(50), + }) + if err := nodes[1].gossipSub.publishOracleUpdate(ctx, update); err != nil { t.Fatalf("Failed to publish oracle update: %v", err) } @@ -1307,19 +1256,18 @@ func TestGossipSubOracleUpdates_FeeRateUpdates(t *testing.T) { // Verify that all nodes received and merged the updates for i := 0; i < numMeshNodes; i++ { oracles[i].mtx.Lock() - mergedCount := len(oracles[i].mergedFeeRates) + mergedCount := len(oracles[i].merged) oracles[i].mtx.Unlock() if mergedCount != 1 { - t.Errorf("Node %d: expected 1 merged fee rate update, got %d", i, mergedCount) + t.Errorf("Node %d: expected 1 merged update, got %d", i, mergedCount) continue } oracles[i].mtx.Lock() - merged := oracles[i].mergedFeeRates[0] + merged := oracles[i].merged[0] oracles[i].mtx.Unlock() - // Verify the merged update if merged.Source != "test-source" { t.Errorf("Node %d: expected source 'test-source', got %s", i, merged.Source) } @@ -1327,11 +1275,11 @@ func TestGossipSubOracleUpdates_FeeRateUpdates(t *testing.T) { t.Errorf("Node %d: expected 2 fee rates, got %d", i, len(merged.FeeRates)) continue } - if merged.FeeRates[0].Network != "Bitcoin" || new(big.Int).SetBytes(merged.FeeRates[0].FeeRate).Cmp(big.NewInt(100)) != 0 { - t.Errorf("Node %d: first fee rate incorrect: %+v", i, merged.FeeRates[0]) + if merged.FeeRates["Bitcoin"].Cmp(big.NewInt(100)) != 0 { + t.Errorf("Node %d: Bitcoin fee rate incorrect: %v", i, merged.FeeRates["Bitcoin"]) } - if merged.FeeRates[1].Network != "Ethereum" || new(big.Int).SetBytes(merged.FeeRates[1].FeeRate).Cmp(big.NewInt(50)) != 0 { - t.Errorf("Node %d: second fee rate incorrect: %+v", i, merged.FeeRates[1]) + if merged.FeeRates["Ethereum"].Cmp(big.NewInt(50)) != 0 { + t.Errorf("Node %d: Ethereum fee rate incorrect: %v", i, merged.FeeRates["Ethereum"]) } } } @@ -1376,72 +1324,37 @@ func TestGossipSubOracleUpdates_MultipleNodes(t *testing.T) { now := time.Now() // Node 0 publishes price updates - priceUpdate0 := &oracle.SourcedPriceUpdate{ - Source: "node-0", - Stamp: now, - Weight: 1.0, - Prices: []*oracle.SourcedPrice{ - {Ticker: "BTC", Price: 50000.0}, - }, - } - if err := nodes[0].gossipSub.publishOracleUpdate(ctx, pbNodePriceUpdate(priceUpdate0)); err != nil { + if err := nodes[0].gossipSub.publishOracleUpdate(ctx, newPriceUpdate("node-0", now, map[oracle.Ticker]float64{ + "BTC": 50000.0, + })); err != nil { t.Fatalf("Failed to publish price update from node 0: %v", err) } // Node 1 publishes fee rate updates - feeRateUpdate := &oracle.SourcedFeeRateUpdate{ - Source: "node-1", - Stamp: now, - Weight: 1.0, - FeeRates: []*oracle.SourcedFeeRate{ - {Network: "Bitcoin", FeeRate: big.NewInt(100).Bytes()}, - }, - } - if err := nodes[1].gossipSub.publishOracleUpdate(ctx, pbNodeFeeRateUpdate(feeRateUpdate)); err != nil { + if err := nodes[1].gossipSub.publishOracleUpdate(ctx, newFeeRateUpdate("node-1", now, map[oracle.Network]*big.Int{ + "Bitcoin": big.NewInt(100), + })); err != nil { t.Fatalf("Failed to publish fee rate update from node 1: %v", err) } // Node 2 publishes price updates - priceUpdate2 := &oracle.SourcedPriceUpdate{ - Source: "node-2", - Stamp: now, - Weight: 0.8, - Prices: []*oracle.SourcedPrice{ - {Ticker: "ETH", Price: 3000.0}, - }, - } - if err := nodes[2].gossipSub.publishOracleUpdate(ctx, pbNodePriceUpdate(priceUpdate2)); err != nil { + if err := nodes[2].gossipSub.publishOracleUpdate(ctx, newPriceUpdate("node-2", now, map[oracle.Ticker]float64{ + "ETH": 3000.0, + })); err != nil { t.Fatalf("Failed to publish price update from node 2: %v", err) } // Wait for gossip propagation time.Sleep(2 * time.Second) - // Verify all nodes received all price updates (2 price updates from nodes 0 and 2) - for i, oracle := range oracles { - oracle.mtx.Lock() - priceCount := len(oracle.mergedPrices) - oracle.mtx.Unlock() - - // All nodes should receive both price updates - expectedPriceCount := 2 - - if priceCount != expectedPriceCount { - t.Errorf("Node %d: expected %d price updates, got %d", i, expectedPriceCount, priceCount) - } - } - - // Verify all nodes received the fee rate update - for i, oracle := range oracles { - oracle.mtx.Lock() - feeRateCount := len(oracle.mergedFeeRates) - oracle.mtx.Unlock() - - // All nodes should receive the fee rate update - expectedFeeRateCount := 1 + // Verify all nodes received all 3 updates (2 price + 1 fee rate) + for i, orc := range oracles { + orc.mtx.Lock() + mergedCount := len(orc.merged) + orc.mtx.Unlock() - if feeRateCount != expectedFeeRateCount { - t.Errorf("Node %d: expected %d fee rate updates, got %d", i, expectedFeeRateCount, feeRateCount) + if mergedCount != 3 { + t.Errorf("Node %d: expected 3 merged updates, got %d", i, mergedCount) } } } @@ -1535,28 +1448,16 @@ func TestGossipSubOracleUpdates_ClientDelivery(t *testing.T) { // Node 0 publishes price updates via gossipsub now := time.Now() - priceUpdate := &oracle.SourcedPriceUpdate{ - Source: "test-source", - Stamp: now, - Weight: 1.0, - Prices: []*oracle.SourcedPrice{ - {Ticker: "BTC", Price: 50000.0}, - }, - } - if err := nodes[0].gossipSub.publishOracleUpdate(ctx, pbNodePriceUpdate(priceUpdate)); err != nil { + if err := nodes[0].gossipSub.publishOracleUpdate(ctx, newPriceUpdate("test-source", now, map[oracle.Ticker]float64{ + "BTC": 50000.0, + })); err != nil { t.Fatalf("Failed to publish price update: %v", err) } // Node 1 publishes fee rate updates via gossipsub - feeRateUpdate := &oracle.SourcedFeeRateUpdate{ - Source: "test-source", - Stamp: now, - Weight: 1.0, - FeeRates: []*oracle.SourcedFeeRate{ - {Network: "BTC", FeeRate: big.NewInt(100).Bytes()}, - }, - } - if err := nodes[1].gossipSub.publishOracleUpdate(ctx, pbNodeFeeRateUpdate(feeRateUpdate)); err != nil { + if err := nodes[1].gossipSub.publishOracleUpdate(ctx, newFeeRateUpdate("test-source", now, map[oracle.Network]*big.Int{ + "BTC": big.NewInt(100), + })); err != nil { t.Fatalf("Failed to publish fee rate update: %v", err) } diff --git a/testing/client/client.go b/testing/client/client.go index c7b46f2..5358105 100644 --- a/testing/client/client.go +++ b/testing/client/client.go @@ -20,7 +20,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/bisoncraft/mesh/bond" tmc "github.com/bisoncraft/mesh/client" - "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "google.golang.org/protobuf/proto" ) @@ -430,14 +430,14 @@ func (c *Client) Run(ctx context.Context, bonds []*bond.BondParams) { // decodeTopicData decodes topic data to a human-readable string. func decodeTopicData(topic string, data []byte) string { - if strings.HasPrefix(topic, oracle.PriceTopicPrefix) { - ticker := topic[len(oracle.PriceTopicPrefix):] + if strings.HasPrefix(topic, protocols.PriceTopicPrefix) { + ticker := topic[len(protocols.PriceTopicPrefix):] var priceUpdate protocolsPb.ClientPriceUpdate if err := proto.Unmarshal(data, &priceUpdate); err == nil { return fmt.Sprintf("%s: $%.2f", ticker, priceUpdate.Price) } - } else if strings.HasPrefix(topic, oracle.FeeRateTopicPrefix) { - network := topic[len(oracle.FeeRateTopicPrefix):] + } else if strings.HasPrefix(topic, protocols.FeeRateTopicPrefix) { + network := topic[len(protocols.FeeRateTopicPrefix):] var feeRateUpdate protocolsPb.ClientFeeRateUpdate if err := proto.Unmarshal(data, &feeRateUpdate); err == nil { feeRate := new(big.Int).SetBytes(feeRateUpdate.FeeRate)