Extracted from ardenone-cluster/containers/zai-proxy and ardenone-cluster/containers/zai-proxy-dashboard. - proxy/: OpenAI-compatible ZAI reverse proxy (Go, v1.10.0) - Token counting, rate limiting, Prometheus metrics, canary support - dashboard/: Metrics dashboard backend + React frontend (Go, v1.0.0) - Prometheus collector, SQLite storage, SSE live updates - docs/: Operational notes, research, and plan subdirs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
609 lines
18 KiB
Go
609 lines
18 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"log"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/tiktoken-go/tokenizer"
|
|
)
|
|
|
|
// TokenCounter is implemented by anything that can report how many tokens a
// piece of text encodes to.
type TokenCounter interface {
	// CountTokens returns the number of tokens in text, or an error when the
	// text could not be tokenized.
	CountTokens(text string) (int, error)
}
|
|
|
|
// TikTokenCounter uses tiktoken-go with cl100k_base encoding (Claude 3 compatible)
|
|
type TikTokenCounter struct {
|
|
encoder tokenizer.Codec
|
|
mu sync.Mutex // Protect encoder access
|
|
}
|
|
|
|
// NewTikTokenCounter creates a new tiktoken-based token counter with cl100k_base encoding
|
|
func NewTikTokenCounter() (*TikTokenCounter, error) {
|
|
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &TikTokenCounter{
|
|
encoder: enc,
|
|
}, nil
|
|
}
|
|
|
|
// CountTokens counts tokens in text using tiktoken cl100k_base encoding
|
|
func (tc *TikTokenCounter) CountTokens(text string) (int, error) {
|
|
if text == "" {
|
|
return 0, nil
|
|
}
|
|
|
|
tc.mu.Lock()
|
|
defer tc.mu.Unlock()
|
|
|
|
// Encode text to token IDs
|
|
ids, _, err := tc.encoder.Encode(text)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(ids), nil
|
|
}
|
|
|
|
// SimpleTokenCounter is a dependency-free TokenCounter that estimates token
// counts from text length. It is meant as a fallback for when the tiktoken
// codec cannot be loaded.
type SimpleTokenCounter struct{}
|
|
|
|
func NewSimpleTokenCounter() *SimpleTokenCounter {
|
|
return &SimpleTokenCounter{}
|
|
}
|
|
|
|
// CountTokens approximates token count using word count * 1.3
|
|
// This is a rough approximation for fallback scenarios
|
|
func (tc *SimpleTokenCounter) CountTokens(text string) (int, error) {
|
|
if text == "" {
|
|
return 0, nil
|
|
}
|
|
|
|
// Rough approximation: ~1.3 tokens per word on average
|
|
words := len(text) / 4 // Average word length ~4 chars
|
|
if words == 0 {
|
|
words = 1
|
|
}
|
|
|
|
return words, nil
|
|
}
|
|
|
|
// UsageData carries the token counts recorded for one request/response pair.
type UsageData struct {
	InputTokens      int
	OutputTokens     int
	CacheReadTokens  int
	CacheWriteTokens int
	// FromAPI is true when the counts were reported by the upstream API and
	// false when they are local tiktoken estimates.
	FromAPI bool
}
|
|
|
|
// ExtractUsageFromJSON reads the usage block from a non-streaming Anthropic-format response.
|
|
// Returns (usage, true) when a usage block with non-zero token counts is present.
|
|
func ExtractUsageFromJSON(body []byte) (UsageData, bool) {
|
|
var resp struct {
|
|
Usage *struct {
|
|
InputTokens int `json:"input_tokens"`
|
|
OutputTokens int `json:"output_tokens"`
|
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
if err := json.Unmarshal(body, &resp); err != nil || resp.Usage == nil {
|
|
return UsageData{}, false
|
|
}
|
|
u := resp.Usage
|
|
if u.InputTokens == 0 && u.OutputTokens == 0 {
|
|
return UsageData{}, false
|
|
}
|
|
return UsageData{
|
|
InputTokens: u.InputTokens,
|
|
OutputTokens: u.OutputTokens,
|
|
CacheReadTokens: u.CacheReadInputTokens,
|
|
CacheWriteTokens: u.CacheCreationInputTokens,
|
|
FromAPI: true,
|
|
}, true
|
|
}
|
|
|
|
// jsonFloat returns v as a float64 when it holds one, and 0 for anything
// else, including nil. encoding/json uses float64 for numbers decoded into
// interface{}.
func jsonFloat(v interface{}) float64 {
	switch num := v.(type) {
	case float64:
		return num
	default:
		return 0
	}
}
|
|
|
|
// RequestBody represents Claude API request structure
|
|
type RequestBody struct {
|
|
Model string `json:"model"`
|
|
Messages []RequestMessage `json:"messages"`
|
|
Stream bool `json:"stream,omitempty"`
|
|
}
|
|
|
|
// ContentBlock is one element of an array-form message content. Only the
// block type and, for text blocks, its text are decoded; the text is
// omitted from encoded output when empty.
type ContentBlock struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}
|
|
|
|
// RequestMessage is a single conversation turn in an Anthropic-format request.
// Content is kept as raw JSON because the API accepts either a plain string
// or an array of content blocks. Callers decode whichever form is present.
type RequestMessage struct {
	Role    string          `json:"role"`
	Content json.RawMessage `json:"content"`
}
|
|
|
|
// CountRequestTokens extracts messages from request body and counts tokens
|
|
// Supports both simple string content and multi-modal array content
|
|
func CountRequestTokens(body []byte, counter TokenCounter) (int, error) {
|
|
if len(body) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
var req RequestBody
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
log.Printf("Warning: failed to parse request body for token counting: %v", err)
|
|
return 0, nil // Graceful degradation
|
|
}
|
|
|
|
totalTokens := 0
|
|
for _, msg := range req.Messages {
|
|
// Try to parse content as string first (simple text message)
|
|
var contentStr string
|
|
if err := json.Unmarshal(msg.Content, &contentStr); err == nil {
|
|
tokens, err := counter.CountTokens(contentStr)
|
|
if err != nil {
|
|
log.Printf("Warning: failed to count tokens for message: %v", err)
|
|
continue
|
|
}
|
|
totalTokens += tokens
|
|
continue
|
|
}
|
|
|
|
// If not a string, try array of ContentBlock (multi-modal message)
|
|
var contentBlocks []ContentBlock
|
|
if err := json.Unmarshal(msg.Content, &contentBlocks); err == nil {
|
|
for _, block := range contentBlocks {
|
|
if block.Type == "text" && block.Text != "" {
|
|
tokens, err := counter.CountTokens(block.Text)
|
|
if err != nil {
|
|
log.Printf("Warning: failed to count tokens for content block: %v", err)
|
|
continue
|
|
}
|
|
totalTokens += tokens
|
|
}
|
|
// Other block types (image, etc.) are skipped for token counting
|
|
}
|
|
} else {
|
|
log.Printf("Warning: failed to parse message content (neither string nor array): %v", err)
|
|
}
|
|
}
|
|
|
|
return totalTokens, nil
|
|
}
|
|
|
|
// ResponseBodyCapture captures streaming response body for token counting
|
|
type ResponseBodyCapture struct {
|
|
originalBody io.ReadCloser
|
|
buffer *bytes.Buffer
|
|
teeReader io.Reader
|
|
counter TokenCounter
|
|
}
|
|
|
|
// NewResponseBodyCapture creates a new response body capture that uses io.TeeReader
|
|
func NewResponseBodyCapture(body io.ReadCloser, counter TokenCounter) *ResponseBodyCapture {
|
|
buffer := &bytes.Buffer{}
|
|
teeReader := io.TeeReader(body, buffer)
|
|
|
|
return &ResponseBodyCapture{
|
|
originalBody: body,
|
|
buffer: buffer,
|
|
teeReader: teeReader,
|
|
counter: counter,
|
|
}
|
|
}
|
|
|
|
// WrapResponseWithUsage reshapes a non-streaming Z.AI JSON response into
// {"result": ..., "usage": {...}}. This lets ccdash track GLM token
// consumption from session logs.
//
// If the response has a top-level "result" field, its value becomes the new
// "result". Otherwise the whole decoded object is used. The usage object holds
// the given input/output counts and zeroed cache counts. On a parse or marshal
// failure the original bytes are returned together with the error (which is
// also logged).
func WrapResponseWithUsage(originalResp []byte, inputTokens, outputTokens int) ([]byte, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal(originalResp, &decoded); err != nil {
		log.Printf("Warning: failed to parse Z.AI response: %v", err)
		return originalResp, err
	}

	// Prefer the nested "result" payload when Z.AI provides one.
	payload, hasResult := decoded["result"]
	if !hasResult {
		payload = decoded
	}

	usage := map[string]interface{}{
		"input_tokens":                inputTokens,
		"output_tokens":               outputTokens,
		"cache_read_input_tokens":     0,
		"cache_creation_input_tokens": 0,
	}
	encoded, err := json.Marshal(map[string]interface{}{
		"result": payload,
		"usage":  usage,
	})
	if err != nil {
		log.Printf("Warning: failed to marshal wrapped response: %v", err)
		return originalResp, err
	}

	log.Printf("Injected usage into Z.AI response: input=%d, output=%d", inputTokens, outputTokens)
	return encoded, nil
}
|
|
|
|
// Read implements io.Reader, forwarding reads while capturing content
|
|
func (rbc *ResponseBodyCapture) Read(p []byte) (n int, err error) {
|
|
return rbc.teeReader.Read(p)
|
|
}
|
|
|
|
// Close implements io.Closer
|
|
func (rbc *ResponseBodyCapture) Close() error {
|
|
return rbc.originalBody.Close()
|
|
}
|
|
|
|
// GetCapturedContent returns the captured response body
|
|
func (rbc *ResponseBodyCapture) GetCapturedContent() []byte {
|
|
return rbc.buffer.Bytes()
|
|
}
|
|
|
|
// CountOutputTokens counts tokens in the captured response
|
|
func (rbc *ResponseBodyCapture) CountOutputTokens() (int, error) {
|
|
content := rbc.buffer.Bytes()
|
|
if len(content) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
// Check if this is a streaming response (SSE format)
|
|
if bytes.Contains(content, []byte("data: ")) {
|
|
return rbc.countSSETokens(content)
|
|
}
|
|
|
|
// Non-streaming response
|
|
return rbc.countJSONTokens(content)
|
|
}
|
|
|
|
// countSSETokens counts tokens in SSE (Server-Sent Events) streaming response
|
|
func (rbc *ResponseBodyCapture) countSSETokens(content []byte) (int, error) {
|
|
lines := bytes.Split(content, []byte("\n"))
|
|
totalTokens := 0
|
|
|
|
for _, line := range lines {
|
|
// Parse SSE data lines
|
|
if !bytes.HasPrefix(line, []byte("data: ")) {
|
|
continue
|
|
}
|
|
|
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
|
if len(jsonData) == 0 {
|
|
continue
|
|
}
|
|
|
|
var event map[string]interface{}
|
|
if err := json.Unmarshal(jsonData, &event); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Extract text from content_block_delta events
|
|
if eventType, ok := event["type"].(string); ok && eventType == "content_block_delta" {
|
|
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
|
if text, ok := delta["text"].(string); ok {
|
|
tokens, err := rbc.counter.CountTokens(text)
|
|
if err == nil {
|
|
totalTokens += tokens
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return totalTokens, nil
|
|
}
|
|
|
|
// countJSONTokens counts tokens in non-streaming JSON response
|
|
func (rbc *ResponseBodyCapture) countJSONTokens(content []byte) (int, error) {
|
|
var resp map[string]interface{}
|
|
if err := json.Unmarshal(content, &resp); err != nil {
|
|
log.Printf("Warning: failed to parse response body for token counting: %v", err)
|
|
return 0, nil
|
|
}
|
|
|
|
totalTokens := 0
|
|
|
|
// Extract text from content blocks
|
|
if contentBlocks, ok := resp["content"].([]interface{}); ok {
|
|
for _, block := range contentBlocks {
|
|
if blockMap, ok := block.(map[string]interface{}); ok {
|
|
if text, ok := blockMap["text"].(string); ok {
|
|
tokens, err := rbc.counter.CountTokens(text)
|
|
if err == nil {
|
|
totalTokens += tokens
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return totalTokens, nil
|
|
}
|
|
|
|
// InjectTokenUsage injects token usage into response body
|
|
// Note: SSE streaming responses are handled by StreamingResponseBodyCapture in main.go
|
|
// This function only handles non-streaming JSON responses
|
|
func InjectTokenUsage(body []byte, inputTokens, outputTokens int) ([]byte, error) {
|
|
// For SSE format, return as-is - streaming is handled elsewhere
|
|
if bytes.Contains(body, []byte("data: ")) {
|
|
return body, nil
|
|
}
|
|
|
|
// Non-streaming JSON response
|
|
return injectJSONUsage(body, inputTokens, outputTokens)
|
|
}
|
|
|
|
// injectSSEUsage rewrites every message_delta event in an SSE body so that its
// "usage" field holds the given input/output counts, replacing any existing
// usage. All other lines are copied through unchanged, including data lines
// that fail to decode or fail to re-encode. The returned error is always nil.
func injectSSEUsage(body []byte, inputTokens, outputTokens int) ([]byte, error) {
	lines := strings.Split(string(body), "\n")

	for i, line := range lines {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}

		var event map[string]interface{}
		if err := json.Unmarshal([]byte(line[len("data: "):]), &event); err != nil {
			continue
		}
		if kind, _ := event["type"].(string); kind != "message_delta" {
			continue
		}

		event["usage"] = map[string]int{
			"input_tokens":  inputTokens,
			"output_tokens": outputTokens,
		}
		if encoded, err := json.Marshal(event); err == nil {
			lines[i] = "data: " + string(encoded)
		}
	}

	return []byte(strings.Join(lines, "\n")), nil
}
|
|
|
|
// injectJSONUsage adds a Claude-style "usage" object to a JSON response body.
//
// The object carries the given input/output counts plus zeroed cache counts,
// the same shape the streaming path injects into message_delta. If the
// response already has a non-empty "usage" object from upstream, the body is
// returned unchanged, so API-reported counts are never overwritten by
// estimates. Bodies that are not JSON objects (invalid JSON, null, ...) are
// returned unchanged with a nil error.
func injectJSONUsage(body []byte, inputTokens, outputTokens int) ([]byte, error) {
	var resp map[string]interface{}
	if err := json.Unmarshal(body, &resp); err != nil {
		log.Printf("Warning: failed to parse response for usage injection: %v", err)
		return body, nil // Return original on error
	}
	// A literal JSON null decodes without error into a nil map, and writing
	// to a nil map would panic.
	if resp == nil {
		return body, nil
	}

	// Upstream already reported usage: keep it, as the streaming path does.
	if existing, ok := resp["usage"].(map[string]interface{}); ok && len(existing) > 0 {
		return body, nil
	}

	resp["usage"] = map[string]int{
		"input_tokens":                inputTokens,
		"output_tokens":               outputTokens,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
	}
	return json.Marshal(resp)
}
|
|
|
|
// StreamingResponseBodyCapture captures streaming response body for token counting
|
|
// and injects usage information into the message_delta event
|
|
type StreamingResponseBodyCapture struct {
|
|
originalBody io.ReadCloser
|
|
buffer *bytes.Buffer
|
|
teeReader io.Reader
|
|
counter TokenCounter
|
|
inputTokens int
|
|
outputTokens int
|
|
state string // "reading", "injecting", "done"
|
|
injectBuffer []byte
|
|
deltaSeen bool
|
|
usage UsageData // API-reported token counts accumulated from SSE events
|
|
}
|
|
|
|
// NewStreamingResponseBodyCapture creates a new streaming response body capture
|
|
// that injects token usage into the message_delta SSE event
|
|
func NewStreamingResponseBodyCapture(body io.ReadCloser, counter TokenCounter, inputTokens int) *StreamingResponseBodyCapture {
|
|
buffer := &bytes.Buffer{}
|
|
teeReader := io.TeeReader(body, buffer)
|
|
|
|
return &StreamingResponseBodyCapture{
|
|
originalBody: body,
|
|
buffer: buffer,
|
|
teeReader: teeReader,
|
|
counter: counter,
|
|
inputTokens: inputTokens,
|
|
state: "reading",
|
|
deltaSeen: false,
|
|
}
|
|
}
|
|
|
|
// Read implements io.Reader with on-the-fly SSE usage injection
|
|
func (srbc *StreamingResponseBodyCapture) Read(p []byte) (n int, err error) {
|
|
// If we have data in the inject buffer, return that first
|
|
if len(srbc.injectBuffer) > 0 {
|
|
n = copy(p, srbc.injectBuffer)
|
|
srbc.injectBuffer = srbc.injectBuffer[n:]
|
|
if len(srbc.injectBuffer) == 0 {
|
|
srbc.state = "done"
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// Read from the underlying reader
|
|
n, err = srbc.teeReader.Read(p)
|
|
if n > 0 {
|
|
// Process the newly read data to find and inject usage
|
|
srbc.processChunk(p[:n], &n)
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// processChunk processes a chunk of data to inject usage into message_delta
|
|
func (srbc *StreamingResponseBodyCapture) processChunk(chunk []byte, n *int) {
|
|
// IMPORTANT: Count tokens FIRST, before checking for message_delta.
|
|
// This ensures tokens from content_block_delta events in the same chunk
|
|
// as message_delta are counted before we inject the usage.
|
|
srbc.countTokensInChunk(chunk)
|
|
|
|
// Look for message_delta events in the chunk
|
|
data := string(chunk)
|
|
|
|
// Check if this chunk contains "message_delta"
|
|
if !srbc.deltaSeen && strings.Contains(data, "message_delta") {
|
|
srbc.deltaSeen = true
|
|
|
|
// Parse and inject usage
|
|
lines := strings.Split(data, "\n")
|
|
modifiedLines := make([]string, 0, len(lines))
|
|
|
|
for _, line := range lines {
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
modifiedLines = append(modifiedLines, line)
|
|
continue
|
|
}
|
|
|
|
jsonData := strings.TrimPrefix(line, "data: ")
|
|
if jsonData == "" {
|
|
modifiedLines = append(modifiedLines, line)
|
|
continue
|
|
}
|
|
|
|
var event map[string]interface{}
|
|
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
|
modifiedLines = append(modifiedLines, line)
|
|
continue
|
|
}
|
|
|
|
// Inject usage into message_delta event
|
|
if eventType, ok := event["type"].(string); ok && eventType == "message_delta" {
|
|
// Check if upstream API already provided usage - pass through if so
|
|
if existingUsage, ok := event["usage"].(map[string]interface{}); ok && len(existingUsage) > 0 {
|
|
log.Printf("Using upstream usage from message_delta: %+v", existingUsage)
|
|
modifiedLines = append(modifiedLines, line)
|
|
continue
|
|
}
|
|
|
|
// No upstream usage provided, inject proxy-counted values
|
|
event["usage"] = map[string]int{
|
|
"input_tokens": srbc.inputTokens,
|
|
"output_tokens": srbc.outputTokens,
|
|
"cache_creation_input_tokens": 0,
|
|
"cache_read_input_tokens": 0,
|
|
}
|
|
|
|
modifiedJSON, err := json.Marshal(event)
|
|
if err == nil {
|
|
modifiedLines = append(modifiedLines, "data: "+string(modifiedJSON))
|
|
log.Printf("Injected token usage into message_delta: input=%d, output=%d", srbc.inputTokens, srbc.outputTokens)
|
|
continue
|
|
}
|
|
}
|
|
|
|
modifiedLines = append(modifiedLines, line)
|
|
}
|
|
|
|
// Reconstruct the chunk with modifications
|
|
modifiedData := strings.Join(modifiedLines, "\n")
|
|
*n = copy(chunk, modifiedData)
|
|
if len(modifiedData) > len(chunk) {
|
|
// If modified data is larger, store the overflow in injectBuffer
|
|
srbc.injectBuffer = []byte(modifiedData[len(chunk):])
|
|
}
|
|
}
|
|
}
|
|
|
|
// countTokensInChunk extracts token usage from SSE events.
|
|
// message_start provides input + cache counts; message_delta provides output count.
|
|
// content_block_delta text is counted via tiktoken as a fallback for output.
|
|
func (srbc *StreamingResponseBodyCapture) countTokensInChunk(chunk []byte) {
|
|
lines := strings.Split(string(chunk), "\n")
|
|
for _, line := range lines {
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
jsonData := strings.TrimPrefix(line, "data: ")
|
|
if len(jsonData) == 0 {
|
|
continue
|
|
}
|
|
var event map[string]interface{}
|
|
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
|
|
continue
|
|
}
|
|
eventType, _ := event["type"].(string)
|
|
switch eventType {
|
|
case "message_start":
|
|
if msg, ok := event["message"].(map[string]interface{}); ok {
|
|
if u, ok := msg["usage"].(map[string]interface{}); ok {
|
|
srbc.usage.InputTokens = int(jsonFloat(u["input_tokens"]))
|
|
srbc.usage.CacheReadTokens = int(jsonFloat(u["cache_read_input_tokens"]))
|
|
srbc.usage.CacheWriteTokens = int(jsonFloat(u["cache_creation_input_tokens"]))
|
|
srbc.usage.FromAPI = true
|
|
}
|
|
}
|
|
case "message_delta":
|
|
if u, ok := event["usage"].(map[string]interface{}); ok {
|
|
if out := int(jsonFloat(u["output_tokens"])); out > 0 {
|
|
srbc.usage.OutputTokens = out
|
|
srbc.usage.FromAPI = true
|
|
}
|
|
}
|
|
case "content_block_delta":
|
|
if srbc.counter != nil {
|
|
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
|
if text, ok := delta["text"].(string); ok {
|
|
if tokens, err := srbc.counter.CountTokens(text); err == nil {
|
|
srbc.outputTokens += tokens
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetUsage returns API-reported token counts, falling back to tiktoken estimates
|
|
// for any values the API did not provide.
|
|
func (srbc *StreamingResponseBodyCapture) GetUsage() UsageData {
|
|
if srbc.usage.FromAPI {
|
|
result := srbc.usage
|
|
if result.OutputTokens == 0 && srbc.outputTokens > 0 {
|
|
result.OutputTokens = srbc.outputTokens
|
|
}
|
|
return result
|
|
}
|
|
return UsageData{
|
|
InputTokens: srbc.inputTokens,
|
|
OutputTokens: srbc.outputTokens,
|
|
}
|
|
}
|
|
|
|
// Close implements io.Closer
|
|
func (srbc *StreamingResponseBodyCapture) Close() error {
|
|
return srbc.originalBody.Close()
|
|
}
|
|
|
|
// GetOutputTokenCount returns the counted output tokens
|
|
func (srbc *StreamingResponseBodyCapture) GetOutputTokenCount() int {
|
|
return srbc.outputTokens
|
|
}
|