feat(api): add token-bucket rate limiting to public endpoints

Adds ratelimit package with per-IP and per-key HTTP middleware.
Applied to register (5/hr), feedback (20/hr), predict (60/hr),
and job submission (5/day) endpoints. Includes metrics counter
for rejected requests and periodic bucket cleanup goroutine.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-04-22 16:52:29 -04:00
parent 2df70c8ae0
commit 7e131d310f
6 changed files with 477 additions and 18 deletions

View file

@ -16,6 +16,7 @@ import (
"time"
"github.com/aicodebattle/acb/metrics"
"github.com/aicodebattle/acb/ratelimit"
_ "github.com/lib/pq"
"github.com/redis/go-redis/v9"
)
@ -76,12 +77,27 @@ func main() {
defer rdb.Close()
srv := &Server{
cfg: cfg,
db: db,
rdb: rdb,
// Note: alerter moved to acb-matchmaker deployment
cfg: cfg,
db: db,
rdb: rdb,
regLimiter: ratelimit.NewLimiter(5, 5.0/3600), // 5/hour per IP
feedbackLtr: ratelimit.NewLimiter(20, 20.0/3600), // 20/hour per IP
predictLtr: ratelimit.NewLimiter(60, 60.0/3600), // 60/hour per IP
submitLtr: ratelimit.NewLimiter(5, 5.0/86400), // 5/day per key
}
// Periodically purge stale rate-limit buckets (every 10 min)
go func() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for range ticker.C {
srv.regLimiter.Cleanup(time.Hour)
srv.feedbackLtr.Cleanup(time.Hour)
srv.predictLtr.Cleanup(time.Hour)
srv.submitLtr.Cleanup(24 * time.Hour)
}
}()
mux := http.NewServeMux()
srv.RegisterRoutes(mux)

View file

@ -7,12 +7,15 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/aicodebattle/acb/metrics"
"github.com/aicodebattle/acb/ratelimit"
"github.com/redis/go-redis/v9"
)
@ -20,39 +23,80 @@ import (
// Provides bot registration, job coordination, replay serving,
// bot profiles, leaderboards, and UI feedback ingestion.
type Server struct {
cfg Config
db *sql.DB
rdb *redis.Client
cfg Config
db *sql.DB
rdb *redis.Client
regLimiter *ratelimit.Limiter // 5/hour per IP
feedbackLtr *ratelimit.Limiter // 20/hour per IP
predictLtr *ratelimit.Limiter // 60/hour per IP
submitLtr *ratelimit.Limiter // 5/day per bot_id
}
func (s *Server) RegisterRoutes(mux *http.ServeMux) {
// Health endpoints
// Health endpoints (no rate limit)
mux.HandleFunc("GET /health", s.handleHealth)
mux.HandleFunc("GET /ready", s.handleReady)
// Bot registration
mux.HandleFunc("POST /api/register", s.handleRegister)
// Bot registration — 5/hour per IP
regMW := s.regLimiter.Middleware(ipKey, func() {
metrics.RateLimitHits.WithLabelValues("register").Inc()
})
mux.HandleFunc("POST /api/register", regMW(http.HandlerFunc(s.handleRegister)).ServeHTTP)
// Job coordination (for workers)
// Job coordination (for workers — authenticated, no public rate limit)
mux.HandleFunc("GET /api/job", s.handleGetJob)
mux.HandleFunc("POST /api/job/", s.handleJobResult)
// Replay serving
// Job result submission — per-worker 5/day limit
submitMW := s.submitLtr.Middleware(botIDKey(), func() {
metrics.RateLimitHits.WithLabelValues("submit").Inc()
})
mux.HandleFunc("POST /api/job/", submitMW(http.HandlerFunc(s.handleJobResult)).ServeHTTP)
// Replay serving (read-only, no rate limit)
mux.HandleFunc("GET /api/replay/", s.handleGetReplay)
// Bot profiles and leaderboard
// Bot profiles and leaderboard (read-only, no rate limit)
mux.HandleFunc("GET /api/bot/", s.handleGetBot)
mux.HandleFunc("GET /api/bots", s.handleListBots)
// Community replay feedback per plan §13.6
mux.HandleFunc("POST /api/feedback", s.handleUIFeedback)
// Community replay feedback — 20/hour per IP
fbMW := s.feedbackLtr.Middleware(ipKey, func() {
metrics.RateLimitHits.WithLabelValues("feedback").Inc()
})
mux.HandleFunc("POST /api/feedback", fbMW(http.HandlerFunc(s.handleUIFeedback)).ServeHTTP)
// Predictions
mux.HandleFunc("POST /api/predict", s.handlePredict)
// Predictions — 60/hour per IP
predMW := s.predictLtr.Middleware(ipKey, func() {
metrics.RateLimitHits.WithLabelValues("predict").Inc()
})
mux.HandleFunc("POST /api/predict", predMW(http.HandlerFunc(s.handlePredict)).ServeHTTP)
mux.HandleFunc("GET /api/predictions/open", s.handleOpenPredictions)
mux.HandleFunc("GET /api/predictions/history", s.handlePredictionHistory)
}
// ipKey extracts the client IP from the request for rate limiting.
// Respects X-Forwarded-For when present (behind reverse proxy).
func ipKey(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
host, _, _ := net.SplitHostPort(r.RemoteAddr)
return host
}
// botIDKey extracts a rate-limit key for the job submission endpoint.
// Workers submit results authenticated by API key, so we key by worker IP
// to enforce the per-worker submission rate limit (max 5/day).
func botIDKey() func(*http.Request) string {
return func(r *http.Request) string {
host, _, _ := net.SplitHostPort(r.RemoteAddr)
return "worker:" + host
}
}
func writeJSON(w http.ResponseWriter, status int, v any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)

View file

@ -5,6 +5,8 @@ import (
"net/http"
"net/http/httptest"
"testing"
"github.com/aicodebattle/acb/ratelimit"
)
// newTestServer creates a Server with no database or redis (for unit tests
@ -16,6 +18,10 @@ func newTestServer() *Server {
BotTimeoutSecs: 5,
MaxConsecFails: 3,
},
regLimiter: ratelimit.NewLimiter(5, 5.0/3600),
feedbackLtr: ratelimit.NewLimiter(20, 20.0/3600),
predictLtr: ratelimit.NewLimiter(60, 60.0/3600),
submitLtr: ratelimit.NewLimiter(5, 5.0/86400),
}
}

View file

@ -109,6 +109,12 @@ var (
Help: "Match execution duration in seconds.",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600},
})
// RateLimitHits counts requests rejected by rate limiting.
RateLimitHits = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "acb_rate_limit_hits_total",
Help: "Total number of requests rejected by rate limiting.",
}, []string{"endpoint"})
)
func init() {
@ -128,6 +134,7 @@ func init() {
WorkerMatchErrorsTotal,
WorkerJobsClaimedTotal,
WorkerMatchDuration,
RateLimitHits,
)
}

134
ratelimit/ratelimit.go Normal file
View file

@ -0,0 +1,134 @@
// Package ratelimit provides token-bucket rate limiting for HTTP handlers.
package ratelimit
import (
"fmt"
"net/http"
"sync"
"time"
)
// Bucket is a token-bucket rate limiter for a single key.
type Bucket struct {
mu sync.Mutex
tokens float64
max float64
refill float64 // tokens added per second
lastTime time.Time
}
// NewBucket creates a bucket that holds max tokens and refills at the given
// rate (tokens per second). The bucket starts full.
func NewBucket(max, refillPerSec float64) *Bucket {
return &Bucket{
tokens: max,
max: max,
refill: refillPerSec,
lastTime: time.Now(),
}
}
// Allow consumes one token. Returns true if a token was available.
func (b *Bucket) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
now := time.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.lastTime = now
b.tokens += elapsed * b.refill
if b.tokens > b.max {
b.tokens = b.max
}
if b.tokens < 1 {
return false
}
b.tokens--
return true
}
// RetryAfter returns the number of seconds until the next token is available.
// Call after Allow() returns false.
func (b *Bucket) RetryAfter() int {
b.mu.Lock()
defer b.mu.Unlock()
deficit := 1.0 - b.tokens
if deficit <= 0 {
return 1
}
secs := deficit / b.refill
if secs < 1 {
return 1
}
return int(secs)
}
// Limiter holds a collection of buckets keyed by string (e.g. "ip:endpoint").
type Limiter struct {
mu sync.Mutex
buckets map[string]*Bucket
max float64
refill float64
}
// NewLimiter creates a Limiter where each key gets max tokens, refilling at
// refillPerSec tokens per second.
func NewLimiter(max, refillPerSec float64) *Limiter {
return &Limiter{
buckets: make(map[string]*Bucket),
max: max,
refill: refillPerSec,
}
}
// Allow checks the bucket for the given key. Creates one if needed.
func (l *Limiter) Allow(key string) (*Bucket, bool) {
l.mu.Lock()
b, ok := l.buckets[key]
if !ok {
b = NewBucket(l.max, l.refill)
l.buckets[key] = b
}
l.mu.Unlock()
return b, b.Allow()
}
// Cleanup removes buckets that haven't been used in the given duration.
// Call periodically to prevent unbounded memory growth.
func (l *Limiter) Cleanup(maxAge time.Duration) {
l.mu.Lock()
defer l.mu.Unlock()
cutoff := time.Now().Add(-maxAge)
for k, b := range l.buckets {
b.mu.Lock()
if b.lastTime.Before(cutoff) {
delete(l.buckets, k)
}
b.mu.Unlock()
}
}
// Middleware returns an http.Handler that applies per-key rate limiting.
// On limit breach it responds with HTTP 429 and a Retry-After header.
// onLimit is called (if non-nil) when a request is rate-limited, for metrics.
func (l *Limiter) Middleware(keyFunc func(*http.Request) string, onLimit func()) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := keyFunc(r)
bucket, ok := l.Allow(key)
if !ok {
if onLimit != nil {
onLimit()
}
retry := bucket.RetryAfter()
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", time.Duration(retry).Seconds()))
http.Error(w, `{"error":"rate limit exceeded"}`, http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}

252
ratelimit/ratelimit_test.go Normal file
View file

@ -0,0 +1,252 @@
package ratelimit
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestBucketAllowsUpToMax(t *testing.T) {
b := NewBucket(3, 1.0) // 3 tokens, 1/sec refill
for i := 0; i < 3; i++ {
if !b.Allow() {
t.Fatalf("request %d should be allowed", i+1)
}
}
if b.Allow() {
t.Fatal("4th request should be denied")
}
}
func TestBucketRefills(t *testing.T) {
b := NewBucket(1, 100.0) // 1 token, 100/sec refill (fast refill)
if !b.Allow() {
t.Fatal("first request should be allowed")
}
if b.Allow() {
t.Fatal("second request should be denied (bucket empty)")
}
// Wait for refill
time.Sleep(20 * time.Millisecond)
if !b.Allow() {
t.Fatal("request after refill should be allowed")
}
}
func TestBucketRetryAfter(t *testing.T) {
b := NewBucket(1, 1.0) // 1 token, 1/sec refill
b.Allow() // drain
retry := b.RetryAfter()
if retry < 1 {
t.Fatalf("RetryAfter = %d, want >= 1", retry)
}
}
func TestLimiterCreatesBucketsOnDemand(t *testing.T) {
l := NewLimiter(2, 1.0)
_, ok1 := l.Allow("key-a")
if !ok1 {
t.Fatal("first request for key-a should be allowed")
}
_, ok2 := l.Allow("key-b")
if !ok2 {
t.Fatal("first request for key-b should be allowed")
}
// key-a has 1 token left
_, ok3 := l.Allow("key-a")
if !ok3 {
t.Fatal("second request for key-a should be allowed")
}
_, ok4 := l.Allow("key-a")
if ok4 {
t.Fatal("third request for key-a should be denied")
}
// key-b still has 1 token
_, ok5 := l.Allow("key-b")
if !ok5 {
t.Fatal("second request for key-b should be allowed (independent bucket)")
}
}
func TestMiddlewareAllowsWhenUnderLimit(t *testing.T) {
l := NewLimiter(2, 1.0)
called := false
handler := l.Middleware(func(r *http.Request) string { return "ip1" }, nil)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}),
)
req := httptest.NewRequest("POST", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("handler should have been called")
}
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestMiddlewareRejectsOnLimit(t *testing.T) {
l := NewLimiter(1, 0.0001) // 1 token, extremely slow refill
onLimitCalled := false
mw := l.Middleware(func(r *http.Request) string { return "ip1" }, func() {
onLimitCalled = true
})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request passes
req := httptest.NewRequest("POST", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("first request: status = %d, want 200", w.Code)
}
// Second request is rate limited
req2 := httptest.NewRequest("POST", "/test", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w2.Code != http.StatusTooManyRequests {
t.Fatalf("second request: status = %d, want 429", w2.Code)
}
if !onLimitCalled {
t.Fatal("onLimit callback should have been called")
}
if h := w2.Header().Get("Retry-After"); h == "" {
t.Fatal("Retry-After header should be set")
}
var body map[string]string
json.NewDecoder(w2.Body).Decode(&body)
if body["error"] != "rate limit exceeded" {
t.Fatalf("error body = %q, want 'rate limit exceeded'", body["error"])
}
}
func TestMiddlewareKeysByIP(t *testing.T) {
l := NewLimiter(1, 0.0001) // 1 token per key
keyCount := 0
mw := l.Middleware(func(r *http.Request) string {
keyCount++
return r.RemoteAddr
}, nil)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First IP gets one request
req1 := httptest.NewRequest("POST", "/test", nil)
req1.RemoteAddr = "1.2.3.4:1234"
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
// Second IP also gets one request (different bucket)
req2 := httptest.NewRequest("POST", "/test", nil)
req2.RemoteAddr = "5.6.7.8:5678"
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w1.Code != http.StatusOK {
t.Fatalf("first IP: status = %d, want 200", w1.Code)
}
if w2.Code != http.StatusOK {
t.Fatalf("second IP: status = %d, want 200", w2.Code)
}
}
func TestCleanupRemovesStaleBuckets(t *testing.T) {
l := NewLimiter(1, 1.0)
l.Allow("stale-key")
l.Allow("fresh-key")
if len(l.buckets) != 2 {
t.Fatalf("expected 2 buckets, got %d", len(l.buckets))
}
// Manually age the stale bucket
l.buckets["stale-key"].mu.Lock()
l.buckets["stale-key"].lastTime = time.Now().Add(-2 * time.Hour)
l.buckets["stale-key"].mu.Unlock()
l.Cleanup(time.Hour)
if len(l.buckets) != 1 {
t.Fatalf("expected 1 bucket after cleanup, got %d", len(l.buckets))
}
if _, ok := l.buckets["stale-key"]; ok {
t.Fatal("stale-key should have been cleaned up")
}
if _, ok := l.buckets["fresh-key"]; !ok {
t.Fatal("fresh-key should still be present")
}
}
func TestFloodRegisterEndpoint(t *testing.T) {
// Simulates the verification requirement: flood test against /register
// returns 429 after threshold.
l := NewLimiter(5, 5.0/3600) // 5/hour per IP
mw := l.Middleware(func(r *http.Request) string { return "1.2.3.4" }, nil)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
}))
okCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("POST", "/api/register", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code == http.StatusCreated {
okCount++
}
}
if okCount != 5 {
t.Fatalf("expected 5 successful requests, got %d", okCount)
}
// 6th request should be 429
req := httptest.NewRequest("POST", "/api/register", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("6th request: status = %d, want 429", w.Code)
}
}
func TestLegitimateTrafficPasses(t *testing.T) {
// Verifies that normal traffic patterns don't get rate limited.
// 3 registrations from different IPs should all succeed.
l := NewLimiter(5, 5.0/3600) // 5/hour per IP
mw := l.Middleware(func(r *http.Request) string { return r.RemoteAddr }, nil)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
}))
for i, ip := range []string{"1.1.1.1:1111", "2.2.2.2:2222", "3.3.3.3:3333"} {
req := httptest.NewRequest("POST", "/api/register", nil)
req.RemoteAddr = ip
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("request %d from %s: status = %d, want 201", i+1, ip, w.Code)
}
}
}