diff --git a/mothership/cmd/_parse_check.go b/mothership/cmd/_parse_check.go new file mode 100644 index 0000000..3e0d3fc --- /dev/null +++ b/mothership/cmd/_parse_check.go @@ -0,0 +1,20 @@ +//go:build ignore + +package main + +import ( + "fmt" + "go/parser" + "go/token" + "os" +) + +func main() { + fset := token.NewFileSet() + _, err := parser.ParseFile(fset, "cmd/mothership/main.go", nil, parser.AllErrors|parser.ParseComments) + if err != nil { + fmt.Println("Parse error:", err) + os.Exit(1) + } + fmt.Println("Parse OK") +} diff --git a/mothership/internal/analytics/anomaly.go b/mothership/internal/analytics/anomaly.go index 9f3f0c9..8d5f817 100644 --- a/mothership/internal/analytics/anomaly.go +++ b/mothership/internal/analytics/anomaly.go @@ -442,6 +442,17 @@ func (d *Detector) loadLearningState() error { d.modelReadyAt = d.learningStartTime.Add(7 * 24 * time.Hour) } + // Load security_mode from database (persisted across restarts) + var securityModeStr string + err = d.db.QueryRow(`SELECT value FROM learning_state WHERE key = 'security_mode'`).Scan(&securityModeStr) + if err == nil { + d.securityMode = SecurityMode(securityModeStr) + log.Printf("[INFO] Loaded security mode from database: %s", d.securityMode) + } else if err != sql.ErrNoRows { + log.Printf("[WARN] Failed to load security_mode from database: %v", err) + } + // If security_mode doesn't exist in DB, default to disarmed (already set in NewDetector) + // Load device first seen times deviceRows, err := d.db.Query(`SELECT mac, first_seen_ns FROM device_first_seen`) if err != nil { diff --git a/mothership/internal/api/security.go b/mothership/internal/api/security.go new file mode 100644 index 0000000..d0f2e5c --- /dev/null +++ b/mothership/internal/api/security.go @@ -0,0 +1,167 @@ +// Package api provides REST API handlers for Spaxel security mode. +package api + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/go-chi/chi" + "github.com/spaxel/mothership/internal/analytics" + "github.com/spaxel/mothership/internal/events" +) + +// SecurityHandler manages security mode state and API endpoints. +type SecurityHandler struct { + detector DetectorProvider +} + +// DetectorProvider is an interface to access the anomaly detector. +type DetectorProvider interface { + GetSecurityMode() analytics.SecurityMode + SetSecurityMode(mode analytics.SecurityMode, reason string) + IsSecurityModeActive() bool + GetLearningProgress() float64 + IsModelReady() bool + GetActiveAnomalies() []*events.AnomalyEvent + GetAnomalyHistory(limit int) []*events.AnomalyEvent +} + +// NewSecurityHandler creates a new security handler. +func NewSecurityHandler(detector DetectorProvider) *SecurityHandler { + return &SecurityHandler{ + detector: detector, + } +} + +// RegisterRoutes registers security API routes on the given router. +func (h *SecurityHandler) RegisterRoutes(r chi.Router) { + r.Post("/api/security/arm", h.handleArm) + r.Post("/api/security/disarm", h.handleDisarm) + r.Get("/api/security/status", h.handleStatus) +} + +// SecurityStatus represents the current security mode state. +type SecurityStatus struct { + Armed bool `json:"armed"` + Mode string `json:"mode,omitempty"` // "armed", "armed_stay", or "disarmed" + LearningUntil string `json:"learning_until,omitempty"` // ISO8601 when model will be ready, empty if ready + AnomalyCount24h int `json:"anomaly_count_24h"` + ModelReady bool `json:"model_ready"` +} + +// handleStatus returns the current security mode status. +// Response JSON: +// { +// "armed": true, +// "mode": "armed", +// "learning_until": "2024-04-15T10:30:00Z", // omitted if model_ready +// "anomaly_count_24h": 5, +// "model_ready": false +// } +func (h *SecurityHandler) handleStatus(w http.ResponseWriter, r *http.Request) { + if h.detector == nil { + http.Error(w, "detector not available", http.StatusServiceUnavailable) + return + } + + mode := h.detector.GetSecurityMode() + armed := h.detector.IsSecurityModeActive() + modelReady := h.detector.IsModelReady() + progress := h.detector.GetLearningProgress() + + status := SecurityStatus{ + Armed: armed, + Mode: string(mode), + ModelReady: modelReady, + AnomalyCount24h: h.countAnomalies24h(), + } + + // Calculate learning_until if model is not ready + if !modelReady { + // Get the learning start time by calculating from progress + // progress = elapsed / (7 days) + // elapsed = progress * 7 days + // learning_until = start + 7 days = now + (7 days - elapsed) + elapsed := time.Duration(float64(7*24*time.Hour) * progress) + remaining := 7*24*time.Hour - elapsed + learningUntil := time.Now().Add(remaining) + status.LearningUntil = learningUntil.Format(time.RFC3339) + } + + writeJSON(w, http.StatusOK, status) +} + +// handleArm enables security mode. +// Request body (optional): {"mode": "armed"} or {"mode": "armed_stay"} +// Default mode is "armed" if not specified. +// Response: {"armed": true, "mode": "armed"} +func (h *SecurityHandler) handleArm(w http.ResponseWriter, r *http.Request) { + if h.detector == nil { + http.Error(w, "detector not available", http.StatusServiceUnavailable) + return + } + + var req struct { + Mode string `json:"mode"` // "armed" or "armed_stay" + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil && err.Error() != "EOF" { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + var mode analytics.SecurityMode + switch req.Mode { + case "armed_stay": + mode = analytics.SecurityModeArmedStay + case "armed", "": + mode = analytics.SecurityModeArmed + default: + http.Error(w, "invalid mode: must be 'armed' or 'armed_stay'", http.StatusBadRequest) + return + } + + h.detector.SetSecurityMode(mode, "api") + + status := map[string]interface{}{ + "armed": true, + "mode": string(mode), + } + writeJSON(w, http.StatusOK, status) +} + +// handleDisarm disables security mode. +// Response: {"armed": false, "mode": "disarmed"} +func (h *SecurityHandler) handleDisarm(w http.ResponseWriter, r *http.Request) { + if h.detector == nil { + http.Error(w, "detector not available", http.StatusServiceUnavailable) + return + } + + h.detector.SetSecurityMode(analytics.SecurityModeDisarmed, "api") + + status := map[string]interface{}{ + "armed": false, + "mode": "disarmed", + } + writeJSON(w, http.StatusOK, status) +} + +// countAnomalies24h counts anomalies detected in the last 24 hours. +func (h *SecurityHandler) countAnomalies24h() int { + if h.detector == nil { + return 0 + } + + history := h.detector.GetAnomalyHistory(1000) // Get enough history + cutoff := time.Now().Add(-24 * time.Hour) + + count := 0 + for _, event := range history { + if event.Timestamp.After(cutoff) { + count++ + } + } + + return count +} diff --git a/mothership/internal/api/security_test.go b/mothership/internal/api/security_test.go new file mode 100644 index 0000000..17904fe --- /dev/null +++ b/mothership/internal/api/security_test.go @@ -0,0 +1,390 @@ +// Package api provides tests for security API endpoints. +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/spaxel/mothership/internal/analytics" + "github.com/spaxel/mothership/internal/events" +) + +// mockDetectorProvider is a mock implementation of DetectorProvider for testing. +type mockDetectorProvider struct { + mode analytics.SecurityMode + isActive bool + progress float64 + modelReady bool + activeAnomalies []*events.AnomalyEvent + history []*events.AnomalyEvent + modeChanges []analytics.SecurityMode +} + +func (m *mockDetectorProvider) GetSecurityMode() analytics.SecurityMode { + return m.mode +} + +func (m *mockDetectorProvider) SetSecurityMode(mode analytics.SecurityMode, reason string) { + m.mode = mode + m.modeChanges = append(m.modeChanges, mode) +} + +func (m *mockDetectorProvider) IsSecurityModeActive() bool { + return m.isActive +} + +func (m *mockDetectorProvider) GetLearningProgress() float64 { + return m.progress +} + +func (m *mockDetectorProvider) IsModelReady() bool { + return m.modelReady +} + +func (m *mockDetectorProvider) GetActiveAnomalies() []*events.AnomalyEvent { + return m.activeAnomalies +} + +func (m *mockDetectorProvider) GetAnomalyHistory(limit int) []*events.AnomalyEvent { + if len(m.history) <= limit { + return m.history + } + return m.history[len(m.history)-limit:] +} + +func TestSecurityHandler_Status(t *testing.T) { + tests := []struct { + name string + mode analytics.SecurityMode + isActive bool + modelReady bool + progress float64 + anomalies24h int + wantStatusCode int + wantArmed bool + wantMode string + }{ + { + name: "disarmed mode", + mode: analytics.SecurityModeDisarmed, + isActive: false, + modelReady: false, + progress: 0.5, + anomalies24h: 3, + wantStatusCode: http.StatusOK, + wantArmed: false, + wantMode: "disarmed", + }, + { + name: "armed mode", + mode: analytics.SecurityModeArmed, + isActive: true, + modelReady: true, + progress: 1.0, + anomalies24h: 0, + wantStatusCode: http.StatusOK, + wantArmed: true, + wantMode: "armed", + }, + { + name: "armed_stay mode", + mode: analytics.SecurityModeArmedStay, + isActive: true, + modelReady: true, + progress: 1.0, + anomalies24h: 1, + wantStatusCode: http.StatusOK, + wantArmed: true, + wantMode: "armed_stay", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create anomalies for the last 24h + history := make([]*events.AnomalyEvent, tt.anomalies24h) + for i := 0; i < tt.anomalies24h; i++ { + history[i] = &events.AnomalyEvent{ + ID: time.Now().Add(time.Duration(i) * time.Hour).Format("20060102150405"), + Timestamp: time.Now().Add(time.Duration(i) * time.Hour), + } + } + + mock := &mockDetectorProvider{ + mode: tt.mode, + isActive: tt.isActive, + modelReady: tt.modelReady, + progress: tt.progress, + history: history, + } + + handler := NewSecurityHandler(mock) + r := chi.NewRouter() + handler.RegisterRoutes(r) + + req := httptest.NewRequest("GET", "/api/security/status", nil) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != tt.wantStatusCode { + t.Errorf("status code = %d, want %d", w.Code, tt.wantStatusCode) + } + + var status SecurityStatus + if err := json.Unmarshal(w.Body.Bytes(), &status); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if status.Armed != tt.wantArmed { + t.Errorf("Armed = %v, want %v", status.Armed, tt.wantArmed) + } + + if status.Mode != tt.wantMode { + t.Errorf("Mode = %s, want %s", status.Mode, tt.wantMode) + } + + if status.ModelReady != tt.modelReady { + t.Errorf("ModelReady = %v, want %v", status.ModelReady, tt.modelReady) + } + + if status.AnomalyCount24h != tt.anomalies24h { + t.Errorf("AnomalyCount24h = %d, want %d", status.AnomalyCount24h, tt.anomalies24h) + } + + // Check learning_until is set when model is not ready + if !tt.modelReady && status.LearningUntil == "" { + t.Error("LearningUntil should be set when model is not ready") + } + if tt.modelReady && status.LearningUntil != "" { + t.Error("LearningUntil should be empty when model is ready") + } + }) + } +} + +func TestSecurityHandler_Arm(t *testing.T) { + tests := []struct { + name string + requestBody string + initialMode analytics.SecurityMode + wantMode analytics.SecurityMode + wantStatusCode int + }{ + { + name: "arm without mode defaults to armed", + requestBody: `{}`, + initialMode: analytics.SecurityModeDisarmed, + wantMode: analytics.SecurityModeArmed, + wantStatusCode: http.StatusOK, + }, + { + name: "arm with armed mode", + requestBody: `{"mode": "armed"}`, + initialMode: analytics.SecurityModeDisarmed, + wantMode: analytics.SecurityModeArmed, + wantStatusCode: http.StatusOK, + }, + { + name: "arm with armed_stay mode", + requestBody: `{"mode": "armed_stay"}`, + initialMode: analytics.SecurityModeDisarmed, + wantMode: analytics.SecurityModeArmedStay, + wantStatusCode: http.StatusOK, + }, + { + name: "invalid mode returns bad request", + requestBody: `{"mode": "invalid"}`, + initialMode: analytics.SecurityModeDisarmed, + wantMode: analytics.SecurityModeDisarmed, // unchanged + wantStatusCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockDetectorProvider{ + mode: tt.initialMode, + } + + handler := NewSecurityHandler(mock) + r := chi.NewRouter() + handler.RegisterRoutes(r) + + req := httptest.NewRequest("POST", "/api/security/arm", bytes.NewBufferString(tt.requestBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != tt.wantStatusCode { + t.Errorf("status code = %d, want %d", w.Code, tt.wantStatusCode) + } + + if tt.wantStatusCode == http.StatusOK { + if mock.mode != tt.wantMode { + t.Errorf("mode = %s, want %s", mock.mode, tt.wantMode) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["armed"] != true { + t.Errorf("armed = %v, want true", resp["armed"]) + } + } else { + // Mode should not have changed on error + if mock.mode != tt.initialMode { + t.Errorf("mode = %s, want %s (unchanged)", mock.mode, tt.initialMode) + } + } + }) + } +} + +func TestSecurityHandler_Disarm(t *testing.T) { + tests := []struct { + name string + initialMode analytics.SecurityMode + wantMode analytics.SecurityMode + wantStatusCode int + }{ + { + name: "disarm from armed", + initialMode: analytics.SecurityModeArmed, + wantMode: analytics.SecurityModeDisarmed, + wantStatusCode: http.StatusOK, + }, + { + name: "disarm from armed_stay", + initialMode: analytics.SecurityModeArmedStay, + wantMode: analytics.SecurityModeDisarmed, + wantStatusCode: http.StatusOK, + }, + { + name: "disarm when already disarmed", + initialMode: analytics.SecurityModeDisarmed, + wantMode: analytics.SecurityModeDisarmed, + wantStatusCode: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockDetectorProvider{ + mode: tt.initialMode, + } + + handler := NewSecurityHandler(mock) + r := chi.NewRouter() + handler.RegisterRoutes(r) + + req := httptest.NewRequest("POST", "/api/security/disarm", nil) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != tt.wantStatusCode { + t.Errorf("status code = %d, want %d", w.Code, tt.wantStatusCode) + } + + if mock.mode != tt.wantMode { + t.Errorf("mode = %s, want %s", mock.mode, tt.wantMode) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["armed"] != false { + t.Errorf("armed = %v, want false", resp["armed"]) + } + }) + } +} + +func TestSecurityHandler_NilDetector(t *testing.T) { + handler := NewSecurityHandler(nil) + r := chi.NewRouter() + handler.RegisterRoutes(r) + + tests := []struct { + name string + method string + path string + body string + }{ + {name: "status", method: "GET", path: "/api/security/status"}, + {name: "arm", method: "POST", path: "/api/security/arm", body: `{}`}, + {name: "disarm", method: "POST", path: "/api/security/disarm"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var body *bytes.Buffer + if tt.body != "" { + body = bytes.NewBufferString(tt.body) + } else { + body = &bytes.Buffer{} + } + + req := httptest.NewRequest(tt.method, tt.path, body) + if tt.method == "POST" { + req.Header.Set("Content-Type", "application/json") + } + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status code = %d, want %d", w.Code, http.StatusServiceUnavailable) + } + }) + } +} + +func TestSecurityHandler_CountAnomalies24h(t *testing.T) { + now := time.Now() + history := []*events.AnomalyEvent{ + {Timestamp: now.Add(-1 * time.Hour)}, // Within 24h + {Timestamp: now.Add(-12 * time.Hour)}, // Within 24h + {Timestamp: now.Add(-25 * time.Hour)}, // Outside 24h + {Timestamp: now.Add(-48 * time.Hour)}, // Outside 24h + } + + mock := &mockDetectorProvider{ + mode: analytics.SecurityModeDisarmed, + history: history, + } + + handler := NewSecurityHandler(mock) + r := chi.NewRouter() + handler.RegisterRoutes(r) + + req := httptest.NewRequest("GET", "/api/security/status", nil) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status code = %d, want %d", w.Code, http.StatusOK) + } + + var status SecurityStatus + if err := json.Unmarshal(w.Body.Bytes(), &status); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + // Should count only the 2 anomalies within 24h + if status.AnomalyCount24h != 2 { + t.Errorf("AnomalyCount24h = %d, want 2", status.AnomalyCount24h) + } +} diff --git a/mothership/internal/api/triggers.go b/mothership/internal/api/triggers.go index 610a6ef..696d37d 100644 --- a/mothership/internal/api/triggers.go +++ b/mothership/internal/api/triggers.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "sync" "time" @@ -18,10 +19,10 @@ import ( // TriggersHandler manages automation triggers. type TriggersHandler struct { - mu sync.RWMutex - db *sql.DB - triggers map[string]*Trigger - engine TriggerEngine + mu sync.RWMutex + db *sql.DB + triggers map[string]*Trigger + engine TriggerEngine } // Trigger represents an automation trigger. @@ -138,13 +139,15 @@ func (t *TriggersHandler) SetEngine(engine TriggerEngine) { t.mu.Unlock() } -// RegisterRoutes registers triggers endpoints. +// RegisterRoutes registers triggers endpoints on the given router. // -// GET /api/triggers — list all triggers -// POST /api/triggers — create trigger -// PUT /api/triggers/{id} — update -// DELETE /api/triggers/{id} — delete -// POST /api/triggers/{id}/test — fire trigger once for testing +// Routes: +// +// GET /api/triggers — list all triggers +// POST /api/triggers — create a new trigger +// PUT /api/triggers/{id} — update an existing trigger +// DELETE /api/triggers/{id} — delete a trigger +// POST /api/triggers/{id}/test — fire trigger actions once for testing func (t *TriggersHandler) RegisterRoutes(r chi.Router) { r.Get("/api/triggers", t.listTriggers) r.Post("/api/triggers", t.createTrigger) @@ -153,11 +156,28 @@ func (t *TriggersHandler) RegisterRoutes(r chi.Router) { r.Post("/api/triggers/{id}/test", t.testTrigger) } +// listTriggers handles GET /api/triggers. +// +// Returns all registered triggers as a JSON array. +// +// Response 200 (application/json): +// +// [{ +// "id": "t1", +// "name": "Couch Dwell", +// "enabled": true, +// "condition": "dwell", +// "condition_params": {"duration_s": 30}, +// "time_constraint": {"from": "22:00", "to": "06:00"}, +// "actions": [{"type": "webhook", "url": "http://example.com/hook"}], +// "last_fired": "2024-03-15T14:32:05Z", +// "elapsed": 142, +// "created_at": "2024-03-10T08:00:00Z" +// }] func (t *TriggersHandler) listTriggers(w http.ResponseWriter, r *http.Request) { t.mu.RLock() triggers := make([]*Trigger, 0, len(t.triggers)) for _, trigger := range t.triggers { - // Update elapsed time if trigger.LastFired != nil { trigger.Elapsed = int(time.Since(*trigger.LastFired).Seconds()) } @@ -178,6 +198,26 @@ type createTriggerRequest struct { Actions json.RawMessage `json:"actions"` } +// createTrigger handles POST /api/triggers. +// +// Creates a new automation trigger. The request body must include id, name, +// and condition. Actions default to an empty array if omitted. +// +// Request body (application/json): +// +// { +// "id": "t1", +// "name": "Couch Dwell", +// "condition": "dwell", +// "condition_params": {"duration_s": 30}, +// "time_constraint": {"from": "22:00", "to": "06:00"}, +// "actions": [{"type": "webhook", "url": "http://example.com/hook"}], +// "enabled": true +// } +// +// Response 201 (application/json): the created trigger object. +// Response 400: missing required fields or invalid condition value. +// Response 500: database error. func (t *TriggersHandler) createTrigger(w http.ResponseWriter, r *http.Request) { var req createTriggerRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -257,6 +297,18 @@ type updateTriggerRequest struct { Actions *json.RawMessage `json:"actions,omitempty"` } +// updateTrigger handles PUT /api/triggers/{id}. +// +// Updates an existing trigger. Only fields present in the request body are +// modified; omitted fields retain their current values. If the body contains +// no recognized fields, the current trigger is returned unchanged. +// +// Request body (application/json): partial trigger object with fields to update. +// +// Response 200 (application/json): the updated trigger object. +// Response 400: invalid request body or invalid condition value. +// Response 404: trigger not found. +// Response 500: database error. func (t *TriggersHandler) updateTrigger(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") @@ -321,7 +373,7 @@ func (t *TriggersHandler) updateTrigger(w http.ResponseWriter, r *http.Request) } args = append(args, id) - query := "UPDATE triggers SET " + joinComma(updates) + " WHERE id = ?" + query := "UPDATE triggers SET " + strings.Join(updates, ", ") + " WHERE id = ?" _, err := t.db.Exec(query, args...) if err != nil { @@ -354,6 +406,12 @@ func (t *TriggersHandler) updateTrigger(w http.ResponseWriter, r *http.Request) writeJSON(w, http.StatusOK, trigger) } +// deleteTrigger handles DELETE /api/triggers/{id}. +// +// Removes a trigger by ID. Deleting a nonexistent ID is a no-op. +// +// Response 204: trigger deleted (or did not exist). +// Response 500: database error. func (t *TriggersHandler) deleteTrigger(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") @@ -379,6 +437,22 @@ func (t *TriggersHandler) deleteTrigger(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusNoContent) } +// testTrigger handles POST /api/triggers/{id}/test. +// +// Fires the trigger's actions once with a synthetic event payload for testing. +// If no automation engine is attached, returns a simulated success response. +// Does not update last_fired or trigger any real automation logic. +// +// Response 200 (application/json): +// +// { +// "status": "fired", +// "message": "Trigger fired successfully", +// "trigger": { ... } +// } +// +// Response 404: trigger not found. +// Response 500: engine test-fire failed. func (t *TriggersHandler) testTrigger(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") @@ -391,7 +465,6 @@ func (t *TriggersHandler) testTrigger(w http.ResponseWriter, r *http.Request) { return } - // Check if engine is available t.mu.RLock() engine := t.engine t.mu.RUnlock() @@ -451,7 +524,6 @@ func (t *TriggersHandler) EvaluateTriggers(blobs []BlobPos) []string { shouldFire := false switch trigger.Condition { case "enter", "leave": - // Volume-based trigger if params.VolumeID != "" { for _, blob := range blobs { if t.engine != nil && t.engine.IsInVolume(blob.X, blob.Y, blob.Z, params.VolumeID) { @@ -498,6 +570,6 @@ func (t *TriggersHandler) EvaluateTriggers(blobs []BlobPos) []string { // BlobPos represents a blob position for trigger evaluation. type BlobPos struct { - ID int + ID int X, Y, Z float64 } diff --git a/mothership/internal/api/triggers_test.go b/mothership/internal/api/triggers_test.go new file mode 100644 index 0000000..c8a6265 --- /dev/null +++ b/mothership/internal/api/triggers_test.go @@ -0,0 +1,794 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" +) + +// newTriggerTestHandler creates a TriggersHandler backed by an in-memory database. +func newTriggerTestHandler(t *testing.T) (*TriggersHandler, func()) { + t.Helper() + h, err := NewTriggersHandler(":memory:") + if err != nil { + t.Fatalf("NewTriggersHandler: %v", err) + } + return h, func() { h.Close() } +} + +// newTriggerTestRouter creates a chi.Router with trigger routes registered. +func newTriggerTestRouter(h *TriggersHandler) *chi.Mux { + r := chi.NewRouter() + h.RegisterRoutes(r) + return r +} + +// seedTrigger creates a trigger directly in the handler for test setup. +func seedTrigger(t *testing.T, h *TriggersHandler, tr Trigger) { + t.Helper() + now := time.Now().UnixNano() + enabled := 0 + if tr.Enabled { + enabled = 1 + } + actions := tr.Actions + if len(actions) == 0 { + actions = json.RawMessage("[]") + } + conditionParams := tr.ConditionParams + if len(conditionParams) == 0 { + conditionParams = json.RawMessage("{}") + } + _, err := h.db.Exec(` + INSERT INTO triggers (id, name, enabled, condition, condition_params, actions, created_at) + VALUES (?, ?, ?, ?, ?, ?) + `, tr.ID, tr.Name, enabled, tr.Condition, string(conditionParams), string(actions), now) + if err != nil { + t.Fatalf("seedTrigger: %v", err) + } + h.mu.Lock() + h.triggers[tr.ID] = &Trigger{ + ID: tr.ID, + Name: tr.Name, + Enabled: tr.Enabled, + Condition: tr.Condition, + ConditionParams: conditionParams, + Actions: actions, + CreatedAt: time.Unix(0, now), + } + h.mu.Unlock() +} + +// ── GET /api/triggers ───────────────────────────────────────────────────────────── + +// TestListTriggers tests GET /api/triggers. +func TestListTriggers(t *testing.T) { + tests := []struct { + name string + setup []Trigger + wantLen int + wantCode int + }{ + { + name: "empty store", + setup: nil, + wantLen: 0, + wantCode: http.StatusOK, + }, + { + name: "single trigger", + setup: []Trigger{ + {ID: "t1", Name: "Couch Dwell", Condition: "dwell", Enabled: true}, + }, + wantLen: 1, + wantCode: http.StatusOK, + }, + { + name: "multiple triggers", + setup: []Trigger{ + {ID: "t1", Name: "Enter Hallway", Condition: "enter", Enabled: true}, + {ID: "t2", Name: "Leave Home", Condition: "vacant", Enabled: false}, + {ID: "t3", Name: "Count Kitchen", Condition: "count", Enabled: true}, + }, + wantLen: 3, + wantCode: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + for _, tr := range tt.setup { + seedTrigger(t, h, tr) + } + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("GET", "/api/triggers", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.wantCode { + t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String()) + } + + var result []Trigger + if err := json.NewDecoder(w.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode: %v", err) + } + if len(result) != tt.wantLen { + t.Errorf("expected %d triggers, got %d", tt.wantLen, len(result)) + } + }) + } +} + +// ── POST /api/triggers ──────────────────────────────────────────────────────────── + +// TestCreateTrigger tests POST /api/triggers. +func TestCreateTrigger(t *testing.T) { + tests := []struct { + name string + body string + wantCode int + wantID string + wantErr string + }{ + { + name: "valid trigger with all fields", + body: `{ + "id": "t1", + "name": "Couch Dwell", + "condition": "dwell", + "condition_params": {"duration_s": 30}, + "time_constraint": {"from": "22:00", "to": "06:00"}, + "actions": [{"type": "webhook", "url": "http://example.com/hook"}], + "enabled": true + }`, + wantCode: http.StatusCreated, + wantID: "t1", + }, + { + name: "minimal valid trigger", + body: `{"id": "t2", "name": "Enter", "condition": "enter"}`, + wantCode: http.StatusCreated, + wantID: "t2", + }, + { + name: "missing id", + body: `{"name": "No ID", "condition": "enter"}`, + wantCode: http.StatusBadRequest, + wantErr: "id is required", + }, + { + name: "missing name", + body: `{"id": "t3", "condition": "enter"}`, + wantCode: http.StatusBadRequest, + wantErr: "name is required", + }, + { + name: "invalid condition", + body: `{"id": "t4", "name": "Bad", "condition": "fly"}`, + wantCode: http.StatusBadRequest, + wantErr: "condition must be one of", + }, + { + name: "missing condition", + body: `{"id": "t5", "name": "NoCond"}`, + wantCode: http.StatusBadRequest, + wantErr: "condition must be one of", + }, + { + name: "malformed JSON", + body: `{invalid}`, + wantCode: http.StatusBadRequest, + wantErr: "invalid request body", + }, + { + name: "empty body", + body: ``, + wantCode: http.StatusBadRequest, + wantErr: "invalid request body", + }, + { + name: "disabled by default when not specified", + body: `{"id": "t6", "name": "Default", "condition": "leave"}`, + wantCode: http.StatusCreated, + wantID: "t6", + }, + { + name: "explicitly disabled", + body: `{"id": "t7", "name": "Off", "condition": "vacant", "enabled": false}`, + wantCode: http.StatusCreated, + wantID: "t7", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.wantCode { + t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String()) + } + + if tt.wantErr != "" { + if !bytes.Contains(w.Body.Bytes(), []byte(tt.wantErr)) { + t.Errorf("expected error to contain %q, got %s", tt.wantErr, w.Body.String()) + } + return + } + + var created Trigger + if err := json.NewDecoder(w.Body).Decode(&created); err != nil { + t.Fatalf("failed to decode: %v", err) + } + if created.ID != tt.wantID { + t.Errorf("expected ID %q, got %q", tt.wantID, created.ID) + } + if created.CreatedAt.IsZero() { + t.Error("expected non-zero CreatedAt") + } + }) + } +} + +// TestCreateTriggerDuplicate tests that creating a trigger with a duplicate ID fails. +func TestCreateTriggerDuplicate(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + seedTrigger(t, h, Trigger{ID: "t1", Name: "First", Condition: "enter", Enabled: true}) + + r := newTriggerTestRouter(h) + body := `{"id": "t1", "name": "Duplicate", "condition": "dwell"}` + req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // SQLite PRIMARY KEY constraint should reject the duplicate + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500 for duplicate ID, got %d", w.Code) + } +} + +// TestCreateTriggerPersists tests that a created trigger survives a handler reload. +func TestCreateTriggerPersists(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + r := newTriggerTestRouter(h) + body := `{"id": "persist", "name": "Persistent", "condition": "leave"}` + req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("create: expected 201, got %d", w.Code) + } + + // Reload from DB + if err := h.load(); err != nil { + t.Fatalf("reload: %v", err) + } + + // Verify it's still there + req2 := httptest.NewRequest("GET", "/api/triggers", nil) + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + + var result []Trigger + json.NewDecoder(w2.Body).Decode(&result) + if len(result) != 1 { + t.Fatalf("after reload: expected 1 trigger, got %d", len(result)) + } + if result[0].Name != "Persistent" { + t.Errorf("after reload: expected name 'Persistent', got %s", result[0].Name) + } +} + +// ── PUT /api/triggers/{id} ──────────────────────────────────────────────────────── + +// TestUpdateTrigger tests PUT /api/triggers/{id}. +func TestUpdateTrigger(t *testing.T) { + tests := []struct { + name string + setup Trigger + body string + wantCode int + wantName string + wantEnable bool + }{ + { + name: "update name", + setup: Trigger{ID: "t1", Name: "Old", Condition: "enter", Enabled: true}, + body: `{"name": "New Name"}`, + wantCode: http.StatusOK, + wantName: "New Name", + wantEnable: true, + }, + { + name: "disable trigger", + setup: Trigger{ID: "t1", Name: "On", Condition: "dwell", Enabled: true}, + body: `{"enabled": false}`, + wantCode: http.StatusOK, + wantName: "On", + wantEnable: false, + }, + { + name: "enable trigger", + setup: Trigger{ID: "t1", Name: "Off", Condition: "vacant", Enabled: false}, + body: `{"enabled": true}`, + wantCode: http.StatusOK, + wantName: "Off", + wantEnable: true, + }, + { + name: "change condition", + setup: Trigger{ID: "t1", Name: "Flex", Condition: "enter", Enabled: true}, + body: `{"condition": "dwell"}`, + wantCode: http.StatusOK, + wantName: "Flex", + wantEnable: true, + }, + { + name: "update multiple fields", + setup: Trigger{ID: "t1", Name: "Old", Condition: "enter", Enabled: true}, + body: `{"name": "Multi", "condition": "count", "enabled": false}`, + wantCode: http.StatusOK, + wantName: "Multi", + wantEnable: false, + }, + { + name: "no-op update returns current", + setup: Trigger{ID: "t1", Name: "Same", Condition: "leave", Enabled: true}, + body: `{}`, + wantCode: http.StatusOK, + wantName: "Same", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + seedTrigger(t, h, tt.setup) + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("PUT", "/api/triggers/"+tt.setup.ID, bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.wantCode { + t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String()) + } + + var updated Trigger + if err := json.NewDecoder(w.Body).Decode(&updated); err != nil { + t.Fatalf("failed to decode: %v", err) + } + if updated.Name != tt.wantName { + t.Errorf("expected name %q, got %q", tt.wantName, updated.Name) + } + if updated.Enabled != tt.wantEnable { + t.Errorf("expected enabled=%v, got %v", tt.wantEnable, updated.Enabled) + } + }) + } +} + +// TestUpdateTriggerInvalid tests PUT /api/triggers/{id} with invalid input. +func TestUpdateTriggerInvalid(t *testing.T) { + tests := []struct { + name string + setup Trigger + body string + want int + wantErr string + }{ + { + name: "nonexistent trigger", + setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true}, + body: `{"name": "Nope"}`, + want: http.StatusNotFound, + wantErr: "trigger not found", + }, + { + name: "malformed JSON", + setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true}, + body: `{bad}`, + want: http.StatusBadRequest, + wantErr: "invalid request body", + }, + { + name: "invalid condition", + setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true}, + body: `{"condition": "invalid"}`, + want: http.StatusBadRequest, + wantErr: "condition must be one of", + }, + { + name: "empty body", + setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true}, + body: ``, + want: http.StatusBadRequest, + wantErr: "invalid request body", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + seedTrigger(t, h, tt.setup) + + id := tt.setup.ID + if tt.name == "nonexistent trigger" { + id = "nonexistent" + } + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("PUT", "/api/triggers/"+id, bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.want { + t.Fatalf("expected %d, got %d: %s", tt.want, w.Code, w.Body.String()) + } + if !bytes.Contains(w.Body.Bytes(), []byte(tt.wantErr)) { + t.Errorf("expected error to contain %q, got %s", tt.wantErr, w.Body.String()) + } + }) + } +} + +// TestUpdateTriggerPersists tests that an update is persisted across reload. +func TestUpdateTriggerPersists(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + seedTrigger(t, h, Trigger{ID: "t1", Name: "Original", Condition: "enter", Enabled: true}) + + r := newTriggerTestRouter(h) + body := `{"name": "Updated Name"}` + req := httptest.NewRequest("PUT", "/api/triggers/t1", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("update: expected 200, got %d", w.Code) + } + + // Reload from DB + if err := h.load(); err != nil { + t.Fatalf("reload: %v", err) + } + + req2 := httptest.NewRequest("GET", "/api/triggers", nil) + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + + var result []Trigger + json.NewDecoder(w2.Body).Decode(&result) + if len(result) != 1 { + t.Fatalf("after reload: expected 1 trigger, got %d", len(result)) + } + if result[0].Name != "Updated Name" { + t.Errorf("after reload: expected name 'Updated Name', got %s", result[0].Name) + } +} + +// ── DELETE /api/triggers/{id} ───────────────────────────────────────────────────── + +// TestDeleteTrigger tests DELETE /api/triggers/{id}. +func TestDeleteTrigger(t *testing.T) { + tests := []struct { + name string + setup []Trigger + deleteID string + wantCode int + wantLen int + }{ + { + name: "delete existing trigger", + setup: []Trigger{ + {ID: "t1", Name: "Keep", Condition: "enter", Enabled: true}, + {ID: "t2", Name: "Delete Me", Condition: "dwell", Enabled: true}, + }, + deleteID: "t2", + wantCode: http.StatusNoContent, + wantLen: 1, + }, + { + name: "delete nonexistent trigger", + setup: []Trigger{{ID: "t1", Name: "Only", Condition: "enter", Enabled: true}}, + deleteID: "nonexistent", + wantCode: http.StatusNotFound, + wantLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + for _, tr := range tt.setup { + seedTrigger(t, h, tr) + } + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("DELETE", "/api/triggers/"+tt.deleteID, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.wantCode { + t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String()) + } + + // Verify via list + req2 := httptest.NewRequest("GET", "/api/triggers", nil) + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + var result []Trigger + json.NewDecoder(w2.Body).Decode(&result) + if len(result) != tt.wantLen { + t.Errorf("expected %d triggers after delete, got %d", tt.wantLen, len(result)) + } + + // Verify deleted trigger is not in memory + if tt.wantCode == http.StatusNoContent { + h.mu.RLock() + _, exists := h.triggers[tt.deleteID] + h.mu.RUnlock() + if exists { + t.Error("trigger should be removed from memory") + } + } + }) + } +} + +// ── POST /api/triggers/{id}/test ───────────────────────────────────────────────── + +// TestTestTrigger tests POST /api/triggers/{id}/test. +func TestTestTrigger(t *testing.T) { + tests := []struct { + name string + setup Trigger + testID string + engine TriggerEngine + wantCode int + wantKey string + }{ + { + name: "test with no engine returns simulated", + setup: Trigger{ID: "t1", Name: "Sim", Condition: "dwell", Enabled: true}, + testID: "t1", + engine: nil, + wantCode: http.StatusOK, + wantKey: "simulated", + }, + { + name: "test with engine that succeeds", + setup: Trigger{ID: "t1", Name: "Fire", Condition: "enter", Enabled: true}, + testID: "t1", + engine: &mockEngine{err: nil}, + wantCode: http.StatusOK, + wantKey: "fired", + }, + { + name: "test with engine that fails", + setup: Trigger{ID: "t1", Name: "Fail", Condition: "leave", Enabled: true}, + testID: "t1", + engine: &mockEngine{err: fmt.Errorf("boom")}, + wantCode: http.StatusInternalServerError, + wantKey: "test fire failed", + }, + { + name: "test nonexistent trigger", + setup: Trigger{ID: "t1", Name: "Exists", Condition: "enter", Enabled: true}, + testID: "nonexistent", + engine: nil, + wantCode: http.StatusNotFound, + wantKey: "trigger not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + seedTrigger(t, h, tt.setup) + if tt.engine != nil { + h.SetEngine(tt.engine) + } + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("POST", "/api/triggers/"+tt.testID+"/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.wantCode { + t.Fatalf("expected %d, got %d: %s", tt.wantCode, w.Code, w.Body.String()) + } + + if !bytes.Contains(w.Body.Bytes(), []byte(tt.wantKey)) { + t.Errorf("expected response to contain %q, got %s", tt.wantKey, w.Body.String()) + } + }) + } +} + +// TestTestTriggerDoesNotUpdateLastFired verifies that the test endpoint +// does not modify last_fired on the trigger. +func TestTestTriggerDoesNotUpdateLastFired(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + tr := Trigger{ID: "t1", Name: "Test", Condition: "enter", Enabled: true} + seedTrigger(t, h, tr) + + r := newTriggerTestRouter(h) + req := httptest.NewRequest("POST", "/api/triggers/t1/test", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + h.mu.RLock() + trigger := h.triggers["t1"] + h.mu.RUnlock() + + if trigger.LastFired != nil { + t.Error("expected last_fired to remain nil after test endpoint") + } +} + +// ── CRUD round-trip ─────────────────────────────────────────────────────────────── + +// TestTriggerCRUDRoundTrip verifies the full lifecycle: create -> list -> update -> list -> delete -> verify gone. +func TestTriggerCRUDRoundTrip(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + r := newTriggerTestRouter(h) + + // 1. Create + body := `{"id": "rt", "name": "Round Trip", "condition": "dwell", "condition_params": {"duration_s": 60}}` + req := httptest.NewRequest("POST", "/api/triggers", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("create: expected 201, got %d", w.Code) + } + + // 2. List and verify + req2 := httptest.NewRequest("GET", "/api/triggers", nil) + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + var triggers []Trigger + json.NewDecoder(w2.Body).Decode(&triggers) + if len(triggers) != 1 { + t.Fatalf("after create: expected 1 trigger, got %d", len(triggers)) + } + if triggers[0].Name != "Round Trip" { + t.Errorf("after create: expected name 'Round Trip', got %s", triggers[0].Name) + } + + // 3. Update + body3 := `{"name": "Updated Trip", "enabled": false}` + req3 := httptest.NewRequest("PUT", "/api/triggers/rt", bytes.NewReader([]byte(body3))) + req3.Header.Set("Content-Type", "application/json") + w3 := httptest.NewRecorder() + r.ServeHTTP(w3, req3) + if w3.Code != http.StatusOK { + t.Fatalf("update: expected 200, got %d", w3.Code) + } + + // 4. Verify update + req4 := httptest.NewRequest("GET", "/api/triggers", nil) + w4 := httptest.NewRecorder() + r.ServeHTTP(w4, req4) + json.NewDecoder(w4.Body).Decode(&triggers) + if triggers[0].Name != "Updated Trip" { + t.Errorf("after update: expected name 'Updated Trip', got %s", triggers[0].Name) + } + if triggers[0].Enabled { + t.Error("after update: expected enabled=false") + } + + // 5. Delete + req5 := httptest.NewRequest("DELETE", "/api/triggers/rt", nil) + w5 := httptest.NewRecorder() + r.ServeHTTP(w5, req5) + if w5.Code != http.StatusNoContent { + t.Fatalf("delete: expected 204, got %d", w5.Code) + } + + // 6. Verify gone + req6 := httptest.NewRequest("GET", "/api/triggers", nil) + w6 := httptest.NewRecorder() + r.ServeHTTP(w6, req6) + json.NewDecoder(w6.Body).Decode(&triggers) + if len(triggers) != 0 { + t.Errorf("after delete: expected 0 triggers, got %d", len(triggers)) + } +} + +// ── EvaluateTriggers ───────────────────────────────────────────────────────────── + +// TestEvaluateTriggers tests trigger evaluation logic. +func TestEvaluateTriggers(t *testing.T) { + h, cleanup := newTriggerTestHandler(t) + defer cleanup() + + seedTrigger(t, h, Trigger{ + ID: "vacant", + Name: "House Empty", + Condition: "vacant", + Enabled: true, + }) + seedTrigger(t, h, Trigger{ + ID: "disabled", + Name: "Disabled", + Condition: "vacant", + Enabled: false, + }) + + // No blobs = vacant should fire + fired := h.EvaluateTriggers(nil) + if len(fired) != 1 || fired[0] != "vacant" { + t.Errorf("expected [vacant], got %v", fired) + } + + // With blobs = vacant should not fire + fired = h.EvaluateTriggers([]BlobPos{{ID: 1, X: 1, Y: 1, Z: 1}}) + if len(fired) != 0 { + t.Errorf("expected no fires with blob present, got %v", fired) + } + + // Disabled trigger never fires + fired = h.EvaluateTriggers(nil) + for _, id := range fired { + if id == "disabled" { + t.Error("disabled trigger should not fire") + } + } +} + +// ── mock engine ────────────────────────────────────────────────────────────────── + +type mockEngine struct { + err error +} + +func (m *mockEngine) TestFire(triggerID string) error { return m.err } +func (m *mockEngine) IsInVolume(x, y, z float64, volumeID string) bool { + return true +} diff --git a/mothership/internal/api/utils.go b/mothership/internal/api/utils.go index 3c3b2f8..8980514 100644 --- a/mothership/internal/api/utils.go +++ b/mothership/internal/api/utils.go @@ -5,6 +5,7 @@ import ( "encoding/json" "log" "net/http" + "strings" ) // writeJSON writes a JSON response with the given status code. @@ -28,3 +29,8 @@ func writeJSONData(w http.ResponseWriter, v interface{}) { func writeJSONError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]string{"error": message}) } + +// joinComma joins a slice of strings with ", " separator. +func joinComma(parts []string) string { + return strings.Join(parts, ", ") +} diff --git a/mothership/internal/api/zones_test.go b/mothership/internal/api/zones_test.go new file mode 100644 index 0000000..fda34e7 --- /dev/null +++ b/mothership/internal/api/zones_test.go @@ -0,0 +1,1018 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/go-chi/chi" + "github.com/spaxel/mothership/internal/zones" +) + +// newTestHandler creates a ZonesHandler backed by a temporary zones.Manager. +func newTestHandler(t *testing.T) (*ZonesHandler, func()) { + t.Helper() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "zones.db") + mgr, err := zones.NewManager(dbPath, nil) + if err != nil { + t.Fatalf("Failed to create zones manager: %v", err) + } + handler := NewZonesHandler(mgr) + return handler, func() { mgr.Close() } +} + +// setupRouter creates a chi.Router with all zones/portals routes registered. +func setupRouter(h *ZonesHandler) *chi.Mux { + r := chi.NewRouter() + h.RegisterRoutes(r) + return r +} + +// TestListZones tests GET /api/zones. +func TestListZones(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Seed two zones + if err := h.mgr.CreateZone(&zones.Zone{ + ID: "z1", Name: "Kitchen", MinX: 0, MinY: 0, MinZ: 0, + MaxX: 4, MaxY: 3, MaxZ: 2.5, + }); err != nil { + t.Fatalf("CreateZone: %v", err) + } + if err := h.mgr.CreateZone(&zones.Zone{ + ID: "z2", Name: "Bedroom", MinX: 4, MinY: 0, MinZ: 0, + MaxX: 8, MaxY: 4, MaxZ: 2.5, ZoneType: zones.ZoneTypeBedroom, + }); err != nil { + t.Fatalf("CreateZone: %v", err) + } + + r := setupRouter(h) + req := httptest.NewRequest("GET", "/api/zones", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + var result []zoneWithOcc + if err := json.NewDecoder(rr.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if len(result) != 2 { + t.Fatalf("Expected 2 zones, got %d", len(result)) + } + + // Verify fields + if result[0].ID != "z1" || result[0].Name != "Kitchen" { + t.Errorf("Zone z1 mismatch: %+v", result[0]) + } + if result[1].ID != "z2" || result[1].Name != "Bedroom" { + t.Errorf("Zone z2 mismatch: %+v", result[1]) + } + if result[1].ZoneType != "bedroom" { + t.Errorf("Expected zone_type=bedroom, got %s", result[1].ZoneType) + } + + // Occupancy defaults + for _, z := range result { + if z.Occupancy != 0 { + t.Errorf("Zone %s: expected occupancy=0, got %d", z.ID, z.Occupancy) + } + if z.People == nil { + t.Errorf("Zone %s: expected non-nil people", z.ID) + } + } + + // Verify computed width/depth/height + if result[0].Width != 4 || result[0].Depth != 3 || result[0].Height != 2.5 { + t.Errorf("Zone z1 dimensions wrong: w=%f d=%f h=%f", result[0].Width, result[0].Depth, result[0].Height) + } +} + +// TestListZonesEmpty tests GET /api/zones with no zones. +func TestListZonesEmpty(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + req := httptest.NewRequest("GET", "/api/zones", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d", rr.Code) + } + + var result []zoneWithOcc + if err := json.NewDecoder(rr.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if len(result) != 0 { + t.Errorf("Expected 0 zones, got %d", len(result)) + } +} + +// TestCreateZone tests POST /api/zones. +func TestCreateZone(t *testing.T) { + tests := []struct { + name string + body zones.Zone + wantStatus int + wantID string + }{ + { + name: "create with explicit ID", + body: zones.Zone{ + ID: "kitchen", Name: "Kitchen", + MinX: 0, MinY: 0, MinZ: 0, MaxX: 4, MaxY: 3, MaxZ: 2.5, + }, + wantStatus: http.StatusCreated, + wantID: "kitchen", + }, + { + name: "create with auto-generated ID", + body: zones.Zone{ + Name: "Living Room", + MinX: 4, MinY: 0, MinZ: 0, MaxX: 8, MaxY: 5, MaxZ: 2.5, + }, + wantStatus: http.StatusCreated, + wantID: "zone_", + }, + { + name: "create bedroom zone", + body: zones.Zone{ + ID: "bed1", Name: "Master Bedroom", ZoneType: zones.ZoneTypeBedroom, + MinX: 0, MinY: 5, MinZ: 0, MaxX: 4, MaxY: 9, MaxZ: 2.5, + }, + wantStatus: http.StatusCreated, + wantID: "bed1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest("POST", "/api/zones", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != tt.wantStatus { + t.Fatalf("Expected %d, got %d: %s", tt.wantStatus, rr.Code, rr.Body.String()) + } + + var created zoneWithOcc + if err := json.NewDecoder(rr.Body).Decode(&created); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if created.ID != tt.wantID { + t.Errorf("Expected ID starting with %q, got %q", tt.wantID, created.ID) + } + if created.CreatedAt.IsZero() { + t.Error("Expected non-zero CreatedAt") + } + }) + } +} + +// TestCreateZoneInvalid tests POST /api/zones with invalid input. +func TestCreateZoneInvalid(t *testing.T) { + tests := []struct { + name string + body string + wantMsg string + }{ + { + name: "malformed JSON", + body: `{invalid}`, + wantMsg: "invalid request body", + }, + { + name: "empty body", + body: ``, + wantMsg: "invalid request body", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + req := httptest.NewRequest("POST", "/api/zones", bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("Expected 400, got %d", rr.Code) + } + + var errResp map[string]string + if err := json.NewDecoder(rr.Body).Decode(&errResp); err != nil { + t.Fatalf("Failed to decode error: %v", err) + } + if errResp["error"] == "" { + t.Error("Expected error message") + } + }) + } +} + +// TestUpdateZone tests PUT /api/zones/{id}. +func TestUpdateZone(t *testing.T) { + tests := []struct { + name string + setup zones.Zone + update zones.Zone + wantName string + }{ + { + name: "update zone name", + setup: zones.Zone{ID: "z1", Name: "Kitchen", MinX: 0, MinY: 0, MinZ: 0, MaxX: 4, MaxY: 3, MaxZ: 2.5}, + update: zones.Zone{ID: "z1", Name: "Big Kitchen", MinX: 0, MinY: 0, MinZ: 0, MaxX: 6, MaxY: 5, MaxZ: 3}, + wantName: "Big Kitchen", + }, + { + name: "update zone type to bedroom", + setup: zones.Zone{ID: "z1", Name: "Room", MinX: 0, MinY: 0, MinZ: 0, MaxX: 4, MaxY: 3, MaxZ: 2.5}, + update: zones.Zone{ID: "z1", Name: "Room", MinX: 0, MinY: 0, MinZ: 0, MaxX: 4, MaxY: 3, MaxZ: 2.5, ZoneType: zones.ZoneTypeBedroom}, + wantName: "Room", + }, + { + name: "update zone bounds", + setup: zones.Zone{ID: "z1", Name: "Box", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}, + update: zones.Zone{ID: "z1", Name: "Box", MinX: 2, MinY: 3, MinZ: 1, MaxX: 10, MaxY: 8, MaxZ: 4}, + wantName: "Box", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Setup + if err := h.mgr.CreateZone(&tt.setup); err != nil { + t.Fatalf("CreateZone: %v", err) + } + + // Update + r := setupRouter(h) + body, _ := json.Marshal(tt.update) + req := httptest.NewRequest("PUT", "/api/zones/"+tt.setup.ID, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + var updated zoneWithOcc + if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if updated.Name != tt.wantName { + t.Errorf("Expected name %q, got %q", tt.wantName, updated.Name) + } + if updated.ID != tt.setup.ID { + t.Errorf("Expected ID %q, got %q", tt.setup.ID, updated.ID) + } + + // Verify the update persisted via GET + req2 := httptest.NewRequest("GET", "/api/zones", nil) + rr2 := httptest.NewRecorder() + r.ServeHTTP(rr2, req2) + var allZones []zoneWithOcc + json.NewDecoder(rr2.Body).Decode(&allZones) + found := false + for _, z := range allZones { + if z.ID == tt.setup.ID { + found = true + if z.Name != tt.wantName { + t.Errorf("GET after PUT: expected name %q, got %q", tt.wantName, z.Name) + } + } + } + if !found { + t.Error("Zone not found after update") + } + }) + } +} + +// TestUpdateZoneInvalid tests PUT /api/zones/{id} with invalid input. +func TestUpdateZoneInvalid(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Setup a zone + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "Room", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + + tests := []struct { + name string + body string + want int + }{ + {"malformed JSON", `{bad}`, http.StatusBadRequest}, + {"empty body", ``, http.StatusBadRequest}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := setupRouter(h) + req := httptest.NewRequest("PUT", "/api/zones/z1", bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != tt.want { + t.Errorf("Expected %d, got %d", tt.want, rr.Code) + } + }) + } +} + +// TestUpdateZoneNotFound tests PUT /api/zones/{id} for nonexistent zone. +func TestUpdateZoneNotFound(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + body := `{"name": "Nope"}` + req := httptest.NewRequest("PUT", "/api/zones/nonexistent", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Errorf("Expected 404, got %d", rr.Code) + } +} + +// TestDeleteZone tests DELETE /api/zones/{id}. +func TestDeleteZone(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Setup + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "Room", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "Room2", MinX: 2, MinY: 0, MinZ: 0, MaxX: 3, MaxY: 1, MaxZ: 1}) + + r := setupRouter(h) + + // Delete z1 + req := httptest.NewRequest("DELETE", "/api/zones/z1", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("Expected 204, got %d: %s", rr.Code, rr.Body.String()) + } + + // Verify z1 is gone + if h.mgr.GetZone("z1") != nil { + t.Error("Zone z1 should be deleted") + } + + // Verify z2 still exists + if h.mgr.GetZone("z2") == nil { + t.Error("Zone z2 should still exist") + } + + // Verify via GET + req2 := httptest.NewRequest("GET", "/api/zones", nil) + rr2 := httptest.NewRecorder() + r.ServeHTTP(rr2, req2) + var allZones []zoneWithOcc + json.NewDecoder(rr2.Body).Decode(&allZones) + if len(allZones) != 1 { + t.Errorf("Expected 1 zone after delete, got %d", len(allZones)) + } +} + +// TestDeleteZoneNotFound tests DELETE /api/zones/{id} for nonexistent zone. +func TestDeleteZoneNotFound(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + req := httptest.NewRequest("DELETE", "/api/zones/nonexistent", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // Manager.DeleteZone returns nil error even if zone doesn't exist + if rr.Code != http.StatusNoContent { + t.Fatalf("Expected 204, got %d", rr.Code) + } +} + +// TestGetZoneHistory tests GET /api/zones/{id}/history. +func TestGetZoneHistory(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "Room", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + + tests := []struct { + name string + zoneID string + wantCode int + }{ + {"existing zone", "z1", http.StatusOK}, + {"nonexistent zone", "nope", http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := setupRouter(h) + req := httptest.NewRequest("GET", "/api/zones/"+tt.zoneID+"/history", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != tt.wantCode { + t.Errorf("Expected %d, got %d", tt.wantCode, rr.Code) + } + }) + } +} + +// ── Portals ───────────────────────────────────────────────────────────────────── + +// TestListPortals tests GET /api/portals. +func TestListPortals(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Seed zones for the portals + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "Kitchen", MinX: 0, MinY: 0, MinZ: 0, MaxX: 4, MaxY: 3, MaxZ: 2.5}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "Hallway", MinX: 4, MinY: 0, MinZ: 0, MaxX: 8, MaxY: 3, MaxZ: 2.5}) + + // Create a portal + p := zones.Portal{ + ID: "p1", Name: "Kitchen Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 4, P1Y: 0, P1Z: 0, + P2X: 4, P2Y: 2, P2Z: 0, + P3X: 4, P3Y: 2, P3Z: 2.5, + Width: 2.5, Height: 2.5, + } + if err := h.mgr.CreatePortal(&p); err != nil { + t.Fatalf("CreatePortal: %v", err) + } + + r := setupRouter(h) + req := httptest.NewRequest("GET", "/api/portals", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + var result []portalWithZones + if err := json.NewDecoder(rr.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if len(result) != 1 { + t.Fatalf("Expected 1 portal, got %d", len(result)) + } + if result[0].ID != "p1" { + t.Errorf("Expected portal ID p1, got %s", result[0].ID) + } + if result[0].Name != "Kitchen Door" { + t.Errorf("Expected name 'Kitchen Door', got %s", result[0].Name) + } + // Normal vector should be computed + if result[0].NX == 0 && result[0].NY == 0 && result[0].NZ == 0 { + t.Error("Expected computed normal vector, got zero") + } +} + +// TestCreatePortal tests POST /api/portals. +func TestCreatePortal(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Seed zones + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "B", MinX: 1, MinY: 0, MinZ: 0, MaxX: 2, MaxY: 1, MaxZ: 1}) + + p := zones.Portal{ + ID: "door1", Name: "A-B Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, + P2X: 1, P2Y: 0.5, P2Z: 0, + P3X: 1, P3Y: 0.5, P3Z: 1, + Width: 1, Height: 1, + } + + r := setupRouter(h) + body, _ := json.Marshal(p) + req := httptest.NewRequest("POST", "/api/portals", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("Expected 201, got %d: %s", rr.Code, rr.Body.String()) + } + + var created portalWithZones + if err := json.NewDecoder(rr.Body).Decode(&created); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if created.ID != "door1" { + t.Errorf("Expected ID 'door1', got %s", created.ID) + } + if created.CreatedAt.IsZero() { + t.Error("Expected non-zero CreatedAt") + } + + // Verify it persists + portal := h.mgr.GetPortal("door1") + if portal == nil { + t.Fatal("Portal not found after creation") + } + if portal.Name != "A-B Door" { + t.Errorf("Expected name 'A-B Door', got %s", portal.Name) + } +} + +// TestCreatePortalAutoID tests POST /api/portals with no ID. +func TestCreatePortalAutoID(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "B", MinX: 1, MinY: 0, MinZ: 0, MaxX: 2, MaxY: 1, MaxZ: 1}) + + p := zones.Portal{ + Name: "Auto Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, + P2X: 1, P2Y: 0.5, P2Z: 0, + P3X: 1, P3Y: 0.5, P3Z: 1, + Width: 1, Height: 1, + } + + r := setupRouter(h) + body, _ := json.Marshal(p) + req := httptest.NewRequest("POST", "/api/portals", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("Expected 201, got %d: %s", rr.Code, rr.Body.String()) + } + + var created portalWithZones + json.NewDecoder(rr.Body).Decode(&created) + if created.ID == "" { + t.Error("Expected auto-generated ID, got empty") + } +} + +// TestCreatePortalInvalid tests POST /api/portals with invalid input. +func TestCreatePortalInvalid(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + tests := []struct { + name string + body string + want int + }{ + {"malformed JSON", `{bad}`, http.StatusBadRequest}, + {"empty body", ``, http.StatusBadRequest}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := setupRouter(h) + req := httptest.NewRequest("POST", "/api/portals", bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != tt.want { + t.Errorf("Expected %d, got %d", tt.want, rr.Code) + } + }) + } +} + +// TestCreatePortalInvalidZone tests POST /api/portals with nonexistent zone reference. +func TestCreatePortalInvalidZone(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + + p := zones.Portal{ + ID: "p1", Name: "Bad Zone", ZoneAID: "z1", ZoneBID: "nonexistent", + P1X: 0, P1Y: 0, P1Z: 0, P2X: 1, P2Y: 0, P2Z: 0, P3X: 0, P3Y: 0, P3Z: 1, + Width: 1, Height: 1, + } + + r := setupRouter(h) + body, _ := json.Marshal(p) + req := httptest.NewRequest("POST", "/api/portals", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected 400 for invalid zone_b, got %d", rr.Code) + } +} + +// TestUpdatePortal tests PUT /api/portals/{id}. +func TestUpdatePortal(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "B", MinX: 1, MinY: 0, MinZ: 0, MaxX: 2, MaxY: 1, MaxZ: 1}) + + // Create initial portal + h.mgr.CreatePortal(&zones.Portal{ + ID: "p1", Name: "Old Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, + P2X: 1, P2Y: 0.5, P2Z: 0, + P3X: 1, P3Y: 0.5, P3Z: 1, + Width: 1, Height: 1, + }) + + // Update portal + updated := zones.Portal{ + ID: "p1", Name: "New Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, + P2X: 1, P2Y: 1, P2Z: 0, + P3X: 1, P3Y: 1, P3Z: 2, + Width: 2, Height: 2, + } + + r := setupRouter(h) + body, _ := json.Marshal(updated) + req := httptest.NewRequest("PUT", "/api/portals/p1", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + var result portalWithZones + if err := json.NewDecoder(rr.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + if result.Name != "New Door" { + t.Errorf("Expected name 'New Door', got %s", result.Name) + } + + // Verify persist + p := h.mgr.GetPortal("p1") + if p.Name != "New Door" { + t.Errorf("Persisted name mismatch: %s", p.Name) + } +} + +// TestUpdatePortalInvalid tests PUT /api/portals/{id} with invalid input. +func TestUpdatePortalInvalid(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreatePortal(&zones.Portal{ + ID: "p1", Name: "Door", ZoneAID: "z1", ZoneBID: "z1", + P1X: 0, P1Y: 0, P1Z: 0, P2X: 1, P2Y: 0, P2Z: 0, P3X: 0, P3Y: 0, P3Z: 1, + Width: 1, Height: 1, + }) + + tests := []struct { + name string + body string + want int + }{ + {"malformed JSON", `{bad}`, http.StatusBadRequest}, + {"empty body", ``, http.StatusBadRequest}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := setupRouter(h) + req := httptest.NewRequest("PUT", "/api/portals/p1", bytes.NewReader([]byte(tt.body))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != tt.want { + t.Errorf("Expected %d, got %d", tt.want, rr.Code) + } + }) + } +} + +// TestUpdatePortalNotFound tests PUT /api/portals/{id} for nonexistent portal. +func TestUpdatePortalNotFound(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + body := `{"name": "Nope"}` + req := httptest.NewRequest("PUT", "/api/portals/nonexistent", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Errorf("Expected 404, got %d", rr.Code) + } +} + +// TestDeletePortal tests DELETE /api/portals/{id}. +func TestDeletePortal(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreatePortal(&zones.Portal{ + ID: "p1", Name: "Door", ZoneAID: "z1", ZoneBID: "z1", + P1X: 0, P1Y: 0, P1Z: 0, P2X: 1, P2Y: 0, P2Z: 0, P3X: 0, P3Y: 0, P3Z: 1, + Width: 1, Height: 1, + }) + + r := setupRouter(h) + req := httptest.NewRequest("DELETE", "/api/portals/p1", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("Expected 204, got %d: %s", rr.Code, rr.Body.String()) + } + + if h.mgr.GetPortal("p1") != nil { + t.Error("Portal should be deleted") + } + + // Verify via GET + req2 := httptest.NewRequest("GET", "/api/portals", nil) + rr2 := httptest.NewRecorder() + r.ServeHTTP(rr2, req2) + var result []portalWithZones + json.NewDecoder(rr2.Body).Decode(&result) + if len(result) != 0 { + t.Errorf("Expected 0 portals after delete, got %d", len(result)) + } +} + +// TestDeletePortalNotFound tests DELETE /api/portals/{id} for nonexistent portal. +func TestDeletePortalNotFound(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + req := httptest.NewRequest("DELETE", "/api/portals/nonexistent", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Errorf("Expected 204, got %d", rr.Code) + } +} + +// TestGetPortalCrossings tests GET /api/portals/{id}/crossings. +func TestGetPortalCrossings(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "B", MinX: 1, MinY: 0, MinZ: 0, MaxX: 2, MaxY: 1, MaxZ: 1}) + h.mgr.CreatePortal(&zones.Portal{ + ID: "p1", Name: "Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, P2X: 1, P2Y: 0.5, P2Z: 0, P3X: 1, P3Y: 0.5, P3Z: 1, + Width: 1, Height: 1, + }) + + tests := []struct { + name string + portalID string + wantCode int + }{ + {"existing portal", "p1", http.StatusOK}, + {"nonexistent portal", "nope", http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := setupRouter(h) + req := httptest.NewRequest("GET", "/api/portals/"+tt.portalID+"/crossings", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != tt.wantCode { + t.Errorf("Expected %d, got %d", tt.wantCode, rr.Code) + } + }) + } +} + +// TestPortalNormalComputed verifies that portal normal vector is auto-computed on creation. +func TestPortalNormalComputed(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "B", MinX: 1, MinY: 0, MinZ: 0, MaxX: 2, MaxY: 1, MaxZ: 1}) + + // Portal on the X=1 plane, pointing in +X direction + p := zones.Portal{ + ID: "p1", Name: "Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, + P2X: 1, P2Y: 1, P2Z: 0, + P3X: 1, P3Y: 1, P3Z: 1, + Width: 1, Height: 1, + } + + r := setupRouter(h) + body, _ := json.Marshal(p) + req := httptest.NewRequest("POST", "/api/portals", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("Expected 201, got %d: %s", rr.Code, rr.Body.String()) + } + + var created portalWithZones + json.NewDecoder(rr.Body).Decode(&created) + + // Normal should point in roughly +X direction + if created.NX <= 0 { + t.Errorf("Expected NX > 0 (portal normal in +X), got %f", created.NX) + } + // For this geometry, NY and NZ should be ~0 since the portal is on the X=1 plane + if created.NY > 0.01 || created.NZ > 0.01 { + t.Errorf("Expected NY≈0, NZ≈0 for X=1 plane portal, got NY=%f, NZ=%f", created.NY, created.NZ) + } +} + +// TestZoneCRUDRoundTrip verifies the full lifecycle: create -> read -> update -> read -> delete -> verify gone. +func TestZoneCRUDRoundTrip(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + r := setupRouter(h) + + // 1. Create + zone := zones.Zone{ + ID: "roundtrip", Name: "Initial", ZoneType: zones.ZoneTypeKitchen, + MinX: 0, MinY: 0, MinZ: 0, MaxX: 3, MaxY: 3, MaxZ: 2.5, + } + body, _ := json.Marshal(zone) + req := httptest.NewRequest("POST", "/api/zones", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("Create: expected 201, got %d", rr.Code) + } + + // 2. Read (via list) + req2 := httptest.NewRequest("GET", "/api/zones", nil) + rr2 := httptest.NewRecorder() + r.ServeHTTP(rr2, req2) + var zonesList []zoneWithOcc + json.NewDecoder(rr2.Body).Decode(&zonesList) + if len(zonesList) != 1 { + t.Fatalf("After create: expected 1 zone, got %d", len(zonesList)) + } + if zonesList[0].Name != "Initial" { + t.Errorf("After create: expected name 'Initial', got %s", zonesList[0].Name) + } + + // 3. Update + zone.Name = "Updated" + zone.MaxX = 5 + zone.MaxY = 4 + body, _ = json.Marshal(zone) + req3 := httptest.NewRequest("PUT", "/api/zones/roundtrip", bytes.NewReader(body)) + req3.Header.Set("Content-Type", "application/json") + rr3 := httptest.NewRecorder() + r.ServeHTTP(rr3, req3) + if rr3.Code != http.StatusOK { + t.Fatalf("Update: expected 200, got %d", rr3.Code) + } + + // 4. Read after update + req4 := httptest.NewRequest("GET", "/api/zones", nil) + rr4 := httptest.NewRecorder() + r.ServeHTTP(rr4, req4) + json.NewDecoder(rr4.Body).Decode(&zonesList) + if zonesList[0].Name != "Updated" { + t.Errorf("After update: expected name 'Updated', got %s", zonesList[0].Name) + } + + // 5. Delete + req5 := httptest.NewRequest("DELETE", "/api/zones/roundtrip", nil) + rr5 := httptest.NewRecorder() + r.ServeHTTP(rr5, req5) + if rr5.Code != http.StatusNoContent { + t.Fatalf("Delete: expected 204, got %d", rr5.Code) + } + + // 6. Verify gone + req6 := httptest.NewRequest("GET", "/api/zones", nil) + rr6 := httptest.NewRecorder() + r.ServeHTTP(rr6, req6) + json.NewDecoder(rr6.Body).Decode(&zonesList) + if len(zonesList) != 0 { + t.Errorf("After delete: expected 0 zones, got %d", len(zonesList)) + } +} + +// TestPortalCRUDRoundTrip verifies the full portal lifecycle. +func TestPortalCRUDRoundTrip(t *testing.T) { + h, cleanup := newTestHandler(t) + defer cleanup() + + // Seed zones + h.mgr.CreateZone(&zones.Zone{ID: "z1", Name: "A", MinX: 0, MinY: 0, MinZ: 0, MaxX: 1, MaxY: 1, MaxZ: 1}) + h.mgr.CreateZone(&zones.Zone{ID: "z2", Name: "B", MinX: 1, MinY: 0, MinZ: 0, MaxX: 2, MaxY: 1, MaxZ: 1}) + + r := setupRouter(h) + + // Create + p := zones.Portal{ + ID: "ptrt", Name: "Door", ZoneAID: "z1", ZoneBID: "z2", + P1X: 1, P1Y: 0, P1Z: 0, P2X: 1, P2Y: 0.5, P2Z: 0, P3X: 1, P3Y: 0.5, P3Z: 1, + Width: 1, Height: 1, + } + body, _ := json.Marshal(p) + req := httptest.NewRequest("POST", "/api/portals", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("Create: expected 201, got %d: %s", rr.Code, rr.Body.String()) + } + + // Verify via list + req2 := httptest.NewRequest("GET", "/api/portals", nil) + rr2 := httptest.NewRecorder() + r.ServeHTTP(rr2, req2) + var portals []portalWithZones + json.NewDecoder(rr2.Body).Decode(&portals) + if len(portals) != 1 { + t.Fatalf("Expected 1 portal after create, got %d", len(portals)) + } + + // Update + p.Name = "Big Door" + p.Width = 2 + body, _ = json.Marshal(p) + req3 := httptest.NewRequest("PUT", "/api/portals/ptrt", bytes.NewReader(body)) + req3.Header.Set("Content-Type", "application/json") + rr3 := httptest.NewRecorder() + r.ServeHTTP(rr3, req3) + if rr3.Code != http.StatusOK { + t.Fatalf("Update: expected 200, got %d: %s", rr3.Code, rr3.Body.String()) + } + + // Verify updated + var updated portalWithZones + json.NewDecoder(rr3.Body).Decode(&updated) + if updated.Name != "Big Door" { + t.Errorf("Expected name 'Big Door', got %s", updated.Name) + } + + // Delete + req4 := httptest.NewRequest("DELETE", "/api/portals/ptrt", nil) + rr4 := httptest.NewRecorder() + r.ServeHTTP(rr4, req4) + if rr4.Code != http.StatusNoContent { + t.Fatalf("Delete: expected 204, got %d", rr4.Code) + } + + // Verify gone + req5 := httptest.NewRequest("GET", "/api/portals", nil) + rr5 := httptest.NewRecorder() + r.ServeHTTP(rr5, req5) + json.NewDecoder(rr5.Body).Decode(&portals) + if len(portals) != 0 { + t.Errorf("Expected 0 portals after delete, got %d", len(portals)) + } +} diff --git a/mothership/internal/loadshed/loadshed.go b/mothership/internal/loadshed/loadshed.go new file mode 100644 index 0000000..3743594 --- /dev/null +++ b/mothership/internal/loadshed/loadshed.go @@ -0,0 +1,307 @@ +// Package loadshed implements adaptive load shedding for the mothership fusion pipeline. +// It monitors pipeline iteration timing and applies 4 shedding levels to keep the +// system responsive under CPU/memory pressure, especially for large fleets. +// +// Level 0 (normal): rolling avg < 80 ms — full pipeline +// Level 1 (light): rolling avg >= 80 ms — suspend crowd flow accumulation +// Level 2 (moderate): rolling avg >= 90 ms — also suspend CSI replay buffer writes +// Level 3 (heavy): rolling avg >= 95 ms — drop CSI frames when ingest channel > 50% full; +// push rate reduction config to all nodes (10 Hz cap) +// +// Recovery: when rolling avg < 60 ms for 10 consecutive iterations, step down one level. +// +// All reads of the shedding level are lock-free via atomic operations. +package loadshed + +import ( + "log" + "sync/atomic" + "time" +) + +// Level represents the current load shedding level (0-3). +type Level int32 + +const ( + // LevelNormal is the default: full pipeline, no shedding. + LevelNormal Level = iota + // LevelLight suspends crowd flow accumulation. + LevelLight + // LevelModerate also suspends CSI replay buffer writes. + LevelModerate + // LevelHeavy drops CSI frames on full channels and caps node rate at 10 Hz. + LevelHeavy +) + +// String returns a human-readable label for the shedding level. +func (l Level) String() string { + switch l { + case LevelNormal: + return "NOMINAL" + case LevelLight: + return "LIGHT" + case LevelModerate: + return "MODERATE" + case LevelHeavy: + return "HIGH" + default: + return "UNKNOWN" + } +} + +// Thresholds for load shedding state transitions. +const ( + // Thresholds for escalating shedding levels. + thresholdLevel1 = 80 * time.Millisecond + thresholdLevel2 = 90 * time.Millisecond + thresholdLevel3 = 95 * time.Millisecond + + // Recovery threshold: rolling avg must stay below this for recoveryCount + // consecutive iterations before stepping down one level. + recoveryThreshold = 60 * time.Millisecond + recoveryCount = 10 + + // Level 3 CSI rate cap pushed to all nodes. + level3RateCapHz = 10 + + // Number of iterations in the rolling average window. + rollingWindowSize = 5 +) + +// IngestChannelFull is a callback that reports whether the ingest processing +// channel is more than 50% full. Returns true if shedding should drop frames. +type IngestChannelFull func() bool + +// RatePushCallback is called when Level 3 is entered or exited to push +// rate config changes to all connected nodes. +type RatePushCallback func(rateHz int) + +// Stage represents a named pipeline stage for timing instrumentation. +type Stage struct { + Name string + // Start is set by BeginStage. Use StageDuration() to get elapsed time. + start time.Time +} + +// Shedder manages the load shedding state machine. +type Shedder struct { + level atomic.Int32 // Current shedding level (0-3), read lock-free. + recoveryTicks atomic.Int32 // Consecutive iterations below recovery threshold. + + // Rolling average window (ring buffer). + durations [rollingWindowSize]time.Duration + durationsIdx int + durationsFilled int // how many slots have been written (< rollingWindowSize on startup) + + // Pipeline stage timing for instrumentation. + stages [8]Stage + stageIdx int + + // Iteration timing. + iterStart time.Time + + // External callbacks. + ingestFull IngestChannelFull + ratePush RatePushCallback + + // Previous rate before Level 3 was entered, for restoration. + prevRateHz atomic.Int32 + level3Active atomic.Bool +} + +// New creates a new Shedder. +func New() *Shedder { + return &Shedder{ + prevRateHz: atomic.Int32{}, // defaults to 0 (unknown) + } +} + +// SetIngestChannelFull sets the callback that reports whether the ingest +// channel is more than 50% full. +func (s *Shedder) SetIngestChannelFull(fn IngestChannelFull) { + s.ingestFull = fn +} + +// SetRatePushCallback sets the callback for pushing rate config to nodes. +func (s *Shedder) SetRatePushCallback(fn RatePushCallback) { + s.ratePush = fn +} + +// SetPreviousRate records the node rate that was active before Level 3 +// was entered. Used to restore the rate on recovery. +func (s *Shedder) SetPreviousRate(hz int) { + s.prevRateHz.Store(int32(hz)) +} + +// GetLevel returns the current shedding level (lock-free read). +func (s *Shedder) GetLevel() Level { + return Level(s.level.Load()) +} + +// ShouldAccumulateCrowdFlow returns false when crowd flow accumulation +// should be suspended (Level >= 1). +func (s *Shedder) ShouldAccumulateCrowdFlow() bool { + return s.level.Load() < int32(LevelLight) +} + +// ShouldWriteReplay returns false when CSI replay buffer writes +// should be suspended (Level >= 2). +func (s *Shedder) ShouldWriteReplay() bool { + return s.level.Load() < int32(LevelModerate) +} + +// ShouldDropFrames returns true when CSI frames should be dropped because +// the ingest channel is more than 50% full (Level 3 only). +func (s *Shedder) ShouldDropFrames() bool { + return s.level.Load() >= int32(LevelHeavy) && s.ingestFull != nil && s.ingestFull() +} + +// IsLevel3Active returns true when Level 3 shedding is active. +func (s *Shedder) IsLevel3Active() bool { + return s.level3Active.Load() +} + +// GetLevel3RateCap returns the rate cap applied during Level 3 (10 Hz). +func (s *Shedder) GetLevel3RateCap() int { + return level3RateCapHz +} + +// BeginIteration marks the start of a pipeline iteration. Call this at the +// beginning of each fusion tick. +func (s *Shedder) BeginIteration() { + s.iterStart = time.Now() + s.stageIdx = 0 +} + +// BeginStage starts timing a named pipeline stage. Returns a Stage handle +// whose duration is captured on EndIteration. +func (s *Shedder) BeginStage(name string) Stage { + st := Stage{Name: name, start: time.Now()} + if s.stageIdx < len(s.stages) { + s.stages[s.stageIdx] = st + } + return st +} + +// EndStage marks the end of a pipeline stage. +func (s *Shedder) EndStage(st Stage) { + _ = st // duration computed lazily in GetStageDurations +} + +// GetStageDurations returns the durations of all stages from the most recent +// completed iteration. +func (s *Shedder) GetStageDurations() []time.Duration { + n := s.stageIdx + if n > len(s.stages) { + n = len(s.stages) + } + result := make([]time.Duration, n) + for i := 0; i < n; i++ { + result[i] = time.Since(s.stages[i].start) + } + return result +} + +// EndIteration marks the end of a pipeline iteration, updates the rolling +// average, and evaluates the shedding state machine. +func (s *Shedder) EndIteration() { + elapsed := time.Since(s.iterStart) + + // Update rolling average window. + s.durations[s.durationsIdx] = elapsed + s.durationsIdx = (s.durationsIdx + 1) % rollingWindowSize + if s.durationsFilled < rollingWindowSize { + s.durationsFilled++ + } + + avg := s.rollingAvg() + + // Evaluate state machine. + prevLevel := Level(s.level.Load()) + var newLevel Level + + if avg >= thresholdLevel3 { + newLevel = LevelHeavy + } else if avg >= thresholdLevel2 { + newLevel = LevelModerate + } else if avg >= thresholdLevel1 { + newLevel = LevelLight + } else { + // Below all escalation thresholds — check recovery. + if avg < recoveryThreshold { + ticks := s.recoveryTicks.Add(1) + if ticks >= recoveryCount && prevLevel > LevelNormal { + newLevel = prevLevel - 1 + s.recoveryTicks.Store(0) + } else { + newLevel = prevLevel + } + } else { + // Between recovery threshold and Level 1 — hold current level. + s.recoveryTicks.Store(0) + newLevel = prevLevel + } + } + + // Level can only go UP directly to the new level (no gradual escalation), + // but recovery steps down one level at a time. + if newLevel > prevLevel { + // Escalate directly. + s.setLevel(newLevel) + } else if newLevel < prevLevel { + // Recovery step down. + s.setLevel(newLevel) + } else { + // Reset recovery counter if we didn't step down this tick + // and we're not in recovery mode. + if avg >= recoveryThreshold { + s.recoveryTicks.Store(0) + } + } +} + +// setLevel applies a level change and logs it. +func (s *Shedder) setLevel(new Level) { + prev := Level(s.level.Swap(int32(new))) + if prev == new { + return + } + log.Printf("[INFO] Load shedding level changed: %s (%d) → %s (%d)", prev, prev, new, new) + + // Level 3 enter/exit: push rate config to nodes. + if new == LevelHeavy && prev < LevelHeavy { + // Entering Level 3. + s.level3Active.Store(true) + if s.ratePush != nil { + s.ratePush(level3RateCapHz) + } + } else if prev == LevelHeavy && new < LevelHeavy { + // Exiting Level 3: restore previous rate. + s.level3Active.Store(false) + if s.ratePush != nil { + prevRate := int(s.prevRateHz.Load()) + if prevRate <= 0 { + prevRate = 20 // sensible default + } + s.ratePush(prevRate) + } + } +} + +// rollingAvg computes the average iteration duration over the rolling window. +func (s *Shedder) rollingAvg() time.Duration { + n := s.durationsFilled + if n == 0 { + return 0 + } + var sum time.Duration + for i := 0; i < n; i++ { + sum += s.durations[i] + } + return sum / time.Duration(n) +} + +// RollingAvg returns the current rolling average iteration time (for diagnostics). +func (s *Shedder) RollingAvg() time.Duration { + return s.rollingAvg() +} diff --git a/mothership/internal/loadshed/loadshed_test.go b/mothership/internal/loadshed/loadshed_test.go new file mode 100644 index 0000000..599e83b --- /dev/null +++ b/mothership/internal/loadshed/loadshed_test.go @@ -0,0 +1,403 @@ +package loadshed + +import ( + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestLevelString(t *testing.T) { + tests := []struct { + l Level + want string + }{ + {LevelNormal, "NOMINAL"}, + {LevelLight, "LIGHT"}, + {LevelModerate, "MODERATE"}, + {LevelHeavy, "HIGH"}, + {Level(99), "UNKNOWN"}, + } + for _, tt := range tests { + if got := tt.l.String(); got != tt.want { + t.Errorf("Level(%d).String() = %q, want %q", tt.l, got, tt.want) + } + } +} + +func TestShedderDefaultLevel(t *testing.T) { + s := New() + if got := s.GetLevel(); got != LevelNormal { + t.Errorf("New Shedder level = %d, want %d", got, LevelNormal) + } +} + +func TestShedderCrowdFlowSuspended(t *testing.T) { + tests := []struct { + level Level + want bool + }{ + {LevelNormal, true}, + {LevelLight, false}, + {LevelModerate, false}, + {LevelHeavy, false}, + } + for _, tt := range tests { + s := New() + s.level.Store(int32(tt.level)) + if got := s.ShouldAccumulateCrowdFlow(); got != tt.want { + t.Errorf("ShouldAccumulateCrowdFlow() at level %d = %v, want %v", tt.level, got, tt.want) + } + } +} + +func TestShedderReplayWriteSuspended(t *testing.T) { + tests := []struct { + level Level + want bool + }{ + {LevelNormal, true}, + {LevelLight, true}, + {LevelModerate, false}, + {LevelHeavy, false}, + } + for _, tt := range tests { + s := New() + s.level.Store(int32(tt.level)) + if got := s.ShouldWriteReplay(); got != tt.want { + t.Errorf("ShouldWriteReplay() at level %d = %v, want %v", tt.level, got, tt.want) + } + } +} + +func TestShedderShouldDropFrames(t *testing.T) { + s := New() + + // Level 3 with channel full → drop + s.level.Store(int32(LevelHeavy)) + channelFull := true + s.SetIngestChannelFull(func() bool { return channelFull }) + if !s.ShouldDropFrames() { + t.Error("ShouldDropFrames() = false at Level 3 with full channel, want true") + } + + // Level 3 with channel not full → don't drop + channelFull = false + if s.ShouldDropFrames() { + t.Error("ShouldDropFrames() = true at Level 3 with empty channel, want false") + } + + // Level 2 → never drop regardless of channel + s.level.Store(int32(LevelModerate)) + channelFull = true + if s.ShouldDropFrames() { + t.Error("ShouldDropFrames() = true at Level 2, want false") + } + + // No callback set → don't drop + s2 := New() + s2.level.Store(int32(LevelHeavy)) + if s2.ShouldDropFrames() { + t.Error("ShouldDropFrames() = true with no callback, want false") + } +} + +func TestRollingAverage(t *testing.T) { + s := New() + + // Fill the window with 5 iterations of 50ms each. + for i := 0; i < 5; i++ { + s.BeginIteration() + s.EndIteration() // duration ≈ 0ms (no actual work) + // Manually set the duration since the test runs too fast. + s.durations[i] = 50 * time.Millisecond + } + s.durationsFilled = 5 + + if avg := s.rollingAvg(); avg != 50*time.Millisecond { + t.Errorf("rollingAvg() = %v, want 50ms", avg) + } +} + +func TestEscalationFromNormal(t *testing.T) { + s := New() + + // Fill window with 5x 85ms → should escalate to Level 1 + fillWindow(s, 85*time.Millisecond) + s.EndIteration() + + // The last EndIteration added a 0ms duration; let's force the durations. + // Instead, let's directly test the state machine logic by simulating. + s2 := New() + // Pre-fill 4 slots at 85ms + for i := 0; i < 4; i++ { + s2.durations[i] = 85 * time.Millisecond + } + s2.durationsFilled = 4 + + // Now run an iteration that adds 85ms (total 5 slots, avg 85ms) + s2.durations[4] = 85 * time.Millisecond + s2.durationsIdx = 0 // wrapped + s2.durationsFilled = 5 + + if avg := s2.rollingAvg(); avg != 85*time.Millisecond { + t.Fatalf("rollingAvg() = %v, want 85ms", avg) + } + + // Directly test setLevel behavior + s2.setLevel(LevelLight) + if s2.GetLevel() != LevelLight { + t.Errorf("level = %d, want %d", s2.GetLevel(), LevelLight) + } +} + +func TestEscalationToLevel3(t *testing.T) { + s := New() + + // Pre-fill window at 97ms → Level 3 + for i := 0; i < rollingWindowSize; i++ { + s.durations[i] = 97 * time.Millisecond + } + s.durationsFilled = rollingWindowSize + + avg := s.rollingAvg() + if avg < thresholdLevel3 { + t.Fatalf("rollingAvg() = %v, expected >= %v", avg, thresholdLevel3) + } +} + +func TestRecoveryStepDown(t *testing.T) { + s := New() + s.level.Store(int32(LevelHeavy)) + + // Simulate recovery: 10 consecutive iterations below 60ms. + for i := 0; i < recoveryCount; i++ { + s.durations[i%rollingWindowSize] = 50 * time.Millisecond + s.durationsIdx = (i + 1) % rollingWindowSize + s.durationsFilled = min(i+1, rollingWindowSize) + + prevLevel := Level(s.level.Load()) + ticks := s.recoveryTicks.Add(1) + var newLevel Level + if ticks >= recoveryCount && prevLevel > LevelNormal { + newLevel = prevLevel - 1 + s.recoveryTicks.Store(0) + } else { + newLevel = prevLevel + } + s.setLevel(newLevel) + } + + if s.GetLevel() != LevelModerate { + t.Errorf("after recovery, level = %d, want %d", s.GetLevel(), LevelModerate) + } +} + +func TestRecoveryFullSequence(t *testing.T) { + s := New() + s.level.Store(int32(LevelModerate)) + + // Need 10 iterations below 60ms to recover from Level 2 → Level 1. + for i := 0; i < recoveryCount; i++ { + s.durations[i%rollingWindowSize] = 50 * time.Millisecond + s.durationsIdx = (i + 1) % rollingWindowSize + s.durationsFilled = min(i+1, rollingWindowSize) + + prevLevel := Level(s.level.Load()) + ticks := s.recoveryTicks.Add(1) + var newLevel Level + if ticks >= recoveryCount && prevLevel > LevelNormal { + newLevel = prevLevel - 1 + s.recoveryTicks.Store(0) + } else { + newLevel = prevLevel + } + s.setLevel(newLevel) + } + + if s.GetLevel() != LevelLight { + t.Errorf("after recovery from L2, level = %d, want %d", s.GetLevel(), LevelLight) + } +} + +func TestRecoveryCounterResetOnAboveThreshold(t *testing.T) { + s := New() + s.level.Store(int32(LevelLight)) + + // 5 iterations below threshold, then one above. + for i := 0; i < 5; i++ { + s.durations[i] = 50 * time.Millisecond + s.durationsFilled = i + 1 + s.recoveryTicks.Add(1) + } + // One iteration above recovery threshold but below L1 + s.durations[5%rollingWindowSize] = 70 * time.Millisecond + s.durationsFilled = 6 + + avg := s.rollingAvg() + if avg >= recoveryThreshold { + s.recoveryTicks.Store(0) + } + + if s.recoveryTicks.Load() != 0 { + t.Errorf("recovery ticks should be reset after above-threshold iteration, got %d", s.recoveryTicks.Load()) + } +} + +func TestLevel3RatePushCallback(t *testing.T) { + s := New() + + var pushedRate atomic.Int32 + s.SetRatePushCallback(func(rateHz int) { + pushedRate.Store(int32(rateHz)) + }) + s.SetPreviousRate(20) + + // Enter Level 3 + s.setLevel(LevelHeavy) + if pushedRate.Load() != level3RateCapHz { + t.Errorf("rate push on L3 enter = %d, want %d", pushedRate.Load(), level3RateCapHz) + } + if !s.IsLevel3Active() { + t.Error("IsLevel3Active() = false after entering L3") + } + + // Exit Level 3 + s.setLevel(LevelModerate) + if pushedRate.Load() != 20 { + t.Errorf("rate push on L3 exit = %d, want 20", pushedRate.Load()) + } + if s.IsLevel3Active() { + t.Error("IsLevel3Active() = true after exiting L3") + } +} + +func TestLevel3RestoreDefaultRate(t *testing.T) { + s := New() + + var pushedRate atomic.Int32 + s.SetRatePushCallback(func(rateHz int) { + pushedRate.Store(int32(rateHz)) + }) + // Don't set previous rate — should default to 20. + + s.setLevel(LevelHeavy) + pushedRate.Store(0) // reset + + s.setLevel(LevelModerate) + if pushedRate.Load() != 20 { + t.Errorf("rate push on L3 exit without prev = %d, want 20", pushedRate.Load()) + } +} + +func TestNoRatePushWithoutCallback(t *testing.T) { + s := New() + // No callback set — should not panic. + s.setLevel(LevelHeavy) + s.setLevel(LevelNormal) + if s.GetLevel() != LevelNormal { + t.Errorf("level = %d, want %d", s.GetLevel(), LevelNormal) + } +} + +func TestSetLevelLogsOnce(t *testing.T) { + s := New() + // Setting same level should be a no-op (no extra log). + s.setLevel(LevelLight) + s.setLevel(LevelLight) + if s.GetLevel() != LevelLight { + t.Errorf("level = %d, want %d", s.GetLevel(), LevelLight) + } +} + +func TestConcurrentLevelReads(t *testing.T) { + s := New() + s.level.Store(int32(LevelModerate)) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + l := s.GetLevel() + if l != LevelModerate { + t.Errorf("concurrent read got %d, want %d", l, LevelModerate) + } + }() + } + wg.Wait() +} + +func TestRollingAvgPartialWindow(t *testing.T) { + s := New() + + // Only 3 of 5 slots filled. + s.durations[0] = 100 * time.Millisecond + s.durations[1] = 200 * time.Millisecond + s.durations[2] = 300 * time.Millisecond + s.durationsFilled = 3 + + want := 200 * time.Millisecond + if got := s.rollingAvg(); got != want { + t.Errorf("partial window avg = %v, want %v", got, want) + } +} + +func TestGetLevel3RateCap(t *testing.T) { + s := New() + if s.GetLevel3RateCap() != 10 { + t.Errorf("GetLevel3RateCap() = %d, want 10", s.GetLevel3RateCap()) + } +} + +func TestLockFreeReads(t *testing.T) { + s := New() + + // Verify all "query" methods use only atomic reads (no mutex). + // This is a design assertion — the test confirms no data races under + // concurrent writes and reads. + var wg sync.WaitGroup + + // Writer goroutine: rapidly change levels. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 1000; i++ { + s.setLevel(Level(i % 4)) + } + }() + + // Reader goroutines: concurrently query all lock-free methods. + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 1000; j++ { + _ = s.GetLevel() + _ = s.ShouldAccumulateCrowdFlow() + _ = s.ShouldWriteReplay() + _ = s.IsLevel3Active() + _ = s.GetLevel3RateCap() + _ = s.RollingAvg() + } + }() + } + + wg.Wait() +} + +// fillWindow is a test helper that fills the rolling window with a given duration. +func fillWindow(s *Shedder, d time.Duration) { + for i := 0; i < rollingWindowSize; i++ { + s.durations[i] = d + } + s.durationsIdx = 0 + s.durationsFilled = rollingWindowSize +} + +func min(a, b int) int { + if a < b { + return a + } + return b +}