feat(api): add token-bucket rate limiting to public endpoints
Adds ratelimit package with per-IP and per-key HTTP middleware. Applied to register (5/hr), feedback (20/hr), predict (60/hr), and job submission (5/day) endpoints. Includes metrics counter for rejected requests and periodic bucket cleanup goroutine. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
2df70c8ae0
commit
7e131d310f
6 changed files with 477 additions and 18 deletions
|
|
@ -16,6 +16,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/aicodebattle/acb/metrics"
|
||||
"github.com/aicodebattle/acb/ratelimit"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
|
@ -76,12 +77,27 @@ func main() {
|
|||
defer rdb.Close()
|
||||
|
||||
srv := &Server{
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
// Note: alerter moved to acb-matchmaker deployment
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
regLimiter: ratelimit.NewLimiter(5, 5.0/3600), // 5/hour per IP
|
||||
feedbackLtr: ratelimit.NewLimiter(20, 20.0/3600), // 20/hour per IP
|
||||
predictLtr: ratelimit.NewLimiter(60, 60.0/3600), // 60/hour per IP
|
||||
submitLtr: ratelimit.NewLimiter(5, 5.0/86400), // 5/day per key
|
||||
}
|
||||
|
||||
// Periodically purge stale rate-limit buckets (every 10 min)
|
||||
go func() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
srv.regLimiter.Cleanup(time.Hour)
|
||||
srv.feedbackLtr.Cleanup(time.Hour)
|
||||
srv.predictLtr.Cleanup(time.Hour)
|
||||
srv.submitLtr.Cleanup(24 * time.Hour)
|
||||
}
|
||||
}()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
srv.RegisterRoutes(mux)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,15 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aicodebattle/acb/metrics"
|
||||
"github.com/aicodebattle/acb/ratelimit"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
|
|
@ -20,39 +23,80 @@ import (
|
|||
// Provides bot registration, job coordination, replay serving,
|
||||
// bot profiles, leaderboards, and UI feedback ingestion.
|
||||
type Server struct {
|
||||
cfg Config
|
||||
db *sql.DB
|
||||
rdb *redis.Client
|
||||
cfg Config
|
||||
db *sql.DB
|
||||
rdb *redis.Client
|
||||
regLimiter *ratelimit.Limiter // 5/hour per IP
|
||||
feedbackLtr *ratelimit.Limiter // 20/hour per IP
|
||||
predictLtr *ratelimit.Limiter // 60/hour per IP
|
||||
submitLtr *ratelimit.Limiter // 5/day per bot_id
|
||||
}
|
||||
|
||||
func (s *Server) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Health endpoints
|
||||
// Health endpoints (no rate limit)
|
||||
mux.HandleFunc("GET /health", s.handleHealth)
|
||||
mux.HandleFunc("GET /ready", s.handleReady)
|
||||
|
||||
// Bot registration
|
||||
mux.HandleFunc("POST /api/register", s.handleRegister)
|
||||
// Bot registration — 5/hour per IP
|
||||
regMW := s.regLimiter.Middleware(ipKey, func() {
|
||||
metrics.RateLimitHits.WithLabelValues("register").Inc()
|
||||
})
|
||||
mux.HandleFunc("POST /api/register", regMW(http.HandlerFunc(s.handleRegister)).ServeHTTP)
|
||||
|
||||
// Job coordination (for workers)
|
||||
// Job coordination (for workers — authenticated, no public rate limit)
|
||||
mux.HandleFunc("GET /api/job", s.handleGetJob)
|
||||
mux.HandleFunc("POST /api/job/", s.handleJobResult)
|
||||
|
||||
// Replay serving
|
||||
// Job result submission — per-worker 5/day limit
|
||||
submitMW := s.submitLtr.Middleware(botIDKey(), func() {
|
||||
metrics.RateLimitHits.WithLabelValues("submit").Inc()
|
||||
})
|
||||
mux.HandleFunc("POST /api/job/", submitMW(http.HandlerFunc(s.handleJobResult)).ServeHTTP)
|
||||
|
||||
// Replay serving (read-only, no rate limit)
|
||||
mux.HandleFunc("GET /api/replay/", s.handleGetReplay)
|
||||
|
||||
// Bot profiles and leaderboard
|
||||
// Bot profiles and leaderboard (read-only, no rate limit)
|
||||
mux.HandleFunc("GET /api/bot/", s.handleGetBot)
|
||||
mux.HandleFunc("GET /api/bots", s.handleListBots)
|
||||
|
||||
// Community replay feedback per plan §13.6
|
||||
mux.HandleFunc("POST /api/feedback", s.handleUIFeedback)
|
||||
// Community replay feedback — 20/hour per IP
|
||||
fbMW := s.feedbackLtr.Middleware(ipKey, func() {
|
||||
metrics.RateLimitHits.WithLabelValues("feedback").Inc()
|
||||
})
|
||||
mux.HandleFunc("POST /api/feedback", fbMW(http.HandlerFunc(s.handleUIFeedback)).ServeHTTP)
|
||||
|
||||
// Predictions
|
||||
mux.HandleFunc("POST /api/predict", s.handlePredict)
|
||||
// Predictions — 60/hour per IP
|
||||
predMW := s.predictLtr.Middleware(ipKey, func() {
|
||||
metrics.RateLimitHits.WithLabelValues("predict").Inc()
|
||||
})
|
||||
mux.HandleFunc("POST /api/predict", predMW(http.HandlerFunc(s.handlePredict)).ServeHTTP)
|
||||
mux.HandleFunc("GET /api/predictions/open", s.handleOpenPredictions)
|
||||
mux.HandleFunc("GET /api/predictions/history", s.handlePredictionHistory)
|
||||
}
|
||||
|
||||
// ipKey extracts the client IP from the request for rate limiting.
|
||||
// Respects X-Forwarded-For when present (behind reverse proxy).
|
||||
func ipKey(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
return host
|
||||
}
|
||||
|
||||
// botIDKey extracts a rate-limit key for the job submission endpoint.
|
||||
// Workers submit results authenticated by API key, so we key by worker IP
|
||||
// to enforce the per-worker submission rate limit (max 5/day).
|
||||
func botIDKey() func(*http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
return "worker:" + host
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/aicodebattle/acb/ratelimit"
|
||||
)
|
||||
|
||||
// newTestServer creates a Server with no database or redis (for unit tests
|
||||
|
|
@ -16,6 +18,10 @@ func newTestServer() *Server {
|
|||
BotTimeoutSecs: 5,
|
||||
MaxConsecFails: 3,
|
||||
},
|
||||
regLimiter: ratelimit.NewLimiter(5, 5.0/3600),
|
||||
feedbackLtr: ratelimit.NewLimiter(20, 20.0/3600),
|
||||
predictLtr: ratelimit.NewLimiter(60, 60.0/3600),
|
||||
submitLtr: ratelimit.NewLimiter(5, 5.0/86400),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -109,6 +109,12 @@ var (
|
|||
Help: "Match execution duration in seconds.",
|
||||
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600},
|
||||
})
|
||||
|
||||
// RateLimitHits counts requests rejected by rate limiting.
|
||||
RateLimitHits = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "acb_rate_limit_hits_total",
|
||||
Help: "Total number of requests rejected by rate limiting.",
|
||||
}, []string{"endpoint"})
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -128,6 +134,7 @@ func init() {
|
|||
WorkerMatchErrorsTotal,
|
||||
WorkerJobsClaimedTotal,
|
||||
WorkerMatchDuration,
|
||||
RateLimitHits,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
134
ratelimit/ratelimit.go
Normal file
134
ratelimit/ratelimit.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
// Package ratelimit provides token-bucket rate limiting for HTTP handlers.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bucket is a token-bucket rate limiter for a single key.
|
||||
type Bucket struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
max float64
|
||||
refill float64 // tokens added per second
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// NewBucket creates a bucket that holds max tokens and refills at the given
|
||||
// rate (tokens per second). The bucket starts full.
|
||||
func NewBucket(max, refillPerSec float64) *Bucket {
|
||||
return &Bucket{
|
||||
tokens: max,
|
||||
max: max,
|
||||
refill: refillPerSec,
|
||||
lastTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow consumes one token. Returns true if a token was available.
|
||||
func (b *Bucket) Allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(b.lastTime).Seconds()
|
||||
b.lastTime = now
|
||||
b.tokens += elapsed * b.refill
|
||||
if b.tokens > b.max {
|
||||
b.tokens = b.max
|
||||
}
|
||||
if b.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// RetryAfter returns the number of seconds until the next token is available.
|
||||
// Call after Allow() returns false.
|
||||
func (b *Bucket) RetryAfter() int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
deficit := 1.0 - b.tokens
|
||||
if deficit <= 0 {
|
||||
return 1
|
||||
}
|
||||
secs := deficit / b.refill
|
||||
if secs < 1 {
|
||||
return 1
|
||||
}
|
||||
return int(secs)
|
||||
}
|
||||
|
||||
// Limiter holds a collection of buckets keyed by string (e.g. "ip:endpoint").
|
||||
type Limiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*Bucket
|
||||
max float64
|
||||
refill float64
|
||||
}
|
||||
|
||||
// NewLimiter creates a Limiter where each key gets max tokens, refilling at
|
||||
// refillPerSec tokens per second.
|
||||
func NewLimiter(max, refillPerSec float64) *Limiter {
|
||||
return &Limiter{
|
||||
buckets: make(map[string]*Bucket),
|
||||
max: max,
|
||||
refill: refillPerSec,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks the bucket for the given key. Creates one if needed.
|
||||
func (l *Limiter) Allow(key string) (*Bucket, bool) {
|
||||
l.mu.Lock()
|
||||
b, ok := l.buckets[key]
|
||||
if !ok {
|
||||
b = NewBucket(l.max, l.refill)
|
||||
l.buckets[key] = b
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
return b, b.Allow()
|
||||
}
|
||||
|
||||
// Cleanup removes buckets that haven't been used in the given duration.
|
||||
// Call periodically to prevent unbounded memory growth.
|
||||
func (l *Limiter) Cleanup(maxAge time.Duration) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
for k, b := range l.buckets {
|
||||
b.mu.Lock()
|
||||
if b.lastTime.Before(cutoff) {
|
||||
delete(l.buckets, k)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns an http.Handler that applies per-key rate limiting.
|
||||
// On limit breach it responds with HTTP 429 and a Retry-After header.
|
||||
// onLimit is called (if non-nil) when a request is rate-limited, for metrics.
|
||||
func (l *Limiter) Middleware(keyFunc func(*http.Request) string, onLimit func()) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := keyFunc(r)
|
||||
bucket, ok := l.Allow(key)
|
||||
if !ok {
|
||||
if onLimit != nil {
|
||||
onLimit()
|
||||
}
|
||||
retry := bucket.RetryAfter()
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", time.Duration(retry).Seconds()))
|
||||
http.Error(w, `{"error":"rate limit exceeded"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
252
ratelimit/ratelimit_test.go
Normal file
252
ratelimit/ratelimit_test.go
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBucketAllowsUpToMax(t *testing.T) {
|
||||
b := NewBucket(3, 1.0) // 3 tokens, 1/sec refill
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
if !b.Allow() {
|
||||
t.Fatalf("request %d should be allowed", i+1)
|
||||
}
|
||||
}
|
||||
if b.Allow() {
|
||||
t.Fatal("4th request should be denied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBucketRefills(t *testing.T) {
|
||||
b := NewBucket(1, 100.0) // 1 token, 100/sec refill (fast refill)
|
||||
|
||||
if !b.Allow() {
|
||||
t.Fatal("first request should be allowed")
|
||||
}
|
||||
if b.Allow() {
|
||||
t.Fatal("second request should be denied (bucket empty)")
|
||||
}
|
||||
|
||||
// Wait for refill
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
if !b.Allow() {
|
||||
t.Fatal("request after refill should be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBucketRetryAfter(t *testing.T) {
|
||||
b := NewBucket(1, 1.0) // 1 token, 1/sec refill
|
||||
b.Allow() // drain
|
||||
|
||||
retry := b.RetryAfter()
|
||||
if retry < 1 {
|
||||
t.Fatalf("RetryAfter = %d, want >= 1", retry)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterCreatesBucketsOnDemand(t *testing.T) {
|
||||
l := NewLimiter(2, 1.0)
|
||||
|
||||
_, ok1 := l.Allow("key-a")
|
||||
if !ok1 {
|
||||
t.Fatal("first request for key-a should be allowed")
|
||||
}
|
||||
_, ok2 := l.Allow("key-b")
|
||||
if !ok2 {
|
||||
t.Fatal("first request for key-b should be allowed")
|
||||
}
|
||||
|
||||
// key-a has 1 token left
|
||||
_, ok3 := l.Allow("key-a")
|
||||
if !ok3 {
|
||||
t.Fatal("second request for key-a should be allowed")
|
||||
}
|
||||
_, ok4 := l.Allow("key-a")
|
||||
if ok4 {
|
||||
t.Fatal("third request for key-a should be denied")
|
||||
}
|
||||
|
||||
// key-b still has 1 token
|
||||
_, ok5 := l.Allow("key-b")
|
||||
if !ok5 {
|
||||
t.Fatal("second request for key-b should be allowed (independent bucket)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareAllowsWhenUnderLimit(t *testing.T) {
|
||||
l := NewLimiter(2, 1.0)
|
||||
called := false
|
||||
handler := l.Middleware(func(r *http.Request) string { return "ip1" }, nil)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("handler should have been called")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareRejectsOnLimit(t *testing.T) {
|
||||
l := NewLimiter(1, 0.0001) // 1 token, extremely slow refill
|
||||
onLimitCalled := false
|
||||
mw := l.Middleware(func(r *http.Request) string { return "ip1" }, func() {
|
||||
onLimitCalled = true
|
||||
})
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request passes
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("first request: status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
// Second request is rate limited
|
||||
req2 := httptest.NewRequest("POST", "/test", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("second request: status = %d, want 429", w2.Code)
|
||||
}
|
||||
if !onLimitCalled {
|
||||
t.Fatal("onLimit callback should have been called")
|
||||
}
|
||||
if h := w2.Header().Get("Retry-After"); h == "" {
|
||||
t.Fatal("Retry-After header should be set")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
json.NewDecoder(w2.Body).Decode(&body)
|
||||
if body["error"] != "rate limit exceeded" {
|
||||
t.Fatalf("error body = %q, want 'rate limit exceeded'", body["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareKeysByIP(t *testing.T) {
|
||||
l := NewLimiter(1, 0.0001) // 1 token per key
|
||||
keyCount := 0
|
||||
mw := l.Middleware(func(r *http.Request) string {
|
||||
keyCount++
|
||||
return r.RemoteAddr
|
||||
}, nil)
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First IP gets one request
|
||||
req1 := httptest.NewRequest("POST", "/test", nil)
|
||||
req1.RemoteAddr = "1.2.3.4:1234"
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
// Second IP also gets one request (different bucket)
|
||||
req2 := httptest.NewRequest("POST", "/test", nil)
|
||||
req2.RemoteAddr = "5.6.7.8:5678"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first IP: status = %d, want 200", w1.Code)
|
||||
}
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("second IP: status = %d, want 200", w2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupRemovesStaleBuckets(t *testing.T) {
|
||||
l := NewLimiter(1, 1.0)
|
||||
|
||||
l.Allow("stale-key")
|
||||
l.Allow("fresh-key")
|
||||
|
||||
if len(l.buckets) != 2 {
|
||||
t.Fatalf("expected 2 buckets, got %d", len(l.buckets))
|
||||
}
|
||||
|
||||
// Manually age the stale bucket
|
||||
l.buckets["stale-key"].mu.Lock()
|
||||
l.buckets["stale-key"].lastTime = time.Now().Add(-2 * time.Hour)
|
||||
l.buckets["stale-key"].mu.Unlock()
|
||||
|
||||
l.Cleanup(time.Hour)
|
||||
|
||||
if len(l.buckets) != 1 {
|
||||
t.Fatalf("expected 1 bucket after cleanup, got %d", len(l.buckets))
|
||||
}
|
||||
if _, ok := l.buckets["stale-key"]; ok {
|
||||
t.Fatal("stale-key should have been cleaned up")
|
||||
}
|
||||
if _, ok := l.buckets["fresh-key"]; !ok {
|
||||
t.Fatal("fresh-key should still be present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloodRegisterEndpoint(t *testing.T) {
|
||||
// Simulates the verification requirement: flood test against /register
|
||||
// returns 429 after threshold.
|
||||
l := NewLimiter(5, 5.0/3600) // 5/hour per IP
|
||||
mw := l.Middleware(func(r *http.Request) string { return "1.2.3.4" }, nil)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}))
|
||||
|
||||
okCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("POST", "/api/register", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code == http.StatusCreated {
|
||||
okCount++
|
||||
}
|
||||
}
|
||||
|
||||
if okCount != 5 {
|
||||
t.Fatalf("expected 5 successful requests, got %d", okCount)
|
||||
}
|
||||
|
||||
// 6th request should be 429
|
||||
req := httptest.NewRequest("POST", "/api/register", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("6th request: status = %d, want 429", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimateTrafficPasses(t *testing.T) {
|
||||
// Verifies that normal traffic patterns don't get rate limited.
|
||||
// 3 registrations from different IPs should all succeed.
|
||||
l := NewLimiter(5, 5.0/3600) // 5/hour per IP
|
||||
mw := l.Middleware(func(r *http.Request) string { return r.RemoteAddr }, nil)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}))
|
||||
|
||||
for i, ip := range []string{"1.1.1.1:1111", "2.2.2.2:2222", "3.3.3.3:3333"} {
|
||||
req := httptest.NewRequest("POST", "/api/register", nil)
|
||||
req.RemoteAddr = ip
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("request %d from %s: status = %d, want 201", i+1, ip, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue