zai-proxy/proxy/metrics_test.go
jedarden e7c24a0c08 feat: initial zai-proxy ecosystem repo
Extracted from ardenone-cluster/containers/zai-proxy and
ardenone-cluster/containers/zai-proxy-dashboard.

- proxy/: OpenAI-compatible ZAI reverse proxy (Go, v1.10.0)
  - Token counting, rate limiting, Prometheus metrics, canary support
- dashboard/: Metrics dashboard backend + React frontend (Go, v1.0.0)
  - Prometheus collector, SQLite storage, SSE live updates
- docs/: Operational notes, research, and plan subdirs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 15:53:52 -04:00

371 lines
10 KiB
Go

package main
import (
"strings"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus/testutil"
)
// TestRecordInputTokens drives RecordInputTokens with positive, zero, negative
// and canary inputs and checks how the matching "input" series of tokensTotal
// moves for each of them.
//
// NOTE(review): every lookup hard-codes "off_peak" as the fourth label value.
// If RecordInputTokens derives that value from the wall clock, this test is
// time-of-day dependent — confirm against the implementation.
func TestRecordInputTokens(t *testing.T) {
	// Start from an empty vector so earlier tests cannot influence the deltas.
	tokensTotal.Reset()

	cases := []struct {
		name     string
		model    string
		version  string
		count    int
		wantZero bool
	}{
		{name: "record positive input tokens", model: "glm-4", version: "stable", count: 100, wantZero: false},
		{name: "record zero tokens - should not increment", model: "glm-4", version: "stable", count: 0, wantZero: true},
		{name: "record negative tokens - should not increment", model: "glm-4", version: "stable", count: -10, wantZero: true},
		{name: "canary deployment tokens", model: "glm-4", version: "canary", count: 250, wantZero: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// current reads the counter for this case's label set.
			current := func() float64 {
				return testutil.ToFloat64(tokensTotal.WithLabelValues("input", tc.model, tc.version, "off_peak"))
			}

			before := current()
			RecordInputTokens(tc.model, tc.version, tc.count)
			after := current()

			switch {
			case tc.wantZero && after != before:
				// Zero and negative counts must leave the counter untouched.
				t.Errorf("RecordInputTokens() changed count for zero/negative input: initial=%v, final=%v", before, after)
			case !tc.wantZero && after != before+float64(tc.count):
				// Positive counts must be added in full.
				t.Errorf("RecordInputTokens() count mismatch: got=%v, want=%v", after, before+float64(tc.count))
			}
		})
	}
}
// TestRecordOutputTokens drives RecordOutputTokens with positive, zero and
// canary inputs and checks how the matching "output" series of tokensTotal
// moves for each of them.
//
// NOTE(review): every lookup hard-codes "off_peak" as the fourth label value.
// If RecordOutputTokens derives that value from the wall clock, this test is
// time-of-day dependent — confirm against the implementation.
func TestRecordOutputTokens(t *testing.T) {
	// Start from an empty vector so earlier tests cannot influence the deltas.
	tokensTotal.Reset()

	cases := []struct {
		name     string
		model    string
		version  string
		count    int
		wantZero bool
	}{
		{name: "record positive output tokens", model: "glm-4", version: "stable", count: 500, wantZero: false},
		{name: "record zero tokens - should not increment", model: "glm-4", version: "stable", count: 0, wantZero: true},
		{name: "canary deployment output tokens", model: "claude-3", version: "canary", count: 1000, wantZero: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			series := tokensTotal.WithLabelValues("output", tc.model, tc.version, "off_peak")
			before := testutil.ToFloat64(series)

			RecordOutputTokens(tc.model, tc.version, tc.count)

			// Look the series up again rather than reusing the handle above,
			// mirroring a fresh read of the vector after the call.
			after := testutil.ToFloat64(tokensTotal.WithLabelValues("output", tc.model, tc.version, "off_peak"))

			// A "zero" case expects no movement; any other case expects the
			// full count to be added.
			want := before
			if !tc.wantZero {
				want += float64(tc.count)
			}
			if after == want {
				return
			}

			if tc.wantZero {
				t.Errorf("RecordOutputTokens() changed count for zero input: initial=%v, final=%v", before, after)
			} else {
				t.Errorf("RecordOutputTokens() count mismatch: got=%v, want=%v", after, want)
			}
		})
	}
}
// TestRecordTokenRate checks that RecordTokenRate produces a rate series only
// for the table entries marked wantRecord, i.e. those with a positive token
// count and a positive duration.
//
// The check is on series presence: after resetting both rate vectors, a call
// that is expected to record must leave at least one series in
// tokenRateSeconds or tokenRate, and a call that is expected to be ignored
// must leave none.
func TestRecordTokenRate(t *testing.T) {
	tests := []struct {
		name       string
		direction  string
		model      string
		version    string
		duration   time.Duration
		tokenCount int
		wantRecord bool
	}{
		{
			name:       "record input token rate",
			direction:  "input",
			model:      "glm-4",
			version:    "stable",
			duration:   10 * time.Millisecond,
			tokenCount: 100,
			wantRecord: true,
		},
		{
			name:       "record output token rate",
			direction:  "output",
			model:      "glm-4",
			version:    "canary",
			duration:   5 * time.Millisecond,
			tokenCount: 250,
			wantRecord: true,
		},
		{
			name:       "zero token count - should not record",
			direction:  "input",
			model:      "glm-4",
			version:    "stable",
			duration:   10 * time.Millisecond,
			tokenCount: 0,
			wantRecord: false,
		},
		{
			name:       "zero duration - should not record",
			direction:  "input",
			model:      "glm-4",
			version:    "stable",
			duration:   0,
			tokenCount: 100,
			wantRecord: false,
		},
		{
			name:       "negative token count - should not record",
			direction:  "input",
			model:      "glm-4",
			version:    "stable",
			duration:   10 * time.Millisecond,
			tokenCount: -10,
			wantRecord: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Reset per case: several cases share the same labels, so a series
			// left behind by an earlier case would mask a missing or an
			// unwanted recording in this one.
			tokenRateSeconds.Reset()
			tokenRate.Reset()

			RecordTokenRate(tt.direction, tt.model, tt.version, tt.duration, tt.tokenCount)

			// Count across both vectors so the check does not depend on which
			// of them RecordTokenRate feeds.
			series := testutil.CollectAndCount(tokenRateSeconds) + testutil.CollectAndCount(tokenRate)
			if recorded := series > 0; recorded != tt.wantRecord {
				t.Errorf("RecordTokenRate(%q, %q, %q, %v, %d): recorded=%v (series=%d), want recorded=%v",
					tt.direction, tt.model, tt.version, tt.duration, tt.tokenCount,
					recorded, series, tt.wantRecord)
			}
		})
	}
}
// TestRecordInputTokenRate checks that RecordInputTokenRate, given a positive
// duration and token count, leaves at least one series in the rate vectors
// instead of merely returning without a panic.
func TestRecordInputTokenRate(t *testing.T) {
	// Empty both vectors so any series found afterwards comes from this call.
	tokenRateSeconds.Reset()
	tokenRate.Reset()

	RecordInputTokenRate("glm-4", "stable", 10*time.Millisecond, 100)

	// Count across both vectors so the check does not depend on which of them
	// the function feeds.
	series := testutil.CollectAndCount(tokenRateSeconds) + testutil.CollectAndCount(tokenRate)
	if series == 0 {
		t.Errorf("RecordInputTokenRate() recorded no rate series for a valid sample")
	}
}
// TestRecordOutputTokenRate checks that RecordOutputTokenRate, given a
// positive duration and token count, leaves at least one series in the rate
// vectors instead of merely returning without a panic.
func TestRecordOutputTokenRate(t *testing.T) {
	// Empty both vectors so any series found afterwards comes from this call.
	tokenRateSeconds.Reset()
	tokenRate.Reset()

	RecordOutputTokenRate("glm-4", "canary", 5*time.Millisecond, 250)

	// Count across both vectors so the check does not depend on which of them
	// the function feeds.
	series := testutil.CollectAndCount(tokenRateSeconds) + testutil.CollectAndCount(tokenRate)
	if series == 0 {
		t.Errorf("RecordOutputTokenRate() recorded no rate series for a valid sample")
	}
}
// TestMetricLabels records token counts under several direction/model/version
// combinations and verifies that tokensTotal keeps a separate value for each.
//
// NOTE(review): the lookups hard-code "off_peak" as the fourth label value —
// confirm this does not depend on the time of day the test runs.
func TestMetricLabels(t *testing.T) {
	// Clear every vector this test writes to.
	tokensTotal.Reset()
	tokenRateSeconds.Reset()
	tokenRate.Reset()

	// Counter samples, one per label combination checked below.
	RecordInputTokens("glm-4", "stable", 100)
	RecordInputTokens("glm-4", "canary", 150)
	RecordOutputTokens("claude-3", "stable", 200)
	RecordOutputTokens("claude-3", "canary", 250)

	// Rate samples are recorded alongside but not asserted on here.
	RecordInputTokenRate("glm-4", "stable", 10*time.Millisecond, 100)
	RecordOutputTokenRate("claude-3", "canary", 5*time.Millisecond, 200)

	expectations := []struct {
		name      string
		direction string
		model     string
		version   string
		wantCount float64
	}{
		{name: "input glm-4 stable", direction: "input", model: "glm-4", version: "stable", wantCount: 100},
		{name: "input glm-4 canary", direction: "input", model: "glm-4", version: "canary", wantCount: 150},
		{name: "output claude-3 stable", direction: "output", model: "claude-3", version: "stable", wantCount: 200},
		{name: "output claude-3 canary", direction: "output", model: "claude-3", version: "canary", wantCount: 250},
	}

	for _, want := range expectations {
		t.Run(want.name, func(t *testing.T) {
			series := tokensTotal.WithLabelValues(want.direction, want.model, want.version, "off_peak")
			if got := testutil.ToFloat64(series); got != want.wantCount {
				t.Errorf("Token count mismatch for %s: got=%v, want=%v", want.name, got, want.wantCount)
			}
		})
	}
}
// TestMetricNoLeaks records the same label set many times and verifies that
// repeated recordings add no new series to any of the vectors and that the
// token counters accumulate every recording.
func TestMetricNoLeaks(t *testing.T) {
	tokensTotal.Reset()
	tokenRateSeconds.Reset()
	tokenRate.Reset()

	const iterations = 1000

	// recordOnce performs one round of every recording under test, always
	// with the same labels.
	recordOnce := func() {
		RecordInputTokens("glm-4", "stable", 1)
		RecordOutputTokens("glm-4", "stable", 1)
		RecordInputTokenRate("glm-4", "stable", 1*time.Millisecond, 10)
		RecordOutputTokenRate("glm-4", "stable", 1*time.Millisecond, 10)
	}

	// seriesCount is the number of series currently held by the three vectors.
	seriesCount := func() int {
		return testutil.CollectAndCount(tokensTotal) +
			testutil.CollectAndCount(tokenRateSeconds) +
			testutil.CollectAndCount(tokenRate)
	}

	// The first round creates whatever series this label set needs; the
	// remaining rounds must reuse them. Comparing against the baseline keeps
	// the check independent of the vectors' label schemas.
	recordOnce()
	baseline := seriesCount()
	for i := 1; i < iterations; i++ {
		recordOnce()
	}
	if got := seriesCount(); got != baseline {
		t.Errorf("series count grew from %d to %d over %d identical recordings", baseline, got, iterations)
	}

	// The counter lookups below may themselves create series, so they run
	// only after the cardinality check.
	inputCount := testutil.ToFloat64(tokensTotal.WithLabelValues("input", "glm-4", "stable", "off_peak"))
	if inputCount != iterations {
		t.Errorf("Input token count incorrect after 1000 iterations: got=%v, want=1000", inputCount)
	}
	outputCount := testutil.ToFloat64(tokensTotal.WithLabelValues("output", "glm-4", "stable", "off_peak"))
	if outputCount != iterations {
		t.Errorf("Output token count incorrect after 1000 iterations: got=%v, want=1000", outputCount)
	}
}
// TestMetricNoConflicts records a distinct value under each of four
// direction/version combinations for one model and verifies that every
// tokensTotal series holds exactly its own value.
//
// NOTE(review): the lookups hard-code "off_peak" as the fourth label value —
// confirm this does not depend on the time of day the test runs.
func TestMetricNoConflicts(t *testing.T) {
	tokensTotal.Reset()

	// Four combinations, four different amounts, so any cross-talk between
	// series shows up as a wrong value.
	RecordInputTokens("glm-4", "stable", 100)
	RecordInputTokens("glm-4", "canary", 200)
	RecordOutputTokens("glm-4", "stable", 300)
	RecordOutputTokens("glm-4", "canary", 400)

	// read returns the counter value for the given direction and version.
	read := func(direction, version string) float64 {
		return testutil.ToFloat64(tokensTotal.WithLabelValues(direction, "glm-4", version, "off_peak"))
	}

	inputStable := read("input", "stable")
	if inputStable != 100 {
		t.Errorf("Input stable tokens incorrect: got=%v, want=100", inputStable)
	}

	inputCanary := read("input", "canary")
	if inputCanary != 200 {
		t.Errorf("Input canary tokens incorrect: got=%v, want=200", inputCanary)
	}

	outputStable := read("output", "stable")
	if outputStable != 300 {
		t.Errorf("Output stable tokens incorrect: got=%v, want=300", outputStable)
	}

	outputCanary := read("output", "canary")
	if outputCanary != 400 {
		t.Errorf("Output canary tokens incorrect: got=%v, want=400", outputCanary)
	}
}
// TestMetricsExportFormat records one input and one output token sample and
// compares what tokensTotal exposes with a hand-written expectation in the
// Prometheus text format.
//
// NOTE(review): the other tests in this file address tokensTotal with four
// label values (the last one "off_peak"), whereas the expected lines below
// carry only three labels (direction, model, variant). Confirm the expectation
// against the metric definition — it looks like it predates the fourth label.
func TestMetricsExportFormat(t *testing.T) {
	tokensTotal.Reset()
	tokenRateSeconds.Reset()
	tokenRate.Reset()

	RecordInputTokens("glm-4", "stable", 100)
	RecordOutputTokens("glm-4", "stable", 200)
	RecordInputTokenRate("glm-4", "stable", 10*time.Millisecond, 100)

	// HELP/TYPE header followed by one sample line per recorded direction.
	header := `
# HELP zai_proxy_tokens_total Total number of tokens processed by direction (input/output), model, and deployment variant (stable/canary)
# TYPE zai_proxy_tokens_total counter
`
	inputSample := `zai_proxy_tokens_total{direction="input",model="glm-4",variant="stable"} 100`
	outputSample := `zai_proxy_tokens_total{direction="output",model="glm-4",variant="stable"} 200`
	expected := header + inputSample + "\n" + outputSample + "\n"

	err := testutil.CollectAndCompare(tokensTotal, strings.NewReader(expected))
	if err != nil {
		t.Errorf("Metrics export format incorrect: %v", err)
	}
}