Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions app/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,23 @@ func TestAllowRedisTokenBucketReject(t *testing.T) {
<-done
}

func TestAllowRedisTokenBucketError(t *testing.T) {
rl := NewRateLimiter(1, time.Second, "token_bucket")
t.Cleanup(rl.Stop)
srv, cli := net.Pipe()
done := make(chan struct{})
go func() {
defer func() { srv.Close(); close(done) }()
br := bufio.NewReader(srv)
parseRedisCommand(t, br)
srv.Write([]byte("-ERR nope\r\n"))
}()
if _, err := rl.allowRedisTokenBucket(cli, "k"); err == nil {
t.Fatal("expected redis error")
}
<-done
}

func TestAllowRedisLeakyBucketPoolFullClosesConnection(t *testing.T) {
old := *redisAddr
*redisAddr = "dummy"
Expand Down Expand Up @@ -815,6 +832,23 @@ func TestAllowRedisLeakyBucketReject(t *testing.T) {
<-done
}

func TestAllowRedisLeakyBucketError(t *testing.T) {
rl := NewRateLimiter(1, time.Second, "leaky_bucket")
t.Cleanup(rl.Stop)
srv, cli := net.Pipe()
done := make(chan struct{})
go func() {
defer func() { srv.Close(); close(done) }()
br := bufio.NewReader(srv)
parseRedisCommand(t, br)
srv.Write([]byte("-ERR boom\r\n"))
}()
if _, err := rl.allowRedisLeakyBucket(cli, "k"); err == nil {
t.Fatal("expected redis GET error")
}
<-done
}

func TestAllowRedisLeakyBucketAllow(t *testing.T) {
rl := NewRateLimiter(1, time.Hour, "leaky_bucket")
t.Cleanup(rl.Stop)
Expand Down
139 changes: 139 additions & 0 deletions app/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,41 @@ import (
"fmt"
"io"
"net"
"strconv"
"strings"
"testing"
"time"
)

func readRESPCommand(t *testing.T, br *bufio.Reader) []string {
t.Helper()

line, err := br.ReadString('\n')
if err != nil {
t.Fatalf("read command length: %v", err)
}
if !strings.HasPrefix(line, "*") {
t.Fatalf("unexpected command prefix %q", line)
}
n, err := strconv.Atoi(strings.TrimSpace(line[1:]))
if err != nil {
t.Fatalf("parse command length: %v", err)
}

args := make([]string, n)
for i := 0; i < n; i++ {
if _, err := br.ReadString('\n'); err != nil {
t.Fatalf("read bulk len: %v", err)
}
arg, err := br.ReadString('\n')
if err != nil {
t.Fatalf("read bulk data: %v", err)
}
args[i] = strings.TrimSpace(arg)
}
return args
}

func TestRedisCmdInt(t *testing.T) {
srv, cli := net.Pipe()
defer srv.Close()
Expand Down Expand Up @@ -268,6 +299,23 @@ func TestRedisCmdStringReadError(t *testing.T) {
}
}

func TestRedisCmdStringBulkReadError(t *testing.T) {
srv, cli := net.Pipe()
defer cli.Close()

go func() {
br := bufio.NewReader(srv)
br.ReadBytes('\n')
br.ReadBytes('\n')
srv.Write([]byte("$4\r\n"))
srv.Close()
}()

if _, err := redisCmdString(cli, "GET", "key"); err == nil {
t.Fatal("expected bulk read error")
}
}

func TestRedisCmdStringSimple(t *testing.T) {
srv, cli := net.Pipe()
defer srv.Close()
Expand Down Expand Up @@ -309,3 +357,94 @@ func TestRedisCmdStringInteger(t *testing.T) {
t.Fatalf("expected 5, got %q", val)
}
}

func TestAllowRedisTokenBucketEmpty(t *testing.T) {
oldAddr := *redisAddr
*redisAddr = "redis://example:6379"
t.Cleanup(func() { *redisAddr = oldAddr })

rl := NewRateLimiter(1, time.Second, "token_bucket")
srv, cli := net.Pipe()
defer srv.Close()
defer cli.Close()

now := time.Now().UnixNano()

go func() {
br := bufio.NewReader(srv)

// GET k
readRESPCommand(t, br)
payload := fmt.Sprintf("0 %d", now)
srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload)))

// SET k <payload>
readRESPCommand(t, br)
srv.Write([]byte("+OK\r\n"))

// PEXPIRE k <ttl>
readRESPCommand(t, br)
srv.Write([]byte(":1\r\n"))
}()

allowed, err := rl.allowRedisTokenBucket(cli, "k")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if allowed {
t.Fatal("expected request to be rate limited when bucket is empty")
}
}

func TestAllowRedisLeakyBucketOverLimit(t *testing.T) {
oldAddr := *redisAddr
*redisAddr = "redis://example:6379"
t.Cleanup(func() { *redisAddr = oldAddr })

rl := NewRateLimiter(1, time.Second, "leaky_bucket")
now := time.Now().UnixNano()
srv, cli := net.Pipe()
defer srv.Close()
defer cli.Close()

go func() {
br := bufio.NewReader(srv)

// GET k
readRESPCommand(t, br)
payload := fmt.Sprintf("2 %d", now)
srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload)))

// SET k <payload>
readRESPCommand(t, br)
srv.Write([]byte("+OK\r\n"))

// PEXPIRE k <ttl>
readRESPCommand(t, br)
srv.Write([]byte(":1\r\n"))
}()

allowed, err := rl.allowRedisLeakyBucket(cli, "k")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if allowed {
t.Fatal("expected request to be rate limited when bucket is over limit")
}
}

func TestRetryAfterRedisTLSMissingCA(t *testing.T) {
oldAddr := *redisAddr
oldCA := *redisCA
*redisAddr = "rediss://example.com:6379"
*redisCA = "does-not-exist"
t.Cleanup(func() {
*redisAddr = oldAddr
*redisCA = oldCA
})

rl := NewRateLimiter(1, time.Second, "fixed_window")
if _, err := rl.retryAfterRedis("key"); err == nil {
t.Fatal("expected error when CA file cannot be read")
}
}
Loading