diff --git a/client/mesh_connection.go b/client/mesh_connection.go index 684abd1..8becaa9 100644 --- a/client/mesh_connection.go +++ b/client/mesh_connection.go @@ -25,6 +25,7 @@ const ( var ( errInvalidBondIndex = errors.New("invalid bond index") + errUnauthorized = errors.New("unauthorized") ) // meshConnection represents a connection to a mesh peer. @@ -134,7 +135,20 @@ func (m *meshConnection) broadcast(ctx context.Context, topic string, data []byt Topic: topic, Data: data, } - return codec.WriteLengthPrefixedMessage(s, req) + if err := codec.WriteLengthPrefixedMessage(s, req); err != nil { + return fmt.Errorf("failed to send publish request: %w", err) + } + + resp := &protocolsPb.Response{} + if err := codec.ReadLengthPrefixedMessage(s, resp); err != nil { + return fmt.Errorf("failed to read publish response: %w", err) + } + + if respErr := resp.GetError(); respErr != nil { + return fmt.Errorf("publish failed: %s", respErr.GetMessage()) + } + + return nil } // unsubscribe unsubscribes from the provided topic. @@ -176,7 +190,7 @@ func (m *meshConnection) postBondInternal(ctx context.Context, req *protocolsPb. switch v := resp.Response.(type) { case *protocolsPb.Response_Error: - switch v.Error.GetError().(type) { + switch rErr := v.Error.GetError().(type) { case *protocolsPb.Error_PostBondError: bondErr := v.Error.GetPostBondError() err := m.bondInfo.RemoveBondAtIndex(bondErr.InvalidBondIndex) @@ -191,8 +205,11 @@ func (m *meshConnection) postBondInternal(ctx context.Context, req *protocolsPb. errMsg := v.Error.GetMessage() return fmt.Errorf("%s: %v", hostID, errMsg) + case *protocolsPb.Error_Unauthorized: + return errUnauthorized + default: - return fmt.Errorf("%s: unknown error type %T", hostID, v) + return fmt.Errorf("%s: unknown error type %T", hostID, rErr) } case *protocolsPb.Response_PostBondResponse: @@ -208,7 +225,7 @@ func (m *meshConnection) postBondInternal(ctx context.Context, req *protocolsPb. // postBond posts the connection's bond, retrying on invalid bond index errors. This needs to be // called before the connection's push stream is established. func (m *meshConnection) postBond(ctx context.Context) error { - hostID := m.host.ID() + hostID := m.host.ID().ShortString() for range maxPostBondRetries { req, err := bond.PostBondReqFromBondInfo(m.bondInfo) if err != nil { @@ -225,6 +242,9 @@ func (m *meshConnection) postBond(ctx context.Context) error { // Retry if an invalid bond index error is returned. continue + case errors.Is(err, errUnauthorized): + return err + case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled): return fmt.Errorf("%s: post bond retry cancelled due to context: %w", hostID, err) diff --git a/client/mesh_connection_manager.go b/client/mesh_connection_manager.go index ffac9ae..fb2c405 100644 --- a/client/mesh_connection_manager.go +++ b/client/mesh_connection_manager.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "math/rand/v2" "sync" "sync/atomic" @@ -188,6 +189,13 @@ func (m *meshConnectionManager) connectToAvailableNode(ctx context.Context) (mes result := m.connectToNode(ctx, peerID) if result.connectErr != nil { + if errors.Is(result.connectErr, errUnauthorized) { + m.log.Infof("%s: unauthorized", peerID.ShortString()) + errCh := make(chan error, 1) + errCh <- result.connectErr + return nil, errCh, false + } + m.log.Errorf("Connection to %s failed: %v", peerID.ShortString(), result.connectErr) continue } @@ -220,6 +228,10 @@ func (m *meshConnectionManager) run(ctx context.Context) { m.setPrimaryConnection(conn) runErrCh = errCh } else { + // No available nodes - schedule retry with backoff + if errCh != nil { + runErrCh = errCh + } reconnectTimer.Reset(backoff) backoff *= 2 if backoff > maxReconnectDelay { @@ -237,9 +249,14 @@ func (m *meshConnectionManager) run(ctx context.Context) { refreshTimer.Reset(meshNodesRefreshInterval) case <-reconnectTimer.C: attemptConnect() - case <-runErrCh: + case err := <-runErrCh: m.setPrimaryConnection(nil) runErrCh = nil + if errors.Is(err, errUnauthorized) { + return + } + + // Primary connection failed, immediately try alternatives attemptConnect() } } diff --git a/client/mesh_connection_manager_test.go b/client/mesh_connection_manager_test.go index 71bdd39..3482ab6 100644 --- a/client/mesh_connection_manager_test.go +++ b/client/mesh_connection_manager_test.go @@ -77,7 +77,7 @@ func TestMeshConnectionManagerFailover(t *testing.T) { return false } return c.remotePeerID() == expected - }, 2*time.Second, 10*time.Millisecond, "primary connection not set to %s", expected) + }, 5*time.Second, 10*time.Millisecond, "primary connection not set to %s", expected) } waitForPrimary(node1ID) @@ -85,7 +85,7 @@ func TestMeshConnectionManagerFailover(t *testing.T) { requireEventually(t, func() bool { addrs := h.Peerstore().Addrs(node2ID) return len(addrs) > 0 - }, 2*time.Second, 10*time.Millisecond, "peerstore missing addresses for node 2") + }, 5*time.Second, 10*time.Millisecond, "peerstore missing addresses for node 2") node1Available = false conn1.fail(errors.New("node-1 down")) diff --git a/client/mesh_connection_test.go b/client/mesh_connection_test.go index 9a48790..2476442 100644 --- a/client/mesh_connection_test.go +++ b/client/mesh_connection_test.go @@ -208,6 +208,13 @@ func (h *meshConnHarness) setupDefaultHandlers(t *testing.T) { t.Fatalf("Publish read error: %v", err) } h.publishReceived <- &msg + + resp := &protocolsPb.Response{ + Response: &protocolsPb.Response_Success{Success: &protocolsPb.Success{}}, + } + if err := codec.WriteLengthPrefixedMessage(s, resp); err != nil { + t.Fatalf("Failed to send publish response: %v", err) + } }) h.tatankaHost.SetStreamHandler(protocols.ClientSubscribeProtocol, func(s network.Stream) { diff --git a/cmd/testclient/main.go b/cmd/testclient/main.go index 6178b9f..b4b2486 100644 --- a/cmd/testclient/main.go +++ b/cmd/testclient/main.go @@ -49,6 +49,7 @@ type Config struct { BondParams []*bondParamsFlag `long:"bondparams" description:"The test client bond params."` ClientPort int `long:"clientport" description:"The port to listen on for client connections"` WebPort int `long:"webport" description:"The web interface port."` + Spam bool `long:"spam" description:"Enable publish spam feature."` } // defaultAppDataDir returns the default application data directory. @@ -204,6 +205,7 @@ func main() { ClientPort: cfg.ClientPort, WebPort: cfg.WebPort, Logger: log, + Spam: cfg.Spam, } tc, err := client.NewClient(tcCfg) diff --git a/tatanka/ban_manager.go b/tatanka/ban_manager.go new file mode 100644 index 0000000..6c193ea --- /dev/null +++ b/tatanka/ban_manager.go @@ -0,0 +1,308 @@ +package tatanka + +import ( + "context" + "fmt" + "slices" + "strings" + "sync" + "time" + + pb "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/peer" +) + +type infractionType int + +const ( + MalformedMessage infractionType = iota + InvalidBond + NodeImpersonation + RateLimitViolation + RateLimitAbuse +) + +const ( + maxInfractionPenalties = 10 + infractionRetentionPeriod = time.Hour +) + +func infractionTypeString(it infractionType) string { + switch it { + case MalformedMessage: + return "sent malformed message" + case InvalidBond: + return "invalid bond" + case NodeImpersonation: + return "attempted to impersonate node" + case RateLimitViolation: + return "rate limit violation" + case RateLimitAbuse: + return "rate limit abuse" + default: + return "unknown infraction" + } +} + +type infraction struct { + infractionType infractionType + penalty uint32 + expiry time.Time +} + + +type banManagerConfig struct { + disconnectClient func(string) + nodeID peer.ID + publishInfraction func(context.Context, *pb.ClientInfractionMsg) error + now func() time.Time + log slog.Logger +} + +type banManager struct { + cfg *banManagerConfig + mtx sync.RWMutex + clientInfractions map[string][]infraction + log slog.Logger +} + +func newBanManager(cfg *banManagerConfig) *banManager { + return &banManager{ + cfg: cfg, + clientInfractions: make(map[string][]infraction), + log: cfg.log, + } +} + +func infractionPenaltyAndDuration(it infractionType) (uint32, time.Duration, error) { + switch it { + case MalformedMessage, InvalidBond: + return 1, infractionRetentionPeriod, nil + case NodeImpersonation: + return maxInfractionPenalties, infractionRetentionPeriod * 24, nil + case RateLimitViolation: + return 2, infractionRetentionPeriod * 2, nil + case RateLimitAbuse: + return 5, infractionRetentionPeriod * 6, nil + default: + return 0, 0, fmt.Errorf("unexpected infraction type provided %v", it) + } +} + +func (bm *banManager) recordInfraction(ip string, id peer.ID, infractionType infractionType) error { + bm.mtx.Lock() + defer bm.mtx.Unlock() + + if _, ok := bm.clientInfractions[ip]; !ok { + bm.clientInfractions[ip] = make([]infraction, 0) + } + + infractions := bm.clientInfractions[ip] + + penalty, duration, err := infractionPenaltyAndDuration(infractionType) + if err != nil { + return err + } + + expiry := bm.cfg.now().Add(duration) + infractions = append(infractions, infraction{ + infractionType: infractionType, + penalty: penalty, + expiry: expiry, + }) + + bm.clientInfractions[ip] = infractions + + now := bm.cfg.now() + var penalties uint32 + for _, inf := range infractions { + if !inf.expiry.Before(now) { + penalties += inf.penalty + } + } + + // Publish the infraction for other nodes to gossip + reporterBytes, err := bm.cfg.nodeID.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal node ID for infraction: %v", err) + } + + infractionMsg := &pb.ClientInfractionMsg{ + Ip: ip, + Reporter: reporterBytes, + InfractionType: uint32(infractionType), + Penalty: penalty, + Expiry: expiry.UnixMilli(), + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + if err := bm.cfg.publishInfraction(ctx, infractionMsg); err != nil { + return fmt.Errorf("failed to publish infraction for client %s (%s): %v", id.ShortString(), ip, err) + } + + // Disconnect the client once it meets or exceeds the maximum infraction penalties. + if penalties >= maxInfractionPenalties { + bm.cfg.disconnectClient(ip) + } + + return nil +} + +func (bm *banManager) recordRemoteInfraction(msg *pb.ClientInfractionMsg) error { + if msg == nil { + return fmt.Errorf("client infraction message cannot be nil") + } + + bm.mtx.Lock() + defer bm.mtx.Unlock() + + ip := msg.Ip + if _, ok := bm.clientInfractions[ip]; !ok { + bm.clientInfractions[ip] = make([]infraction, 0) + } + + // Prevent duplicate infractions from gossip re-delivery + infractionKey := InfractionDedupKey(msg.Ip, msg.Reporter, msg.InfractionType, msg.Expiry) + + infractions := bm.clientInfractions[ip] + for _, existing := range infractions { + existingKey := InfractionDedupKey(ip, msg.Reporter, uint32(existing.infractionType), existing.expiry.UnixMilli()) + if existingKey == infractionKey { + return nil + } + } + + infractions = append(infractions, infraction{ + infractionType: infractionType(msg.InfractionType), + penalty: msg.Penalty, + expiry: time.UnixMilli(msg.Expiry), + }) + bm.clientInfractions[ip] = infractions + + + now := bm.cfg.now() + var penalties uint32 + for _, inf := range infractions { + if !inf.expiry.Before(now) { + penalties += inf.penalty + } + } + + if penalties >= maxInfractionPenalties { + bm.cfg.disconnectClient(ip) + } + + return nil +} + +// InfractionDedupKey generates a unique key for deduplication based on (IP, reporter, type, expiry) +// This is used by both recordRemoteInfraction and syncInfractionsFromRandomPeer +func InfractionDedupKey(ip string, reporter []byte, infractionType uint32, expiry int64) string { + return fmt.Sprintf("%s_%x_%d_%d", ip, reporter, infractionType, expiry) +} + + +func (bm *banManager) purgeExpiredInfractions() { + bm.mtx.Lock() + defer bm.mtx.Unlock() + + cutoffTime := bm.cfg.now() + for ip, infractions := range bm.clientInfractions { + filtered := slices.DeleteFunc(infractions, func(inf infraction) bool { + return inf.expiry.Before(cutoffTime) + }) + infractions = filtered + bm.clientInfractions[ip] = infractions + + if len(infractions) == 0 { + delete(bm.clientInfractions, ip) + } + } +} + +func (bm *banManager) isClientBanned(ip string) bool { + bm.mtx.RLock() + defer bm.mtx.RUnlock() + + now := bm.cfg.now() + + if infractions, ok := bm.clientInfractions[ip]; ok { + var penalties uint32 + for _, inf := range infractions { + if !inf.expiry.Before(now) { + penalties += inf.penalty + } + } + + if penalties >= maxInfractionPenalties { + return true + } + } + + return false +} + +func (bm *banManager) getClientBanReason(ip string) string { + bm.mtx.RLock() + defer bm.mtx.RUnlock() + + now := bm.cfg.now() + + if infractions, ok := bm.clientInfractions[ip]; ok { + var penalties uint32 + var reasons []string + for _, inf := range infractions { + if !inf.expiry.Before(now) { + penalties += inf.penalty + reasons = append(reasons, infractionTypeString(inf.infractionType)) + } + } + + if penalties >= maxInfractionPenalties { + return strings.Join(reasons, ", ") + } + } + + return "" +} + +// getActiveInfractions returns all non-expired infractions across all clients. It is safe for +// concurrent use. +func (bm *banManager) getActiveInfractions() []*pb.ClientInfractionMsg { + bm.mtx.RLock() + defer bm.mtx.RUnlock() + + now := bm.cfg.now() + var result []*pb.ClientInfractionMsg + + for ip, infractions := range bm.clientInfractions { + for _, inf := range infractions { + if !inf.expiry.Before(now) { + result = append(result, &pb.ClientInfractionMsg{ + Ip: ip, + InfractionType: uint32(inf.infractionType), + Penalty: inf.penalty, + Expiry: inf.expiry.UnixMilli(), + }) + } + } + } + + return result +} + +func (bm *banManager) run(ctx context.Context) { + ticker := time.NewTicker(time.Minute * 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + bm.purgeExpiredInfractions() + } + } +} diff --git a/tatanka/ban_manager_test.go b/tatanka/ban_manager_test.go new file mode 100644 index 0000000..85f8ece --- /dev/null +++ b/tatanka/ban_manager_test.go @@ -0,0 +1,668 @@ +package tatanka + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + "time" + + pb "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +func testLogger() slog.Logger { + backend := slog.NewBackend(os.Stdout) + return backend.Logger("test") +} + +func generateTestPeerID(t *testing.T) peer.ID { + priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) + if err != nil { + t.Fatalf("Failed to generate key pair: %v", err) + } + peerID, err := peer.IDFromPublicKey(priv.GetPublic()) + if err != nil { + t.Fatalf("Failed to generate peer ID: %v", err) + } + return peerID +} + +func TestBanManagerRecordInfraction(t *testing.T) { + tests := []struct { + name string + ipAddr string + numInfra int + shouldBan bool + }{ + { + name: "IPv4 below threshold", + ipAddr: "192.168.1.1", + numInfra: 3, + shouldBan: false, + }, + { + name: "IPv4 at threshold", + ipAddr: "10.0.0.1", + numInfra: 10, + shouldBan: true, + }, + { + name: "IPv6 below threshold", + ipAddr: "2001:db8::1", + numInfra: 2, + shouldBan: false, + }, + { + name: "IPv6 at threshold", + ipAddr: "fe80::1", + numInfra: 10, + shouldBan: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + disconnected := false + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) { + disconnected = true + }, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + for i := 0; i < tt.numInfra; i++ { + err := bm.recordInfraction(tt.ipAddr, clientID, MalformedMessage) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + } + + if disconnected != tt.shouldBan { + t.Errorf("expected disconnect=%v, got disconnect=%v", tt.shouldBan, disconnected) + } + }) + } +} + +func TestBanManagerMultipleIPs(t *testing.T) { + disconnectedIPs := make(map[string]bool) + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) { + disconnectedIPs[ip] = true + }, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddrs := []string{"192.168.1.1", "192.168.1.2", "2001:db8::1", "2001:db8::2"} + + // Record 9 infractions for each IP (below threshold) + for _, ip := range ipAddrs { + for i := 0; i < 9; i++ { + err := bm.recordInfraction(ip, clientID, MalformedMessage) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + } + } + + // Record one more infraction for first IP to trigger threshold + err := bm.recordInfraction(ipAddrs[0], clientID, MalformedMessage) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + + if !disconnectedIPs[ipAddrs[0]] { + t.Errorf("expected disconnect for %s", ipAddrs[0]) + } + + for _, ip := range ipAddrs[1:] { + if disconnectedIPs[ip] { + t.Errorf("unexpected disconnect for %s", ip) + } + } +} + +func TestBanManagerPublishInfractionCallback(t *testing.T) { + t.Run("Publish each infraction", func(t *testing.T) { + var publishedInfractions []*pb.ClientInfractionMsg + var mtx sync.Mutex + nodeID := generateTestPeerID(t) + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: nodeID, + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { + mtx.Lock() + publishedInfractions = append(publishedInfractions, msg) + mtx.Unlock() + return nil + }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddr := "192.168.1.100" + + // Record 10 infractions - each should be published + for i := 0; i < 10; i++ { + err := bm.recordInfraction(ipAddr, clientID, MalformedMessage) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + } + + // Verify each infraction was published + mtx.Lock() + if len(publishedInfractions) != 10 { + t.Errorf("expected 10 published infractions, got %d", len(publishedInfractions)) + } + + for i, infraction := range publishedInfractions { + if infraction.Ip != ipAddr { + t.Errorf("infraction %d: expected IP %s, got %s", i, ipAddr, infraction.Ip) + } + if infraction.Penalty != 1 { + t.Errorf("infraction %d: expected penalty 1, got %d", i, infraction.Penalty) + } + if len(infraction.Reporter) == 0 { + t.Errorf("infraction %d: reporter should not be empty", i) + } + if infraction.Expiry == 0 { + t.Errorf("infraction %d: expiry should be set", i) + } + if infraction.InfractionType != uint32(MalformedMessage) { + t.Errorf("infraction %d: expected type MalformedMessage, got %d", i, infraction.InfractionType) + } + } + mtx.Unlock() + }) + + t.Run("Publish all infractions", func(t *testing.T) { + var publishedInfractions []*pb.ClientInfractionMsg + var mtx sync.Mutex + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: generateTestPeerID(t), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { + mtx.Lock() + publishedInfractions = append(publishedInfractions, msg) + mtx.Unlock() + return nil + }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddr := "192.168.1.100" + + // Record 9 infractions - each should be published, even though below ban threshold + for i := 0; i < 9; i++ { + bm.recordInfraction(ipAddr, clientID, MalformedMessage) + } + + mtx.Lock() + if len(publishedInfractions) != 9 { + t.Errorf("expected 9 published infractions, got %d", len(publishedInfractions)) + } + mtx.Unlock() + }) + + t.Run("Publish error handling", func(t *testing.T) { + var publishedBans []*pb.ClientInfractionMsg + var mtx sync.Mutex + publishErr := fmt.Errorf("publish failed") + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: generateTestPeerID(t), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { + return publishErr + }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddr := "192.168.1.100" + + // Record 10 infractions + var lastErr error + for i := 0; i < 10; i++ { + err := bm.recordInfraction(ipAddr, clientID, MalformedMessage) + if err != nil { + lastErr = err + } + } + + // Verify error was returned + if lastErr == nil { + t.Errorf("expected error from recordInfraction") + } + + mtx.Lock() + if len(publishedBans) != 0 { + t.Errorf("expected 0 bans (due to error), got %d", len(publishedBans)) + } + mtx.Unlock() + }) +} + +func TestBanManagerPublishInfractionMultipleTypes(t *testing.T) { + var publishedInfractions []*pb.ClientInfractionMsg + var mtx sync.Mutex + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: generateTestPeerID(t), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { + mtx.Lock() + publishedInfractions = append(publishedInfractions, msg) + mtx.Unlock() + return nil + }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddr := "192.168.1.100" + + // Record 5 malformed messages (penalty 1 each) and 5 invalid bonds (penalty 1 each) + for i := 0; i < 5; i++ { + bm.recordInfraction(ipAddr, clientID, MalformedMessage) + } + for i := 0; i < 5; i++ { + bm.recordInfraction(ipAddr, clientID, InvalidBond) + } + + // Should publish 10 individual infractions + mtx.Lock() + if len(publishedInfractions) != 10 { + t.Errorf("expected 10 published infractions, got %d", len(publishedInfractions)) + } + + malformedCount := 0 + invalidBondCount := 0 + for _, inf := range publishedInfractions { + if inf.InfractionType == uint32(MalformedMessage) { + malformedCount++ + } else if inf.InfractionType == uint32(InvalidBond) { + invalidBondCount++ + } + } + + if malformedCount != 5 { + t.Errorf("expected 5 malformed message infractions, got %d", malformedCount) + } + if invalidBondCount != 5 { + t.Errorf("expected 5 invalid bond infractions, got %d", invalidBondCount) + } + mtx.Unlock() +} + +func TestBanManagerIsClientBanned(t *testing.T) { + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddr := "192.168.1.100" + + // Record infractions below threshold + for i := 0; i < 9; i++ { + bm.recordInfraction(ipAddr, clientID, MalformedMessage) + } + + // Should not be banned yet + if bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to not be banned", ipAddr) + } + + // Record one more to hit threshold + bm.recordInfraction(ipAddr, clientID, MalformedMessage) + + // Should now be banned + if !bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to be banned", ipAddr) + } +} + +func TestBanManagerPurgeExpiredInfractions(t *testing.T) { + now := time.Unix(1000, 0) + currentTime := now + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: func() time.Time { return currentTime }, + }) + + clientID := peer.ID("test-client") + ipAddr := "192.168.1.50" + + // Record infractions at current time + for i := 0; i < 5; i++ { + bm.recordInfraction(ipAddr, clientID, MalformedMessage) + } + + // Verify infractions are stored + bm.mtx.Lock() + if len(bm.clientInfractions[ipAddr]) != 5 { + bm.mtx.Unlock() + t.Errorf("expected 5 infractions, got %d", len(bm.clientInfractions[ipAddr])) + return + } + bm.mtx.Unlock() + + // Advance time beyond retention period + currentTime = now.Add(infractionRetentionPeriod + 1*time.Second) + + // Purge expired infractions + bm.purgeExpiredInfractions() + + // All infractions should be gone + bm.mtx.Lock() + if _, exists := bm.clientInfractions[ipAddr]; exists { + bm.mtx.Unlock() + t.Errorf("expected infractions to be purged") + return + } + bm.mtx.Unlock() +} + +func TestBanManagerMultipleInfractionTypes(t *testing.T) { + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + ipAddr := "2001:db8::100" + + // Record 5 malformed messages and 5 invalid bonds (total 10) + for i := 0; i < 5; i++ { + err := bm.recordInfraction(ipAddr, clientID, MalformedMessage) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + } + + for i := 0; i < 5; i++ { + err := bm.recordInfraction(ipAddr, clientID, InvalidBond) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + } + + // Should be banned (total 10 penalties) + if !bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to be banned with 10 total infractions", ipAddr) + } +} + +func TestBanManagerRun(t *testing.T) { + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Start the cleanup loop + bm.run(ctx) + + // Should exit cleanly after context cancels +} + +func TestBanManagerRecordRemoteInfraction(t *testing.T) { + reporterID := generateTestPeerID(t) + reporterBytes, err := reporterID.Marshal() + if err != nil { + t.Fatalf("Failed to marshal reporter ID: %v", err) + } + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: generateTestPeerID(t), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + ipAddr := "192.168.1.100" + + // Record 10 infractions worth 1 point each to reach the ban threshold + // Use different expiry times for each to avoid deduplication + for i := 0; i < 10; i++ { + infractionMsg := &pb.ClientInfractionMsg{ + Ip: ipAddr, + Reporter: reporterBytes, + InfractionType: uint32(MalformedMessage), + Penalty: 1, + Expiry: time.Now().Add(time.Hour).Add(time.Duration(i) * time.Millisecond).UnixMilli(), + } + + err := bm.recordRemoteInfraction(infractionMsg) + if err != nil { + t.Fatalf("recordRemoteInfraction: %v", err) + } + } + + // Verify client is banned + if !bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to be banned after recording remote infractions", ipAddr) + } +} + +func TestBanManagerRemoteInfractionExpiry(t *testing.T) { + now := time.Unix(1000, 0) + currentTime := now + + reporterID := generateTestPeerID(t) + reporterBytes, err := reporterID.Marshal() + if err != nil { + t.Fatalf("Failed to marshal reporter ID: %v", err) + } + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: generateTestPeerID(t), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: func() time.Time { return currentTime }, + }) + + ipAddr := "192.168.1.100" + + // Record 10 remote infractions that expire in 1 hour + // Use different expiry times for each to avoid deduplication + for i := 0; i < 10; i++ { + infractionMsg := &pb.ClientInfractionMsg{ + Ip: ipAddr, + Reporter: reporterBytes, + InfractionType: uint32(MalformedMessage), + Penalty: 1, + Expiry: now.Add(time.Hour).Add(time.Duration(i) * time.Millisecond).UnixMilli(), + } + + err := bm.recordRemoteInfraction(infractionMsg) + if err != nil { + t.Fatalf("recordRemoteInfraction: %v", err) + } + } + + // Should be banned initially + if !bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to be banned", ipAddr) + } + + // Advance time beyond expiry + currentTime = now.Add(2 * time.Hour) + + // Should no longer be banned + if bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to not be banned after expiry", ipAddr) + } +} + + + +func TestGetIPFromStream(t *testing.T) { + // This test is minimal since we can't easily mock network streams + // But it verifies the function handles nil input gracefully + ip, err := getIPFromStream(nil) + if err == nil { + t.Errorf("expected error for nil stream, got none") + } + if ip != "" { + t.Errorf("expected empty string for nil stream, got %s", ip) + } +} + +func TestBanManagerLocalAndRemoteInfractions(t *testing.T) { + reporterID := generateTestPeerID(t) + reporterBytes, err := reporterID.Marshal() + if err != nil { + t.Fatalf("Failed to marshal reporter ID: %v", err) + } + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) {}, + nodeID: generateTestPeerID(t), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + clientID := generateTestPeerID(t) + ipAddr := "192.168.1.100" + + // Record 5 local infractions (with small delays to ensure different timestamps) + for i := 0; i < 5; i++ { + bm.recordInfraction(ipAddr, clientID, MalformedMessage) + time.Sleep(1 * time.Millisecond) // Ensure different timestamps + } + + // Record 5 remote infractions + // Use different expiry times for each to avoid deduplication + for i := 0; i < 5; i++ { + infractionMsg := &pb.ClientInfractionMsg{ + Ip: ipAddr, + Reporter: reporterBytes, + InfractionType: uint32(MalformedMessage), + Penalty: 1, + Expiry: time.Now().Add(time.Hour).Add(time.Duration(i) * time.Millisecond).UnixMilli(), + } + + err := bm.recordRemoteInfraction(infractionMsg) + if err != nil { + t.Fatalf("recordRemoteInfraction: %v", err) + } + } + + // Client should be banned (5 local + 5 remote = 10 total) + if !bm.isClientBanned(ipAddr) { + t.Errorf("expected %s to be banned with combined local and remote infractions", ipAddr) + } +} + +func TestBanManagerDisconnectCallback(t *testing.T) { + tests := []struct { + name string + ipAddr string + infractions int + shouldDisconnect bool + }{ + { + name: "Disconnect at threshold", + ipAddr: "192.168.1.1", + infractions: 10, + shouldDisconnect: true, + }, + { + name: "No disconnect below threshold", + ipAddr: "192.168.1.2", + infractions: 9, + shouldDisconnect: false, + }, + { + name: "Multiple IPs, only one at threshold", + ipAddr: "192.168.1.3", + infractions: 10, + shouldDisconnect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + disconnectCalls := []string{} + var mtx sync.Mutex + + bm := newBanManager(&banManagerConfig{ + disconnectClient: func(ip string) { + mtx.Lock() + disconnectCalls = append(disconnectCalls, ip) + mtx.Unlock() + }, + nodeID: peer.ID("this-node"), + publishInfraction: func(ctx context.Context, msg *pb.ClientInfractionMsg) error { return nil }, + log: testLogger(), + now: time.Now, + }) + + clientID := peer.ID("test-client") + + // Record infractions up to the specified count + for i := 0; i < tt.infractions; i++ { + err := bm.recordInfraction(tt.ipAddr, clientID, MalformedMessage) + if err != nil { + t.Fatalf("recordInfraction: %v", err) + } + } + + // Verify disconnect callback was invoked correctly + mtx.Lock() + if tt.shouldDisconnect { + if len(disconnectCalls) != 1 { + t.Errorf("expected 1 disconnect call, got %d", len(disconnectCalls)) + } else if disconnectCalls[0] != tt.ipAddr { + t.Errorf("expected disconnect for %s, got %s", tt.ipAddr, disconnectCalls[0]) + } + } else { + if len(disconnectCalls) != 0 { + t.Errorf("expected 0 disconnect calls, got %d", len(disconnectCalls)) + } + } + mtx.Unlock() + }) + } +} diff --git a/tatanka/broadcast_rate_limiter.go b/tatanka/broadcast_rate_limiter.go new file mode 100644 index 0000000..d3d3585 --- /dev/null +++ b/tatanka/broadcast_rate_limiter.go @@ -0,0 +1,130 @@ +package tatanka + +import ( + "context" + "math" + "sync" + "time" + + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/peer" +) + +const ( + tokensPerSecond = 4.0 // 4 messages per second sustained rate + bucketCapacity = 8.0 // Allow bursts of up to 8 messages + warningThreshold = 3 + + abuseThreshold = 10 + violationWindowDuration = 5 * time.Minute +) + +type clientBucket struct { + tokens float64 + lastRefillTime time.Time + violations uint32 + lastViolation time.Time +} + +type rateLimitConfig struct { + recordInfraction func(ip string, peerID peer.ID, infractionType infractionType) error + now func() time.Time + log slog.Logger +} + +type broadcastRateLimiter struct { + cfg *rateLimitConfig + + mtx sync.RWMutex + clientBuckets map[peer.ID]*clientBucket +} + +func newBroadcastRateLimiter(cfg *rateLimitConfig) *broadcastRateLimiter { + return &broadcastRateLimiter{ + clientBuckets: make(map[peer.ID]*clientBucket), + cfg: cfg, + } +} + +func (rl *broadcastRateLimiter) allowBroadcast(client peer.ID) (bool, infractionType) { + rl.mtx.Lock() + defer rl.mtx.Unlock() + + now := rl.cfg.now() + + bucket, exists := rl.clientBuckets[client] + if !exists { + bucket = &clientBucket{ + tokens: bucketCapacity, + lastRefillTime: now, + } + + rl.clientBuckets[client] = bucket + bucket.tokens-- + + return true, 0 + } + + // Refill tokens based on elapsed time + elapsed := now.Sub(bucket.lastRefillTime).Seconds() + tokensToAdd := elapsed * tokensPerSecond + bucket.tokens = math.Min(bucket.tokens+tokensToAdd, bucketCapacity) + bucket.lastRefillTime = now + + if bucket.tokens >= 1.0 { + bucket.tokens-- + return true, 0 + } + + // Rate limit violation + shouldRecord := rl.recordViolation(bucket, now) + if !shouldRecord { + return false, 0 + } + + if bucket.violations >= abuseThreshold { + return false, RateLimitAbuse + } + + return false, RateLimitViolation +} + +func (rl *broadcastRateLimiter) recordViolation(bucket *clientBucket, now time.Time) bool { + if now.Sub(bucket.lastViolation) >= violationWindowDuration { + bucket.violations = 0 + } + + bucket.violations++ + bucket.lastViolation = now + + // Only report violations after warning threshold + return bucket.violations > uint32(warningThreshold) +} + +func (rl *broadcastRateLimiter) cleanup() { + rl.mtx.Lock() + defer rl.mtx.Unlock() + + now := rl.cfg.now() + cutoff := now.Add(-violationWindowDuration) + + for client, bucket := range rl.clientBuckets { + if bucket.lastRefillTime.Before(cutoff) && bucket.lastViolation.Before(cutoff) { + delete(rl.clientBuckets, client) + } + } +} + +func (rl *broadcastRateLimiter) run(ctx context.Context) { + ticker := time.NewTicker(time.Minute * 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rl.cleanup() + } + } +} diff --git a/tatanka/broadcast_rate_limiter_test.go b/tatanka/broadcast_rate_limiter_test.go new file mode 100644 index 0000000..08dbda0 --- /dev/null +++ b/tatanka/broadcast_rate_limiter_test.go @@ -0,0 +1,342 @@ +package tatanka + +import ( + "testing" + "time" + + "github.com/decred/slog" + "github.com/libp2p/go-libp2p/core/peer" +) + +func TestTokenBucketRefill(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // First broadcast is always allowed + allowed, _ := limiter.allowBroadcast(clientID) + if !allowed { + t.Fatalf("First broadcast should be allowed") + } + + // Consume remaining bucket capacity (7 tokens left) + for i := 0; i < 7; i++ { + allowed, _ := limiter.allowBroadcast(clientID) + if !allowed { + t.Fatalf("Broadcast %d should be allowed (bucket not empty)", i+2) + } + } + + // Now bucket is empty + allowed, _ = limiter.allowBroadcast(clientID) + if allowed { + t.Fatalf("Broadcast should be denied (bucket empty)") + } + + // Simulate 1 second passing - should get 4 new tokens + now = now.Add(1 * time.Second) + allowed, _ = limiter.allowBroadcast(clientID) + if !allowed { + t.Fatalf("Broadcast should be allowed (tokens refilled)") + } +} + +func TestTokenBucketBurstAllowance(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // Bucket capacity is 8, so 8 broadcasts should be allowed immediately + for i := 0; i < 8; i++ { + allowed, _ := limiter.allowBroadcast(clientID) + if !allowed { + t.Fatalf("Broadcast %d should be allowed (burst capacity = 8)", i+1) + } + } + + // 9th should be denied + allowed, _ := limiter.allowBroadcast(clientID) + if allowed { + t.Fatalf("9th broadcast should be denied (bucket capacity exceeded)") + } +} + +func TestTokenBucketRateEnforcement(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // Consume full bucket + for i := 0; i < 8; i++ { + limiter.allowBroadcast(clientID) + } + + // Try immediately - should fail + allowed, _ := limiter.allowBroadcast(clientID) + if allowed { + t.Fatalf("Should be rate limited immediately") + } + + // After 100ms (0.4 tokens), should still fail + now = now.Add(100 * time.Millisecond) + allowed, _ = limiter.allowBroadcast(clientID) + if allowed { + t.Fatalf("Should be rate limited at 100ms") + } + + // After 300ms (1.2 tokens), should succeed + now = now.Add(200 * time.Millisecond) + allowed, _ = limiter.allowBroadcast(clientID) + if !allowed { + t.Fatalf("Should be allowed at 300ms (1+ token available)") + } +} + +func TestViolationCounting(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // Consume full bucket + for i := 0; i < 8; i++ { + limiter.allowBroadcast(clientID) + } + + // First 3 violations should not record infractions (warning threshold = 3) + for i := 0; i < 3; i++ { + _, infractionType := limiter.allowBroadcast(clientID) + if infractionType != 0 { + t.Fatalf("Violation %d should not record infraction yet", i+1) + } + } + + // 4th violation should record infraction + _, infractionType := limiter.allowBroadcast(clientID) + if infractionType == 0 { + t.Fatalf("Violation 4 should record infraction") + } +} + +func TestViolationWindowExpiry(t *testing.T) { + currentTime := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return currentTime }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // Consume full bucket and record violations + for i := 0; i < 8; i++ { + limiter.allowBroadcast(clientID) + } + + // Record 5 violations (keep consuming, tokens won't refill yet) + for i := 0; i < 5; i++ { + limiter.allowBroadcast(clientID) + } + + // Verify violations are recorded before window expiry + limiter.mtx.RLock() + bucket := limiter.clientBuckets[clientID] + limiter.mtx.RUnlock() + + if bucket.violations != 5 { + t.Fatalf("Expected 5 violations before window expiry, got %d", bucket.violations) + } + + // Move past violation window - time moves enough that window expires but not enough to refill many tokens + currentTime = currentTime.Add(violationWindowDuration + 100*time.Millisecond) + + // Consume bucket again (more violations after window expires) + // These should trigger a new window with reset violations + for i := 0; i < 8; i++ { + limiter.allowBroadcast(clientID) + } + + // Record one more violation after window expiry + limiter.allowBroadcast(clientID) + + limiter.mtx.RLock() + bucket = limiter.clientBuckets[clientID] + limiter.mtx.RUnlock() + + // After window expires, violations should have been reset and then incremented from abuse + // Should be between 1 and 2 depending on exact timing + if bucket.violations < 1 || bucket.violations > 2 { + t.Fatalf("Expected violations between 1-2 after window expiry, got %d", bucket.violations) + } +} + +func TestViolationAbuseThreshold(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // Consume full bucket + for i := 0; i < 8; i++ { + limiter.allowBroadcast(clientID) + } + + // Record violations to reach severe abuse threshold (10 violations) + var lastInfractionType infractionType + for i := 0; i < 10; i++ { + _, lastInfractionType = limiter.allowBroadcast(clientID) + } + + // Check the infraction type + if lastInfractionType != RateLimitAbuse { + t.Fatalf("Expected RateLimitAbuse at 10 violations, got %v", lastInfractionType) + } + + // Test that at 9 violations it's still normal violation + now = now.Add(violationWindowDuration + 1*time.Second) + limiter.allowBroadcast(clientID) // Reset violations + + // Consume full bucket again + now = time.Now() // Reset time to prevent accumulation + cfg.now = func() time.Time { return now } + limiter2 := newBroadcastRateLimiter(cfg) + + for i := 0; i < 8; i++ { + limiter2.allowBroadcast(clientID) + } + + // Record 9 violations + for i := 0; i < 9; i++ { + _, lastInfractionType = limiter2.allowBroadcast(clientID) + } + + if lastInfractionType != RateLimitViolation { + t.Fatalf("Expected RateLimitViolation at 9 violations, got %v", lastInfractionType) + } +} + +func TestNewClientFirstBroadcast(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // First broadcast for any client should be allowed + allowed, infractionType := limiter.allowBroadcast(clientID) + if !allowed { + t.Fatalf("First broadcast for new client should be allowed") + } + if infractionType != 0 { + t.Fatalf("First broadcast should not record infraction") + } +} + +func TestConcurrentBroadcasts(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + clientID := generateTestPeerID(t) + + // This test should be run with go test -race to detect race conditions + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + for j := 0; j < 100; j++ { + limiter.allowBroadcast(clientID) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + <-done + } +} + +func TestMultipleIndependentClients(t *testing.T) { + now := time.Now() + cfg := &rateLimitConfig{ + recordInfraction: func(ip string, peerID peer.ID, it infractionType) error { + return nil + }, + now: func() time.Time { return now }, + log: slog.Disabled, + } + + limiter := newBroadcastRateLimiter(cfg) + client1 := generateTestPeerID(t) + client2 := generateTestPeerID(t) + + // Consume client1's bucket + for i := 0; i < 8; i++ { + limiter.allowBroadcast(client1) + } + + // client1 should be rate limited + allowed, _ := limiter.allowBroadcast(client1) + if allowed { + t.Fatalf("Client1 should be rate limited") + } + + // client2 should still be allowed + allowed, _ = limiter.allowBroadcast(client2) + if !allowed { + t.Fatalf("Client2 should not be rate limited") + } +} + diff --git a/tatanka/gossipsub.go b/tatanka/gossipsub.go index 0178cfb..4f66ca5 100644 --- a/tatanka/gossipsub.go +++ b/tatanka/gossipsub.go @@ -28,6 +28,11 @@ const ( // oracleUpdatesTopicName is the name of the pubsub topic used to // propagate oracle updates between tatanka nodes. oracleUpdatesTopicName = "oracle_updates" + + // clientInfractionsTopicName is the name of the pubsub topic used to propagate + // client infractions between tatanka nodes. + clientInfractionsTopicName = "client_infractions" + ) type clientConnectionUpdate struct { @@ -72,20 +77,22 @@ type gossipSubCfg struct { handleBroadcastMessage func(msg *protocolsPb.PushMessage) handleClientConnectionMessage func(update *clientConnectionUpdate) handleOracleUpdate func(update *pb.NodeOracleUpdate) + handleClientInfractionMessage func(msg *pb.ClientInfractionMsg) } // gossipSub manages the nodes connection to a gossip sub network between tatanka // nodes. This network is used to gossip all client broadcast messages, -// client connections, and oracle updates. +// client connections, oracle updates, and client infractions. type gossipSub struct { - log slog.Logger - ps *pubsub.PubSub - cfg *gossipSubCfg - clientMessageTopic *pubsub.Topic - clientConnectionsTopic *pubsub.Topic - oracleUpdatesTopic *pubsub.Topic + log slog.Logger + ps *pubsub.PubSub + cfg *gossipSubCfg + clientMessageTopic *pubsub.Topic + clientConnectionsTopic *pubsub.Topic + oracleUpdatesTopic *pubsub.Topic + clientInfractionsTopic *pubsub.Topic zstdEncoder *zstd.Encoder - zstdDecoder *zstd.Decoder + zstdDecoder *zstd.Decoder } func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { @@ -122,6 +129,11 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { return nil, fmt.Errorf("failed to join oracle updates topic: %w", err) } + clientInfractionsTopic, err := ps.Join(clientInfractionsTopicName) + if err != nil { + return nil, fmt.Errorf("failed to join client infractions topic: %w", err) + } + zstdEncoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) if err != nil { return nil, fmt.Errorf("failed to create zstd encoder: %w", err) @@ -139,6 +151,7 @@ func newGossipSub(ctx context.Context, cfg *gossipSubCfg) (*gossipSub, error) { clientMessageTopic: clientMessageTopic, clientConnectionsTopic: clientConnectionsTopic, oracleUpdatesTopic: oracleUpdatesTopic, + clientInfractionsTopic: clientInfractionsTopic, zstdEncoder: zstdEncoder, zstdDecoder: zstdDecoder, }, nil @@ -240,6 +253,33 @@ func (gs *gossipSub) listenForOracleUpdates(ctx context.Context) error { } } +func (gs *gossipSub) listenForClientInfractions(ctx context.Context) error { + sub, err := gs.clientInfractionsTopic.Subscribe() + if err != nil { + return fmt.Errorf("failed to subscribe to client infractions topic: %w", err) + } + + for { + msg, err := sub.Next(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + + if msg != nil { + infractionMsg := &pb.ClientInfractionMsg{} + if err := proto.Unmarshal(msg.Data, infractionMsg); err != nil { + gs.log.Errorf("Failed to unmarshal client infraction message: %v", err) + continue + } + + gs.cfg.handleClientInfractionMessage(infractionMsg) + } + } +} + func (gs *gossipSub) publishClientMessage(ctx context.Context, msg *protocolsPb.PushMessage) error { data, err := proto.Marshal(msg) if err != nil { @@ -268,6 +308,15 @@ func (gs *gossipSub) publishOracleUpdate(ctx context.Context, update *pb.NodeOra return gs.oracleUpdatesTopic.Publish(ctx, compressed) } +func (gs *gossipSub) publishClientInfractionMessage(ctx context.Context, msg *pb.ClientInfractionMsg) error { + data, err := proto.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal client infraction message: %w", err) + } + + return gs.clientInfractionsTopic.Publish(ctx, data) +} + func (gs *gossipSub) run(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) @@ -289,5 +338,11 @@ func (gs *gossipSub) run(ctx context.Context) error { return err }) + g.Go(func() error { + err := gs.listenForClientInfractions(ctx) + gs.log.Debug("Client infractions listener stopped.") + return err + }) + return g.Wait() } diff --git a/tatanka/handler_permissions.go b/tatanka/handler_permissions.go index 46805a6..6ec43c4 100644 --- a/tatanka/handler_permissions.go +++ b/tatanka/handler_permissions.go @@ -2,6 +2,7 @@ package tatanka import ( "errors" + "fmt" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/protocol" @@ -27,6 +28,19 @@ func requireAny(permissions ...permissionDecorator) permissionDecorator { } } +// requireAll returns a single decorator that succeeds only if ALL +// underlying permissions pass. It returns an error if ANY permission fails. +func requireAll(permissions ...permissionDecorator) permissionDecorator { + return func(s network.Stream) error { + for _, p := range permissions { + if err := p(s); err != nil { + return err + } + } + return nil + } +} + func requireNoPermission(s network.Stream) error { return nil } @@ -78,3 +92,17 @@ func (t *TatankaNode) setStreamHandler(protocolID string, handler func(s network t.node.SetStreamHandler(protocol.ID(protocolID), finalHandler) } + +func (t *TatankaNode) requireNotBanned(s network.Stream) error { + ip, err := getIPFromStream(s) + if err != nil { + return errUnauthorized + } + + if t.banManager.isClientBanned(ip) { + reason := t.banManager.getClientBanReason(ip) + return fmt.Errorf("%w: %s", errUnauthorized, reason) + } + + return nil +} diff --git a/tatanka/handler_permissions_test.go b/tatanka/handler_permissions_test.go index ae7da91..d821788 100644 --- a/tatanka/handler_permissions_test.go +++ b/tatanka/handler_permissions_test.go @@ -2,9 +2,11 @@ package tatanka import ( "context" + "errors" "testing" "time" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/bisoncraft/mesh/codec" @@ -12,6 +14,143 @@ import ( protocolsPb "github.com/bisoncraft/mesh/protocols/pb" ) +// TestRequireAny tests the requireAny permission combinator. +func TestRequireAny(t *testing.T) { + tests := []struct { + name string + perms []permissionDecorator + wantErr bool + wantMsg string + }{ + { + name: "all permissions fail, returns last error", + perms: []permissionDecorator{ + func(s network.Stream) error { return errors.New("error1") }, + func(s network.Stream) error { return errors.New("error2") }, + func(s network.Stream) error { return errors.New("error3") }, + }, + wantErr: true, + wantMsg: "error3", + }, + { + name: "first permission succeeds, returns nil", + perms: []permissionDecorator{ + func(s network.Stream) error { return nil }, + func(s network.Stream) error { return errors.New("error2") }, + func(s network.Stream) error { return errors.New("error3") }, + }, + wantErr: false, + }, + { + name: "last permission succeeds, returns nil", + perms: []permissionDecorator{ + func(s network.Stream) error { return errors.New("error1") }, + func(s network.Stream) error { return errors.New("error2") }, + func(s network.Stream) error { return nil }, + }, + wantErr: false, + }, + { + name: "single permission fails", + perms: []permissionDecorator{ + func(s network.Stream) error { return errors.New("fail") }, + }, + wantErr: true, + wantMsg: "fail", + }, + { + name: "single permission succeeds", + perms: []permissionDecorator{ + func(s network.Stream) error { return nil }, + }, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + combined := requireAny(tc.perms...) + err := combined(nil) + + if (err != nil) != tc.wantErr { + t.Errorf("got error %v, want error %v", err, tc.wantErr) + } + if tc.wantErr && tc.wantMsg != "" && err.Error() != tc.wantMsg { + t.Errorf("got error message %q, want %q", err.Error(), tc.wantMsg) + } + }) + } +} + +// TestRequireAll tests the requireAll permission combinator. +func TestRequireAll(t *testing.T) { + tests := []struct { + name string + perms []permissionDecorator + wantErr bool + wantMsg string + }{ + { + name: "all permissions pass, returns nil", + perms: []permissionDecorator{ + func(s network.Stream) error { return nil }, + func(s network.Stream) error { return nil }, + func(s network.Stream) error { return nil }, + }, + wantErr: false, + }, + { + name: "first permission fails, returns error immediately", + perms: []permissionDecorator{ + func(s network.Stream) error { return errors.New("error1") }, + func(s network.Stream) error { return errors.New("error2") }, + func(s network.Stream) error { return errors.New("error3") }, + }, + wantErr: true, + wantMsg: "error1", + }, + { + name: "last permission fails, returns error", + perms: []permissionDecorator{ + func(s network.Stream) error { return nil }, + func(s network.Stream) error { return nil }, + func(s network.Stream) error { return errors.New("error3") }, + }, + wantErr: true, + wantMsg: "error3", + }, + { + name: "single permission fails", + perms: []permissionDecorator{ + func(s network.Stream) error { return errors.New("fail") }, + }, + wantErr: true, + wantMsg: "fail", + }, + { + name: "single permission succeeds", + perms: []permissionDecorator{ + func(s network.Stream) error { return nil }, + }, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + combined := requireAll(tc.perms...) + err := combined(nil) + + if (err != nil) != tc.wantErr { + t.Errorf("got error %v, want error %v", err, tc.wantErr) + } + if tc.wantErr && tc.wantMsg != "" && err.Error() != tc.wantMsg { + t.Errorf("got error message %q, want %q", err.Error(), tc.wantMsg) + } + }) + } +} + // TestPushPermissions tests the permissions for the push protocol. func TestPushPermissions(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) diff --git a/tatanka/handlers.go b/tatanka/handlers.go index a3f517e..44491a2 100644 --- a/tatanka/handlers.go +++ b/tatanka/handlers.go @@ -2,19 +2,21 @@ package tatanka import ( "context" + "errors" "fmt" + "io" "math/big" "strings" "time" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" "github.com/bisoncraft/mesh/bond" "github.com/bisoncraft/mesh/codec" "github.com/bisoncraft/mesh/oracle" "github.com/bisoncraft/mesh/protocols" protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/bisoncraft/mesh/tatanka/pb" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" "google.golang.org/protobuf/proto" ) @@ -61,6 +63,11 @@ func pbPeerInfoToLibp2p(pbPeer *protocolsPb.PeerInfo) (peer.AddrInfo, error) { // handleClientPush is called when the client opens a push stream to the node. func (t *TatankaNode) handleClientPush(s network.Stream) { client := s.Conn().RemotePeer() + ip, err := getIPFromStream(s) + if err != nil { + t.log.Debugf("Failed to fetch client ip from stream: %v", err) + return + } var success bool defer func() { @@ -72,6 +79,11 @@ func (t *TatankaNode) handleClientPush(s network.Stream) { initialSubs := &protocolsPb.InitialSubscriptions{} if err := codec.ReadLengthPrefixedMessage(s, initialSubs); err != nil { t.log.Errorf("Failed to read initial subscriptions from client %s: %v", client.ShortString(), err) + if !errors.Is(err, io.EOF) { + if err := t.banManager.recordInfraction(ip, client, MalformedMessage); err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + } + } return } @@ -110,11 +122,20 @@ func (t *TatankaNode) handleClientSubscribe(s network.Stream) { defer func() { _ = s.Close() }() client := s.Conn().RemotePeer() + ip, err := getIPFromStream(s) + if err != nil { + t.log.Debugf("Failed to fetch client ip from stream: %v", err) + return + } subscribeMessage := &protocolsPb.SubscribeRequest{} if err := codec.ReadLengthPrefixedMessage(s, subscribeMessage); err != nil { t.log.Debugf("Failed to read/unmarshal subscribe message from client %s: %v.", client.ShortString(), err) - // TODO: client sent invalid message, remove client? + if !errors.Is(err, io.EOF) { + if err := t.banManager.recordInfraction(ip, client, MalformedMessage); err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + } + } return } @@ -188,11 +209,50 @@ func (t *TatankaNode) handleClientPublish(s network.Stream) { defer func() { _ = s.Close() }() client := s.Conn().RemotePeer() + ip, err := getIPFromStream(s) + if err != nil { + t.log.Debugf("Failed to fetch client ip from stream: %v", err) + return + } publishMessage := &protocolsPb.PublishRequest{} if err := codec.ReadLengthPrefixedMessage(s, publishMessage); err != nil { t.log.Debugf("Failed to read/unmarshal publish message from client %s: %v.", client.ShortString(), err) - // TODO: remove client? + if !errors.Is(err, io.EOF) { + if err := t.banManager.recordInfraction(ip, client, MalformedMessage); err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + } + } + return + } + + sendError := func(msg string) { + response := &protocolsPb.Response{ + Response: &protocolsPb.Response_Error{ + Error: &protocolsPb.Error{ + Error: &protocolsPb.Error_Message{ + Message: msg, + }, + }, + }, + } + if err := codec.WriteLengthPrefixedMessage(s, response); err != nil { + t.log.Errorf("Failed to write error response: %v", err) + } + } + + allowed, infractionType := t.broadcastRateLimiter.allowBroadcast(client) + if !allowed { + t.log.Warnf("Client %s rate limit exceeded, dropping broadcast to topic %s", + client.ShortString(), publishMessage.Topic) + sendError("broadcast rate limit exceeded") + + if infractionType != 0 { + err = t.banManager.recordInfraction(ip, client, infractionType) + if err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + } + } return } @@ -200,6 +260,12 @@ func (t *TatankaNode) handleClientPublish(s network.Stream) { strings.HasPrefix(publishMessage.Topic, oracle.FeeRateTopicPrefix) { t.log.Warnf("Client %s attempted to publish to restricted oracle topic %s", client.ShortString(), publishMessage.Topic) + sendError("cannot publish to restricted topic") + err = t.banManager.recordInfraction(ip, client, NodeImpersonation) + if err != nil { + t.log.Errorf("Failed to record infraction: %v", err) + } + return } @@ -207,9 +273,15 @@ func (t *TatankaNode) handleClientPublish(s network.Stream) { defer cancel() message := pbPushMessageBroadcast(publishMessage.Topic, publishMessage.Data, client) - err := t.gossipSub.publishClientMessage(ctx, message) + err = t.gossipSub.publishClientMessage(ctx, message) if err != nil { t.log.Errorf("Failed to publish client message: %w", err) + sendError("failed to publish message") + return + } + + if err := codec.WriteLengthPrefixedMessage(s, pbResponseSuccess()); err != nil { + t.log.Errorf("Failed to write success response: %v", err) } } @@ -217,11 +289,20 @@ func (t *TatankaNode) handlePostBonds(s network.Stream) { defer func() { _ = s.Close() }() client := s.Conn().RemotePeer() + ip, err := getIPFromStream(s) + if err != nil { + t.log.Debugf("Failed to fetch client ip from stream: %v", err) + return + } postBondMessage := &protocolsPb.PostBondRequest{} if err := codec.ReadLengthPrefixedMessage(s, postBondMessage); err != nil { t.log.Debugf("Failed to read/unmarshal post bond message from client %s: %v.", client.ShortString(), err) - // TODO: remove client? + if !errors.Is(err, io.EOF) { + if err := t.banManager.recordInfraction(ip, client, MalformedMessage); err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + } + } return } @@ -247,6 +328,11 @@ func (t *TatankaNode) handlePostBonds(s network.Stream) { return } if !valid { + err = t.banManager.recordInfraction(ip, client, InvalidBond) + if err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + return + } sendInvalidBondIndex(uint32(i)) return } @@ -302,10 +388,20 @@ func (t *TatankaNode) handleClientRelayMessage(s network.Stream) { defer func() { _ = s.Close() }() client := s.Conn().RemotePeer() + ip, err := getIPFromStream(s) + if err != nil { + t.log.Debugf("Failed to fetch client ip from stream: %v", err) + return + } requestMessage := &protocolsPb.ClientRelayMessageRequest{} if err := codec.ReadLengthPrefixedMessage(s, requestMessage); err != nil { t.log.Debugf("Failed to read/unmarshal relay message request from client %s: %v.", client.ShortString(), err) + if !errors.Is(err, io.EOF) { + if err := t.banManager.recordInfraction(ip, client, MalformedMessage); err != nil { + t.log.Errorf("Failed to record infraction for client %s (%s): %v", client.ShortString(), ip, err) + } + } return } @@ -747,6 +843,27 @@ func (t *TatankaNode) handleWhitelist(s network.Stream) { } } +func (t *TatankaNode) handleClientInfractionsSnapshot(s network.Stream) { + defer func() { _ = s.Close() }() + + remotePeerID := s.Conn().RemotePeer() + + request := &pb.ClientInfractionsSnapshotRequest{} + if err := codec.ReadLengthPrefixedMessage(s, request); err != nil { + t.log.Warnf("Failed to read client infractions snapshot request from peer %s: %v", + remotePeerID.ShortString(), err) + return + } + + infractions := t.banManager.getActiveInfractions() + response := pbClientInfractionsSnapshotResponse(infractions) + + if err := codec.WriteLengthPrefixedMessage(s, response); err != nil { + t.log.Warnf("Failed to write client infractions snapshot response to peer %s: %v", + remotePeerID.ShortString(), err) + } +} + // handleAvailableMeshNodes handles a request from a client to get a list of // all mesh nodes that this tatanka node is connected to. func (t *TatankaNode) handleAvailableMeshNodes(s network.Stream) { @@ -1003,3 +1120,9 @@ func bigIntToBytes(bi *big.Int) []byte { } return bi.Bytes() } + +func pbClientInfractionsSnapshotResponse(infractions []*pb.ClientInfractionMsg) *pb.ClientInfractionsSnapshotResponse { + return &pb.ClientInfractionsSnapshotResponse{ + Infractions: infractions, + } +} diff --git a/tatanka/pb/messages.pb.go b/tatanka/pb/messages.pb.go index 5106ee2..6154015 100644 --- a/tatanka/pb/messages.pb.go +++ b/tatanka/pb/messages.pb.go @@ -2,7 +2,7 @@ // versions: // protoc-gen-go v1.36.10 // protoc v6.33.1 -// source: tatanka/pb/messages.proto +// source: messages.proto package pb @@ -34,7 +34,7 @@ type ClientConnectionMsg struct { func (x *ClientConnectionMsg) Reset() { *x = ClientConnectionMsg{} - mi := &file_tatanka_pb_messages_proto_msgTypes[0] + mi := &file_messages_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -46,7 +46,7 @@ func (x *ClientConnectionMsg) String() string { func (*ClientConnectionMsg) ProtoMessage() {} func (x *ClientConnectionMsg) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[0] + mi := &file_messages_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -59,7 +59,7 @@ func (x *ClientConnectionMsg) ProtoReflect() protoreflect.Message { // Deprecated: Use ClientConnectionMsg.ProtoReflect.Descriptor instead. func (*ClientConnectionMsg) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{0} + return file_messages_proto_rawDescGZIP(), []int{0} } func (x *ClientConnectionMsg) GetId() []byte { @@ -90,6 +90,152 @@ func (x *ClientConnectionMsg) GetTimestamp() int64 { return 0 } +// ClientBanMsg is used internally by tatanka nodes to share client ban information. +type ClientBanMsg struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` // banned client IP address + Reporter []byte `protobuf:"bytes,2,opt,name=reporter,proto3" json:"reporter,omitempty"` // peer.ID of reporting node + TotalPenalties uint32 `protobuf:"varint,3,opt,name=total_penalties,json=totalPenalties,proto3" json:"total_penalties,omitempty"` + Expiry int64 `protobuf:"varint,4,opt,name=expiry,proto3" json:"expiry,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ClientBanMsg) Reset() { + *x = ClientBanMsg{} + mi := &file_messages_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ClientBanMsg) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientBanMsg) ProtoMessage() {} + +func (x *ClientBanMsg) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientBanMsg.ProtoReflect.Descriptor instead. +func (*ClientBanMsg) Descriptor() ([]byte, []int) { + return file_messages_proto_rawDescGZIP(), []int{1} +} + +func (x *ClientBanMsg) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +func (x *ClientBanMsg) GetReporter() []byte { + if x != nil { + return x.Reporter + } + return nil +} + +func (x *ClientBanMsg) GetTotalPenalties() uint32 { + if x != nil { + return x.TotalPenalties + } + return 0 +} + +func (x *ClientBanMsg) GetExpiry() int64 { + if x != nil { + return x.Expiry + } + return 0 +} + +// ClientInfractionMsg is used to gossip individual client infractions between tatanka nodes. +type ClientInfractionMsg struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` // client IP address + Reporter []byte `protobuf:"bytes,2,opt,name=reporter,proto3" json:"reporter,omitempty"` // peer.ID of reporting node + InfractionType uint32 `protobuf:"varint,3,opt,name=infraction_type,json=infractionType,proto3" json:"infraction_type,omitempty"` // infractionType enum value + Penalty uint32 `protobuf:"varint,4,opt,name=penalty,proto3" json:"penalty,omitempty"` // penalty for this infraction + Expiry int64 `protobuf:"varint,5,opt,name=expiry,proto3" json:"expiry,omitempty"` // when this infraction expires + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ClientInfractionMsg) Reset() { + *x = ClientInfractionMsg{} + mi := &file_messages_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ClientInfractionMsg) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientInfractionMsg) ProtoMessage() {} + +func (x *ClientInfractionMsg) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientInfractionMsg.ProtoReflect.Descriptor instead. +func (*ClientInfractionMsg) Descriptor() ([]byte, []int) { + return file_messages_proto_rawDescGZIP(), []int{2} +} + +func (x *ClientInfractionMsg) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +func (x *ClientInfractionMsg) GetReporter() []byte { + if x != nil { + return x.Reporter + } + return nil +} + +func (x *ClientInfractionMsg) GetInfractionType() uint32 { + if x != nil { + return x.InfractionType + } + return 0 +} + +func (x *ClientInfractionMsg) GetPenalty() uint32 { + if x != nil { + return x.Penalty + } + return 0 +} + +func (x *ClientInfractionMsg) GetExpiry() int64 { + if x != nil { + return x.Expiry + } + return 0 +} + type TatankaForwardRelayRequest struct { state protoimpl.MessageState `protogen:"open.v1"` InitiatorId []byte `protobuf:"bytes,1,opt,name=initiator_id,json=initiatorId,proto3" json:"initiator_id,omitempty"` // peer.ID serialized as bytes @@ -101,7 +247,7 @@ type TatankaForwardRelayRequest struct { func (x *TatankaForwardRelayRequest) Reset() { *x = TatankaForwardRelayRequest{} - mi := &file_tatanka_pb_messages_proto_msgTypes[1] + mi := &file_messages_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -113,7 +259,7 @@ func (x *TatankaForwardRelayRequest) String() string { func (*TatankaForwardRelayRequest) ProtoMessage() {} func (x *TatankaForwardRelayRequest) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[1] + mi := &file_messages_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -126,7 +272,7 @@ func (x *TatankaForwardRelayRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TatankaForwardRelayRequest.ProtoReflect.Descriptor instead. func (*TatankaForwardRelayRequest) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{1} + return file_messages_proto_rawDescGZIP(), []int{3} } func (x *TatankaForwardRelayRequest) GetInitiatorId() []byte { @@ -165,7 +311,7 @@ type TatankaForwardRelayResponse struct { func (x *TatankaForwardRelayResponse) Reset() { *x = TatankaForwardRelayResponse{} - mi := &file_tatanka_pb_messages_proto_msgTypes[2] + mi := &file_messages_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -177,7 +323,7 @@ func (x *TatankaForwardRelayResponse) String() string { func (*TatankaForwardRelayResponse) ProtoMessage() {} func (x *TatankaForwardRelayResponse) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[2] + mi := &file_messages_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -190,7 +336,7 @@ func (x *TatankaForwardRelayResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TatankaForwardRelayResponse.ProtoReflect.Descriptor instead. func (*TatankaForwardRelayResponse) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{2} + return file_messages_proto_rawDescGZIP(), []int{4} } func (x *TatankaForwardRelayResponse) GetResponse() isTatankaForwardRelayResponse_Response { @@ -273,7 +419,7 @@ type DiscoveryRequest struct { func (x *DiscoveryRequest) Reset() { *x = DiscoveryRequest{} - mi := &file_tatanka_pb_messages_proto_msgTypes[3] + mi := &file_messages_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -285,7 +431,7 @@ func (x *DiscoveryRequest) String() string { func (*DiscoveryRequest) ProtoMessage() {} func (x *DiscoveryRequest) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[3] + mi := &file_messages_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -298,7 +444,7 @@ func (x *DiscoveryRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryRequest.ProtoReflect.Descriptor instead. func (*DiscoveryRequest) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{3} + return file_messages_proto_rawDescGZIP(), []int{5} } func (x *DiscoveryRequest) GetId() []byte { @@ -321,7 +467,7 @@ type DiscoveryResponse struct { func (x *DiscoveryResponse) Reset() { *x = DiscoveryResponse{} - mi := &file_tatanka_pb_messages_proto_msgTypes[4] + mi := &file_messages_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -333,7 +479,7 @@ func (x *DiscoveryResponse) String() string { func (*DiscoveryResponse) ProtoMessage() {} func (x *DiscoveryResponse) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[4] + mi := &file_messages_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -346,7 +492,7 @@ func (x *DiscoveryResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryResponse.ProtoReflect.Descriptor instead. func (*DiscoveryResponse) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{4} + return file_messages_proto_rawDescGZIP(), []int{6} } func (x *DiscoveryResponse) GetResponse() isDiscoveryResponse_Response { @@ -399,7 +545,7 @@ type WhitelistRequest struct { func (x *WhitelistRequest) Reset() { *x = WhitelistRequest{} - mi := &file_tatanka_pb_messages_proto_msgTypes[5] + mi := &file_messages_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -411,7 +557,7 @@ func (x *WhitelistRequest) String() string { func (*WhitelistRequest) ProtoMessage() {} func (x *WhitelistRequest) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[5] + mi := &file_messages_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -424,7 +570,7 @@ func (x *WhitelistRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WhitelistRequest.ProtoReflect.Descriptor instead. func (*WhitelistRequest) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{5} + return file_messages_proto_rawDescGZIP(), []int{7} } func (x *WhitelistRequest) GetPeerIDs() [][]byte { @@ -447,7 +593,7 @@ type WhitelistResponse struct { func (x *WhitelistResponse) Reset() { *x = WhitelistResponse{} - mi := &file_tatanka_pb_messages_proto_msgTypes[6] + mi := &file_messages_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -459,7 +605,7 @@ func (x *WhitelistResponse) String() string { func (*WhitelistResponse) ProtoMessage() {} func (x *WhitelistResponse) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[6] + mi := &file_messages_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -472,7 +618,7 @@ func (x *WhitelistResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WhitelistResponse.ProtoReflect.Descriptor instead. func (*WhitelistResponse) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{6} + return file_messages_proto_rawDescGZIP(), []int{8} } func (x *WhitelistResponse) GetResponse() isWhitelistResponse_Response { @@ -527,7 +673,7 @@ type SourcedPrice struct { func (x *SourcedPrice) Reset() { *x = SourcedPrice{} - mi := &file_tatanka_pb_messages_proto_msgTypes[7] + mi := &file_messages_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -539,7 +685,7 @@ func (x *SourcedPrice) String() string { func (*SourcedPrice) ProtoMessage() {} func (x *SourcedPrice) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[7] + mi := &file_messages_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -552,7 +698,7 @@ func (x *SourcedPrice) ProtoReflect() protoreflect.Message { // Deprecated: Use SourcedPrice.ProtoReflect.Descriptor instead. func (*SourcedPrice) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{7} + return file_messages_proto_rawDescGZIP(), []int{9} } func (x *SourcedPrice) GetTicker() string { @@ -582,7 +728,7 @@ type SourcedPriceUpdate struct { func (x *SourcedPriceUpdate) Reset() { *x = SourcedPriceUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[8] + mi := &file_messages_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -594,7 +740,7 @@ func (x *SourcedPriceUpdate) String() string { func (*SourcedPriceUpdate) ProtoMessage() {} func (x *SourcedPriceUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[8] + mi := &file_messages_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -607,7 +753,7 @@ func (x *SourcedPriceUpdate) ProtoReflect() protoreflect.Message { // Deprecated: Use SourcedPriceUpdate.ProtoReflect.Descriptor instead. func (*SourcedPriceUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{8} + return file_messages_proto_rawDescGZIP(), []int{10} } func (x *SourcedPriceUpdate) GetSource() string { @@ -642,7 +788,7 @@ type SourcedFeeRate struct { func (x *SourcedFeeRate) Reset() { *x = SourcedFeeRate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[9] + mi := &file_messages_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -654,7 +800,7 @@ func (x *SourcedFeeRate) String() string { func (*SourcedFeeRate) ProtoMessage() {} func (x *SourcedFeeRate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[9] + mi := &file_messages_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -667,7 +813,7 @@ func (x *SourcedFeeRate) ProtoReflect() protoreflect.Message { // Deprecated: Use SourcedFeeRate.ProtoReflect.Descriptor instead. func (*SourcedFeeRate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{9} + return file_messages_proto_rawDescGZIP(), []int{11} } func (x *SourcedFeeRate) GetNetwork() string { @@ -697,7 +843,7 @@ type SourcedFeeRateUpdate struct { func (x *SourcedFeeRateUpdate) Reset() { *x = SourcedFeeRateUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[10] + mi := &file_messages_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -709,7 +855,7 @@ func (x *SourcedFeeRateUpdate) String() string { func (*SourcedFeeRateUpdate) ProtoMessage() {} func (x *SourcedFeeRateUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[10] + mi := &file_messages_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -722,7 +868,7 @@ func (x *SourcedFeeRateUpdate) ProtoReflect() protoreflect.Message { // Deprecated: Use SourcedFeeRateUpdate.ProtoReflect.Descriptor instead. func (*SourcedFeeRateUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{10} + return file_messages_proto_rawDescGZIP(), []int{12} } func (x *SourcedFeeRateUpdate) GetSource() string { @@ -760,7 +906,7 @@ type NodeOracleUpdate struct { func (x *NodeOracleUpdate) Reset() { *x = NodeOracleUpdate{} - mi := &file_tatanka_pb_messages_proto_msgTypes[11] + mi := &file_messages_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -772,7 +918,7 @@ func (x *NodeOracleUpdate) String() string { func (*NodeOracleUpdate) ProtoMessage() {} func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[11] + mi := &file_messages_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -785,7 +931,7 @@ func (x *NodeOracleUpdate) ProtoReflect() protoreflect.Message { // Deprecated: Use NodeOracleUpdate.ProtoReflect.Descriptor instead. func (*NodeOracleUpdate) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{11} + return file_messages_proto_rawDescGZIP(), []int{13} } func (x *NodeOracleUpdate) GetUpdate() isNodeOracleUpdate_Update { @@ -829,6 +975,88 @@ func (*NodeOracleUpdate_PriceUpdate) isNodeOracleUpdate_Update() {} func (*NodeOracleUpdate_FeeRateUpdate) isNodeOracleUpdate_Update() {} +// ClientInfractionsSnapshotRequest is sent by a tatanka node to query another node for its active client infractions. +type ClientInfractionsSnapshotRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ClientInfractionsSnapshotRequest) Reset() { + *x = ClientInfractionsSnapshotRequest{} + mi := &file_messages_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ClientInfractionsSnapshotRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientInfractionsSnapshotRequest) ProtoMessage() {} + +func (x *ClientInfractionsSnapshotRequest) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientInfractionsSnapshotRequest.ProtoReflect.Descriptor instead. +func (*ClientInfractionsSnapshotRequest) Descriptor() ([]byte, []int) { + return file_messages_proto_rawDescGZIP(), []int{14} +} + +// ClientInfractionsSnapshotResponse contains the snapshot of active client infractions. +type ClientInfractionsSnapshotResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Infractions []*ClientInfractionMsg `protobuf:"bytes,1,rep,name=infractions,proto3" json:"infractions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ClientInfractionsSnapshotResponse) Reset() { + *x = ClientInfractionsSnapshotResponse{} + mi := &file_messages_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ClientInfractionsSnapshotResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientInfractionsSnapshotResponse) ProtoMessage() {} + +func (x *ClientInfractionsSnapshotResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientInfractionsSnapshotResponse.ProtoReflect.Descriptor instead. +func (*ClientInfractionsSnapshotResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_rawDescGZIP(), []int{15} +} + +func (x *ClientInfractionsSnapshotResponse) GetInfractions() []*ClientInfractionMsg { + if x != nil { + return x.Infractions + } + return nil +} + type TatankaForwardRelayResponse_ClientNotFound struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -837,7 +1065,7 @@ type TatankaForwardRelayResponse_ClientNotFound struct { func (x *TatankaForwardRelayResponse_ClientNotFound) Reset() { *x = TatankaForwardRelayResponse_ClientNotFound{} - mi := &file_tatanka_pb_messages_proto_msgTypes[12] + mi := &file_messages_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -849,7 +1077,7 @@ func (x *TatankaForwardRelayResponse_ClientNotFound) String() string { func (*TatankaForwardRelayResponse_ClientNotFound) ProtoMessage() {} func (x *TatankaForwardRelayResponse_ClientNotFound) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[12] + mi := &file_messages_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -862,7 +1090,7 @@ func (x *TatankaForwardRelayResponse_ClientNotFound) ProtoReflect() protoreflect // Deprecated: Use TatankaForwardRelayResponse_ClientNotFound.ProtoReflect.Descriptor instead. func (*TatankaForwardRelayResponse_ClientNotFound) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{2, 0} + return file_messages_proto_rawDescGZIP(), []int{4, 0} } type TatankaForwardRelayResponse_ClientRejected struct { @@ -873,7 +1101,7 @@ type TatankaForwardRelayResponse_ClientRejected struct { func (x *TatankaForwardRelayResponse_ClientRejected) Reset() { *x = TatankaForwardRelayResponse_ClientRejected{} - mi := &file_tatanka_pb_messages_proto_msgTypes[13] + mi := &file_messages_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -885,7 +1113,7 @@ func (x *TatankaForwardRelayResponse_ClientRejected) String() string { func (*TatankaForwardRelayResponse_ClientRejected) ProtoMessage() {} func (x *TatankaForwardRelayResponse_ClientRejected) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[13] + mi := &file_messages_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -898,7 +1126,7 @@ func (x *TatankaForwardRelayResponse_ClientRejected) ProtoReflect() protoreflect // Deprecated: Use TatankaForwardRelayResponse_ClientRejected.ProtoReflect.Descriptor instead. func (*TatankaForwardRelayResponse_ClientRejected) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{2, 1} + return file_messages_proto_rawDescGZIP(), []int{4, 1} } type DiscoveryResponse_Success struct { @@ -910,7 +1138,7 @@ type DiscoveryResponse_Success struct { func (x *DiscoveryResponse_Success) Reset() { *x = DiscoveryResponse_Success{} - mi := &file_tatanka_pb_messages_proto_msgTypes[14] + mi := &file_messages_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -922,7 +1150,7 @@ func (x *DiscoveryResponse_Success) String() string { func (*DiscoveryResponse_Success) ProtoMessage() {} func (x *DiscoveryResponse_Success) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[14] + mi := &file_messages_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -935,7 +1163,7 @@ func (x *DiscoveryResponse_Success) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryResponse_Success.ProtoReflect.Descriptor instead. func (*DiscoveryResponse_Success) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{4, 0} + return file_messages_proto_rawDescGZIP(), []int{6, 0} } func (x *DiscoveryResponse_Success) GetAddrs() [][]byte { @@ -953,7 +1181,7 @@ type DiscoveryResponse_NotFound struct { func (x *DiscoveryResponse_NotFound) Reset() { *x = DiscoveryResponse_NotFound{} - mi := &file_tatanka_pb_messages_proto_msgTypes[15] + mi := &file_messages_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -965,7 +1193,7 @@ func (x *DiscoveryResponse_NotFound) String() string { func (*DiscoveryResponse_NotFound) ProtoMessage() {} func (x *DiscoveryResponse_NotFound) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[15] + mi := &file_messages_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -978,7 +1206,7 @@ func (x *DiscoveryResponse_NotFound) ProtoReflect() protoreflect.Message { // Deprecated: Use DiscoveryResponse_NotFound.ProtoReflect.Descriptor instead. func (*DiscoveryResponse_NotFound) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{4, 1} + return file_messages_proto_rawDescGZIP(), []int{6, 1} } type WhitelistResponse_Success struct { @@ -989,7 +1217,7 @@ type WhitelistResponse_Success struct { func (x *WhitelistResponse_Success) Reset() { *x = WhitelistResponse_Success{} - mi := &file_tatanka_pb_messages_proto_msgTypes[16] + mi := &file_messages_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1001,7 +1229,7 @@ func (x *WhitelistResponse_Success) String() string { func (*WhitelistResponse_Success) ProtoMessage() {} func (x *WhitelistResponse_Success) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[16] + mi := &file_messages_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1014,7 +1242,7 @@ func (x *WhitelistResponse_Success) ProtoReflect() protoreflect.Message { // Deprecated: Use WhitelistResponse_Success.ProtoReflect.Descriptor instead. func (*WhitelistResponse_Success) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{6, 0} + return file_messages_proto_rawDescGZIP(), []int{8, 0} } type WhitelistResponse_Mismatch struct { @@ -1026,7 +1254,7 @@ type WhitelistResponse_Mismatch struct { func (x *WhitelistResponse_Mismatch) Reset() { *x = WhitelistResponse_Mismatch{} - mi := &file_tatanka_pb_messages_proto_msgTypes[17] + mi := &file_messages_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1038,7 +1266,7 @@ func (x *WhitelistResponse_Mismatch) String() string { func (*WhitelistResponse_Mismatch) ProtoMessage() {} func (x *WhitelistResponse_Mismatch) ProtoReflect() protoreflect.Message { - mi := &file_tatanka_pb_messages_proto_msgTypes[17] + mi := &file_messages_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1051,7 +1279,7 @@ func (x *WhitelistResponse_Mismatch) ProtoReflect() protoreflect.Message { // Deprecated: Use WhitelistResponse_Mismatch.ProtoReflect.Descriptor instead. func (*WhitelistResponse_Mismatch) Descriptor() ([]byte, []int) { - return file_tatanka_pb_messages_proto_rawDescGZIP(), []int{6, 1} + return file_messages_proto_rawDescGZIP(), []int{8, 1} } func (x *WhitelistResponse_Mismatch) GetPeerIDs() [][]byte { @@ -1061,17 +1289,28 @@ func (x *WhitelistResponse_Mismatch) GetPeerIDs() [][]byte { return nil } -var File_tatanka_pb_messages_proto protoreflect.FileDescriptor +var File_messages_proto protoreflect.FileDescriptor -const file_tatanka_pb_messages_proto_rawDesc = "" + +const file_messages_proto_rawDesc = "" + "\n" + - "\x19tatanka/pb/messages.proto\x12\x02pb\"\x82\x01\n" + + "\x0emessages.proto\x12\x02pb\"\x82\x01\n" + "\x13ClientConnectionMsg\x12\x0e\n" + "\x02id\x18\x01 \x01(\fR\x02id\x12\x1f\n" + "\vreporter_id\x18\x02 \x01(\fR\n" + "reporterId\x12\x1c\n" + "\tconnected\x18\x03 \x01(\bR\tconnected\x12\x1c\n" + - "\ttimestamp\x18\x04 \x01(\x03R\ttimestamp\"\x82\x01\n" + + "\ttimestamp\x18\x04 \x01(\x03R\ttimestamp\"{\n" + + "\fClientBanMsg\x12\x0e\n" + + "\x02ip\x18\x01 \x01(\tR\x02ip\x12\x1a\n" + + "\breporter\x18\x02 \x01(\fR\breporter\x12'\n" + + "\x0ftotal_penalties\x18\x03 \x01(\rR\x0etotalPenalties\x12\x16\n" + + "\x06expiry\x18\x04 \x01(\x03R\x06expiry\"\x9c\x01\n" + + "\x13ClientInfractionMsg\x12\x0e\n" + + "\x02ip\x18\x01 \x01(\tR\x02ip\x12\x1a\n" + + "\breporter\x18\x02 \x01(\fR\breporter\x12'\n" + + "\x0finfraction_type\x18\x03 \x01(\rR\x0einfractionType\x12\x18\n" + + "\apenalty\x18\x04 \x01(\rR\apenalty\x12\x16\n" + + "\x06expiry\x18\x05 \x01(\x03R\x06expiry\"\x82\x01\n" + "\x1aTatankaForwardRelayRequest\x12!\n" + "\finitiator_id\x18\x01 \x01(\fR\vinitiatorId\x12'\n" + "\x0fcounterparty_id\x18\x02 \x01(\fR\x0ecounterpartyId\x12\x18\n" + @@ -1123,79 +1362,87 @@ const file_tatanka_pb_messages_proto_rawDesc = "" + "\x10NodeOracleUpdate\x12;\n" + "\fprice_update\x18\x01 \x01(\v2\x16.pb.SourcedPriceUpdateH\x00R\vpriceUpdate\x12B\n" + "\x0ffee_rate_update\x18\x02 \x01(\v2\x18.pb.SourcedFeeRateUpdateH\x00R\rfeeRateUpdateB\b\n" + - "\x06updateB'Z%github.com/bisoncraft/mesh/tatanka/pbb\x06proto3" + "\x06update\"\"\n" + + " ClientInfractionsSnapshotRequest\"^\n" + + "!ClientInfractionsSnapshotResponse\x129\n" + + "\vinfractions\x18\x01 \x03(\v2\x17.pb.ClientInfractionMsgR\vinfractionsB'Z%github.com/bisoncraft/mesh/tatanka/pbb\x06proto3" var ( - file_tatanka_pb_messages_proto_rawDescOnce sync.Once - file_tatanka_pb_messages_proto_rawDescData []byte + file_messages_proto_rawDescOnce sync.Once + file_messages_proto_rawDescData []byte ) -func file_tatanka_pb_messages_proto_rawDescGZIP() []byte { - file_tatanka_pb_messages_proto_rawDescOnce.Do(func() { - file_tatanka_pb_messages_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_tatanka_pb_messages_proto_rawDesc), len(file_tatanka_pb_messages_proto_rawDesc))) +func file_messages_proto_rawDescGZIP() []byte { + file_messages_proto_rawDescOnce.Do(func() { + file_messages_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_messages_proto_rawDesc), len(file_messages_proto_rawDesc))) }) - return file_tatanka_pb_messages_proto_rawDescData + return file_messages_proto_rawDescData } -var file_tatanka_pb_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 18) -var file_tatanka_pb_messages_proto_goTypes = []any{ +var file_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 22) +var file_messages_proto_goTypes = []any{ (*ClientConnectionMsg)(nil), // 0: pb.ClientConnectionMsg - (*TatankaForwardRelayRequest)(nil), // 1: pb.TatankaForwardRelayRequest - (*TatankaForwardRelayResponse)(nil), // 2: pb.TatankaForwardRelayResponse - (*DiscoveryRequest)(nil), // 3: pb.DiscoveryRequest - (*DiscoveryResponse)(nil), // 4: pb.DiscoveryResponse - (*WhitelistRequest)(nil), // 5: pb.WhitelistRequest - (*WhitelistResponse)(nil), // 6: pb.WhitelistResponse - (*SourcedPrice)(nil), // 7: pb.SourcedPrice - (*SourcedPriceUpdate)(nil), // 8: pb.SourcedPriceUpdate - (*SourcedFeeRate)(nil), // 9: pb.SourcedFeeRate - (*SourcedFeeRateUpdate)(nil), // 10: pb.SourcedFeeRateUpdate - (*NodeOracleUpdate)(nil), // 11: pb.NodeOracleUpdate - (*TatankaForwardRelayResponse_ClientNotFound)(nil), // 12: pb.TatankaForwardRelayResponse.ClientNotFound - (*TatankaForwardRelayResponse_ClientRejected)(nil), // 13: pb.TatankaForwardRelayResponse.ClientRejected - (*DiscoveryResponse_Success)(nil), // 14: pb.DiscoveryResponse.Success - (*DiscoveryResponse_NotFound)(nil), // 15: pb.DiscoveryResponse.NotFound - (*WhitelistResponse_Success)(nil), // 16: pb.WhitelistResponse.Success - (*WhitelistResponse_Mismatch)(nil), // 17: pb.WhitelistResponse.Mismatch -} -var file_tatanka_pb_messages_proto_depIdxs = []int32{ - 12, // 0: pb.TatankaForwardRelayResponse.client_not_found:type_name -> pb.TatankaForwardRelayResponse.ClientNotFound - 13, // 1: pb.TatankaForwardRelayResponse.client_rejected:type_name -> pb.TatankaForwardRelayResponse.ClientRejected - 14, // 2: pb.DiscoveryResponse.success:type_name -> pb.DiscoveryResponse.Success - 15, // 3: pb.DiscoveryResponse.not_found:type_name -> pb.DiscoveryResponse.NotFound - 16, // 4: pb.WhitelistResponse.success:type_name -> pb.WhitelistResponse.Success - 17, // 5: pb.WhitelistResponse.mismatch:type_name -> pb.WhitelistResponse.Mismatch - 7, // 6: pb.SourcedPriceUpdate.prices:type_name -> pb.SourcedPrice - 9, // 7: pb.SourcedFeeRateUpdate.fee_rates:type_name -> pb.SourcedFeeRate - 8, // 8: pb.NodeOracleUpdate.price_update:type_name -> pb.SourcedPriceUpdate - 10, // 9: pb.NodeOracleUpdate.fee_rate_update:type_name -> pb.SourcedFeeRateUpdate - 10, // [10:10] is the sub-list for method output_type - 10, // [10:10] is the sub-list for method input_type - 10, // [10:10] is the sub-list for extension type_name - 10, // [10:10] is the sub-list for extension extendee - 0, // [0:10] is the sub-list for field type_name -} - -func init() { file_tatanka_pb_messages_proto_init() } -func file_tatanka_pb_messages_proto_init() { - if File_tatanka_pb_messages_proto != nil { + (*ClientBanMsg)(nil), // 1: pb.ClientBanMsg + (*ClientInfractionMsg)(nil), // 2: pb.ClientInfractionMsg + (*TatankaForwardRelayRequest)(nil), // 3: pb.TatankaForwardRelayRequest + (*TatankaForwardRelayResponse)(nil), // 4: pb.TatankaForwardRelayResponse + (*DiscoveryRequest)(nil), // 5: pb.DiscoveryRequest + (*DiscoveryResponse)(nil), // 6: pb.DiscoveryResponse + (*WhitelistRequest)(nil), // 7: pb.WhitelistRequest + (*WhitelistResponse)(nil), // 8: pb.WhitelistResponse + (*SourcedPrice)(nil), // 9: pb.SourcedPrice + (*SourcedPriceUpdate)(nil), // 10: pb.SourcedPriceUpdate + (*SourcedFeeRate)(nil), // 11: pb.SourcedFeeRate + (*SourcedFeeRateUpdate)(nil), // 12: pb.SourcedFeeRateUpdate + (*NodeOracleUpdate)(nil), // 13: pb.NodeOracleUpdate + (*ClientInfractionsSnapshotRequest)(nil), // 14: pb.ClientInfractionsSnapshotRequest + (*ClientInfractionsSnapshotResponse)(nil), // 15: pb.ClientInfractionsSnapshotResponse + (*TatankaForwardRelayResponse_ClientNotFound)(nil), // 16: pb.TatankaForwardRelayResponse.ClientNotFound + (*TatankaForwardRelayResponse_ClientRejected)(nil), // 17: pb.TatankaForwardRelayResponse.ClientRejected + (*DiscoveryResponse_Success)(nil), // 18: pb.DiscoveryResponse.Success + (*DiscoveryResponse_NotFound)(nil), // 19: pb.DiscoveryResponse.NotFound + (*WhitelistResponse_Success)(nil), // 20: pb.WhitelistResponse.Success + (*WhitelistResponse_Mismatch)(nil), // 21: pb.WhitelistResponse.Mismatch +} +var file_messages_proto_depIdxs = []int32{ + 16, // 0: pb.TatankaForwardRelayResponse.client_not_found:type_name -> pb.TatankaForwardRelayResponse.ClientNotFound + 17, // 1: pb.TatankaForwardRelayResponse.client_rejected:type_name -> pb.TatankaForwardRelayResponse.ClientRejected + 18, // 2: pb.DiscoveryResponse.success:type_name -> pb.DiscoveryResponse.Success + 19, // 3: pb.DiscoveryResponse.not_found:type_name -> pb.DiscoveryResponse.NotFound + 20, // 4: pb.WhitelistResponse.success:type_name -> pb.WhitelistResponse.Success + 21, // 5: pb.WhitelistResponse.mismatch:type_name -> pb.WhitelistResponse.Mismatch + 9, // 6: pb.SourcedPriceUpdate.prices:type_name -> pb.SourcedPrice + 11, // 7: pb.SourcedFeeRateUpdate.fee_rates:type_name -> pb.SourcedFeeRate + 10, // 8: pb.NodeOracleUpdate.price_update:type_name -> pb.SourcedPriceUpdate + 12, // 9: pb.NodeOracleUpdate.fee_rate_update:type_name -> pb.SourcedFeeRateUpdate + 2, // 10: pb.ClientInfractionsSnapshotResponse.infractions:type_name -> pb.ClientInfractionMsg + 11, // [11:11] is the sub-list for method output_type + 11, // [11:11] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name +} + +func init() { file_messages_proto_init() } +func file_messages_proto_init() { + if File_messages_proto != nil { return } - file_tatanka_pb_messages_proto_msgTypes[2].OneofWrappers = []any{ + file_messages_proto_msgTypes[4].OneofWrappers = []any{ (*TatankaForwardRelayResponse_Success)(nil), (*TatankaForwardRelayResponse_ClientNotFound_)(nil), (*TatankaForwardRelayResponse_ClientRejected_)(nil), (*TatankaForwardRelayResponse_Error)(nil), } - file_tatanka_pb_messages_proto_msgTypes[4].OneofWrappers = []any{ + file_messages_proto_msgTypes[6].OneofWrappers = []any{ (*DiscoveryResponse_Success_)(nil), (*DiscoveryResponse_NotFound_)(nil), } - file_tatanka_pb_messages_proto_msgTypes[6].OneofWrappers = []any{ + file_messages_proto_msgTypes[8].OneofWrappers = []any{ (*WhitelistResponse_Success_)(nil), (*WhitelistResponse_Mismatch_)(nil), } - file_tatanka_pb_messages_proto_msgTypes[11].OneofWrappers = []any{ + file_messages_proto_msgTypes[13].OneofWrappers = []any{ (*NodeOracleUpdate_PriceUpdate)(nil), (*NodeOracleUpdate_FeeRateUpdate)(nil), } @@ -1203,17 +1450,17 @@ func file_tatanka_pb_messages_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_tatanka_pb_messages_proto_rawDesc), len(file_tatanka_pb_messages_proto_rawDesc)), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_messages_proto_rawDesc), len(file_messages_proto_rawDesc)), NumEnums: 0, - NumMessages: 18, + NumMessages: 22, NumExtensions: 0, NumServices: 0, }, - GoTypes: file_tatanka_pb_messages_proto_goTypes, - DependencyIndexes: file_tatanka_pb_messages_proto_depIdxs, - MessageInfos: file_tatanka_pb_messages_proto_msgTypes, + GoTypes: file_messages_proto_goTypes, + DependencyIndexes: file_messages_proto_depIdxs, + MessageInfos: file_messages_proto_msgTypes, }.Build() - File_tatanka_pb_messages_proto = out.File - file_tatanka_pb_messages_proto_goTypes = nil - file_tatanka_pb_messages_proto_depIdxs = nil + File_messages_proto = out.File + file_messages_proto_goTypes = nil + file_messages_proto_depIdxs = nil } diff --git a/tatanka/pb/messages.proto b/tatanka/pb/messages.proto index 51472c6..1383b19 100644 --- a/tatanka/pb/messages.proto +++ b/tatanka/pb/messages.proto @@ -12,6 +12,23 @@ message ClientConnectionMsg { int64 timestamp = 4; } +// ClientBanMsg is used internally by tatanka nodes to share client ban information. +message ClientBanMsg { + string ip = 1; // banned client IP address + bytes reporter = 2; // peer.ID of reporting node + uint32 total_penalties = 3; + int64 expiry = 4; +} + +// ClientInfractionMsg is used to gossip individual client infractions between tatanka nodes. +message ClientInfractionMsg { + string ip = 1; // client IP address + bytes reporter = 2; // peer.ID of reporting node + uint32 infraction_type = 3; // infractionType enum value + uint32 penalty = 4; // penalty for this infraction + int64 expiry = 5; // when this infraction expires +} + message TatankaForwardRelayRequest { bytes initiator_id = 1; // peer.ID serialized as bytes bytes counterparty_id = 2; // peer.ID serialized as bytes @@ -103,3 +120,11 @@ message NodeOracleUpdate { SourcedFeeRateUpdate fee_rate_update = 2; } } + +// ClientInfractionsSnapshotRequest is sent by a tatanka node to query another node for its active client infractions. +message ClientInfractionsSnapshotRequest {} + +// ClientInfractionsSnapshotResponse contains the snapshot of active client infractions. +message ClientInfractionsSnapshotResponse { + repeated ClientInfractionMsg infractions = 1; +} diff --git a/tatanka/push_stream_manager.go b/tatanka/push_stream_manager.go index 5ccace6..fbfb010 100644 --- a/tatanka/push_stream_manager.go +++ b/tatanka/push_stream_manager.go @@ -6,11 +6,12 @@ import ( "sync" "time" + "github.com/bisoncraft/mesh/codec" + protocolsPb "github.com/bisoncraft/mesh/protocols/pb" "github.com/decred/slog" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/bisoncraft/mesh/codec" - protocolsPb "github.com/bisoncraft/mesh/protocols/pb" + ma "github.com/multiformats/go-multiaddr" ) const ( @@ -36,8 +37,11 @@ type pushStreamManager struct { log slog.Logger notifyConnected notifyConnectedFunc - mtx sync.RWMutex - pushStreams map[peer.ID]*pushStreamWrapper + pushStreamMtx sync.RWMutex + pushStreams map[peer.ID]*pushStreamWrapper + + ipsMtx sync.RWMutex + ips map[string]peer.ID } // newPushStreamManager creates a new pushStreamManager. The notifyConnectedFunc @@ -48,6 +52,7 @@ func newPushStreamManager(log slog.Logger, f notifyConnectedFunc) *pushStreamMan log: log, notifyConnected: f, pushStreams: make(map[peer.ID]*pushStreamWrapper), + ips: make(map[string]peer.ID), } } @@ -58,7 +63,16 @@ func newPushStreamManager(log slog.Logger, f notifyConnectedFunc) *pushStreamMan func (p *pushStreamManager) newPushStream(stream network.Stream) { client := stream.Conn().RemotePeer() - p.mtx.Lock() + ip, err := getIPFromStream(stream) + if err != nil { + p.log.Errorf("Failed to fetch ip from stream: %v", err) + } + + p.ipsMtx.Lock() + p.ips[ip] = client + p.ipsMtx.Unlock() + + p.pushStreamMtx.Lock() oldWrapper := p.pushStreams[client] wrapper := &pushStreamWrapper{ stream: stream, @@ -66,7 +80,7 @@ func (p *pushStreamManager) newPushStream(stream network.Stream) { } p.pushStreams[client] = wrapper newStreamTimestamp := time.Now() - p.mtx.Unlock() + p.pushStreamMtx.Unlock() if oldWrapper != nil { _ = oldWrapper.stream.Close() @@ -93,16 +107,20 @@ func (p *pushStreamManager) newPushStream(stream network.Stream) { p.log.Debugf("Error discarding data from client %s push stream: %v", client.ShortString(), err) } - p.mtx.Lock() + p.pushStreamMtx.Lock() if p.pushStreams[client] == wrapper { close(wrapper.writeCh) delete(p.pushStreams, client) timestamp := time.Now() - p.mtx.Unlock() + p.pushStreamMtx.Unlock() + + p.ipsMtx.Lock() + delete(p.ips, ip) + p.ipsMtx.Unlock() p.notifyConnected(client, timestamp, false) } else { - p.mtx.Unlock() + p.pushStreamMtx.Unlock() } }() } @@ -121,8 +139,8 @@ func (p *pushStreamManager) distribute(clients []peer.ID, msg *protocolsPb.PushM return } - p.mtx.RLock() - defer p.mtx.RUnlock() + p.pushStreamMtx.RLock() + defer p.pushStreamMtx.RUnlock() for _, client := range clients { if client == sender { @@ -138,3 +156,49 @@ func (p *pushStreamManager) distribute(clients []peer.ID, msg *protocolsPb.PushM } } } + +func (p *pushStreamManager) disconnectClientByIP(ip string) { + p.ipsMtx.RLock() + clientID, ok := p.ips[ip] + p.ipsMtx.RUnlock() + + if !ok { + p.log.Debugf("No connected client found with IP %s", ip) + return + } + + p.pushStreamMtx.RLock() + wrapper, ok := p.pushStreams[clientID] + p.pushStreamMtx.RUnlock() + + if !ok { + p.log.Debugf("No stream found for connected client %s", clientID.ShortString()) + return + } + + if err := wrapper.stream.Close(); err != nil { + p.log.Errorf("Error closing stream for client %s: %v", clientID.ShortString(), err) + return + } +} + +func getIPFromStream(s network.Stream) (string, error) { + if s == nil { + return "", errors.New("network stream cannot be nil") + } + + remoteMa := s.Conn().RemoteMultiaddr() + if remoteMa == nil { + return "", errors.New("remote multi address cannot be nil") + } + + if ipv4, err := remoteMa.ValueForProtocol(ma.P_IP4); err == nil && ipv4 != "" { + return ipv4, nil + } + + if ipv6, err := remoteMa.ValueForProtocol(ma.P_IP6); err == nil && ipv6 != "" { + return ipv6, nil + } + + return "", errors.New("failed to deduce ip address from stream") +} diff --git a/tatanka/tatanka.go b/tatanka/tatanka.go index 9f6c900..ea312b5 100644 --- a/tatanka/tatanka.go +++ b/tatanka/tatanka.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math/big" + "math/rand" "net/http" "net/http/pprof" "os" @@ -14,15 +15,18 @@ import ( "sync/atomic" "time" + "github.com/bisoncraft/mesh/codec" + "github.com/bisoncraft/mesh/oracle" + "github.com/bisoncraft/mesh/protocols" + protocolsPb "github.com/bisoncraft/mesh/protocols/pb" + "github.com/bisoncraft/mesh/tatanka/admin" + pb "github.com/bisoncraft/mesh/tatanka/pb" "github.com/decred/slog" "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/bisoncraft/mesh/oracle" - "github.com/bisoncraft/mesh/protocols" - protocolsPb "github.com/bisoncraft/mesh/protocols/pb" - "github.com/bisoncraft/mesh/tatanka/admin" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -40,6 +44,10 @@ const ( // whitelistProtocol is the protocol used to verify the whitelist alignment of a tatanka node. whitelistProtocol = "/tatanka/whitelist/1.0.0" + + // clientInfractionsSnapshotProtocol is the protocol used to query a tatanka + // node for a snapshot of active client infractions. + clientInfractionsSnapshotProtocol = "/tatanka/client-infractions-snapshot/1.0.0" ) // Config is the configuration for the tatanka node @@ -80,16 +88,18 @@ type Oracle interface { // TatankaNode is a permissioned node in the tatanka mesh type TatankaNode struct { - config *Config - node host.Host - log slog.Logger - whitelist atomic.Value // *whitelist - readyCh chan struct{} - readyOnce sync.Once - readyErr atomic.Value // error - privateKey crypto.PrivKey - bondVerifier *bondVerifier - bondStorage bondStorage + config *Config + node host.Host + log slog.Logger + whitelist atomic.Value // *whitelist + readyCh chan struct{} + readyOnce sync.Once + readyErr atomic.Value // error + privateKey crypto.PrivKey + bondVerifier *bondVerifier + bondStorage bondStorage + banManager *banManager + broadcastRateLimiter *broadcastRateLimiter gossipSub *gossipSub clientConnectionManager *clientConnectionManager @@ -153,6 +163,14 @@ func (t *TatankaNode) handleClientConnectionMessage(update *clientConnectionUpda t.clientConnectionManager.updateClientConnectionInfo(update) } +func (t *TatankaNode) handleClientInfractionMessage(msg *pb.ClientInfractionMsg) { + err := t.banManager.recordRemoteInfraction(msg) + if err != nil { + t.log.Errorf("Failed to record remote infraction: %v", err) + return + } +} + // Run starts the tatanka node and blocks until the context is done. func (t *TatankaNode) Run(ctx context.Context) error { wg := sync.WaitGroup{} @@ -192,6 +210,7 @@ func (t *TatankaNode) Run(ctx context.Context) error { handleBroadcastMessage: t.handleBroadcastMessage, handleClientConnectionMessage: t.handleClientConnectionMessage, handleOracleUpdate: t.handleOracleUpdate, + handleClientInfractionMessage: t.handleClientInfractionMessage, }) if err != nil { t.markReady(err) @@ -224,6 +243,20 @@ func (t *TatankaNode) Run(ctx context.Context) error { } } + t.banManager = newBanManager(&banManagerConfig{ + disconnectClient: t.pushStreamManager.disconnectClientByIP, + nodeID: t.node.ID(), + publishInfraction: t.gossipSub.publishClientInfractionMessage, + now: time.Now, + log: t.config.Logger, + }) + + t.broadcastRateLimiter = newBroadcastRateLimiter(&rateLimitConfig{ + recordInfraction: t.banManager.recordInfraction, + now: time.Now, + log: t.config.Logger, + }) + // Create admin callback function and setup the admin server if configured. adminCallback := func(peerID peer.ID, connected bool, whitelistMismatch bool, addresses []string, peerWhitelist []string) { } @@ -292,6 +325,20 @@ func (t *TatankaNode) Run(ctx context.Context) error { } }() + // Run ban manager + wg.Add(1) + go func() { + defer wg.Done() + t.banManager.run(ctx) + }() + + // Run broadcast rate limiter + wg.Add(1) + go func() { + defer wg.Done() + t.broadcastRateLimiter.run(ctx) + }() + // Maintain mesh connectivity wg.Add(1) go func() { @@ -301,6 +348,11 @@ func (t *TatankaNode) Run(ctx context.Context) error { // Wait for the initial connectivity pass to finish before reporting ready. t.connectionManager.waitInitial(ctx) + + // Query one random connected peer for their infraction snapshot + // (all peers maintain equivalent sets via gossip) + t.syncInfractionsFromRandomPeer(ctx) + t.markReady(nil) // Run Oracle @@ -341,6 +393,81 @@ func (t *TatankaNode) WaitReady(ctx context.Context) error { } } +// fetchClientInfractionsSnapshot fetches a snapshot of active client infractions from a peer. +func (t *TatankaNode) fetchClientInfractionsSnapshot(ctx context.Context, peerID peer.ID) ([]*pb.ClientInfractionMsg, error) { + s, err := t.node.NewStream(ctx, peerID, clientInfractionsSnapshotProtocol) + if err != nil { + return nil, fmt.Errorf("failed to open stream: %w", err) + } + defer func() { _ = s.Close() }() + + request := &pb.ClientInfractionsSnapshotRequest{} + if err := codec.WriteLengthPrefixedMessage(s, request); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + response := &pb.ClientInfractionsSnapshotResponse{} + if err := codec.ReadLengthPrefixedMessage(s, response); err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + return response.Infractions, nil +} + +// syncInfractionsFromRandomPeer queries a random connected peer for their infraction snapshot. +func (t *TatankaNode) syncInfractionsFromRandomPeer(ctx context.Context) { + whitelist := t.getWhitelist() + + var connectedPeers []peer.ID + for peerID := range whitelist.allPeerIDs() { + if peerID == t.node.ID() { + continue + } + if t.node.Network().Connectedness(peerID) == network.Connected { + connectedPeers = append(connectedPeers, peerID) + } + } + + if len(connectedPeers) == 0 { + t.log.Info("No connected peers to sync infractions from") + return + } + + rand.Shuffle(len(connectedPeers), func(i, j int) { + connectedPeers[i], connectedPeers[j] = connectedPeers[j], connectedPeers[i] + }) + + for _, peerID := range connectedPeers { + infractions, err := t.fetchClientInfractionsSnapshot(ctx, peerID) + if err != nil { + t.log.Warnf("Failed to fetch infractions from peer %s: %v", peerID.ShortString(), err) + continue + } + + // Deduplicate to prevent gossip re-delivery from creating duplicates + seen := make(map[string]bool) + var dedupedInfractions []*pb.ClientInfractionMsg + for _, infraction := range infractions { + infractionKey := InfractionDedupKey(infraction.Ip, infraction.Reporter, infraction.InfractionType, infraction.Expiry) + if !seen[infractionKey] { + seen[infractionKey] = true + dedupedInfractions = append(dedupedInfractions, infraction) + } + } + + for _, infraction := range dedupedInfractions { + if err := t.banManager.recordRemoteInfraction(infraction); err != nil { + t.log.Errorf("Failed to record remote infraction from snapshot: %v", err) + } + } + + t.log.Infof("Synced %d infractions from peer %s", len(dedupedInfractions), peerID.ShortString()) + return + } + + t.log.Warn("Failed to sync infractions from any connected peer") +} + // markReady signals that initialization is complete. If an error occurred, // it is stored for WaitReady callers. Only the first call takes effect. func (t *TatankaNode) markReady(err error) { @@ -382,15 +509,23 @@ func getOrCreatePrivateKey(filePath string) (crypto.PrivKey, error) { } func (t *TatankaNode) setupStreamHandlers() { - t.setStreamHandler(protocols.PostBondsProtocol, t.handlePostBonds, requireNoPermission) t.setStreamHandler(forwardRelayProtocol, t.handleForwardRelay, t.isWhitelistPeer) - t.setStreamHandler(protocols.ClientSubscribeProtocol, t.handleClientSubscribe, t.requireBonds) - t.setStreamHandler(protocols.ClientPublishProtocol, t.handleClientPublish, t.requireBonds) - t.setStreamHandler(protocols.ClientPushProtocol, t.handleClientPush, t.requireBonds) - t.setStreamHandler(protocols.ClientRelayMessageProtocol, t.handleClientRelayMessage, t.requireBonds) - t.setStreamHandler(protocols.AvailableMeshNodesProtocol, t.handleAvailableMeshNodes, t.requireBonds) t.setStreamHandler(discoveryProtocol, t.handleDiscovery, t.isWhitelistPeer) t.setStreamHandler(whitelistProtocol, t.handleWhitelist, t.isWhitelistPeer) + t.setStreamHandler(clientInfractionsSnapshotProtocol, t.handleClientInfractionsSnapshot, t.isWhitelistPeer) + + t.setStreamHandler(protocols.PostBondsProtocol, t.handlePostBonds, + requireAll(t.requireNotBanned, requireNoPermission)) + t.setStreamHandler(protocols.ClientSubscribeProtocol, t.handleClientSubscribe, + requireAll(t.requireNotBanned, t.requireBonds)) + t.setStreamHandler(protocols.ClientPublishProtocol, t.handleClientPublish, + requireAll(t.requireNotBanned, t.requireBonds)) + t.setStreamHandler(protocols.ClientPushProtocol, t.handleClientPush, + requireAll(t.requireNotBanned, t.requireBonds)) + t.setStreamHandler(protocols.ClientRelayMessageProtocol, t.handleClientRelayMessage, + requireAll(t.requireNotBanned, t.requireBonds)) + t.setStreamHandler(protocols.AvailableMeshNodesProtocol, t.handleAvailableMeshNodes, + requireAll(t.requireNotBanned, t.requireBonds)) } func (t *TatankaNode) setupObservability() { diff --git a/tatanka/tatanka_test.go b/tatanka/tatanka_test.go index 93f34ff..a69598c 100644 --- a/tatanka/tatanka_test.go +++ b/tatanka/tatanka_test.go @@ -1096,8 +1096,10 @@ func TestClientSubscriptionAndBroadcast(t *testing.T) { rng := rand.New(rand.NewSource(time.Now().UnixNano())) - // Run 100 iterations of random publish/receive - for iteration := 0; iteration < 100; iteration++ { + // Run 20 iterations of random publish/receive with sufficient delay to avoid rate limiting. + // Rate limiter allows 4 messages/sec per client with burst of 8. + // With 6 clients and random publisher selection, 250ms delay ensures we stay within limits. + for iteration := 0; iteration < 20; iteration++ { // Randomly select a topic var topic string var subscribers []*testClient @@ -1148,8 +1150,8 @@ func TestClientSubscriptionAndBroadcast(t *testing.T) { iteration, subscriber.host.ID().ShortString(), topic) } - // Small delay between iterations - time.Sleep(50 * time.Millisecond) + // Delay between iterations to stay within rate limits (4 msgs/sec per client) + time.Sleep(250 * time.Millisecond) } // Terminate clients. @@ -1673,3 +1675,96 @@ func TestClientSubscriptionEvents(t *testing.T) { clients[idx].Close() } } + +func TestInfractionSnapshotSync(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + mnet, err := mocknet.WithNPeers(3) + if err != nil { + t.Fatal(err) + } + + allPeers := mnet.Peers() + h1 := mnet.Host(allPeers[0]) + h2 := mnet.Host(allPeers[1]) + h3 := mnet.Host(allPeers[2]) + + whitelistPeers := []*peer.AddrInfo{ + {ID: h1.ID(), Addrs: h1.Addrs()}, + {ID: h2.ID(), Addrs: h2.Addrs()}, + {ID: h3.ID(), Addrs: h3.Addrs()}, + } + mockWhitelist := &whitelist{peers: whitelistPeers} + + t.Log("Starting Node 1") + dir1 := t.TempDir() + node1 := newTestNode(t, ctx, h1, dir1, mockWhitelist) + + t.Log("Starting Node 2") + dir2 := t.TempDir() + if _, err := mnet.LinkPeers(h2.ID(), h1.ID()); err != nil { + t.Fatalf("Failed to link h2 to h1: %v", err) + } + if _, err := mnet.ConnectPeers(h2.ID(), h1.ID()); err != nil { + t.Fatalf("Failed to connect h2 to h1: %v", err) + } + node2 := newTestNode(t, ctx, h2, dir2, mockWhitelist) + + time.Sleep(1 * time.Second) + checkFullyConnected(t, []*TatankaNode{node1, node2}) + + t.Log("Recording infractions on Node 1") + testIPs := []string{"192.168.1.100", "192.168.1.101"} + testClientID := peer.ID("test-client-id") + + for _, ip := range testIPs { + if err := node1.banManager.recordInfraction(ip, testClientID, MalformedMessage); err != nil { + t.Fatalf("Failed to record infraction: %v", err) + } + } + + time.Sleep(1 * time.Second) + + t.Log("Starting Node 3") + dir3 := t.TempDir() + + if _, err := mnet.LinkPeers(h3.ID(), h1.ID()); err != nil { + t.Fatalf("Failed to link h3 to h1: %v", err) + } + if _, err := mnet.ConnectPeers(h3.ID(), h1.ID()); err != nil { + t.Fatalf("Failed to connect h3 to h1: %v", err) + } + + node3 := newTestNode(t, ctx, h3, dir3, mockWhitelist) + + readyCtx, readyCancel := context.WithTimeout(ctx, 5*time.Second) + if err := node3.WaitReady(readyCtx); err != nil { + readyCancel() + t.Fatalf("Node 3 failed to be ready: %v", err) + } + readyCancel() + + time.Sleep(500 * time.Millisecond) + + t.Log("Verifying Node 3 synced infractions") + node3Infractions := node3.banManager.getActiveInfractions() + + if len(node3Infractions) < len(testIPs) { + t.Fatalf("Expected at least %d infractions on node 3, got %d", len(testIPs), len(node3Infractions)) + } + + infractionIPs := make(map[string]bool) + for _, inf := range node3Infractions { + infractionIPs[inf.Ip] = true + } + + for _, testIP := range testIPs { + if !infractionIPs[testIP] { + t.Errorf("Expected IP %s in node 3 infractions, but it wasn't found", testIP) + } + } + + t.Logf("Node 3 synced %d infractions containing all expected IPs", len(node3Infractions)) + t.Log("SUCCESS: Node 3 successfully synced infractions from peers on startup") +} diff --git a/testing/client/client.go b/testing/client/client.go index c7b46f2..33a6466 100644 --- a/testing/client/client.go +++ b/testing/client/client.go @@ -1,9 +1,11 @@ package client import ( + "bytes" "context" "embed" "encoding/base64" + "encoding/json" "encoding/hex" "errors" "fmt" @@ -42,6 +44,7 @@ type Config struct { ClientPort int WebPort int Logger slog.Logger + Spam bool } // Client represents a tatanka test client. @@ -394,6 +397,57 @@ func writeSSEEvent(w http.ResponseWriter, e Event) error { return err } +func (c *Client) publishSpam(ctx context.Context) { + // Wait a few seconds for the server to start + select { + case <-time.After(2 * time.Second): + case <-ctx.Done(): + return + } + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + client := &http.Client{ + Timeout: 5 * time.Second, + } + + c.log.Infof("Periodically publishing spam messages") + + for { + select { + case <-ticker.C: + now := time.Now().UnixMilli() + timeStr := strconv.FormatInt(now, 10) + + payload := map[string]string{ + "topic": "time", + "data": base64.StdEncoding.EncodeToString([]byte(timeStr)), + } + + payloadJSON, err := json.Marshal(payload) + if err != nil { + c.log.Debugf("Failed to marshal spam payload: %v", err) + continue + } + + resp, err := client.Post( + fmt.Sprintf("http://localhost:%d/broadcast", c.cfg.WebPort), + "application/json", + bytes.NewReader(payloadJSON), + ) + if err != nil { + c.log.Debugf("Failed to publish spam message: %v", err) + continue + } + resp.Body.Close() + + case <-ctx.Done(): + return + } + } +} + func (c *Client) Run(ctx context.Context, bonds []*bond.BondParams) { c.log.Infof("Running test client ...") @@ -425,6 +479,14 @@ func (c *Client) Run(ctx context.Context, bonds []*bond.BondParams) { } }() + if c.cfg.Spam { + wg.Add(1) + go func() { + defer wg.Done() + c.publishSpam(ctx) + }() + } + wg.Wait() } diff --git a/testing/harness.sh b/testing/harness.sh index 7c45f2e..1ff4d93 100755 --- a/testing/harness.sh +++ b/testing/harness.sh @@ -68,6 +68,7 @@ create_client_config() { local node_addr=$2 local client_port=$3 local web_port=$4 + local spam=$5 local config_path=$client_dir/testclient.conf cat < $config_path @@ -76,6 +77,7 @@ loglevel=debug nodeaddr=$node_addr clientport=$client_port webport=$web_port +spam=$spam EOF echo $config_path @@ -84,6 +86,7 @@ EOF start_harness() { local num_nodes=$1 local num_clients=$2 + local spam=${3:-false} mkdir -p $ROOT_DIR @@ -146,7 +149,7 @@ start_harness() { client_port=$((12455 + i)) web_port=$((12465 + i)) - client_cfg=$(create_client_config $client_dir "$node_addr" $client_port $web_port) + client_cfg=$(create_client_config $client_dir "$node_addr" $client_port $web_port $spam) tmux new-window -t $session_name -n client-$i tmux send-keys -t $session_name "ROOT_DIR=$ROOT_DIR $TESTCLIENT_BIN -C $client_cfg" C-m @@ -158,14 +161,31 @@ start_harness() { # Check if number of nodes and clients are provided if [ $# -lt 2 ]; then - echo "Usage: $0 " + echo "Usage: $0 [-s|--spam]" echo " num_nodes: Number of tatanka nodes to start" echo " num_clients: Number of test clients to start" + echo " -s, --spam: Enable spam publisher on clients (default: disabled)" exit 1 fi num_nodes=$1 num_clients=$2 +spam=false + +# Parse remaining arguments for flags +shift 2 +while [ $# -gt 0 ]; do + case "$1" in + -s|--spam) + spam=true + shift + ;; + *) + echo "Error: Unknown option '$1'" + exit 1 + ;; + esac +done # Validate that the argument is a positive integer if ! [[ "$num_nodes" =~ ^[0-9]+$ ]] || [ "$num_nodes" -eq 0 ]; then @@ -178,6 +198,6 @@ if ! [[ "$num_clients" =~ ^[0-9]+$ ]] || [ "$num_clients" -lt 0 ]; then exit 1 fi -start_harness $num_nodes $num_clients +start_harness $num_nodes $num_clients $spam tmux attach-session -t $session_name \ No newline at end of file