package main import ( "bytes" "context" "encoding/json" "io" "log" "net/http" "os" "strconv" "sync" "sync/atomic" "time" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/time/rate" ) var ( currentRequests int64 maxWorkersValue int64 = 10 // Default tokenCounter TokenCounter tokenizerModel string = "glm-4" // Default tokenizer model name for metrics deploymentVariant string = "production" // Default deployment variant // Build info — set via -ldflags at build time buildVersion string buildCommit string buildTimeStr string ) // AdaptiveRateLimiter manages rate limiting by tracking the upstream 429 ceiling // and holding just below it. Periodically probes above to detect ceiling shifts. type AdaptiveRateLimiter struct { limiter *rate.Limiter mu sync.RWMutex currentRate float64 minRate float64 maxRate float64 estimatedCeiling float64 // EWMA of the rate at which 429s occur ceilingSmoothAlpha float64 // EWMA smoothing factor (0-1, higher = more reactive) holdMargin float64 // Hold this fraction below ceiling (e.g., 0.02 = 2%) probeInterval int // Probe above ceiling every N clean windows cleanWindows int // Consecutive clean windows since last 429 lastAdjustment time.Time adjustmentWindow time.Duration recent429Count int64 recentSuccessCount int64 } func NewAdaptiveRateLimiter(initialRate, minRate, maxRate float64) *AdaptiveRateLimiter { return &AdaptiveRateLimiter{ limiter: rate.NewLimiter(rate.Limit(initialRate), int(initialRate*2)), currentRate: initialRate, minRate: minRate, maxRate: maxRate, estimatedCeiling: maxRate, // Assume max until we learn otherwise ceilingSmoothAlpha: 0.3, // 30% new observation, 70% history holdMargin: 0.02, // Hold 2% below estimated ceiling probeInterval: 10, // Probe every 10 clean windows (5 min at 30s windows) cleanWindows: 0, lastAdjustment: time.Now(), adjustmentWindow: 30 * time.Second, } } func (arl *AdaptiveRateLimiter) Wait(variant string) time.Duration { start := time.Now() // Protect access to limiter with read lock to prevent race with tryAdjustRate() arl.mu.RLock() limiter := arl.limiter arl.mu.RUnlock() limiter.Wait(context.Background()) waitTime := time.Since(start) rateLimitWaitTime.WithLabelValues(variant).Observe(waitTime.Seconds()) return waitTime } func (arl *AdaptiveRateLimiter) Record429() { atomic.AddInt64(&arl.recent429Count, 1) arl.tryAdjustRate() } func (arl *AdaptiveRateLimiter) RecordSuccess() { atomic.AddInt64(&arl.recentSuccessCount, 1) arl.tryAdjustRate() } func (arl *AdaptiveRateLimiter) tryAdjustRate() { arl.mu.Lock() defer arl.mu.Unlock() if time.Since(arl.lastAdjustment) < arl.adjustmentWindow { return } count429 := atomic.SwapInt64(&arl.recent429Count, 0) countSuccess := atomic.SwapInt64(&arl.recentSuccessCount, 0) total := count429 + countSuccess if total == 0 { return } error429Rate := float64(count429) / float64(total) newRate := arl.currentRate if error429Rate > 0.05 { // 429s detected — update ceiling estimate via EWMA oldCeiling := arl.estimatedCeiling arl.estimatedCeiling = arl.ceilingSmoothAlpha*arl.currentRate + (1-arl.ceilingSmoothAlpha)*arl.estimatedCeiling arl.cleanWindows = 0 // Drop to hold position: just below the updated ceiling newRate = arl.estimatedCeiling * (1 - arl.holdMargin) if newRate < arl.minRate { newRate = arl.minRate } log.Printf("Rate limit: Ceiling updated %.2f → %.2f req/s, holding at %.2f req/s (429 rate: %.2f%%)", oldCeiling, arl.estimatedCeiling, newRate, error429Rate*100) rateLimitAdjustments.WithLabelValues("decrease", deploymentVariant).Inc() } else if error429Rate < 0.01 { arl.cleanWindows++ targetRate := arl.estimatedCeiling * (1 - arl.holdMargin) if arl.cleanWindows >= arl.probeInterval && arl.currentRate < arl.maxRate { // Probe: the ceiling may have shifted up. Step above our hold point // to test for higher throughput. probeRate := arl.estimatedCeiling * (1 + arl.holdMargin) if probeRate > arl.maxRate { probeRate = arl.maxRate } newRate = probeRate arl.cleanWindows = 0 log.Printf("Rate limit: Probing ceiling at %.2f req/s (estimated ceiling: %.2f, clean windows: %d)", newRate, arl.estimatedCeiling, arl.probeInterval) rateLimitAdjustments.WithLabelValues("probe", deploymentVariant).Inc() } else if arl.currentRate < targetRate { // Below hold point — move toward it quickly gap := targetRate - arl.currentRate step := gap * 0.5 // Close half the gap each window if step < 0.25 { step = 0.25 } newRate = arl.currentRate + step if newRate > targetRate { newRate = targetRate } log.Printf("Rate limit: Converging to %.2f req/s (target: %.2f, ceiling: %.2f)", newRate, targetRate, arl.estimatedCeiling) rateLimitAdjustments.WithLabelValues("increase", deploymentVariant).Inc() } // At or above target with no 429s — hold steady, no log spam } if newRate != arl.currentRate { arl.currentRate = newRate arl.limiter.SetLimit(rate.Limit(newRate)) arl.limiter.SetBurst(int(newRate * 2)) rateLimitCurrentRate.WithLabelValues(deploymentVariant).Set(newRate) } arl.lastAdjustment = time.Now() } func (arl *AdaptiveRateLimiter) GetCurrentRate() float64 { arl.mu.RLock() defer arl.mu.RUnlock() return arl.currentRate } func (arl *AdaptiveRateLimiter) Reset(initialRate float64) { arl.mu.Lock() defer arl.mu.Unlock() arl.currentRate = initialRate arl.estimatedCeiling = initialRate arl.cleanWindows = 0 arl.limiter = rate.NewLimiter(rate.Limit(initialRate), int(initialRate*2)) arl.lastAdjustment = time.Now() atomic.StoreInt64(&arl.recent429Count, 0) atomic.StoreInt64(&arl.recentSuccessCount, 0) log.Printf("Rate limiter reset: rate=%.1f, ceiling=%.1f", arl.currentRate, arl.estimatedCeiling) } func updateUtilization() { current := atomic.LoadInt64(¤tRequests) max := atomic.LoadInt64(&maxWorkersValue) if max > 0 { utilization := float64(current) / float64(max) workerUtilization.WithLabelValues(deploymentVariant).Set(utilization) } } func main() { apiKey := os.Getenv("ZAI_API_KEY") if apiKey == "" { log.Fatal("ZAI_API_KEY environment variable required") } // Read deployment variant from environment deploymentVariant = os.Getenv("DEPLOYMENT_VARIANT") if deploymentVariant == "" { deploymentVariant = "production" log.Printf("DEPLOYMENT_VARIANT not set, defaulting to: %s", deploymentVariant) } else { log.Printf("Deployment variant: %s", deploymentVariant) } // Read build info — prefer ldflags, fall back to env vars, then "unknown" version := buildVersion if version == "" { version = os.Getenv("ZAI_PROXY_VERSION") } if version == "" { version = "unknown" } commit := buildCommit if commit == "" { commit = os.Getenv("ZAI_PROXY_COMMIT") } if commit == "" { commit = "unknown" } buildTime := buildTimeStr if buildTime == "" { buildTime = os.Getenv("ZAI_PROXY_BUILD_TIME") } if buildTime == "" { buildTime = "unknown" } // Set build info metric buildInfo.WithLabelValues(version, deploymentVariant, commit, buildTime).Set(1) log.Printf("Build info: version=%s, variant=%s, commit=%s, build_time=%s", version, deploymentVariant, commit, buildTime) // Read tokenizer configuration from environment // // TOKEN_COUNTING_ENABLED: Enable/disable token counting (default: true) // Set to "false" or "0" to disable token counting entirely. // When disabled, no token metrics are collected and tokenizer is not initialized. // // TOKENIZER_MODEL: Model name for Prometheus metrics labels (default: glm-4) // Used to tag token count metrics in Prometheus (e.g., glm-4, claude-3, etc.) // This is purely for metrics labeling and does not affect tokenization algorithm. tokenCountingEnabled := true if val := os.Getenv("TOKEN_COUNTING_ENABLED"); val != "" { if val == "false" || val == "0" { tokenCountingEnabled = false } } if val := os.Getenv("TOKENIZER_MODEL"); val != "" { tokenizerModel = val } // Initialize tokenizer with tiktoken cl100k_base encoding if tokenCountingEnabled { tikTokenCounter, err := NewTikTokenCounter() if err != nil { log.Printf("Warning: Failed to initialize TikToken counter: %v", err) log.Println("Falling back to SimpleTokenCounter") tokenCounter = NewSimpleTokenCounter() log.Printf("Token counting enabled (fallback mode, model: %s)", tokenizerModel) } else { tokenCounter = tikTokenCounter log.Printf("Token counting enabled (tiktoken cl100k_base encoding, model: %s)", tokenizerModel) } } else { log.Println("Token counting disabled (TOKEN_COUNTING_ENABLED=false)") tokenCounter = nil } // Read max workers from environment if maxWorkersEnv := os.Getenv("MAX_WORKERS"); maxWorkersEnv != "" { if val, err := strconv.ParseInt(maxWorkersEnv, 10, 64); err == nil && val > 0 { atomic.StoreInt64(&maxWorkersValue, val) maxWorkers.WithLabelValues(deploymentVariant).Set(float64(val)) log.Printf("Max workers set to: %d", val) } } else { maxWorkers.WithLabelValues(deploymentVariant).Set(float64(maxWorkersValue)) log.Printf("Max workers defaulting to: %d", maxWorkersValue) } // Initialize adaptive rate limiter initialRate := 10.0 // 10 req/s minRate := 1.0 // 1 req/s minimum maxRate := 50.0 // 50 req/s maximum // Read rate limit config from environment if val := os.Getenv("RATE_LIMIT_INITIAL"); val != "" { if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 { initialRate = parsed } } if val := os.Getenv("RATE_LIMIT_MIN"); val != "" { if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 { minRate = parsed } } if val := os.Getenv("RATE_LIMIT_MAX"); val != "" { if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 { maxRate = parsed } } rateLimiter := NewAdaptiveRateLimiter(initialRate, minRate, maxRate) if val := os.Getenv("RATE_LIMIT_CEILING_ALPHA"); val != "" { if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 && parsed <= 1 { rateLimiter.ceilingSmoothAlpha = parsed } } if val := os.Getenv("RATE_LIMIT_HOLD_MARGIN"); val != "" { if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 && parsed < 1 { rateLimiter.holdMargin = parsed } } if val := os.Getenv("RATE_LIMIT_PROBE_INTERVAL"); val != "" { if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 { rateLimiter.probeInterval = parsed } } rateLimitCurrentRate.WithLabelValues(deploymentVariant).Set(initialRate) log.Printf("Adaptive rate limiting: initial=%.1f, min=%.1f, max=%.1f req/s (ceiling alpha=%.2f, margin=%.1f%%, probe every %d windows)", initialRate, minRate, maxRate, rateLimiter.ceilingSmoothAlpha, rateLimiter.holdMargin*100, rateLimiter.probeInterval) // Retry configuration maxRetries := 3 if val := os.Getenv("MAX_RETRIES"); val != "" { if parsed, err := strconv.Atoi(val); err == nil && parsed >= 0 { maxRetries = parsed } } target := "https://api.z.ai/api/anthropic" if val := os.Getenv("ZAI_TARGET_URL"); val != "" { target = val } client := &http.Client{ Timeout: 5 * time.Minute, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } // Metrics endpoint http.Handle("/metrics", promhttp.Handler()) // Health endpoint http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) }) // Admin: reset rate limiter http.HandleFunc("/admin/reset-rate-limit", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } rateLimiter.Reset(initialRate) json.NewEncoder(w).Encode(map[string]interface{}{ "status": "reset", "current_rate": rateLimiter.GetCurrentRate(), }) }) // Proxy handler with adaptive rate limiting http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { start := time.Now() // Increment concurrent requests current := atomic.AddInt64(¤tRequests, 1) concurrentRequests.WithLabelValues(deploymentVariant).Set(float64(current)) updateUtilization() defer func() { current := atomic.AddInt64(¤tRequests, -1) concurrentRequests.WithLabelValues(deploymentVariant).Set(float64(current)) updateUtilization() }() // Check if we're at max capacity max := atomic.LoadInt64(&maxWorkersValue) if current > max { log.Printf("Max workers exceeded: %d/%d", current, max) requestsTotal.WithLabelValues(r.Method, r.URL.Path, "503", deploymentVariant).Inc() http.Error(w, "Service at capacity", http.StatusServiceUnavailable) return } // Apply rate limiting rateLimiter.Wait(deploymentVariant) // Track request size if r.ContentLength > 0 { requestSize.WithLabelValues(r.Method, r.URL.Path, deploymentVariant).Observe(float64(r.ContentLength)) } // Capture request body (always, for translation + optional token counting). var requestBody []byte if r.Body != nil { var buf bytes.Buffer if _, err := io.Copy(&buf, r.Body); err != nil { log.Printf("Error reading request body: %v", err) } else { requestBody = buf.Bytes() } r.Body.Close() } // Extract model name from request body for metrics labels. reqModel := tokenizerModel if len(requestBody) > 0 { var rb RequestBody if err := json.Unmarshal(requestBody, &rb); err == nil && rb.Model != "" { reqModel = rb.Model } } // Count input tokens if enabled. var inputTokens int if tokenCounter != nil && len(requestBody) > 0 { countStart := time.Now() inputTokens, _ = CountRequestTokens(requestBody, tokenCounter) countDuration := time.Since(countStart) tokenCountDuration.WithLabelValues(deploymentVariant).Observe(countDuration.Seconds()) // Input tokens recorded after response completes via RecordUsage. } // Translate request body: strip Anthropic API fields ZhipuAI doesn't support. translatedBody := requestBody if len(requestBody) > 0 { if translated, changed, err := TranslateRequest(requestBody); err != nil { log.Printf("Warning: failed to translate request body: %v", err) } else if changed { translatedBody = translated } } // Retry logic for transient errors var lastErr error var resp *http.Response var validatedBody []byte // Pre-validated non-streaming body (read inside retry loop) var streamingPeek []byte // First bytes peeked from streaming response (read inside retry loop) for attempt := 0; attempt <= maxRetries; attempt++ { if attempt > 0 { // Exponential backoff: 1s, 2s, 4s, etc. backoffDuration := time.Duration(1< 0 { reqBodyReader = bytes.NewReader(translatedBody) } upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, reqBodyReader) if err != nil { log.Printf("Error creating request: %v", err) upstreamErrors.WithLabelValues("request_creation", deploymentVariant).Inc() requestsTotal.WithLabelValues(r.Method, r.URL.Path, "400", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "400", deploymentVariant).Observe(time.Since(start).Seconds()) http.Error(w, "Bad request", http.StatusBadRequest) return } for key, values := range r.Header { for _, value := range values { upstreamReq.Header.Add(key, value) } } upstreamReq.Header.Set("Host", "api.z.ai") upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) // Disable gzip so the proxy can parse/modify the response body upstreamReq.Header.Set("Accept-Encoding", "identity") resp, err = client.Do(upstreamReq) if err != nil { lastErr = err log.Printf("Error forwarding request (attempt %d/%d): %v", attempt+1, maxRetries+1, err) upstreamErrors.WithLabelValues("upstream_connection", deploymentVariant).Inc() // Retry on network errors if attempt < maxRetries { retryAttempts.WithLabelValues("network_error", deploymentVariant).Inc() continue } requestsTotal.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Observe(time.Since(start).Seconds()) http.Error(w, "Upstream error", http.StatusBadGateway) return } // Handle 429 Rate Limit if resp.StatusCode == 429 { resp.Body.Close() rateLimiter.Record429() // Check Retry-After header retryAfter := resp.Header.Get("Retry-After") if retryAfter != "" { if seconds, err := strconv.Atoi(retryAfter); err == nil { log.Printf("429 Rate Limited, retry after %d seconds", seconds) time.Sleep(time.Duration(seconds) * time.Second) } } if attempt < maxRetries { log.Printf("429 Rate Limited, retrying (attempt %d/%d)", attempt+1, maxRetries+1) retryAttempts.WithLabelValues("429", deploymentVariant).Inc() continue } // Exceeded max retries, return 429 to client log.Printf("429 Rate Limited, max retries exceeded") requestsTotal.WithLabelValues(r.Method, r.URL.Path, "429", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "429", deploymentVariant).Observe(time.Since(start).Seconds()) http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } // Handle 422 Unprocessable Entity — log full bodies for diagnosis, // then return a clear error to the client so callers can fail fast. // 422s are not retried: they indicate a structural problem with the // request body that retrying won't fix. if resp.StatusCode == 422 { respBody, _ := io.ReadAll(resp.Body) resp.Body.Close() log.Printf("422 from upstream — request body: %s", string(requestBody)) log.Printf("422 from upstream — response body: %s", string(respBody)) upstreamErrors.WithLabelValues("422", deploymentVariant).Inc() requestsTotal.WithLabelValues(r.Method, r.URL.Path, "422", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "422", deploymentVariant).Observe(time.Since(start).Seconds()) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnprocessableEntity) w.Write(respBody) return } // Validate response body before committing to the client. // Z.AI occasionally returns HTTP 200 with empty or truncated JSON. if resp.StatusCode >= 200 && resp.StatusCode < 300 { if !IsStreamingRequest(requestBody) { // Non-streaming: validate entire body bodyBytes, _ := io.ReadAll(resp.Body) resp.Body.Close() if len(bodyBytes) == 0 || !json.Valid(bodyBytes) { log.Printf("Malformed response from upstream (empty=%v, size=%d), retrying (attempt %d/%d)", len(bodyBytes) == 0, len(bodyBytes), attempt+1, maxRetries+1) upstreamErrors.WithLabelValues("truncated_response", deploymentVariant).Inc() if attempt < maxRetries { retryAttempts.WithLabelValues("truncated_response", deploymentVariant).Inc() continue } log.Printf("Malformed response from upstream, max retries exceeded - returning 502") requestsTotal.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Observe(time.Since(start).Seconds()) http.Error(w, "Upstream returned empty or malformed response after retries", http.StatusBadGateway) return } validatedBody = bodyBytes } else { // Streaming: peek at first chunk to confirm the response has data peekBuf := make([]byte, 4096) n, _ := resp.Body.Read(peekBuf) if n == 0 { resp.Body.Close() log.Printf("Empty streaming response from upstream, retrying (attempt %d/%d)", attempt+1, maxRetries+1) upstreamErrors.WithLabelValues("empty_streaming", deploymentVariant).Inc() if attempt < maxRetries { retryAttempts.WithLabelValues("empty_streaming", deploymentVariant).Inc() continue } log.Printf("Empty streaming response, max retries exceeded - returning 502") requestsTotal.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Observe(time.Since(start).Seconds()) http.Error(w, "Upstream returned empty streaming response after retries", http.StatusBadGateway) return } streamingPeek = peekBuf[:n] } } // Success or non-retryable error break } if resp == nil { log.Printf("All retry attempts failed: %v", lastErr) requestsTotal.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, "502", deploymentVariant).Observe(time.Since(start).Seconds()) http.Error(w, "Upstream error after retries", http.StatusBadGateway) return } statusCode := strconv.Itoa(resp.StatusCode) // Record success for rate limiter adaptation if resp.StatusCode >= 200 && resp.StatusCode < 300 { rateLimiter.RecordSuccess() } for key, values := range resp.Header { for _, value := range values { w.Header().Add(key, value) } } // Remove Content-Length: proxy may modify body (usage injection), causing overrun w.Header().Del("Content-Length") // Declare trailer headers for token usage (will be sent after response body) if tokenCounter != nil && inputTokens > 0 { w.Header().Add("Trailer", "X-Token-Output") w.Header().Add("Trailer", "X-Token-Total") // Set input token count in initial headers (we know this upfront) w.Header().Set("X-Token-Input", strconv.Itoa(inputTokens)) } w.WriteHeader(resp.StatusCode) var bytesWritten int64 // Use token counting and injection if enabled, otherwise direct copy if tokenCounter != nil { // For streaming responses, we need to inject usage into the message_delta event // Check if this is a streaming request by checking the request body isStreaming := false if len(requestBody) > 0 { var req RequestBody if err := json.Unmarshal(requestBody, &req); err == nil { isStreaming = req.Stream } } if isStreaming { // Streaming response: capture, count, and inject usage into message_delta // If we peeked at the first chunk during validation, prepend it var bodyReader io.Reader = resp.Body if len(streamingPeek) > 0 { bodyReader = io.MultiReader(bytes.NewReader(streamingPeek), resp.Body) } bodyCapture := NewStreamingResponseBodyCapture(io.NopCloser(bodyReader), tokenCounter, inputTokens) defer bodyCapture.Close() buf := make([]byte, 1024) flusher, canFlush := w.(http.Flusher) for { n, readErr := bodyCapture.Read(buf) if n > 0 { written, writeErr := w.Write(buf[:n]) bytesWritten += int64(written) if writeErr != nil { log.Printf("Error writing response: %v", writeErr) upstreamErrors.WithLabelValues("write_error", deploymentVariant).Inc() return } if canFlush { flusher.Flush() } } if readErr == io.EOF { break } if readErr != nil { log.Printf("Error reading response: %v", readErr) upstreamErrors.WithLabelValues("read_error", deploymentVariant).Inc() return } } // Record token counts from API response (or tiktoken fallback). usage := bodyCapture.GetUsage() RecordUsage(reqModel, deploymentVariant, usage) log.Printf("Token usage (stream, fromAPI=%v): input=%d output=%d cache_read=%d cache_write=%d", usage.FromAPI, usage.InputTokens, usage.OutputTokens, usage.CacheReadTokens, usage.CacheWriteTokens) } else { // Non-streaming: capture body, count tokens, wrap with usage, then send. // If body was pre-read during truncation validation, reuse it. var bodySource io.ReadCloser if len(validatedBody) > 0 { bodySource = io.NopCloser(bytes.NewReader(validatedBody)) } else { bodySource = resp.Body } bodyCapture := NewResponseBodyCapture(bodySource, tokenCounter) defer bodyCapture.Close() bodyBytes, readErr := io.ReadAll(bodyCapture) if readErr != nil { log.Printf("Error reading response: %v", readErr) upstreamErrors.WithLabelValues("read_error", deploymentVariant).Inc() return } // Prefer API-reported usage from the response body. if usage, ok := ExtractUsageFromJSON(bodyBytes); ok { RecordUsage(reqModel, deploymentVariant, usage) log.Printf("Token usage (API): input=%d output=%d cache_read=%d cache_write=%d", usage.InputTokens, usage.OutputTokens, usage.CacheReadTokens, usage.CacheWriteTokens) written, writeErr := w.Write(bodyBytes) bytesWritten += int64(written) if writeErr != nil { log.Printf("Error writing response: %v", writeErr) upstreamErrors.WithLabelValues("write_error", deploymentVariant).Inc() return } if flusher, canFlush := w.(http.Flusher); canFlush { flusher.Flush() } } else { // Tiktoken fallback: estimate and wrap. countStart := time.Now() outputTokens, err := bodyCapture.CountOutputTokens() countDuration := time.Since(countStart) tokenCountDuration.WithLabelValues(deploymentVariant).Observe(countDuration.Seconds()) if err != nil { log.Printf("Warning: failed to count output tokens: %v", err) } RecordUsage(reqModel, deploymentVariant, UsageData{InputTokens: inputTokens, OutputTokens: outputTokens}) RecordOutputTokenRate(tokenizerModel, deploymentVariant, countDuration, outputTokens) log.Printf("Token usage (estimated): input=%d output=%d", inputTokens, outputTokens) wrappedResp, wrapErr := WrapResponseWithUsage(bodyBytes, inputTokens, outputTokens) if wrapErr != nil { log.Printf("Warning: failed to wrap response with usage, sending original: %v", wrapErr) wrappedResp = bodyBytes } written, writeErr := w.Write(wrappedResp) bytesWritten += int64(written) if writeErr != nil { log.Printf("Error writing response: %v", writeErr) upstreamErrors.WithLabelValues("write_error", deploymentVariant).Inc() return } if flusher, canFlush := w.(http.Flusher); canFlush { flusher.Flush() } } } } else { // Token counting disabled, direct streaming defer resp.Body.Close() buf := make([]byte, 1024) flusher, canFlush := w.(http.Flusher) for { n, err := resp.Body.Read(buf) if n > 0 { written, writeErr := w.Write(buf[:n]) bytesWritten += int64(written) if writeErr != nil { log.Printf("Error writing response: %v", writeErr) upstreamErrors.WithLabelValues("write_error", deploymentVariant).Inc() return } if canFlush { flusher.Flush() } } if err == io.EOF { break } if err != nil { log.Printf("Error reading response: %v", err) upstreamErrors.WithLabelValues("read_error", deploymentVariant).Inc() return } } } // Record metrics duration := time.Since(start).Seconds() requestsTotal.WithLabelValues(r.Method, r.URL.Path, statusCode, deploymentVariant).Inc() requestDuration.WithLabelValues(r.Method, r.URL.Path, statusCode, deploymentVariant).Observe(duration) responseSize.WithLabelValues(r.Method, r.URL.Path, statusCode, deploymentVariant).Observe(float64(bytesWritten)) }) log.Println("Z.AI proxy listening on :8080") log.Println("Metrics available at :8080/metrics") log.Fatal(http.ListenAndServe(":8080", nil)) }