ai-code-battle/ratelimit/ratelimit_test.go
jedarden 7e131d310f feat(api): add token-bucket rate limiting to public endpoints
Adds ratelimit package with per-IP and per-key HTTP middleware.
Applied to register (5/hr), feedback (20/hr), predict (60/hr),
and job submission (5/day) endpoints. Includes metrics counter
for rejected requests and periodic bucket cleanup goroutine.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-22 16:52:29 -04:00

252 lines
6.5 KiB
Go

package ratelimit
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestBucketAllowsUpToMax checks that a bucket created with three tokens
// grants three back-to-back requests and refuses the one after that.
func TestBucketAllowsUpToMax(t *testing.T) {
	const capacity = 3
	bucket := NewBucket(capacity, 1.0) // 3 tokens, refilled at 1 per second

	// Spend every token the bucket starts with.
	for attempt := 1; attempt <= capacity; attempt++ {
		if ok := bucket.Allow(); !ok {
			t.Fatalf("request %d should be allowed", attempt)
		}
	}

	// Nothing should be left for one more immediate request.
	if ok := bucket.Allow(); ok {
		t.Fatal("4th request should be denied")
	}
}
// TestBucketRefills drains a single-token bucket, confirms it is empty, and
// then verifies that a request is allowed again after a short pause.
func TestBucketRefills(t *testing.T) {
	// One token of capacity with a fast 100 tokens/sec refill, so the short
	// pause below is enough to earn a token back.
	const pause = 20 * time.Millisecond
	bucket := NewBucket(1, 100.0)

	if granted := bucket.Allow(); !granted {
		t.Fatal("first request should be allowed")
	}
	if granted := bucket.Allow(); granted {
		t.Fatal("second request should be denied (bucket empty)")
	}

	// Give the bucket time to refill before trying again.
	time.Sleep(pause)

	if granted := bucket.Allow(); !granted {
		t.Fatal("request after refill should be allowed")
	}
}
// TestBucketRetryAfter verifies that a drained bucket reports a RetryAfter
// value of at least 1.
func TestBucketRetryAfter(t *testing.T) {
	bucket := NewBucket(1, 1.0) // 1 token, refilled at 1 per second
	bucket.Allow()              // spend the only token; the result is not inspected

	if got := bucket.RetryAfter(); got < 1 {
		t.Fatalf("RetryAfter = %d, want >= 1", got)
	}
}
// TestLimiterCreatesBucketsOnDemand drives a two-token limiter with two
// different keys and checks that each key is tracked by its own bucket:
// exhausting "key-a" must not affect "key-b".
func TestLimiterCreatesBucketsOnDemand(t *testing.T) {
	limiter := NewLimiter(2, 1.0)

	// Steps run strictly in order; each one states whether the request for
	// that key is expected to pass and the message to fail with otherwise.
	steps := []struct {
		key     string
		allowed bool
		failMsg string
	}{
		{"key-a", true, "first request for key-a should be allowed"},
		{"key-b", true, "first request for key-b should be allowed"},
		// key-a has one token left at this point.
		{"key-a", true, "second request for key-a should be allowed"},
		{"key-a", false, "third request for key-a should be denied"},
		// key-b still holds one token of its own.
		{"key-b", true, "second request for key-b should be allowed (independent bucket)"},
	}

	for _, step := range steps {
		if _, ok := limiter.Allow(step.key); ok != step.allowed {
			t.Fatal(step.failMsg)
		}
	}
}
// TestMiddlewareAllowsWhenUnderLimit sends one request through the middleware
// of a fresh two-token limiter and expects the wrapped handler to run and
// respond with 200.
func TestMiddlewareAllowsWhenUnderLimit(t *testing.T) {
	limiter := NewLimiter(2, 1.0)

	// Every request maps to the same fixed key.
	keyFn := func(*http.Request) string { return "ip1" }

	reached := false
	inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	wrapped := limiter.Middleware(keyFn, nil)(inner)

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest("POST", "/test", nil))

	if !reached {
		t.Fatal("handler should have been called")
	}
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("status = %d, want 200", got)
	}
}
// TestMiddlewareRejectsOnLimit exhausts a one-token limiter and verifies the
// rejection path of the middleware: a 429 status, a Retry-After header, the
// onLimit callback being invoked, and a JSON body whose "error" field reads
// "rate limit exceeded".
func TestMiddlewareRejectsOnLimit(t *testing.T) {
	l := NewLimiter(1, 0.0001) // 1 token, extremely slow refill
	onLimitCalled := false
	mw := l.Middleware(func(r *http.Request) string { return "ip1" }, func() {
		onLimitCalled = true
	})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// First request passes and consumes the only token.
	req := httptest.NewRequest("POST", "/test", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("first request: status = %d, want 200", w.Code)
	}

	// Second request is rate limited.
	req2 := httptest.NewRequest("POST", "/test", nil)
	w2 := httptest.NewRecorder()
	handler.ServeHTTP(w2, req2)
	if w2.Code != http.StatusTooManyRequests {
		t.Fatalf("second request: status = %d, want 429", w2.Code)
	}
	if !onLimitCalled {
		t.Fatal("onLimit callback should have been called")
	}
	if h := w2.Header().Get("Retry-After"); h == "" {
		t.Fatal("Retry-After header should be set")
	}

	// Decode into a struct rather than map[string]string so that any extra,
	// non-string fields in the payload are ignored instead of producing a
	// type error, and surface a malformed body as an explicit failure rather
	// than as a confusing empty "error" value.
	var body struct {
		Error string `json:"error"`
	}
	if err := json.NewDecoder(w2.Body).Decode(&body); err != nil {
		t.Fatalf("decoding 429 response body: %v", err)
	}
	if body.Error != "rate limit exceeded" {
		t.Fatalf("error body = %q, want 'rate limit exceeded'", body.Error)
	}
}
// TestMiddlewareKeysByIP verifies that the middleware buckets requests by the
// key returned from the key function (here the request's RemoteAddr): two
// different addresses each get their own single token, while a repeat request
// from an already-seen address is rejected.
func TestMiddlewareKeysByIP(t *testing.T) {
	l := NewLimiter(1, 0.0001) // 1 token per key, extremely slow refill
	keyCalls := 0
	mw := l.Middleware(func(r *http.Request) string {
		keyCalls++
		return r.RemoteAddr
	}, nil)
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// send issues one POST from the given remote address and returns the
	// recorded status code.
	send := func(remoteAddr string) int {
		req := httptest.NewRequest("POST", "/test", nil)
		req.RemoteAddr = remoteAddr
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	// First IP gets one request.
	if code := send("1.2.3.4:1234"); code != http.StatusOK {
		t.Fatalf("first IP: status = %d, want 200", code)
	}
	// Second IP also gets one request (different bucket).
	if code := send("5.6.7.8:5678"); code != http.StatusOK {
		t.Fatalf("second IP: status = %d, want 200", code)
	}
	// The first IP has spent its only token, so a repeat must be rejected;
	// without this the test would also pass if limiting were not applied at all.
	if code := send("1.2.3.4:1234"); code != http.StatusTooManyRequests {
		t.Fatalf("first IP repeat: status = %d, want 429", code)
	}

	// Three requests were sent, so the key function must have been consulted
	// at least that many times.
	if keyCalls < 3 {
		t.Fatalf("key function called %d times, want at least 3", keyCalls)
	}
}
// TestCleanupRemovesStaleBuckets creates two buckets, backdates one of them by
// two hours, runs Cleanup with a one-hour argument, and expects only the
// untouched bucket to remain in the limiter's map.
func TestCleanupRemovesStaleBuckets(t *testing.T) {
	limiter := NewLimiter(1, 1.0)

	// Touch both keys so that each gets a bucket.
	for _, key := range []string{"stale-key", "fresh-key"} {
		limiter.Allow(key)
	}
	if n := len(limiter.buckets); n != 2 {
		t.Fatalf("expected 2 buckets, got %d", n)
	}

	// Backdate the stale bucket's lastTime under its own mutex.
	stale := limiter.buckets["stale-key"]
	stale.mu.Lock()
	stale.lastTime = time.Now().Add(-2 * time.Hour)
	stale.mu.Unlock()

	limiter.Cleanup(time.Hour)

	if n := len(limiter.buckets); n != 1 {
		t.Fatalf("expected 1 bucket after cleanup, got %d", n)
	}
	if _, present := limiter.buckets["stale-key"]; present {
		t.Fatal("stale-key should have been cleaned up")
	}
	if _, present := limiter.buckets["fresh-key"]; !present {
		t.Fatal("fresh-key should still be present")
	}
}
// TestFloodRegisterEndpoint simulates the verification requirement: a flood
// of requests against /register from a single key succeeds up to the
// configured threshold and is answered with 429 afterwards.
func TestFloodRegisterEndpoint(t *testing.T) {
	const (
		limit = 5  // tokens available per key
		burst = 10 // size of the initial flood
	)
	l := NewLimiter(limit, 5.0/3600) // 5/hour per IP
	mw := l.Middleware(func(r *http.Request) string { return "1.2.3.4" }, nil)
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusCreated)
	}))

	// send issues one registration request and returns the recorded status.
	send := func() int {
		req := httptest.NewRequest("POST", "/api/register", nil)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	okCount := 0
	for i := 0; i < burst; i++ {
		switch code := send(); code {
		case http.StatusCreated:
			okCount++
		case http.StatusTooManyRequests:
			// Expected once the bucket is empty.
		default:
			// A rejected request must be a 429, not some other failure.
			t.Fatalf("request %d: status = %d, want 201 or 429", i+1, code)
		}
	}
	if okCount != limit {
		t.Fatalf("expected %d successful requests, got %d", limit, okCount)
	}

	// One more request after the flood (the 11th overall) must still be 429.
	if code := send(); code != http.StatusTooManyRequests {
		t.Fatalf("request %d (after flood): status = %d, want 429", burst+1, code)
	}
}
// TestLegitimateTrafficPasses verifies that ordinary traffic is not rate
// limited: one registration from each of three different remote addresses
// must succeed with 201.
func TestLegitimateTrafficPasses(t *testing.T) {
	limiter := NewLimiter(5, 5.0/3600) // 5/hour per IP

	// Requests are keyed by their RemoteAddr.
	byRemoteAddr := func(r *http.Request) string { return r.RemoteAddr }
	created := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusCreated)
	})
	handler := limiter.Middleware(byRemoteAddr, nil)(created)

	clients := []string{"1.1.1.1:1111", "2.2.2.2:2222", "3.3.3.3:3333"}
	for idx := 0; idx < len(clients); idx++ {
		addr := clients[idx]

		req := httptest.NewRequest("POST", "/api/register", nil)
		req.RemoteAddr = addr
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		if rec.Code == http.StatusCreated {
			continue
		}
		t.Fatalf("request %d from %s: status = %d, want 201", idx+1, addr, rec.Code)
	}
}