diff --git a/cmd/acb-api/main.go b/cmd/acb-api/main.go index 9f41bd4..02143cc 100644 --- a/cmd/acb-api/main.go +++ b/cmd/acb-api/main.go @@ -16,6 +16,7 @@ import ( "time" "github.com/aicodebattle/acb/metrics" + "github.com/aicodebattle/acb/ratelimit" _ "github.com/lib/pq" "github.com/redis/go-redis/v9" ) @@ -76,12 +77,27 @@ func main() { defer rdb.Close() srv := &Server{ - cfg: cfg, - db: db, - rdb: rdb, - // Note: alerter moved to acb-matchmaker deployment + cfg: cfg, + db: db, + rdb: rdb, + regLimiter: ratelimit.NewLimiter(5, 5.0/3600), // 5/hour per IP + feedbackLtr: ratelimit.NewLimiter(20, 20.0/3600), // 20/hour per IP + predictLtr: ratelimit.NewLimiter(60, 60.0/3600), // 60/hour per IP + submitLtr: ratelimit.NewLimiter(5, 5.0/86400), // 5/day per key } + // Periodically purge stale rate-limit buckets (every 10 min) + go func() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + for range ticker.C { + srv.regLimiter.Cleanup(time.Hour) + srv.feedbackLtr.Cleanup(time.Hour) + srv.predictLtr.Cleanup(time.Hour) + srv.submitLtr.Cleanup(24 * time.Hour) + } + }() + mux := http.NewServeMux() srv.RegisterRoutes(mux) diff --git a/cmd/acb-api/server.go b/cmd/acb-api/server.go index f0cef15..1a79b65 100644 --- a/cmd/acb-api/server.go +++ b/cmd/acb-api/server.go @@ -7,12 +7,15 @@ import ( "fmt" "io" "log" + "net" "net/http" "os" "strconv" "strings" "time" + "github.com/aicodebattle/acb/metrics" + "github.com/aicodebattle/acb/ratelimit" "github.com/redis/go-redis/v9" ) @@ -20,39 +23,80 @@ import ( // Provides bot registration, job coordination, replay serving, // bot profiles, leaderboards, and UI feedback ingestion. 
type Server struct { - cfg Config - db *sql.DB - rdb *redis.Client + cfg Config + db *sql.DB + rdb *redis.Client + regLimiter *ratelimit.Limiter // 5/hour per IP + feedbackLtr *ratelimit.Limiter // 20/hour per IP + predictLtr *ratelimit.Limiter // 60/hour per IP + submitLtr *ratelimit.Limiter // 5/day per bot_id } func (s *Server) RegisterRoutes(mux *http.ServeMux) { - // Health endpoints + // Health endpoints (no rate limit) mux.HandleFunc("GET /health", s.handleHealth) mux.HandleFunc("GET /ready", s.handleReady) - // Bot registration - mux.HandleFunc("POST /api/register", s.handleRegister) + // Bot registration — 5/hour per IP + regMW := s.regLimiter.Middleware(ipKey, func() { + metrics.RateLimitHits.WithLabelValues("register").Inc() + }) + mux.HandleFunc("POST /api/register", regMW(http.HandlerFunc(s.handleRegister)).ServeHTTP) - // Job coordination (for workers) + // Job coordination (for workers — authenticated, no public rate limit) mux.HandleFunc("GET /api/job", s.handleGetJob) - mux.HandleFunc("POST /api/job/", s.handleJobResult) - // Replay serving + // Job result submission — per-worker 5/day limit + submitMW := s.submitLtr.Middleware(botIDKey(), func() { + metrics.RateLimitHits.WithLabelValues("submit").Inc() + }) + mux.HandleFunc("POST /api/job/", submitMW(http.HandlerFunc(s.handleJobResult)).ServeHTTP) + + // Replay serving (read-only, no rate limit) mux.HandleFunc("GET /api/replay/", s.handleGetReplay) - // Bot profiles and leaderboard + // Bot profiles and leaderboard (read-only, no rate limit) mux.HandleFunc("GET /api/bot/", s.handleGetBot) mux.HandleFunc("GET /api/bots", s.handleListBots) - // Community replay feedback per plan §13.6 - mux.HandleFunc("POST /api/feedback", s.handleUIFeedback) + // Community replay feedback — 20/hour per IP + fbMW := s.feedbackLtr.Middleware(ipKey, func() { + metrics.RateLimitHits.WithLabelValues("feedback").Inc() + }) + mux.HandleFunc("POST /api/feedback", fbMW(http.HandlerFunc(s.handleUIFeedback)).ServeHTTP) - // 
Predictions - mux.HandleFunc("POST /api/predict", s.handlePredict) + // Predictions — 60/hour per IP + predMW := s.predictLtr.Middleware(ipKey, func() { + metrics.RateLimitHits.WithLabelValues("predict").Inc() + }) + mux.HandleFunc("POST /api/predict", predMW(http.HandlerFunc(s.handlePredict)).ServeHTTP) mux.HandleFunc("GET /api/predictions/open", s.handleOpenPredictions) mux.HandleFunc("GET /api/predictions/history", s.handlePredictionHistory) } +// ipKey extracts the client IP from the request for rate limiting. +// Respects X-Forwarded-For when present (behind reverse proxy). +func ipKey(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + host, _, _ := net.SplitHostPort(r.RemoteAddr) + return host +} + +// botIDKey extracts a rate-limit key for the job submission endpoint. +// Workers submit results authenticated by API key, so we key by worker IP +// to enforce the per-worker submission rate limit (max 5/day). 
+func botIDKey() func(*http.Request) string { + return func(r *http.Request) string { + host, _, _ := net.SplitHostPort(r.RemoteAddr) + return "worker:" + host + } +} + func writeJSON(w http.ResponseWriter, status int, v any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) diff --git a/cmd/acb-api/server_test.go b/cmd/acb-api/server_test.go index 170fed8..c54de61 100644 --- a/cmd/acb-api/server_test.go +++ b/cmd/acb-api/server_test.go @@ -5,6 +5,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/aicodebattle/acb/ratelimit" ) // newTestServer creates a Server with no database or redis (for unit tests @@ -16,6 +18,10 @@ func newTestServer() *Server { BotTimeoutSecs: 5, MaxConsecFails: 3, }, + regLimiter: ratelimit.NewLimiter(5, 5.0/3600), + feedbackLtr: ratelimit.NewLimiter(20, 20.0/3600), + predictLtr: ratelimit.NewLimiter(60, 60.0/3600), + submitLtr: ratelimit.NewLimiter(5, 5.0/86400), } } diff --git a/metrics/metrics.go b/metrics/metrics.go index fca2312..9590846 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -109,6 +109,12 @@ var ( Help: "Match execution duration in seconds.", Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600}, }) + + // RateLimitHits counts requests rejected by rate limiting. + RateLimitHits = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "acb_rate_limit_hits_total", + Help: "Total number of requests rejected by rate limiting.", + }, []string{"endpoint"}) ) func init() { @@ -128,6 +134,7 @@ func init() { WorkerMatchErrorsTotal, WorkerJobsClaimedTotal, WorkerMatchDuration, + RateLimitHits, ) } diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go new file mode 100644 index 0000000..27df3f4 --- /dev/null +++ b/ratelimit/ratelimit.go @@ -0,0 +1,134 @@ +// Package ratelimit provides token-bucket rate limiting for HTTP handlers. +package ratelimit + +import ( + "fmt" + "net/http" + "sync" + "time" +) + +// Bucket is a token-bucket rate limiter for a single key. 
type Bucket struct {
	mu       sync.Mutex // guards every field below
	tokens   float64    // tokens currently available; fractional while refilling
	max      float64    // capacity: tokens is clamped to this on refill
	refill   float64    // tokens added per second
	lastTime time.Time  // instant of the most recent Allow call
}

// maxRetryAfterSecs is the value RetryAfter reports when no finite wait can
// be computed (a bucket that never refills) or when the wait would not fit
// in an int.
const maxRetryAfterSecs = 1<<31 - 1

// NewBucket creates a bucket that holds max tokens and refills at the given
// rate (tokens per second). The bucket starts full.
func NewBucket(max, refillPerSec float64) *Bucket {
	return &Bucket{
		tokens:   max,
		max:      max,
		refill:   refillPerSec,
		lastTime: time.Now(),
	}
}

// Allow consumes one token. Returns true if a token was available.
//
// Before deciding, the bucket is credited with the tokens earned since the
// previous call (elapsed seconds times the refill rate), capped at capacity.
// A denied call still advances lastTime, so the partial credit is kept.
func (b *Bucket) Allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	now := time.Now()
	elapsed := now.Sub(b.lastTime).Seconds()
	b.lastTime = now

	b.tokens += elapsed * b.refill
	if b.tokens > b.max {
		b.tokens = b.max
	}
	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// RetryAfter returns the number of whole seconds until the next token is
// available, rounded up so that a caller who waits exactly that long is not
// rejected again. Call after Allow() returns false.
//
// The result is always at least 1. If the bucket never refills (refill rate
// of zero or less) maxRetryAfterSecs is returned instead of dividing by zero.
func (b *Bucket) RetryAfter() int {
	b.mu.Lock()
	defer b.mu.Unlock()

	deficit := 1.0 - b.tokens
	if deficit <= 0 {
		// A token is already there; report the minimum wait.
		return 1
	}
	if b.refill <= 0 {
		return maxRetryAfterSecs
	}
	secs := deficit / b.refill
	if secs >= maxRetryAfterSecs {
		return maxRetryAfterSecs
	}
	// Ceiling without importing math: truncate, then bump if anything was
	// cut off. secs > 0 here, so the result is >= 1.
	whole := int(secs)
	if float64(whole) < secs {
		whole++
	}
	return whole
}

// Limiter holds a collection of buckets keyed by string (e.g. "ip:endpoint").
// Every bucket it creates shares the same capacity and refill rate.
type Limiter struct {
	mu      sync.Mutex         // guards buckets
	buckets map[string]*Bucket // one bucket per key, created on first use
	max     float64            // capacity handed to each new bucket
	refill  float64            // refill rate (tokens/second) handed to each new bucket
}

// NewLimiter creates a Limiter where each key gets max tokens, refilling at
// refillPerSec tokens per second.
func NewLimiter(max, refillPerSec float64) *Limiter {
	return &Limiter{
		buckets: make(map[string]*Bucket),
		max:     max,
		refill:  refillPerSec,
	}
}

// Allow checks the bucket for the given key. Creates one if needed.
//
// It returns the key's bucket together with the verdict, so a caller that
// was denied can ask the bucket for RetryAfter. The limiter's lock is held
// only for the map lookup; the token accounting runs under the bucket's own
// lock.
func (l *Limiter) Allow(key string) (*Bucket, bool) {
	l.mu.Lock()
	b, ok := l.buckets[key]
	if !ok {
		b = NewBucket(l.max, l.refill)
		l.buckets[key] = b
	}
	l.mu.Unlock()

	return b, b.Allow()
}

// Cleanup removes buckets that haven't been used in the given duration.
+// Call periodically to prevent unbounded memory growth. +func (l *Limiter) Cleanup(maxAge time.Duration) { + l.mu.Lock() + defer l.mu.Unlock() + + cutoff := time.Now().Add(-maxAge) + for k, b := range l.buckets { + b.mu.Lock() + if b.lastTime.Before(cutoff) { + delete(l.buckets, k) + } + b.mu.Unlock() + } +} + +// Middleware returns an http.Handler that applies per-key rate limiting. +// On limit breach it responds with HTTP 429 and a Retry-After header. +// onLimit is called (if non-nil) when a request is rate-limited, for metrics. +func (l *Limiter) Middleware(keyFunc func(*http.Request) string, onLimit func()) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := keyFunc(r) + bucket, ok := l.Allow(key) + if !ok { + if onLimit != nil { + onLimit() + } + retry := bucket.RetryAfter() + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", time.Duration(retry).Seconds())) + http.Error(w, `{"error":"rate limit exceeded"}`, http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/ratelimit/ratelimit_test.go b/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..a74f324 --- /dev/null +++ b/ratelimit/ratelimit_test.go @@ -0,0 +1,252 @@ +package ratelimit + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestBucketAllowsUpToMax(t *testing.T) { + b := NewBucket(3, 1.0) // 3 tokens, 1/sec refill + + for i := 0; i < 3; i++ { + if !b.Allow() { + t.Fatalf("request %d should be allowed", i+1) + } + } + if b.Allow() { + t.Fatal("4th request should be denied") + } +} + +func TestBucketRefills(t *testing.T) { + b := NewBucket(1, 100.0) // 1 token, 100/sec refill (fast refill) + + if !b.Allow() { + t.Fatal("first request should be allowed") + } + if b.Allow() { + t.Fatal("second request should be denied (bucket empty)") + } + + // Wait for refill + time.Sleep(20 * time.Millisecond) + + 
if !b.Allow() { + t.Fatal("request after refill should be allowed") + } +} + +func TestBucketRetryAfter(t *testing.T) { + b := NewBucket(1, 1.0) // 1 token, 1/sec refill + b.Allow() // drain + + retry := b.RetryAfter() + if retry < 1 { + t.Fatalf("RetryAfter = %d, want >= 1", retry) + } +} + +func TestLimiterCreatesBucketsOnDemand(t *testing.T) { + l := NewLimiter(2, 1.0) + + _, ok1 := l.Allow("key-a") + if !ok1 { + t.Fatal("first request for key-a should be allowed") + } + _, ok2 := l.Allow("key-b") + if !ok2 { + t.Fatal("first request for key-b should be allowed") + } + + // key-a has 1 token left + _, ok3 := l.Allow("key-a") + if !ok3 { + t.Fatal("second request for key-a should be allowed") + } + _, ok4 := l.Allow("key-a") + if ok4 { + t.Fatal("third request for key-a should be denied") + } + + // key-b still has 1 token + _, ok5 := l.Allow("key-b") + if !ok5 { + t.Fatal("second request for key-b should be allowed (independent bucket)") + } +} + +func TestMiddlewareAllowsWhenUnderLimit(t *testing.T) { + l := NewLimiter(2, 1.0) + called := false + handler := l.Middleware(func(r *http.Request) string { return "ip1" }, nil)( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }), + ) + + req := httptest.NewRequest("POST", "/test", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if !called { + t.Fatal("handler should have been called") + } + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } +} + +func TestMiddlewareRejectsOnLimit(t *testing.T) { + l := NewLimiter(1, 0.0001) // 1 token, extremely slow refill + onLimitCalled := false + mw := l.Middleware(func(r *http.Request) string { return "ip1" }, func() { + onLimitCalled = true + }) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request passes + req := httptest.NewRequest("POST", "/test", nil) + w := httptest.NewRecorder() + 
handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("first request: status = %d, want 200", w.Code) + } + + // Second request is rate limited + req2 := httptest.NewRequest("POST", "/test", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + if w2.Code != http.StatusTooManyRequests { + t.Fatalf("second request: status = %d, want 429", w2.Code) + } + if !onLimitCalled { + t.Fatal("onLimit callback should have been called") + } + if h := w2.Header().Get("Retry-After"); h == "" { + t.Fatal("Retry-After header should be set") + } + + var body map[string]string + json.NewDecoder(w2.Body).Decode(&body) + if body["error"] != "rate limit exceeded" { + t.Fatalf("error body = %q, want 'rate limit exceeded'", body["error"]) + } +} + +func TestMiddlewareKeysByIP(t *testing.T) { + l := NewLimiter(1, 0.0001) // 1 token per key + keyCount := 0 + mw := l.Middleware(func(r *http.Request) string { + keyCount++ + return r.RemoteAddr + }, nil) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First IP gets one request + req1 := httptest.NewRequest("POST", "/test", nil) + req1.RemoteAddr = "1.2.3.4:1234" + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + // Second IP also gets one request (different bucket) + req2 := httptest.NewRequest("POST", "/test", nil) + req2.RemoteAddr = "5.6.7.8:5678" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + if w1.Code != http.StatusOK { + t.Fatalf("first IP: status = %d, want 200", w1.Code) + } + if w2.Code != http.StatusOK { + t.Fatalf("second IP: status = %d, want 200", w2.Code) + } +} + +func TestCleanupRemovesStaleBuckets(t *testing.T) { + l := NewLimiter(1, 1.0) + + l.Allow("stale-key") + l.Allow("fresh-key") + + if len(l.buckets) != 2 { + t.Fatalf("expected 2 buckets, got %d", len(l.buckets)) + } + + // Manually age the stale bucket + l.buckets["stale-key"].mu.Lock() + l.buckets["stale-key"].lastTime = 
time.Now().Add(-2 * time.Hour) + l.buckets["stale-key"].mu.Unlock() + + l.Cleanup(time.Hour) + + if len(l.buckets) != 1 { + t.Fatalf("expected 1 bucket after cleanup, got %d", len(l.buckets)) + } + if _, ok := l.buckets["stale-key"]; ok { + t.Fatal("stale-key should have been cleaned up") + } + if _, ok := l.buckets["fresh-key"]; !ok { + t.Fatal("fresh-key should still be present") + } +} + +func TestFloodRegisterEndpoint(t *testing.T) { + // Simulates the verification requirement: flood test against /register + // returns 429 after threshold. + l := NewLimiter(5, 5.0/3600) // 5/hour per IP + mw := l.Middleware(func(r *http.Request) string { return "1.2.3.4" }, nil) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + })) + + okCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("POST", "/api/register", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code == http.StatusCreated { + okCount++ + } + } + + if okCount != 5 { + t.Fatalf("expected 5 successful requests, got %d", okCount) + } + + // 6th request should be 429 + req := httptest.NewRequest("POST", "/api/register", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("6th request: status = %d, want 429", w.Code) + } +} + +func TestLegitimateTrafficPasses(t *testing.T) { + // Verifies that normal traffic patterns don't get rate limited. + // 3 registrations from different IPs should all succeed. 
+ l := NewLimiter(5, 5.0/3600) // 5/hour per IP + mw := l.Middleware(func(r *http.Request) string { return r.RemoteAddr }, nil) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + })) + + for i, ip := range []string{"1.1.1.1:1111", "2.2.2.2:2222", "3.3.3.3:3333"} { + req := httptest.NewRequest("POST", "/api/register", nil) + req.RemoteAddr = ip + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("request %d from %s: status = %d, want 201", i+1, ip, w.Code) + } + } +}