Extracted from ardenone-cluster/containers/zai-proxy and ardenone-cluster/containers/zai-proxy-dashboard. - proxy/: OpenAI-compatible ZAI reverse proxy (Go, v1.10.0) - Token counting, rate limiting, Prometheus metrics, canary support - dashboard/: Metrics dashboard backend + React frontend (Go, v1.0.0) - Prometheus collector, SQLite storage, SSE live updates - docs/: Operational notes, research, and plan subdirs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
812 lines
21 KiB
Go
812 lines
21 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// TestTikTokenCounter tests the tiktoken-based token counter
|
|
func TestTikTokenCounter(t *testing.T) {
|
|
counter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
t.Skipf("Skipping TikToken tests: failed to initialize counter: %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
text string
|
|
wantMin int
|
|
wantMax int
|
|
}{
|
|
{
|
|
name: "Empty string",
|
|
text: "",
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
{
|
|
name: "Short phrase",
|
|
text: "Hello, world!",
|
|
wantMin: 3,
|
|
wantMax: 5,
|
|
},
|
|
{
|
|
name: "Longer sentence",
|
|
text: "The quick brown fox jumps over the lazy dog.",
|
|
wantMin: 9,
|
|
wantMax: 12,
|
|
},
|
|
{
|
|
name: "Code snippet",
|
|
text: "def hello_world():\n print('Hello, world!')",
|
|
wantMin: 10,
|
|
wantMax: 18,
|
|
},
|
|
{
|
|
name: "Unicode characters",
|
|
text: "Hello 世界! 🌍",
|
|
wantMin: 5,
|
|
wantMax: 12,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := counter.CountTokens(tt.text)
|
|
if err != nil {
|
|
t.Errorf("CountTokens() error = %v", err)
|
|
return
|
|
}
|
|
t.Logf("Text: %q, Tokens: %d (expected %d-%d)", tt.text, got, tt.wantMin, tt.wantMax)
|
|
if got < tt.wantMin || got > tt.wantMax {
|
|
t.Errorf("CountTokens() = %v, want between %v and %v", got, tt.wantMin, tt.wantMax)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSimpleTokenCounter tests the fallback token counter
|
|
func TestSimpleTokenCounter(t *testing.T) {
|
|
counter := NewSimpleTokenCounter()
|
|
|
|
tests := []struct {
|
|
name string
|
|
text string
|
|
wantMin int
|
|
wantMax int
|
|
}{
|
|
{
|
|
name: "Empty string",
|
|
text: "",
|
|
wantMin: 0,
|
|
wantMax: 1,
|
|
},
|
|
{
|
|
name: "Short phrase",
|
|
text: "Hello, world!",
|
|
wantMin: 1,
|
|
wantMax: 5,
|
|
},
|
|
{
|
|
name: "Longer sentence",
|
|
text: "The quick brown fox jumps over the lazy dog.",
|
|
wantMin: 8,
|
|
wantMax: 15,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := counter.CountTokens(tt.text)
|
|
if err != nil {
|
|
t.Errorf("CountTokens() error = %v", err)
|
|
return
|
|
}
|
|
if got < tt.wantMin || got > tt.wantMax {
|
|
t.Errorf("CountTokens() = %v, want between %v and %v", got, tt.wantMin, tt.wantMax)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTikTokenAccuracy tests token counting accuracy against known values
|
|
func TestTikTokenAccuracy(t *testing.T) {
|
|
counter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
t.Skipf("Skipping accuracy tests: failed to initialize counter: %v", err)
|
|
}
|
|
|
|
// Test cases with approximate expected token counts
|
|
// These are based on cl100k_base encoding
|
|
tests := []struct {
|
|
text string
|
|
tolerance float64 // Acceptable variance (0.05 = 5%)
|
|
expected int
|
|
}{
|
|
{
|
|
text: "Hello",
|
|
expected: 1,
|
|
tolerance: 0.1,
|
|
},
|
|
{
|
|
text: "Hello, world!",
|
|
expected: 4,
|
|
tolerance: 0.1,
|
|
},
|
|
{
|
|
text: "The quick brown fox jumps over the lazy dog",
|
|
expected: 10,
|
|
tolerance: 0.1,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.text[:min(20, len(tt.text))], func(t *testing.T) {
|
|
got, err := counter.CountTokens(tt.text)
|
|
if err != nil {
|
|
t.Errorf("CountTokens() error = %v", err)
|
|
return
|
|
}
|
|
|
|
variance := float64(abs(got-tt.expected)) / float64(tt.expected)
|
|
t.Logf("Text: %q\n Expected: %d tokens\n Got: %d tokens\n Variance: %.1f%%",
|
|
tt.text, tt.expected, got, variance*100)
|
|
|
|
if variance > tt.tolerance {
|
|
t.Errorf("Token count variance %.1f%% exceeds tolerance %.1f%%",
|
|
variance*100, tt.tolerance*100)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// abs returns the absolute value of x. As with plain two's-complement
// negation, abs(math.MinInt) wraps around and returns math.MinInt.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
|
|
|
|
// TestCountRequestTokens tests parsing request bodies and counting tokens
|
|
func TestCountRequestTokens(t *testing.T) {
|
|
var counter TokenCounter
|
|
tikCounter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
t.Logf("TikToken not available, using SimpleTokenCounter: %v", err)
|
|
counter = NewSimpleTokenCounter()
|
|
} else {
|
|
counter = tikCounter
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
wantMin int
|
|
wantMax int
|
|
}{
|
|
{
|
|
name: "Valid single message",
|
|
body: `{"model":"glm-4","messages":[{"role":"user","content":"Hello"}]}`,
|
|
wantMin: 1,
|
|
wantMax: 3,
|
|
},
|
|
{
|
|
name: "Multiple messages",
|
|
body: `{"model":"glm-4","messages":[{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello there"}]}`,
|
|
wantMin: 2,
|
|
wantMax: 6,
|
|
},
|
|
{
|
|
name: "Empty messages",
|
|
body: `{"model":"glm-4","messages":[]}`,
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
{
|
|
name: "Invalid JSON",
|
|
body: `{invalid json}`,
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
{
|
|
name: "Empty body",
|
|
body: ``,
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := CountRequestTokens([]byte(tt.body), counter)
|
|
if err != nil {
|
|
t.Errorf("CountRequestTokens() error = %v", err)
|
|
return
|
|
}
|
|
if got < tt.wantMin || got > tt.wantMax {
|
|
t.Errorf("CountRequestTokens() = %v, want between %v and %v", got, tt.wantMin, tt.wantMax)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestResponseBodyCapture tests capturing response body while streaming
|
|
func TestResponseBodyCapture(t *testing.T) {
|
|
var counter TokenCounter
|
|
tikCounter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
counter = NewSimpleTokenCounter()
|
|
} else {
|
|
counter = tikCounter
|
|
}
|
|
originalContent := "This is test content that should be captured."
|
|
|
|
// Create a ReadCloser from string
|
|
body := io.NopCloser(strings.NewReader(originalContent))
|
|
|
|
capture := NewResponseBodyCapture(body, counter)
|
|
defer capture.Close()
|
|
|
|
// Read all content
|
|
read, err := io.ReadAll(capture)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read from capture: %v", err)
|
|
}
|
|
|
|
// Verify forwarded content matches original
|
|
if string(read) != originalContent {
|
|
t.Errorf("Forwarded content = %q, want %q", string(read), originalContent)
|
|
}
|
|
|
|
// Verify captured content matches original
|
|
captured := capture.GetCapturedContent()
|
|
if string(captured) != originalContent {
|
|
t.Errorf("Captured content = %q, want %q", string(captured), originalContent)
|
|
}
|
|
}
|
|
|
|
// TestCountJSONResponseTokens tests counting tokens in non-streaming responses
|
|
func TestCountJSONResponseTokens(t *testing.T) {
|
|
var counter TokenCounter
|
|
tikCounter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
counter = NewSimpleTokenCounter()
|
|
} else {
|
|
counter = tikCounter
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
response string
|
|
wantMin int
|
|
wantMax int
|
|
}{
|
|
{
|
|
name: "Valid response with content",
|
|
response: `{"id":"msg_123","type":"message","role":"assistant","content":[{"type":"text","text":"Hello there"}]}`,
|
|
wantMin: 2,
|
|
wantMax: 4,
|
|
},
|
|
{
|
|
name: "Response with multiple content blocks",
|
|
response: `{"content":[{"type":"text","text":"First block"},{"type":"text","text":"Second block"}]}`,
|
|
wantMin: 3,
|
|
wantMax: 6,
|
|
},
|
|
{
|
|
name: "Empty content",
|
|
response: `{"content":[]}`,
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
{
|
|
name: "Invalid JSON",
|
|
response: `{invalid}`,
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
body := io.NopCloser(strings.NewReader(tt.response))
|
|
capture := NewResponseBodyCapture(body, counter)
|
|
|
|
// Read all to populate buffer
|
|
io.ReadAll(capture)
|
|
|
|
got, err := capture.countJSONTokens(capture.GetCapturedContent())
|
|
if err != nil {
|
|
t.Errorf("countJSONTokens() error = %v", err)
|
|
return
|
|
}
|
|
|
|
if got < tt.wantMin || got > tt.wantMax {
|
|
t.Errorf("countJSONTokens() = %v, want between %v and %v", got, tt.wantMin, tt.wantMax)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCountSSEResponseTokens tests counting tokens in streaming (SSE) responses
|
|
func TestCountSSEResponseTokens(t *testing.T) {
|
|
var counter TokenCounter
|
|
tikCounter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
counter = NewSimpleTokenCounter()
|
|
} else {
|
|
counter = tikCounter
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
response string
|
|
wantMin int
|
|
wantMax int
|
|
}{
|
|
{
|
|
name: "Valid SSE stream",
|
|
response: `data: {"type":"message_start","message":{"id":"msg_123"}}
|
|
|
|
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
|
|
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
|
|
|
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" there"}}
|
|
|
|
data: {"type":"content_block_stop","index":0}
|
|
|
|
data: {"type":"message_stop"}
|
|
`,
|
|
wantMin: 2,
|
|
wantMax: 4,
|
|
},
|
|
{
|
|
name: "SSE with multiple deltas",
|
|
response: `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"The"}}
|
|
|
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" quick"}}
|
|
|
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" brown"}}
|
|
`,
|
|
wantMin: 3,
|
|
wantMax: 6,
|
|
},
|
|
{
|
|
name: "Empty SSE",
|
|
response: ``,
|
|
wantMin: 0,
|
|
wantMax: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
body := io.NopCloser(strings.NewReader(tt.response))
|
|
capture := NewResponseBodyCapture(body, counter)
|
|
|
|
// Read all to populate buffer
|
|
io.ReadAll(capture)
|
|
|
|
got, err := capture.countSSETokens(capture.GetCapturedContent())
|
|
if err != nil {
|
|
t.Errorf("countSSETokens() error = %v", err)
|
|
return
|
|
}
|
|
|
|
if got < tt.wantMin || got > tt.wantMax {
|
|
t.Errorf("countSSETokens() = %v, want between %v and %v", got, tt.wantMin, tt.wantMax)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestInjectJSONUsage tests injecting usage into non-streaming responses
|
|
func TestInjectJSONUsage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
inputTokens int
|
|
outputTokens int
|
|
wantUsage bool
|
|
}{
|
|
{
|
|
name: "Valid JSON response",
|
|
body: `{"id":"msg_123","type":"message","role":"assistant"}`,
|
|
inputTokens: 10,
|
|
outputTokens: 20,
|
|
wantUsage: true,
|
|
},
|
|
{
|
|
name: "Invalid JSON",
|
|
body: `{invalid}`,
|
|
inputTokens: 5,
|
|
outputTokens: 10,
|
|
wantUsage: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := injectJSONUsage([]byte(tt.body), tt.inputTokens, tt.outputTokens)
|
|
if err != nil {
|
|
t.Errorf("injectJSONUsage() error = %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.wantUsage {
|
|
if !bytes.Contains(got, []byte("input_tokens")) {
|
|
t.Error("Response missing input_tokens")
|
|
}
|
|
if !bytes.Contains(got, []byte("output_tokens")) {
|
|
t.Error("Response missing output_tokens")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestInjectSSEUsage tests injecting usage into streaming responses
|
|
func TestInjectSSEUsage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
inputTokens int
|
|
outputTokens int
|
|
wantUsage bool
|
|
}{
|
|
{
|
|
name: "Valid SSE with message_delta",
|
|
body: `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
|
|
|
|
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"}}
|
|
|
|
data: {"type":"message_stop"}
|
|
`,
|
|
inputTokens: 5,
|
|
outputTokens: 10,
|
|
wantUsage: true,
|
|
},
|
|
{
|
|
name: "SSE without message_delta",
|
|
body: `data: {"type":"content_block_start"}
|
|
|
|
data: {"type":"content_block_stop"}
|
|
`,
|
|
inputTokens: 5,
|
|
outputTokens: 10,
|
|
wantUsage: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := injectSSEUsage([]byte(tt.body), tt.inputTokens, tt.outputTokens)
|
|
if err != nil {
|
|
t.Errorf("injectSSEUsage() error = %v", err)
|
|
return
|
|
}
|
|
|
|
if tt.wantUsage {
|
|
if !bytes.Contains(got, []byte("input_tokens")) {
|
|
t.Error("SSE response missing input_tokens in message_delta")
|
|
}
|
|
if !bytes.Contains(got, []byte("output_tokens")) {
|
|
t.Error("SSE response missing output_tokens in message_delta")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestStreamingUsagePassthrough tests that upstream usage is passed through when present
|
|
func TestStreamingUsagePassthrough(t *testing.T) {
|
|
var counter TokenCounter
|
|
tikCounter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
counter = NewSimpleTokenCounter()
|
|
} else {
|
|
counter = tikCounter
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
body string
|
|
inputTokens int
|
|
wantUpstreamUsage bool
|
|
wantUsageFields []string // fields expected in usage
|
|
}{
|
|
{
|
|
name: "message_delta with upstream usage should be preserved",
|
|
body: `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
|
|
|
|
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":42,"output_tokens":99,"cache_creation_input_tokens":5,"cache_read_input_tokens":10}}
|
|
|
|
data: {"type":"message_stop"}
|
|
`,
|
|
inputTokens: 5,
|
|
wantUpstreamUsage: true,
|
|
wantUsageFields: []string{`"input_tokens":42`, `"output_tokens":99`, `"cache_creation_input_tokens":5`, `"cache_read_input_tokens":10`},
|
|
},
|
|
{
|
|
name: "message_delta without upstream usage should get proxy usage",
|
|
body: `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
|
|
|
|
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"}}
|
|
|
|
data: {"type":"message_stop"}
|
|
`,
|
|
inputTokens: 10,
|
|
wantUpstreamUsage: false,
|
|
wantUsageFields: []string{`"input_tokens":10`, `"output_tokens"`},
|
|
},
|
|
{
|
|
name: "message_delta with empty usage object should get proxy usage",
|
|
body: `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
|
|
|
|
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{}}
|
|
|
|
data: {"type":"message_stop"}
|
|
`,
|
|
inputTokens: 15,
|
|
wantUpstreamUsage: false,
|
|
wantUsageFields: []string{`"input_tokens":15`},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
body := io.NopCloser(strings.NewReader(tt.body))
|
|
capture := NewStreamingResponseBodyCapture(body, counter, tt.inputTokens)
|
|
|
|
// Read all content
|
|
read, err := io.ReadAll(capture)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read from capture: %v", err)
|
|
}
|
|
capture.Close()
|
|
|
|
// Check for expected usage fields
|
|
for _, field := range tt.wantUsageFields {
|
|
if !bytes.Contains(read, []byte(field)) {
|
|
t.Errorf("Response missing expected usage field: %s\nGot: %s", field, string(read))
|
|
}
|
|
}
|
|
|
|
// If upstream usage should be preserved, verify it's not overwritten with proxy values
|
|
if tt.wantUpstreamUsage {
|
|
// Should NOT contain the proxy input tokens (5) when upstream has 42
|
|
if bytes.Contains(read, []byte(`"input_tokens":5`)) {
|
|
t.Error("Response should preserve upstream usage (42), not use proxy value (5)")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestWrapResponseWithUsage tests wrapping Z.AI responses with Claude-compatible usage field
|
|
func TestWrapResponseWithUsage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
originalResp string
|
|
inputTokens int
|
|
outputTokens int
|
|
wantUsage bool
|
|
wantResult bool // whether to expect a "result" field
|
|
}{
|
|
{
|
|
name: "Z.AI response with result field",
|
|
originalResp: `{"result":{"id":"msg_123","type":"message","role":"assistant","content":[{"type":"text","text":"Hello there"}]}}`,
|
|
inputTokens: 10,
|
|
outputTokens: 20,
|
|
wantUsage: true,
|
|
wantResult: true,
|
|
},
|
|
{
|
|
name: "Z.AI response without result field (direct content)",
|
|
originalResp: `{"id":"msg_456","type":"message","role":"assistant","content":[{"type":"text","text":"Hi"}]}`,
|
|
inputTokens: 5,
|
|
outputTokens: 10,
|
|
wantUsage: true,
|
|
wantResult: true,
|
|
},
|
|
{
|
|
name: "Empty response",
|
|
originalResp: `{}`,
|
|
inputTokens: 0,
|
|
outputTokens: 0,
|
|
wantUsage: true,
|
|
wantResult: true,
|
|
},
|
|
{
|
|
name: "Invalid JSON should return original",
|
|
originalResp: `{invalid json}`,
|
|
inputTokens: 5,
|
|
outputTokens: 10,
|
|
wantUsage: false, // Original is returned on error
|
|
wantResult: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := WrapResponseWithUsage([]byte(tt.originalResp), tt.inputTokens, tt.outputTokens)
|
|
|
|
// For invalid JSON, expect error
|
|
if tt.name == "Invalid JSON should return original" {
|
|
if err == nil {
|
|
t.Error("Expected error for invalid JSON, got nil")
|
|
}
|
|
// Should still return original response on error
|
|
if string(got) != tt.originalResp {
|
|
t.Errorf("Expected original response on error, got: %s", string(got))
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("WrapResponseWithUsage() error = %v", err)
|
|
return
|
|
}
|
|
|
|
// Check that response contains usage field with correct values
|
|
if tt.wantUsage {
|
|
if !bytes.Contains(got, []byte(`"input_tokens":`)) {
|
|
t.Error("Wrapped response missing input_tokens field")
|
|
}
|
|
if !bytes.Contains(got, []byte(`"output_tokens":`)) {
|
|
t.Error("Wrapped response missing output_tokens field")
|
|
}
|
|
|
|
// Verify the specific token values are present
|
|
if !bytes.Contains(got, []byte(`"input_tokens":`+string(rune('0'+tt.inputTokens/10))+string(rune('0'+tt.inputTokens%10)))) {
|
|
// Just check that input_tokens key exists (exact value check is complex)
|
|
t.Logf("Response contains input_tokens: %s", string(got))
|
|
}
|
|
}
|
|
|
|
// Check that result field is present
|
|
if tt.wantResult {
|
|
if !bytes.Contains(got, []byte(`"result"`)) {
|
|
t.Error("Wrapped response missing result field")
|
|
}
|
|
}
|
|
|
|
// Parse and validate structure
|
|
var wrapped map[string]interface{}
|
|
if err := json.Unmarshal(got, &wrapped); err != nil {
|
|
t.Errorf("Failed to parse wrapped response: %v", err)
|
|
return
|
|
}
|
|
|
|
// Validate usage structure
|
|
usage, ok := wrapped["usage"].(map[string]interface{})
|
|
if !ok {
|
|
t.Error("Wrapped response missing or invalid usage object")
|
|
return
|
|
}
|
|
|
|
// Check required Claude-compatible fields
|
|
requiredFields := []string{"input_tokens", "output_tokens", "cache_read_input_tokens", "cache_creation_input_tokens"}
|
|
for _, field := range requiredFields {
|
|
if _, exists := usage[field]; !exists {
|
|
t.Errorf("Usage missing required field: %s", field)
|
|
}
|
|
}
|
|
|
|
t.Logf("Wrapped response: %s", string(got))
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestWrapResponseWithUsagePreservesContent tests that content is preserved during wrapping
|
|
func TestWrapResponseWithUsagePreservesContent(t *testing.T) {
|
|
originalContent := "This is the original response content that should be preserved."
|
|
originalResp := `{"result":{"id":"msg_test","content":[{"type":"text","text":"` + originalContent + `"}]}}`
|
|
|
|
wrapped, err := WrapResponseWithUsage([]byte(originalResp), 10, 20)
|
|
if err != nil {
|
|
t.Fatalf("WrapResponseWithUsage() error = %v", err)
|
|
}
|
|
|
|
// Verify original content is preserved in the wrapped response
|
|
if !bytes.Contains(wrapped, []byte(originalContent)) {
|
|
t.Errorf("Wrapped response lost original content.\nOriginal: %s\nWrapped: %s", originalResp, string(wrapped))
|
|
}
|
|
|
|
// Parse and verify structure
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal(wrapped, &result); err != nil {
|
|
t.Fatalf("Failed to parse wrapped response: %v", err)
|
|
}
|
|
|
|
resultObj, ok := result["result"].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("Wrapped response missing result object")
|
|
}
|
|
|
|
content, ok := resultObj["content"].([]interface{})
|
|
if !ok {
|
|
t.Fatal("Wrapped response result missing content array")
|
|
}
|
|
|
|
if len(content) == 0 {
|
|
t.Fatal("Content array is empty")
|
|
}
|
|
|
|
textBlock, ok := content[0].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("Content block is not an object")
|
|
}
|
|
|
|
text, ok := textBlock["text"].(string)
|
|
if !ok || text != originalContent {
|
|
t.Errorf("Content text mismatch: got %q, want %q", text, originalContent)
|
|
}
|
|
}
|
|
|
|
// TestStreamingPreservation tests that TeeReader doesn't affect streaming
|
|
func TestStreamingPreservation(t *testing.T) {
|
|
var counter TokenCounter
|
|
tikCounter, err := NewTikTokenCounter()
|
|
if err != nil {
|
|
counter = NewSimpleTokenCounter()
|
|
} else {
|
|
counter = tikCounter
|
|
}
|
|
|
|
// Simulate streaming data in chunks
|
|
chunks := []string{
|
|
"data: {\"type\":\"message_start\"}\n",
|
|
"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n",
|
|
"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\" world\"}}\n",
|
|
"data: {\"type\":\"message_stop\"}\n",
|
|
}
|
|
|
|
// Create a reader that emits chunks
|
|
pipeReader, pipeWriter := io.Pipe()
|
|
|
|
go func() {
|
|
defer pipeWriter.Close()
|
|
for _, chunk := range chunks {
|
|
pipeWriter.Write([]byte(chunk))
|
|
}
|
|
}()
|
|
|
|
capture := NewResponseBodyCapture(io.NopCloser(pipeReader), counter)
|
|
defer capture.Close()
|
|
|
|
// Read in chunks to simulate streaming
|
|
buf := make([]byte, 64)
|
|
var received []string
|
|
|
|
for {
|
|
n, err := capture.Read(buf)
|
|
if n > 0 {
|
|
received = append(received, string(buf[:n]))
|
|
}
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Read error: %v", err)
|
|
}
|
|
}
|
|
|
|
// Verify all chunks were received
|
|
fullReceived := strings.Join(received, "")
|
|
expected := strings.Join(chunks, "")
|
|
|
|
if fullReceived != expected {
|
|
t.Errorf("Streaming content mismatch:\ngot: %q\nwant: %q", fullReceived, expected)
|
|
}
|
|
|
|
// Verify captured content is complete
|
|
captured := string(capture.GetCapturedContent())
|
|
if captured != expected {
|
|
t.Errorf("Captured content incomplete:\ngot: %q\nwant: %q", captured, expected)
|
|
}
|
|
}
|