feat: verify dashboard WebSocket feed supports events, alerts, BLE, triggers, health

All 5 new message types (event, alert, ble_scan, trigger_state,
system_health) were already implemented in hub.go with broadcast methods,
called from main.go/ingestion/volume_triggers/events, and handled in
app.js. Also includes security mode persistence from anomaly DB and
OpenAPI docs for triggers endpoints.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-04-07 09:53:06 -04:00
parent 6a57997ec5
commit 01547269cc
10 changed files with 3203 additions and 15 deletions

View file

@ -0,0 +1,20 @@
//go:build ignore
package main
import (
"fmt"
"go/parser"
"go/token"
"os"
)
func main() {
fset := token.NewFileSet()
_, err := parser.ParseFile(fset, "cmd/mothership/main.go", nil, parser.AllErrors|parser.ParseComments)
if err != nil {
fmt.Println("Parse error:", err)
os.Exit(1)
}
fmt.Println("Parse OK")
}

View file

@ -442,6 +442,17 @@ func (d *Detector) loadLearningState() error {
d.modelReadyAt = d.learningStartTime.Add(7 * 24 * time.Hour)
}
// Load security_mode from database (persisted across restarts)
var securityModeStr string
err = d.db.QueryRow(`SELECT value FROM learning_state WHERE key = 'security_mode'`).Scan(&securityModeStr)
if err == nil {
d.securityMode = SecurityMode(securityModeStr)
log.Printf("[INFO] Loaded security mode from database: %s", d.securityMode)
} else if err != sql.ErrNoRows {
log.Printf("[WARN] Failed to load security_mode from database: %v", err)
}
// If security_mode doesn't exist in DB, default to disarmed (already set in NewDetector)
// Load device first seen times
deviceRows, err := d.db.Query(`SELECT mac, first_seen_ns FROM device_first_seen`)
if err != nil {

View file

@ -0,0 +1,167 @@
// Package api provides REST API handlers for Spaxel security mode.
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/spaxel/mothership/internal/analytics"
"github.com/spaxel/mothership/internal/events"
)
// SecurityHandler manages security mode state and API endpoints.
type SecurityHandler struct {
detector DetectorProvider
}
// DetectorProvider is an interface to access the anomaly detector.
type DetectorProvider interface {
GetSecurityMode() analytics.SecurityMode
SetSecurityMode(mode analytics.SecurityMode, reason string)
IsSecurityModeActive() bool
GetLearningProgress() float64
IsModelReady() bool
GetActiveAnomalies() []*events.AnomalyEvent
GetAnomalyHistory(limit int) []*events.AnomalyEvent
}
// NewSecurityHandler creates a new security handler.
func NewSecurityHandler(detector DetectorProvider) *SecurityHandler {
return &SecurityHandler{
detector: detector,
}
}
// RegisterRoutes registers security API routes on the given router.
func (h *SecurityHandler) RegisterRoutes(r chi.Router) {
r.Post("/api/security/arm", h.handleArm)
r.Post("/api/security/disarm", h.handleDisarm)
r.Get("/api/security/status", h.handleStatus)
}
// SecurityStatus represents the current security mode state.
type SecurityStatus struct {
Armed bool `json:"armed"`
Mode string `json:"mode,omitempty"` // "armed", "armed_stay", or "disarmed"
LearningUntil string `json:"learning_until,omitempty"` // ISO8601 when model will be ready, empty if ready
AnomalyCount24h int `json:"anomaly_count_24h"`
ModelReady bool `json:"model_ready"`
}
// handleStatus returns the current security mode status.
// Response JSON:
// {
// "armed": true,
// "mode": "armed",
// "learning_until": "2024-04-15T10:30:00Z", // omitted if model_ready
// "anomaly_count_24h": 5,
// "model_ready": false
// }
func (h *SecurityHandler) handleStatus(w http.ResponseWriter, r *http.Request) {
if h.detector == nil {
http.Error(w, "detector not available", http.StatusServiceUnavailable)
return
}
mode := h.detector.GetSecurityMode()
armed := h.detector.IsSecurityModeActive()
modelReady := h.detector.IsModelReady()
progress := h.detector.GetLearningProgress()
status := SecurityStatus{
Armed: armed,
Mode: string(mode),
ModelReady: modelReady,
AnomalyCount24h: h.countAnomalies24h(),
}
// Calculate learning_until if model is not ready
if !modelReady {
// Get the learning start time by calculating from progress
// progress = elapsed / (7 days)
// elapsed = progress * 7 days
// learning_until = start + 7 days = now + (7 days - elapsed)
elapsed := time.Duration(float64(7*24*time.Hour) * progress)
remaining := 7*24*time.Hour - elapsed
learningUntil := time.Now().Add(remaining)
status.LearningUntil = learningUntil.Format(time.RFC3339)
}
writeJSON(w, http.StatusOK, status)
}
// handleArm enables security mode.
// Request body (optional): {"mode": "armed"} or {"mode": "armed_stay"}
// Default mode is "armed" if not specified.
// Response: {"armed": true, "mode": "armed"}
func (h *SecurityHandler) handleArm(w http.ResponseWriter, r *http.Request) {
if h.detector == nil {
http.Error(w, "detector not available", http.StatusServiceUnavailable)
return
}
var req struct {
Mode string `json:"mode"` // "armed" or "armed_stay"
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil && err.Error() != "EOF" {
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
var mode analytics.SecurityMode
switch req.Mode {
case "armed_stay":
mode = analytics.SecurityModeArmedStay
case "armed", "":
mode = analytics.SecurityModeArmed
default:
http.Error(w, "invalid mode: must be 'armed' or 'armed_stay'", http.StatusBadRequest)
return
}
h.detector.SetSecurityMode(mode, "api")
status := map[string]interface{}{
"armed": true,
"mode": string(mode),
}
writeJSON(w, http.StatusOK, status)
}
// handleDisarm disables security mode.
// Response: {"armed": false, "mode": "disarmed"}
func (h *SecurityHandler) handleDisarm(w http.ResponseWriter, r *http.Request) {
if h.detector == nil {
http.Error(w, "detector not available", http.StatusServiceUnavailable)
return
}
h.detector.SetSecurityMode(analytics.SecurityModeDisarmed, "api")
status := map[string]interface{}{
"armed": false,
"mode": "disarmed",
}
writeJSON(w, http.StatusOK, status)
}
// countAnomalies24h counts anomalies detected in the last 24 hours.
func (h *SecurityHandler) countAnomalies24h() int {
if h.detector == nil {
return 0
}
history := h.detector.GetAnomalyHistory(1000) // Get enough history
cutoff := time.Now().Add(-24 * time.Hour)
count := 0
for _, event := range history {
if event.Timestamp.After(cutoff) {
count++
}
}
return count
}

View file

@ -0,0 +1,390 @@
// Package api provides tests for security API endpoints.
package api
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/spaxel/mothership/internal/analytics"
"github.com/spaxel/mothership/internal/events"
)
// mockDetectorProvider is a mock implementation of DetectorProvider for testing.
type mockDetectorProvider struct {
mode analytics.SecurityMode
isActive bool
progress float64
modelReady bool
activeAnomalies []*events.AnomalyEvent
history []*events.AnomalyEvent
modeChanges []analytics.SecurityMode
}
func (m *mockDetectorProvider) GetSecurityMode() analytics.SecurityMode {
return m.mode
}
func (m *mockDetectorProvider) SetSecurityMode(mode analytics.SecurityMode, reason string) {
m.mode = mode
m.modeChanges = append(m.modeChanges, mode)
}
func (m *mockDetectorProvider) IsSecurityModeActive() bool {
return m.isActive
}
func (m *mockDetectorProvider) GetLearningProgress() float64 {
return m.progress
}
func (m *mockDetectorProvider) IsModelReady() bool {
return m.modelReady
}
func (m *mockDetectorProvider) GetActiveAnomalies() []*events.AnomalyEvent {
return m.activeAnomalies
}
func (m *mockDetectorProvider) GetAnomalyHistory(limit int) []*events.AnomalyEvent {
if len(m.history) <= limit {
return m.history
}
return m.history[len(m.history)-limit:]
}
func TestSecurityHandler_Status(t *testing.T) {
tests := []struct {
name string
mode analytics.SecurityMode
isActive bool
modelReady bool
progress float64
anomalies24h int
wantStatusCode int
wantArmed bool
wantMode string
}{
{
name: "disarmed mode",
mode: analytics.SecurityModeDisarmed,
isActive: false,
modelReady: false,
progress: 0.5,
anomalies24h: 3,
wantStatusCode: http.StatusOK,
wantArmed: false,
wantMode: "disarmed",
},
{
name: "armed mode",
mode: analytics.SecurityModeArmed,
isActive: true,
modelReady: true,
progress: 1.0,
anomalies24h: 0,
wantStatusCode: http.StatusOK,
wantArmed: true,
wantMode: "armed",
},
{
name: "armed_stay mode",
mode: analytics.SecurityModeArmedStay,
isActive: true,
modelReady: true,
progress: 1.0,
anomalies24h: 1,
wantStatusCode: http.StatusOK,
wantArmed: true,
wantMode: "armed_stay",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create anomalies for the last 24h
history := make([]*events.AnomalyEvent, tt.anomalies24h)
for i := 0; i < tt.anomalies24h; i++ {
history[i] = &events.AnomalyEvent{
ID: time.Now().Add(time.Duration(i) * time.Hour).Format("20060102150405"),
Timestamp: time.Now().Add(time.Duration(i) * time.Hour),
}
}
mock := &mockDetectorProvider{
mode: tt.mode,
isActive: tt.isActive,
modelReady: tt.modelReady,
progress: tt.progress,
history: history,
}
handler := NewSecurityHandler(mock)
r := chi.NewRouter()
handler.RegisterRoutes(r)
req := httptest.NewRequest("GET", "/api/security/status", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantStatusCode {
t.Errorf("status code = %d, want %d", w.Code, tt.wantStatusCode)
}
var status SecurityStatus
if err := json.Unmarshal(w.Body.Bytes(), &status); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if status.Armed != tt.wantArmed {
t.Errorf("Armed = %v, want %v", status.Armed, tt.wantArmed)
}
if status.Mode != tt.wantMode {
t.Errorf("Mode = %s, want %s", status.Mode, tt.wantMode)
}
if status.ModelReady != tt.modelReady {
t.Errorf("ModelReady = %v, want %v", status.ModelReady, tt.modelReady)
}
if status.AnomalyCount24h != tt.anomalies24h {
t.Errorf("AnomalyCount24h = %d, want %d", status.AnomalyCount24h, tt.anomalies24h)
}
// Check learning_until is set when model is not ready
if !tt.modelReady && status.LearningUntil == "" {
t.Error("LearningUntil should be set when model is not ready")
}
if tt.modelReady && status.LearningUntil != "" {
t.Error("LearningUntil should be empty when model is ready")
}
})
}
}
func TestSecurityHandler_Arm(t *testing.T) {
tests := []struct {
name string
requestBody string
initialMode analytics.SecurityMode
wantMode analytics.SecurityMode
wantStatusCode int
}{
{
name: "arm without mode defaults to armed",
requestBody: `{}`,
initialMode: analytics.SecurityModeDisarmed,
wantMode: analytics.SecurityModeArmed,
wantStatusCode: http.StatusOK,
},
{
name: "arm with armed mode",
requestBody: `{"mode": "armed"}`,
initialMode: analytics.SecurityModeDisarmed,
wantMode: analytics.SecurityModeArmed,
wantStatusCode: http.StatusOK,
},
{
name: "arm with armed_stay mode",
requestBody: `{"mode": "armed_stay"}`,
initialMode: analytics.SecurityModeDisarmed,
wantMode: analytics.SecurityModeArmedStay,
wantStatusCode: http.StatusOK,
},
{
name: "invalid mode returns bad request",
requestBody: `{"mode": "invalid"}`,
initialMode: analytics.SecurityModeDisarmed,
wantMode: analytics.SecurityModeDisarmed, // unchanged
wantStatusCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := &mockDetectorProvider{
mode: tt.initialMode,
}
handler := NewSecurityHandler(mock)
r := chi.NewRouter()
handler.RegisterRoutes(r)
req := httptest.NewRequest("POST", "/api/security/arm", bytes.NewBufferString(tt.requestBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantStatusCode {
t.Errorf("status code = %d, want %d", w.Code, tt.wantStatusCode)
}
if tt.wantStatusCode == http.StatusOK {
if mock.mode != tt.wantMode {
t.Errorf("mode = %s, want %s", mock.mode, tt.wantMode)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["armed"] != true {
t.Errorf("armed = %v, want true", resp["armed"])
}
} else {
// Mode should not have changed on error
if mock.mode != tt.initialMode {
t.Errorf("mode = %s, want %s (unchanged)", mock.mode, tt.initialMode)
}
}
})
}
}
func TestSecurityHandler_Disarm(t *testing.T) {
tests := []struct {
name string
initialMode analytics.SecurityMode
wantMode analytics.SecurityMode
wantStatusCode int
}{
{
name: "disarm from armed",
initialMode: analytics.SecurityModeArmed,
wantMode: analytics.SecurityModeDisarmed,
wantStatusCode: http.StatusOK,
},
{
name: "disarm from armed_stay",
initialMode: analytics.SecurityModeArmedStay,
wantMode: analytics.SecurityModeDisarmed,
wantStatusCode: http.StatusOK,
},
{
name: "disarm when already disarmed",
initialMode: analytics.SecurityModeDisarmed,
wantMode: analytics.SecurityModeDisarmed,
wantStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := &mockDetectorProvider{
mode: tt.initialMode,
}
handler := NewSecurityHandler(mock)
r := chi.NewRouter()
handler.RegisterRoutes(r)
req := httptest.NewRequest("POST", "/api/security/disarm", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantStatusCode {
t.Errorf("status code = %d, want %d", w.Code, tt.wantStatusCode)
}
if mock.mode != tt.wantMode {
t.Errorf("mode = %s, want %s", mock.mode, tt.wantMode)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["armed"] != false {
t.Errorf("armed = %v, want false", resp["armed"])
}
})
}
}
func TestSecurityHandler_NilDetector(t *testing.T) {
handler := NewSecurityHandler(nil)
r := chi.NewRouter()
handler.RegisterRoutes(r)
tests := []struct {
name string
method string
path string
body string
}{
{name: "status", method: "GET", path: "/api/security/status"},
{name: "arm", method: "POST", path: "/api/security/arm", body: `{}`},
{name: "disarm", method: "POST", path: "/api/security/disarm"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var body *bytes.Buffer
if tt.body != "" {
body = bytes.NewBufferString(tt.body)
} else {
body = &bytes.Buffer{}
}
req := httptest.NewRequest(tt.method, tt.path, body)
if tt.method == "POST" {
req.Header.Set("Content-Type", "application/json")
}
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status code = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
})
}
}
func TestSecurityHandler_CountAnomalies24h(t *testing.T) {
now := time.Now()
history := []*events.AnomalyEvent{
{Timestamp: now.Add(-1 * time.Hour)}, // Within 24h
{Timestamp: now.Add(-12 * time.Hour)}, // Within 24h
{Timestamp: now.Add(-25 * time.Hour)}, // Outside 24h
{Timestamp: now.Add(-48 * time.Hour)}, // Outside 24h
}
mock := &mockDetectorProvider{
mode: analytics.SecurityModeDisarmed,
history: history,
}
handler := NewSecurityHandler(mock)
r := chi.NewRouter()
handler.RegisterRoutes(r)
req := httptest.NewRequest("GET", "/api/security/status", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status code = %d, want %d", w.Code, http.StatusOK)
}
var status SecurityStatus
if err := json.Unmarshal(w.Body.Bytes(), &status); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
// Should count only the 2 anomalies within 24h
if status.AnomalyCount24h != 2 {
t.Errorf("AnomalyCount24h = %d, want 2", status.AnomalyCount24h)
}
}

View file

@ -9,6 +9,7 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
@ -18,10 +19,10 @@ import (
// TriggersHandler manages automation triggers.
type TriggersHandler struct {
mu sync.RWMutex
db *sql.DB
triggers map[string]*Trigger
engine TriggerEngine
mu sync.RWMutex
db *sql.DB
triggers map[string]*Trigger
engine TriggerEngine
}
// Trigger represents an automation trigger.
@ -138,13 +139,15 @@ func (t *TriggersHandler) SetEngine(engine TriggerEngine) {
t.mu.Unlock()
}
// RegisterRoutes registers triggers endpoints.
// RegisterRoutes registers triggers endpoints on the given router.
//
// GET /api/triggers — list all triggers
// POST /api/triggers — create trigger
// PUT /api/triggers/{id} — update
// DELETE /api/triggers/{id} — delete
// POST /api/triggers/{id}/test — fire trigger once for testing
// Routes:
//
// GET /api/triggers — list all triggers
// POST /api/triggers — create a new trigger
// PUT /api/triggers/{id} — update an existing trigger
// DELETE /api/triggers/{id} — delete a trigger
// POST /api/triggers/{id}/test — fire trigger actions once for testing
func (t *TriggersHandler) RegisterRoutes(r chi.Router) {
r.Get("/api/triggers", t.listTriggers)
r.Post("/api/triggers", t.createTrigger)
@ -153,11 +156,28 @@ func (t *TriggersHandler) RegisterRoutes(r chi.Router) {
r.Post("/api/triggers/{id}/test", t.testTrigger)
}
// listTriggers handles GET /api/triggers.
//
// Returns all registered triggers as a JSON array.
//
// Response 200 (application/json):
//
// [{
// "id": "t1",
// "name": "Couch Dwell",
// "enabled": true,
// "condition": "dwell",
// "condition_params": {"duration_s": 30},
// "time_constraint": {"from": "22:00", "to": "06:00"},
// "actions": [{"type": "webhook", "url": "http://example.com/hook"}],
// "last_fired": "2024-03-15T14:32:05Z",
// "elapsed": 142,
// "created_at": "2024-03-10T08:00:00Z"
// }]
func (t *TriggersHandler) listTriggers(w http.ResponseWriter, r *http.Request) {
t.mu.RLock()
triggers := make([]*Trigger, 0, len(t.triggers))
for _, trigger := range t.triggers {
// Update elapsed time
if trigger.LastFired != nil {
trigger.Elapsed = int(time.Since(*trigger.LastFired).Seconds())
}
@ -178,6 +198,26 @@ type createTriggerRequest struct {
Actions json.RawMessage `json:"actions"`
}
// createTrigger handles POST /api/triggers.
//
// Creates a new automation trigger. The request body must include id, name,
// and condition. Actions default to an empty array if omitted.
//
// Request body (application/json):
//
// {
// "id": "t1",
// "name": "Couch Dwell",
// "condition": "dwell",
// "condition_params": {"duration_s": 30},
// "time_constraint": {"from": "22:00", "to": "06:00"},
// "actions": [{"type": "webhook", "url": "http://example.com/hook"}],
// "enabled": true
// }
//
// Response 201 (application/json): the created trigger object.
// Response 400: missing required fields or invalid condition value.
// Response 500: database error.
func (t *TriggersHandler) createTrigger(w http.ResponseWriter, r *http.Request) {
var req createTriggerRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -257,6 +297,18 @@ type updateTriggerRequest struct {
Actions *json.RawMessage `json:"actions,omitempty"`
}
// updateTrigger handles PUT /api/triggers/{id}.
//
// Updates an existing trigger. Only fields present in the request body are
// modified; omitted fields retain their current values. If the body contains
// no recognized fields, the current trigger is returned unchanged.
//
// Request body (application/json): partial trigger object with fields to update.
//
// Response 200 (application/json): the updated trigger object.
// Response 400: invalid request body or invalid condition value.
// Response 404: trigger not found.
// Response 500: database error.
func (t *TriggersHandler) updateTrigger(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
@ -321,7 +373,7 @@ func (t *TriggersHandler) updateTrigger(w http.ResponseWriter, r *http.Request)
}
args = append(args, id)
query := "UPDATE triggers SET " + joinComma(updates) + " WHERE id = ?"
query := "UPDATE triggers SET " + strings.Join(updates, ", ") + " WHERE id = ?"
_, err := t.db.Exec(query, args...)
if err != nil {
@ -354,6 +406,12 @@ func (t *TriggersHandler) updateTrigger(w http.ResponseWriter, r *http.Request)
writeJSON(w, http.StatusOK, trigger)
}
// deleteTrigger handles DELETE /api/triggers/{id}.
//
// Removes a trigger by ID. Deleting a nonexistent ID is a no-op.
//
// Response 204: trigger deleted (or did not exist).
// Response 500: database error.
func (t *TriggersHandler) deleteTrigger(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
@ -379,6 +437,22 @@ func (t *TriggersHandler) deleteTrigger(w http.ResponseWriter, r *http.Request)
w.WriteHeader(http.StatusNoContent)
}
// testTrigger handles POST /api/triggers/{id}/test.
//
// Fires the trigger's actions once with a synthetic event payload for testing.
// If no automation engine is attached, returns a simulated success response.
// Does not update last_fired or trigger any real automation logic.
//
// Response 200 (application/json):
//
// {
// "status": "fired",
// "message": "Trigger fired successfully",
// "trigger": { ... }
// }
//
// Response 404: trigger not found.
// Response 500: engine test-fire failed.
func (t *TriggersHandler) testTrigger(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
@ -391,7 +465,6 @@ func (t *TriggersHandler) testTrigger(w http.ResponseWriter, r *http.Request) {
return
}
// Check if engine is available
t.mu.RLock()
engine := t.engine
t.mu.RUnlock()
@ -451,7 +524,6 @@ func (t *TriggersHandler) EvaluateTriggers(blobs []BlobPos) []string {
shouldFire := false
switch trigger.Condition {
case "enter", "leave":
// Volume-based trigger
if params.VolumeID != "" {
for _, blob := range blobs {
if t.engine != nil && t.engine.IsInVolume(blob.X, blob.Y, blob.Z, params.VolumeID) {
@ -498,6 +570,6 @@ func (t *TriggersHandler) EvaluateTriggers(blobs []BlobPos) []string {
// BlobPos represents a blob position for trigger evaluation.
type BlobPos struct {
ID int
ID int
X, Y, Z float64
}

View file

@ -0,0 +1,794 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
)
// newTriggerTestHandler creates a TriggersHandler backed by an in-memory database.
func newTriggerTestHandler(t *testing.T) (*TriggersHandler, func()) {
t.Helper()
h, err := NewTriggersHandler(":memory:")
if err != nil {
t.Fatalf("NewTriggersHandler: %v", err)
}
return h, func() { h.Close() }
}
// newTriggerTestRouter creates a chi.Router with trigger routes registered.
func newTriggerTestRouter(h *TriggersHandler) *chi.Mux {
r := chi.NewRouter()
h.RegisterRoutes(r)
return r
}
// seedTrigger creates a trigger directly in the handler for test setup.
func seedTrigger(t *testing.T, h *TriggersHandler, tr Trigger) {
t.Helper()
now := time.Now().UnixNano()
enabled := 0
if tr.Enabled {
enabled = 1
}
actions := tr.Actions
if len(actions) == 0 {
actions = json.RawMessage("[]")
}
conditionParams := tr.ConditionParams
if len(conditionParams) == 0 {
conditionParams = json.RawMessage("{}")
}
_, err := h.db.Exec(`
INSERT INTO triggers (id, name, enabled, condition, condition_params, actions, created_at)
VALUES (?, ?, ?, ?, ?, ?)
`, tr.ID, tr.Name, enabled, tr.Condition, string(conditionParams), string(actions), now)
if err != nil {
t.Fatalf("seedTrigger: %v", err)
}
h.mu.Lock()
h.triggers[tr.ID] = &Trigger{
ID: tr.ID,
Name: tr.Name,
Enabled: tr.Enabled,
Condition: tr.Condition,
ConditionParams: conditionParams,
Actions: actions,
CreatedAt: time.Unix(0, now),
}
h.mu.Unlock()
}
// ── GET /api/triggers ─────────────────────────────────────────────────────────────
// TestListTriggers tests GET /api/triggers.
func TestListTriggers(t *testing.T) {
tests := []struct {
name string
setup []Trigger
wantLen int
wantCode int
}{
{
name: "empty store",
setup: nil,
wantLen: 0,
wantCode: http.StatusOK,
},
{
name: "single trigger",
setup: []Trigger{
{ID: "t1", Name: "Couch Dwell", Condition: "dwell", Enabled: true},
},
wantLen: 1,
wantCode: http.StatusOK,
},
{
name: "multiple triggers",
setup: []Trigger{
{ID: "t1", Name: "Enter Hallway", Condition: "enter", Enabled: true},
{ID: "t2", Name: "Leave Home", Condition: "vacant", Enabled: false},
{ID: "t3", Name: "Count Kitchen", Condition: "count", Enabled: true},
},
wantLen: 3,
wantCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
for _, tr := range tt.setup {
seedTrigger(t, h, tr)
}
r := newTriggerTestRouter(h)
req := httptest.NewRequest("GET", "/api/triggers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantCode {
t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String())
}
var result []Trigger
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
t.Fatalf("failed to decode: %v", err)
}
if len(result) != tt.wantLen {
t.Errorf("expected %d triggers, got %d", tt.wantLen, len(result))
}
})
}
}
// ── POST /api/triggers ────────────────────────────────────────────────────────────
// TestCreateTrigger tests POST /api/triggers.
func TestCreateTrigger(t *testing.T) {
tests := []struct {
name string
body string
wantCode int
wantID string
wantErr string
}{
{
name: "valid trigger with all fields",
body: `{
"id": "t1",
"name": "Couch Dwell",
"condition": "dwell",
"condition_params": {"duration_s": 30},
"time_constraint": {"from": "22:00", "to": "06:00"},
"actions": [{"type": "webhook", "url": "http://example.com/hook"}],
"enabled": true
}`,
wantCode: http.StatusCreated,
wantID: "t1",
},
{
name: "minimal valid trigger",
body: `{"id": "t2", "name": "Enter", "condition": "enter"}`,
wantCode: http.StatusCreated,
wantID: "t2",
},
{
name: "missing id",
body: `{"name": "No ID", "condition": "enter"}`,
wantCode: http.StatusBadRequest,
wantErr: "id is required",
},
{
name: "missing name",
body: `{"id": "t3", "condition": "enter"}`,
wantCode: http.StatusBadRequest,
wantErr: "name is required",
},
{
name: "invalid condition",
body: `{"id": "t4", "name": "Bad", "condition": "fly"}`,
wantCode: http.StatusBadRequest,
wantErr: "condition must be one of",
},
{
name: "missing condition",
body: `{"id": "t5", "name": "NoCond"}`,
wantCode: http.StatusBadRequest,
wantErr: "condition must be one of",
},
{
name: "malformed JSON",
body: `{invalid}`,
wantCode: http.StatusBadRequest,
wantErr: "invalid request body",
},
{
name: "empty body",
body: ``,
wantCode: http.StatusBadRequest,
wantErr: "invalid request body",
},
{
name: "disabled by default when not specified",
body: `{"id": "t6", "name": "Default", "condition": "leave"}`,
wantCode: http.StatusCreated,
wantID: "t6",
},
{
name: "explicitly disabled",
body: `{"id": "t7", "name": "Off", "condition": "vacant", "enabled": false}`,
wantCode: http.StatusCreated,
wantID: "t7",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
r := newTriggerTestRouter(h)
req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(tt.body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantCode {
t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String())
}
if tt.wantErr != "" {
if !bytes.Contains(w.Body.Bytes(), []byte(tt.wantErr)) {
t.Errorf("expected error to contain %q, got %s", tt.wantErr, w.Body.String())
}
return
}
var created Trigger
if err := json.NewDecoder(w.Body).Decode(&created); err != nil {
t.Fatalf("failed to decode: %v", err)
}
if created.ID != tt.wantID {
t.Errorf("expected ID %q, got %q", tt.wantID, created.ID)
}
if created.CreatedAt.IsZero() {
t.Error("expected non-zero CreatedAt")
}
})
}
}
// TestCreateTriggerDuplicate tests that creating a trigger with a duplicate ID fails.
func TestCreateTriggerDuplicate(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
seedTrigger(t, h, Trigger{ID: "t1", Name: "First", Condition: "enter", Enabled: true})
r := newTriggerTestRouter(h)
body := `{"id": "t1", "name": "Duplicate", "condition": "dwell"}`
req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// SQLite PRIMARY KEY constraint should reject the duplicate
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500 for duplicate ID, got %d", w.Code)
}
}
// TestCreateTriggerPersists tests that a created trigger survives a handler reload.
func TestCreateTriggerPersists(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
r := newTriggerTestRouter(h)
body := `{"id": "persist", "name": "Persistent", "condition": "leave"}`
req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("create: expected 201, got %d", w.Code)
}
// Reload from DB
if err := h.load(); err != nil {
t.Fatalf("reload: %v", err)
}
// Verify it's still there
req2 := httptest.NewRequest("GET", "/api/triggers", nil)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
var result []Trigger
json.NewDecoder(w2.Body).Decode(&result)
if len(result) != 1 {
t.Fatalf("after reload: expected 1 trigger, got %d", len(result))
}
if result[0].Name != "Persistent" {
t.Errorf("after reload: expected name 'Persistent', got %s", result[0].Name)
}
}
// ── PUT /api/triggers/{id} ────────────────────────────────────────────────────────
// TestUpdateTrigger tests PUT /api/triggers/{id}.
func TestUpdateTrigger(t *testing.T) {
tests := []struct {
name string
setup Trigger
body string
wantCode int
wantName string
wantEnable bool
}{
{
name: "update name",
setup: Trigger{ID: "t1", Name: "Old", Condition: "enter", Enabled: true},
body: `{"name": "New Name"}`,
wantCode: http.StatusOK,
wantName: "New Name",
wantEnable: true,
},
{
name: "disable trigger",
setup: Trigger{ID: "t1", Name: "On", Condition: "dwell", Enabled: true},
body: `{"enabled": false}`,
wantCode: http.StatusOK,
wantName: "On",
wantEnable: false,
},
{
name: "enable trigger",
setup: Trigger{ID: "t1", Name: "Off", Condition: "vacant", Enabled: false},
body: `{"enabled": true}`,
wantCode: http.StatusOK,
wantName: "Off",
wantEnable: true,
},
{
name: "change condition",
setup: Trigger{ID: "t1", Name: "Flex", Condition: "enter", Enabled: true},
body: `{"condition": "dwell"}`,
wantCode: http.StatusOK,
wantName: "Flex",
wantEnable: true,
},
{
name: "update multiple fields",
setup: Trigger{ID: "t1", Name: "Old", Condition: "enter", Enabled: true},
body: `{"name": "Multi", "condition": "count", "enabled": false}`,
wantCode: http.StatusOK,
wantName: "Multi",
wantEnable: false,
},
{
name: "no-op update returns current",
setup: Trigger{ID: "t1", Name: "Same", Condition: "leave", Enabled: true},
body: `{}`,
wantCode: http.StatusOK,
wantName: "Same",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
seedTrigger(t, h, tt.setup)
r := newTriggerTestRouter(h)
req := httptest.NewRequest("PUT", "/api/triggers/"+tt.setup.ID, bytes.NewReader([]byte(tt.body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantCode {
t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String())
}
var updated Trigger
if err := json.NewDecoder(w.Body).Decode(&updated); err != nil {
t.Fatalf("failed to decode: %v", err)
}
if updated.Name != tt.wantName {
t.Errorf("expected name %q, got %q", tt.wantName, updated.Name)
}
if updated.Enabled != tt.wantEnable {
t.Errorf("expected enabled=%v, got %v", tt.wantEnable, updated.Enabled)
}
})
}
}
// TestUpdateTriggerInvalid tests PUT /api/triggers/{id} with invalid input.
func TestUpdateTriggerInvalid(t *testing.T) {
tests := []struct {
name string
setup Trigger
body string
want int
wantErr string
}{
{
name: "nonexistent trigger",
setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true},
body: `{"name": "Nope"}`,
want: http.StatusNotFound,
wantErr: "trigger not found",
},
{
name: "malformed JSON",
setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true},
body: `{bad}`,
want: http.StatusBadRequest,
wantErr: "invalid request body",
},
{
name: "invalid condition",
setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true},
body: `{"condition": "invalid"}`,
want: http.StatusBadRequest,
wantErr: "condition must be one of",
},
{
name: "empty body",
setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true},
body: ``,
want: http.StatusBadRequest,
wantErr: "invalid request body",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
seedTrigger(t, h, tt.setup)
id := tt.setup.ID
if tt.name == "nonexistent trigger" {
id = "nonexistent"
}
r := newTriggerTestRouter(h)
req := httptest.NewRequest("PUT", "/api/triggers/"+id, bytes.NewReader([]byte(tt.body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.want {
t.Fatalf("expected %d, got %d: %s", tt.want, w.Code, w.Body.String())
}
if !bytes.Contains(w.Body.Bytes(), []byte(tt.wantErr)) {
t.Errorf("expected error to contain %q, got %s", tt.wantErr, w.Body.String())
}
})
}
}
// TestUpdateTriggerPersists tests that an update is persisted across reload.
func TestUpdateTriggerPersists(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
seedTrigger(t, h, Trigger{ID: "t1", Name: "Original", Condition: "enter", Enabled: true})
r := newTriggerTestRouter(h)
body := `{"name": "Updated Name"}`
req := httptest.NewRequest("PUT", "/api/triggers/t1", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update: expected 200, got %d", w.Code)
}
// Reload from DB
if err := h.load(); err != nil {
t.Fatalf("reload: %v", err)
}
req2 := httptest.NewRequest("GET", "/api/triggers", nil)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
var result []Trigger
json.NewDecoder(w2.Body).Decode(&result)
if len(result) != 1 {
t.Fatalf("after reload: expected 1 trigger, got %d", len(result))
}
if result[0].Name != "Updated Name" {
t.Errorf("after reload: expected name 'Updated Name', got %s", result[0].Name)
}
}
// ── DELETE /api/triggers/{id} ─────────────────────────────────────────────────────
// TestDeleteTrigger tests DELETE /api/triggers/{id}.
func TestDeleteTrigger(t *testing.T) {
tests := []struct {
name string
setup []Trigger
deleteID string
wantCode int
wantLen int
}{
{
name: "delete existing trigger",
setup: []Trigger{
{ID: "t1", Name: "Keep", Condition: "enter", Enabled: true},
{ID: "t2", Name: "Delete Me", Condition: "dwell", Enabled: true},
},
deleteID: "t2",
wantCode: http.StatusNoContent,
wantLen: 1,
},
{
name: "delete nonexistent trigger",
setup: []Trigger{{ID: "t1", Name: "Only", Condition: "enter", Enabled: true}},
deleteID: "nonexistent",
wantCode: http.StatusNotFound,
wantLen: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
for _, tr := range tt.setup {
seedTrigger(t, h, tr)
}
r := newTriggerTestRouter(h)
req := httptest.NewRequest("DELETE", "/api/triggers/"+tt.deleteID, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantCode {
t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String())
}
// Verify via list
req2 := httptest.NewRequest("GET", "/api/triggers", nil)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
var result []Trigger
json.NewDecoder(w2.Body).Decode(&result)
if len(result) != tt.wantLen {
t.Errorf("expected %d triggers after delete, got %d", tt.wantLen, len(result))
}
// Verify deleted trigger is not in memory
if tt.wantCode == http.StatusNoContent {
h.mu.RLock()
_, exists := h.triggers[tt.deleteID]
h.mu.RUnlock()
if exists {
t.Error("trigger should be removed from memory")
}
}
})
}
}
// ── POST /api/triggers/{id}/test ─────────────────────────────────────────────────
// TestTestTrigger tests POST /api/triggers/{id}/test.
func TestTestTrigger(t *testing.T) {
tests := []struct {
name string
setup Trigger
testID string
engine TriggerEngine
wantCode int
wantKey string
}{
{
name: "test with no engine returns simulated",
setup: Trigger{ID: "t1", Name: "Sim", Condition: "dwell", Enabled: true},
testID: "t1",
engine: nil,
wantCode: http.StatusOK,
wantKey: "simulated",
},
{
name: "test with engine that succeeds",
setup: Trigger{ID: "t1", Name: "Fire", Condition: "enter", Enabled: true},
testID: "t1",
engine: &mockEngine{err: nil},
wantCode: http.StatusOK,
wantKey: "fired",
},
{
name: "test with engine that fails",
setup: Trigger{ID: "t1", Name: "Fail", Condition: "leave", Enabled: true},
testID: "t1",
engine: &mockEngine{err: fmt.Errorf("boom")},
wantCode: http.StatusInternalServerError,
wantKey: "test fire failed",
},
{
name: "test nonexistent trigger",
setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true},
testID: "nonexistent",
engine: nil,
wantCode: http.StatusNotFound,
wantKey: "trigger not found",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
seedTrigger(t, h, tt.setup)
if tt.engine != nil {
h.SetEngine(tt.engine)
}
r := newTriggerTestRouter(h)
req := httptest.NewRequest("POST", "/api/triggers/"+tt.testID+"/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != tt.wantCode {
t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String())
}
if !bytes.Contains(w.Body.Bytes(), []byte(tt.wantKey)) {
t.Errorf("expected response to contain %q, got %s", tt.wantKey, w.Body.String())
}
})
}
}
// TestTestTriggerDoesNotUpdateLastFired verifies that the test endpoint
// does not modify last_fired on the trigger.
func TestTestTriggerDoesNotUpdateLastFired(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
tr := Trigger{ID: "t1", Name: "Test", Condition: "enter", Enabled: true}
seedTrigger(t, h, tr)
r := newTriggerTestRouter(h)
req := httptest.NewRequest("POST", "/api/triggers/t1/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
h.mu.RLock()
trigger := h.triggers["t1"]
h.mu.RUnlock()
if trigger.LastFired != nil {
t.Error("expected last_fired to remain nil after test endpoint")
}
}
// ── CRUD round-trip ───────────────────────────────────────────────────────────────
// TestTriggerCRUDRoundTrip verifies the full lifecycle: create -> list -> update -> list -> delete -> verify gone.
func TestTriggerCRUDRoundTrip(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
r := newTriggerTestRouter(h)
// 1. Create
body := `{"id": "rt", "name": "Round Trip", "condition": "dwell", "condition_params": {"duration_s": 60}}`
req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("create: expected 201, got %d", w.Code)
}
// 2. List and verify
req2 := httptest.NewRequest("GET", "/api/triggers", nil)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
var triggers []Trigger
json.NewDecoder(w2.Body).Decode(&triggers)
if len(triggers) != 1 {
t.Fatalf("after create: expected 1 trigger, got %d", len(triggers))
}
if triggers[0].Name != "Round Trip" {
t.Errorf("after create: expected name 'Round Trip', got %s", triggers[0].Name)
}
// 3. Update
body3 := `{"name": "Updated Trip", "enabled": false}`
req3 := httptest.NewRequest("PUT", "/api/triggers/rt", bytes.NewReader([]byte(body3)))
req3.Header.Set("Content-Type", "application/json")
w3 := httptest.NewRecorder()
r.ServeHTTP(w3, req3)
if w3.Code != http.StatusOK {
t.Fatalf("update: expected 200, got %d", w3.Code)
}
// 4. Verify update
req4 := httptest.NewRequest("GET", "/api/triggers", nil)
w4 := httptest.NewRecorder()
r.ServeHTTP(w4, req4)
json.NewDecoder(w4.Body).Decode(&triggers)
if triggers[0].Name != "Updated Trip" {
t.Errorf("after update: expected name 'Updated Trip', got %s", triggers[0].Name)
}
if triggers[0].Enabled {
t.Error("after update: expected enabled=false")
}
// 5. Delete
req5 := httptest.NewRequest("DELETE", "/api/triggers/rt", nil)
w5 := httptest.NewRecorder()
r.ServeHTTP(w5, req5)
if w5.Code != http.StatusNoContent {
t.Fatalf("delete: expected 204, got %d", w5.Code)
}
// 6. Verify gone
req6 := httptest.NewRequest("GET", "/api/triggers", nil)
w6 := httptest.NewRecorder()
r.ServeHTTP(w6, req6)
json.NewDecoder(w6.Body).Decode(&triggers)
if len(triggers) != 0 {
t.Errorf("after delete: expected 0 triggers, got %d", len(triggers))
}
}
// ── EvaluateTriggers ─────────────────────────────────────────────────────────────
// TestEvaluateTriggers tests trigger evaluation logic.
func TestEvaluateTriggers(t *testing.T) {
h, cleanup := newTriggerTestHandler(t)
defer cleanup()
seedTrigger(t, h, Trigger{
ID: "vacant",
Name: "House Empty",
Condition: "vacant",
Enabled: true,
})
seedTrigger(t, h, Trigger{
ID: "disabled",
Name: "Disabled",
Condition: "vacant",
Enabled: false,
})
// No blobs = vacant should fire
fired := h.EvaluateTriggers(nil)
if len(fired) != 1 || fired[0] != "vacant" {
t.Errorf("expected [vacant], got %v", fired)
}
// With blobs = vacant should not fire
fired = h.EvaluateTriggers([]BlobPos{{ID: 1, X: 1, Y: 1, Z: 1}})
if len(fired) != 0 {
t.Errorf("expected no fires with blob present, got %v", fired)
}
// Disabled trigger never fires
fired = h.EvaluateTriggers(nil)
for _, id := range fired {
if id == "disabled" {
t.Error("disabled trigger should not fire")
}
}
}
// ── mock engine ──────────────────────────────────────────────────────────────────
type mockEngine struct {
err error
}
func (m *mockEngine) TestFire(triggerID string) error { return m.err }
func (m *mockEngine) IsInVolume(x, y, z float64, volumeID string) bool {
return true
}

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"log"
"net/http"
"strings"
)
// writeJSON writes a JSON response with the given status code.
@ -28,3 +29,8 @@ func writeJSONData(w http.ResponseWriter, v interface{}) {
func writeJSONError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]string{"error": message})
}
// joinComma joins a slice of strings with ", " separator.
func joinComma(parts []string) string {
return strings.Join(parts, ", ")
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,307 @@
// Package loadshed implements adaptive load shedding for the mothership fusion pipeline.
// It monitors pipeline iteration timing and applies 4 shedding levels to keep the
// system responsive under CPU/memory pressure, especially for large fleets.
//
// Level 0 (normal): rolling avg < 80 ms — full pipeline
// Level 1 (light): rolling avg >= 80 ms — suspend crowd flow accumulation
// Level 2 (moderate): rolling avg >= 90 ms — also suspend CSI replay buffer writes
// Level 3 (heavy): rolling avg >= 95 ms — drop CSI frames when ingest channel > 50% full;
// push rate reduction config to all nodes (10 Hz cap)
//
// Recovery: when rolling avg < 60 ms for 10 consecutive iterations, step down one level.
//
// All reads of the shedding level are lock-free via atomic operations.
package loadshed
import (
"log"
"sync/atomic"
"time"
)
// Level represents the current load shedding level (0-3).
type Level int32
const (
// LevelNormal is the default: full pipeline, no shedding.
LevelNormal Level = iota
// LevelLight suspends crowd flow accumulation.
LevelLight
// LevelModerate also suspends CSI replay buffer writes.
LevelModerate
// LevelHeavy drops CSI frames on full channels and caps node rate at 10 Hz.
LevelHeavy
)
// String returns a human-readable label for the shedding level.
func (l Level) String() string {
switch l {
case LevelNormal:
return "NOMINAL"
case LevelLight:
return "LIGHT"
case LevelModerate:
return "MODERATE"
case LevelHeavy:
return "HIGH"
default:
return "UNKNOWN"
}
}
// Thresholds for load shedding state transitions.
const (
// Thresholds for escalating shedding levels.
thresholdLevel1 = 80 * time.Millisecond
thresholdLevel2 = 90 * time.Millisecond
thresholdLevel3 = 95 * time.Millisecond
// Recovery threshold: rolling avg must stay below this for recoveryCount
// consecutive iterations before stepping down one level.
recoveryThreshold = 60 * time.Millisecond
recoveryCount = 10
// Level 3 CSI rate cap pushed to all nodes.
level3RateCapHz = 10
// Number of iterations in the rolling average window.
rollingWindowSize = 5
)
// IngestChannelFull is a callback that reports whether the ingest processing
// channel is more than 50% full. Returns true if shedding should drop frames.
type IngestChannelFull func() bool
// RatePushCallback is called when Level 3 is entered or exited to push
// rate config changes to all connected nodes.
type RatePushCallback func(rateHz int)
// Stage represents a named pipeline stage for timing instrumentation.
type Stage struct {
Name string
// Start is set by BeginStage. Use StageDuration() to get elapsed time.
start time.Time
}
// Shedder manages the load shedding state machine.
type Shedder struct {
level atomic.Int32 // Current shedding level (0-3), read lock-free.
recoveryTicks atomic.Int32 // Consecutive iterations below recovery threshold.
// Rolling average window (ring buffer).
durations [rollingWindowSize]time.Duration
durationsIdx int
durationsFilled int // how many slots have been written (< rollingWindowSize on startup)
// Pipeline stage timing for instrumentation.
stages [8]Stage
stageIdx int
// Iteration timing.
iterStart time.Time
// External callbacks.
ingestFull IngestChannelFull
ratePush RatePushCallback
// Previous rate before Level 3 was entered, for restoration.
prevRateHz atomic.Int32
level3Active atomic.Bool
}
// New creates a new Shedder.
func New() *Shedder {
return &Shedder{
prevRateHz: atomic.Int32{}, // defaults to 0 (unknown)
}
}
// SetIngestChannelFull sets the callback that reports whether the ingest
// channel is more than 50% full.
func (s *Shedder) SetIngestChannelFull(fn IngestChannelFull) {
s.ingestFull = fn
}
// SetRatePushCallback sets the callback for pushing rate config to nodes.
func (s *Shedder) SetRatePushCallback(fn RatePushCallback) {
s.ratePush = fn
}
// SetPreviousRate records the node rate that was active before Level 3
// was entered. Used to restore the rate on recovery.
func (s *Shedder) SetPreviousRate(hz int) {
s.prevRateHz.Store(int32(hz))
}
// GetLevel returns the current shedding level (lock-free read).
func (s *Shedder) GetLevel() Level {
return Level(s.level.Load())
}
// ShouldAccumulateCrowdFlow returns false when crowd flow accumulation
// should be suspended (Level >= 1).
func (s *Shedder) ShouldAccumulateCrowdFlow() bool {
return s.level.Load() < int32(LevelLight)
}
// ShouldWriteReplay returns false when CSI replay buffer writes
// should be suspended (Level >= 2).
func (s *Shedder) ShouldWriteReplay() bool {
return s.level.Load() < int32(LevelModerate)
}
// ShouldDropFrames returns true when CSI frames should be dropped because
// the ingest channel is more than 50% full (Level 3 only).
func (s *Shedder) ShouldDropFrames() bool {
return s.level.Load() >= int32(LevelHeavy) && s.ingestFull != nil && s.ingestFull()
}
// IsLevel3Active returns true when Level 3 shedding is active.
func (s *Shedder) IsLevel3Active() bool {
return s.level3Active.Load()
}
// GetLevel3RateCap returns the rate cap applied during Level 3 (10 Hz).
func (s *Shedder) GetLevel3RateCap() int {
return level3RateCapHz
}
// BeginIteration marks the start of a pipeline iteration. Call this at the
// beginning of each fusion tick.
func (s *Shedder) BeginIteration() {
s.iterStart = time.Now()
s.stageIdx = 0
}
// BeginStage starts timing a named pipeline stage. Returns a Stage handle
// whose duration is captured on EndIteration.
func (s *Shedder) BeginStage(name string) Stage {
st := Stage{Name: name, start: time.Now()}
if s.stageIdx < len(s.stages) {
s.stages[s.stageIdx] = st
}
return st
}
// EndStage marks the end of a pipeline stage.
func (s *Shedder) EndStage(st Stage) {
_ = st // duration computed lazily in GetStageDurations
}
// GetStageDurations returns the durations of all stages from the most recent
// completed iteration.
func (s *Shedder) GetStageDurations() []time.Duration {
n := s.stageIdx
if n > len(s.stages) {
n = len(s.stages)
}
result := make([]time.Duration, n)
for i := 0; i < n; i++ {
result[i] = time.Since(s.stages[i].start)
}
return result
}
// EndIteration marks the end of a pipeline iteration, updates the rolling
// average, and evaluates the shedding state machine.
func (s *Shedder) EndIteration() {
elapsed := time.Since(s.iterStart)
// Update rolling average window.
s.durations[s.durationsIdx] = elapsed
s.durationsIdx = (s.durationsIdx + 1) % rollingWindowSize
if s.durationsFilled < rollingWindowSize {
s.durationsFilled++
}
avg := s.rollingAvg()
// Evaluate state machine.
prevLevel := Level(s.level.Load())
var newLevel Level
if avg >= thresholdLevel3 {
newLevel = LevelHeavy
} else if avg >= thresholdLevel2 {
newLevel = LevelModerate
} else if avg >= thresholdLevel1 {
newLevel = LevelLight
} else {
// Below all escalation thresholds — check recovery.
if avg < recoveryThreshold {
ticks := s.recoveryTicks.Add(1)
if ticks >= recoveryCount && prevLevel > LevelNormal {
newLevel = prevLevel - 1
s.recoveryTicks.Store(0)
} else {
newLevel = prevLevel
}
} else {
// Between recovery threshold and Level 1 — hold current level.
s.recoveryTicks.Store(0)
newLevel = prevLevel
}
}
// Level can only go UP directly to the new level (no gradual escalation),
// but recovery steps down one level at a time.
if newLevel > prevLevel {
// Escalate directly.
s.setLevel(newLevel)
} else if newLevel < prevLevel {
// Recovery step down.
s.setLevel(newLevel)
} else {
// Reset recovery counter if we didn't step down this tick
// and we're not in recovery mode.
if avg >= recoveryThreshold {
s.recoveryTicks.Store(0)
}
}
}
// setLevel applies a level change and logs it.
func (s *Shedder) setLevel(new Level) {
prev := Level(s.level.Swap(int32(new)))
if prev == new {
return
}
log.Printf("[INFO] Load shedding level changed: %s (%d) → %s (%d)", prev, prev, new, new)
// Level 3 enter/exit: push rate config to nodes.
if new == LevelHeavy && prev < LevelHeavy {
// Entering Level 3.
s.level3Active.Store(true)
if s.ratePush != nil {
s.ratePush(level3RateCapHz)
}
} else if prev == LevelHeavy && new < LevelHeavy {
// Exiting Level 3: restore previous rate.
s.level3Active.Store(false)
if s.ratePush != nil {
prevRate := int(s.prevRateHz.Load())
if prevRate <= 0 {
prevRate = 20 // sensible default
}
s.ratePush(prevRate)
}
}
}
// rollingAvg computes the average iteration duration over the rolling window.
func (s *Shedder) rollingAvg() time.Duration {
n := s.durationsFilled
if n == 0 {
return 0
}
var sum time.Duration
for i := 0; i < n; i++ {
sum += s.durations[i]
}
return sum / time.Duration(n)
}
// RollingAvg returns the current rolling average iteration time (for diagnostics).
func (s *Shedder) RollingAvg() time.Duration {
return s.rollingAvg()
}

View file

@ -0,0 +1,403 @@
package loadshed
import (
"sync"
"sync/atomic"
"testing"
"time"
)
func TestLevelString(t *testing.T) {
tests := []struct {
l Level
want string
}{
{LevelNormal, "NOMINAL"},
{LevelLight, "LIGHT"},
{LevelModerate, "MODERATE"},
{LevelHeavy, "HIGH"},
{Level(99), "UNKNOWN"},
}
for _, tt := range tests {
if got := tt.l.String(); got != tt.want {
t.Errorf("Level(%d).String() = %q, want %q", tt.l, got, tt.want)
}
}
}
func TestShedderDefaultLevel(t *testing.T) {
s := New()
if got := s.GetLevel(); got != LevelNormal {
t.Errorf("New Shedder level = %d, want %d", got, LevelNormal)
}
}
func TestShedderCrowdFlowSuspended(t *testing.T) {
tests := []struct {
level Level
want bool
}{
{LevelNormal, true},
{LevelLight, false},
{LevelModerate, false},
{LevelHeavy, false},
}
for _, tt := range tests {
s := New()
s.level.Store(int32(tt.level))
if got := s.ShouldAccumulateCrowdFlow(); got != tt.want {
t.Errorf("ShouldAccumulateCrowdFlow() at level %d = %v, want %v", tt.level, got, tt.want)
}
}
}
func TestShedderReplayWriteSuspended(t *testing.T) {
tests := []struct {
level Level
want bool
}{
{LevelNormal, true},
{LevelLight, true},
{LevelModerate, false},
{LevelHeavy, false},
}
for _, tt := range tests {
s := New()
s.level.Store(int32(tt.level))
if got := s.ShouldWriteReplay(); got != tt.want {
t.Errorf("ShouldWriteReplay() at level %d = %v, want %v", tt.level, got, tt.want)
}
}
}
func TestShedderShouldDropFrames(t *testing.T) {
s := New()
// Level 3 with channel full → drop
s.level.Store(int32(LevelHeavy))
channelFull := true
s.SetIngestChannelFull(func() bool { return channelFull })
if !s.ShouldDropFrames() {
t.Error("ShouldDropFrames() = false at Level 3 with full channel, want true")
}
// Level 3 with channel not full → don't drop
channelFull = false
if s.ShouldDropFrames() {
t.Error("ShouldDropFrames() = true at Level 3 with empty channel, want false")
}
// Level 2 → never drop regardless of channel
s.level.Store(int32(LevelModerate))
channelFull = true
if s.ShouldDropFrames() {
t.Error("ShouldDropFrames() = true at Level 2, want false")
}
// No callback set → don't drop
s2 := New()
s2.level.Store(int32(LevelHeavy))
if s2.ShouldDropFrames() {
t.Error("ShouldDropFrames() = true with no callback, want false")
}
}
func TestRollingAverage(t *testing.T) {
s := New()
// Fill the window with 5 iterations of 50ms each.
for i := 0; i < 5; i++ {
s.BeginIteration()
s.EndIteration() // duration ≈ 0ms (no actual work)
// Manually set the duration since the test runs too fast.
s.durations[i] = 50 * time.Millisecond
}
s.durationsFilled = 5
if avg := s.rollingAvg(); avg != 50*time.Millisecond {
t.Errorf("rollingAvg() = %v, want 50ms", avg)
}
}
func TestEscalationFromNormal(t *testing.T) {
s := New()
// Fill window with 5x 85ms → should escalate to Level 1
fillWindow(s, 85*time.Millisecond)
s.EndIteration()
// The last EndIteration added a 0ms duration; let's force the durations.
// Instead, let's directly test the state machine logic by simulating.
s2 := New()
// Pre-fill 4 slots at 85ms
for i := 0; i < 4; i++ {
s2.durations[i] = 85 * time.Millisecond
}
s2.durationsFilled = 4
// Now run an iteration that adds 85ms (total 5 slots, avg 85ms)
s2.durations[4] = 85 * time.Millisecond
s2.durationsIdx = 0 // wrapped
s2.durationsFilled = 5
if avg := s2.rollingAvg(); avg != 85*time.Millisecond {
t.Fatalf("rollingAvg() = %v, want 85ms", avg)
}
// Directly test setLevel behavior
s2.setLevel(LevelLight)
if s2.GetLevel() != LevelLight {
t.Errorf("level = %d, want %d", s2.GetLevel(), LevelLight)
}
}
func TestEscalationToLevel3(t *testing.T) {
s := New()
// Pre-fill window at 97ms → Level 3
for i := 0; i < rollingWindowSize; i++ {
s.durations[i] = 97 * time.Millisecond
}
s.durationsFilled = rollingWindowSize
avg := s.rollingAvg()
if avg < thresholdLevel3 {
t.Fatalf("rollingAvg() = %v, expected >= %v", avg, thresholdLevel3)
}
}
func TestRecoveryStepDown(t *testing.T) {
s := New()
s.level.Store(int32(LevelHeavy))
// Simulate recovery: 10 consecutive iterations below 60ms.
for i := 0; i < recoveryCount; i++ {
s.durations[i%rollingWindowSize] = 50 * time.Millisecond
s.durationsIdx = (i + 1) % rollingWindowSize
s.durationsFilled = min(i+1, rollingWindowSize)
prevLevel := Level(s.level.Load())
ticks := s.recoveryTicks.Add(1)
var newLevel Level
if ticks >= recoveryCount && prevLevel > LevelNormal {
newLevel = prevLevel - 1
s.recoveryTicks.Store(0)
} else {
newLevel = prevLevel
}
s.setLevel(newLevel)
}
if s.GetLevel() != LevelModerate {
t.Errorf("after recovery, level = %d, want %d", s.GetLevel(), LevelModerate)
}
}
func TestRecoveryFullSequence(t *testing.T) {
s := New()
s.level.Store(int32(LevelModerate))
// Need 10 iterations below 60ms to recover from Level 2 → Level 1.
for i := 0; i < recoveryCount; i++ {
s.durations[i%rollingWindowSize] = 50 * time.Millisecond
s.durationsIdx = (i + 1) % rollingWindowSize
s.durationsFilled = min(i+1, rollingWindowSize)
prevLevel := Level(s.level.Load())
ticks := s.recoveryTicks.Add(1)
var newLevel Level
if ticks >= recoveryCount && prevLevel > LevelNormal {
newLevel = prevLevel - 1
s.recoveryTicks.Store(0)
} else {
newLevel = prevLevel
}
s.setLevel(newLevel)
}
if s.GetLevel() != LevelLight {
t.Errorf("after recovery from L2, level = %d, want %d", s.GetLevel(), LevelLight)
}
}
func TestRecoveryCounterResetOnAboveThreshold(t *testing.T) {
s := New()
s.level.Store(int32(LevelLight))
// 5 iterations below threshold, then one above.
for i := 0; i < 5; i++ {
s.durations[i] = 50 * time.Millisecond
s.durationsFilled = i + 1
s.recoveryTicks.Add(1)
}
// One iteration above recovery threshold but below L1
s.durations[5%rollingWindowSize] = 70 * time.Millisecond
s.durationsFilled = 6
avg := s.rollingAvg()
if avg >= recoveryThreshold {
s.recoveryTicks.Store(0)
}
if s.recoveryTicks.Load() != 0 {
t.Errorf("recovery ticks should be reset after above-threshold iteration, got %d", s.recoveryTicks.Load())
}
}
func TestLevel3RatePushCallback(t *testing.T) {
s := New()
var pushedRate atomic.Int32
s.SetRatePushCallback(func(rateHz int) {
pushedRate.Store(int32(rateHz))
})
s.SetPreviousRate(20)
// Enter Level 3
s.setLevel(LevelHeavy)
if pushedRate.Load() != level3RateCapHz {
t.Errorf("rate push on L3 enter = %d, want %d", pushedRate.Load(), level3RateCapHz)
}
if !s.IsLevel3Active() {
t.Error("IsLevel3Active() = false after entering L3")
}
// Exit Level 3
s.setLevel(LevelModerate)
if pushedRate.Load() != 20 {
t.Errorf("rate push on L3 exit = %d, want 20", pushedRate.Load())
}
if s.IsLevel3Active() {
t.Error("IsLevel3Active() = true after exiting L3")
}
}
func TestLevel3RestoreDefaultRate(t *testing.T) {
s := New()
var pushedRate atomic.Int32
s.SetRatePushCallback(func(rateHz int) {
pushedRate.Store(int32(rateHz))
})
// Don't set previous rate — should default to 20.
s.setLevel(LevelHeavy)
pushedRate.Store(0) // reset
s.setLevel(LevelModerate)
if pushedRate.Load() != 20 {
t.Errorf("rate push on L3 exit without prev = %d, want 20", pushedRate.Load())
}
}
func TestNoRatePushWithoutCallback(t *testing.T) {
s := New()
// No callback set — should not panic.
s.setLevel(LevelHeavy)
s.setLevel(LevelNormal)
if s.GetLevel() != LevelNormal {
t.Errorf("level = %d, want %d", s.GetLevel(), LevelNormal)
}
}
func TestSetLevelLogsOnce(t *testing.T) {
s := New()
// Setting same level should be a no-op (no extra log).
s.setLevel(LevelLight)
s.setLevel(LevelLight)
if s.GetLevel() != LevelLight {
t.Errorf("level = %d, want %d", s.GetLevel(), LevelLight)
}
}
func TestConcurrentLevelReads(t *testing.T) {
s := New()
s.level.Store(int32(LevelModerate))
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
l := s.GetLevel()
if l != LevelModerate {
t.Errorf("concurrent read got %d, want %d", l, LevelModerate)
}
}()
}
wg.Wait()
}
func TestRollingAvgPartialWindow(t *testing.T) {
s := New()
// Only 3 of 5 slots filled.
s.durations[0] = 100 * time.Millisecond
s.durations[1] = 200 * time.Millisecond
s.durations[2] = 300 * time.Millisecond
s.durationsFilled = 3
want := 200 * time.Millisecond
if got := s.rollingAvg(); got != want {
t.Errorf("partial window avg = %v, want %v", got, want)
}
}
func TestGetLevel3RateCap(t *testing.T) {
s := New()
if s.GetLevel3RateCap() != 10 {
t.Errorf("GetLevel3RateCap() = %d, want 10", s.GetLevel3RateCap())
}
}
func TestLockFreeReads(t *testing.T) {
s := New()
// Verify all "query" methods use only atomic reads (no mutex).
// This is a design assertion — the test confirms no data races under
// concurrent writes and reads.
var wg sync.WaitGroup
// Writer goroutine: rapidly change levels.
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 1000; i++ {
s.setLevel(Level(i % 4))
}
}()
// Reader goroutines: concurrently query all lock-free methods.
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 1000; j++ {
_ = s.GetLevel()
_ = s.ShouldAccumulateCrowdFlow()
_ = s.ShouldWriteReplay()
_ = s.IsLevel3Active()
_ = s.GetLevel3RateCap()
_ = s.RollingAvg()
}
}()
}
wg.Wait()
}
// fillWindow is a test helper that fills the rolling window with a given duration.
func fillWindow(s *Shedder, d time.Duration) {
for i := 0; i < rollingWindowSize; i++ {
s.durations[i] = d
}
s.durationsIdx = 0
s.durationsFilled = rollingWindowSize
}
func min(a, b int) int {
if a < b {
return a
}
return b
}