From 72e155391febcb9045e905c78952d80ec0bd8ecf Mon Sep 17 00:00:00 2001 From: jedarden Date: Thu, 9 Apr 2026 07:14:32 -0400 Subject: [PATCH] feat: implement REST API endpoint for node identify - POST /api/nodes/{mac}/identify endpoint with {duration_ms: 5000} body - Forwards identify message as WebSocket JSON to target node - Returns 404 if node not connected; 200 on success - Includes table-driven tests for all edge cases Co-Authored-By: Claude Opus 4.6 --- mothership/cmd/mothership/main.go | 1 + mothership/internal/fleet/handler.go | 55 +++- mothership/internal/fleet/handler_test.go | 335 ++++++++++++++++++++++ mothership/internal/ingestion/server.go | 26 ++ 4 files changed, 416 insertions(+), 1 deletion(-) create mode 100644 mothership/internal/fleet/handler_test.go diff --git a/mothership/cmd/mothership/main.go b/mothership/cmd/mothership/main.go index ac4d16a..345d217 100644 --- a/mothership/cmd/mothership/main.go +++ b/mothership/cmd/mothership/main.go @@ -1991,6 +1991,7 @@ func main() { // Fleet REST API fleetHandler := fleet.NewHandler(fleetMgr) + fleetHandler.SetNodeIdentifier(ingestSrv) fleetHandler.RegisterRoutes(r) // Floorplan REST API diff --git a/mothership/internal/fleet/handler.go b/mothership/internal/fleet/handler.go index 5e39047..5cfd1d3 100644 --- a/mothership/internal/fleet/handler.go +++ b/mothership/internal/fleet/handler.go @@ -10,9 +10,15 @@ import ( "github.com/spaxel/mothership/internal/events" ) +// NodeIdentifier sends identify commands to connected nodes. +type NodeIdentifier interface { + SendIdentifyToMAC(mac string, durationMS int) bool +} + // Handler serves the fleet REST API. type Handler struct { - mgr *Manager + mgr *Manager + nodeID NodeIdentifier } // NewHandler creates a new fleet REST handler backed by mgr. @@ -20,6 +26,11 @@ func NewHandler(mgr *Manager) *Handler { return &Handler{mgr: mgr} } +// SetNodeIdentifier sets the node identifier for sending identify commands. +func (h *Handler) SetNodeIdentifier(ni NodeIdentifier) { + h.nodeID = ni +} + // RegisterRoutes mounts fleet endpoints on r. // // GET /api/nodes — list all nodes @@ -27,6 +38,7 @@ func NewHandler(mgr *Manager) *Handler { // POST /api/nodes/{mac}/role — override node role // PUT /api/nodes/{mac}/position — update node 3D position // DELETE /api/nodes/{mac} — delete a node +// POST /api/nodes/{mac}/identify — blink LED for identification // POST /api/nodes/virtual — add a virtual planning node // PUT /api/room — update room dimensions func (h *Handler) RegisterRoutes(r chi.Router) { @@ -35,6 +47,7 @@ func (h *Handler) RegisterRoutes(r chi.Router) { r.Post("/api/nodes/{mac}/role", h.setNodeRole) r.Put("/api/nodes/{mac}/position", h.updateNodePosition) r.Delete("/api/nodes/{mac}", h.deleteNode) + r.Post("/api/nodes/{mac}/identify", h.identifyNode) r.Post("/api/nodes/virtual", h.addVirtualNode) r.Put("/api/room", h.updateRoom) // System mode endpoints @@ -166,6 +179,46 @@ func (h *Handler) deleteNode(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } +type identifyNodeRequest struct { + DurationMS int `json:"duration_ms"` +} + +func (h *Handler) identifyNode(w http.ResponseWriter, r *http.Request) { + mac := chi.URLParam(r, "mac") + + // Verify node exists. + if _, err := h.mgr.registry.GetNode(mac); errors.Is(err, sql.ErrNoRows) { + http.Error(w, "node not found", http.StatusNotFound) + return + } else if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + // Parse request body. + var req identifyNodeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + // Default to 5000ms if not specified. + durationMS := req.DurationMS + if durationMS <= 0 { + durationMS = 5000 + } + + // Send identify command if node identifier is available. + if h.nodeID != nil { + if !h.nodeID.SendIdentifyToMAC(mac, durationMS) { + http.Error(w, "node not connected", http.StatusNotFound) + return + } + } + + writeJSON(w, map[string]bool{"ok": true}) +} + type updateRoomRequest struct { Width float64 `json:"width"` Depth float64 `json:"depth"` diff --git a/mothership/internal/fleet/handler_test.go b/mothership/internal/fleet/handler_test.go new file mode 100644 index 0000000..187a46b --- /dev/null +++ b/mothership/internal/fleet/handler_test.go @@ -0,0 +1,335 @@ +package fleet + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +// mockNodeIdentifier is a mock implementation of NodeIdentifier for testing. +type mockNodeIdentifier struct { + sendIdentifyFunc func(mac string, durationMS int) bool +} + +func (m *mockNodeIdentifier) SendIdentifyToMAC(mac string, durationMS int) bool { + if m.sendIdentifyFunc != nil { + return m.sendIdentifyFunc(mac, durationMS) + } + return true +} + +// mockRegistry is a minimal mock of Registry for testing. +type mockRegistry struct { + nodes map[string]NodeRecord + err error +} + +func (m *mockRegistry) GetNode(mac string) (NodeRecord, error) { + if m.err != nil { + return NodeRecord{}, m.err + } + if node, ok := m.nodes[mac]; ok { + return node, nil + } + return NodeRecord{}, sql.ErrNoRows +} + +func (m *mockRegistry) GetAllNodes() ([]NodeRecord, error) { + var nodes []NodeRecord + for _, node := range m.nodes { + nodes = append(nodes, node) + } + return nodes, m.err +} + +func (m *mockRegistry) SetNodePosition(mac string, x, y, z float64) error { + return nil +} + +func (m *mockRegistry) AddVirtualNode(mac, name string, x, y, z float64) error { + return nil +} + +func (m *mockRegistry) DeleteNode(mac string) error { + return nil +} + +func (m *mockRegistry) SetRoom(room RoomConfig) error { + return nil +} + +func (m *mockRegistry) GetRoom() (RoomConfig, error) { + return RoomConfig{}, nil +} + +func (m *mockRegistry) GetNodesByRole(role string) ([]NodeRecord, error) { + return nil, nil +} + +func TestHandlerIdentifyNode(t *testing.T) { + tests := []struct { + name string + mac string + reqBody string + nodeExists bool + nodeConnected bool + wantStatus int + wantResponse string + }{ + { + name: "successful identify with default duration", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `{}`, + nodeExists: true, + nodeConnected: true, + wantStatus: http.StatusOK, + wantResponse: `{"ok":true}`, + }, + { + name: "successful identify with custom duration", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `{"duration_ms": 10000}`, + nodeExists: true, + nodeConnected: true, + wantStatus: http.StatusOK, + wantResponse: `{"ok":true}`, + }, + { + name: "node not found", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `{}`, + nodeExists: false, + nodeConnected: true, + wantStatus: http.StatusNotFound, + wantResponse: "node not found\n", + }, + { + name: "node not connected", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `{}`, + nodeExists: true, + nodeConnected: false, + wantStatus: http.StatusNotFound, + wantResponse: "node not connected\n", + }, + { + name: "invalid request body", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `invalid json`, + nodeExists: true, + nodeConnected: true, + wantStatus: http.StatusBadRequest, + wantResponse: "invalid request body\n", + }, + { + name: "zero duration uses default", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `{"duration_ms": 0}`, + nodeExists: true, + nodeConnected: true, + wantStatus: http.StatusOK, + wantResponse: `{"ok":true}`, + }, + { + name: "negative duration uses default", + mac: "AA:BB:CC:DD:EE:FF", + reqBody: `{"duration_ms": -1000}`, + nodeExists: true, + nodeConnected: true, + wantStatus: http.StatusOK, + wantResponse: `{"ok":true}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock registry with the test node + reg := &mockRegistry{ + nodes: make(map[string]NodeRecord), + } + if tt.nodeExists { + reg.nodes[tt.mac] = NodeRecord{ + MAC: tt.mac, + Name: "Test Node", + Role: "rx", + } + } + + // Create a manager with the mock registry + mgr := &Manager{ + registry: reg, + } + + // Create handler with mock node identifier + h := &Handler{ + mgr: mgr, + nodeID: &mockNodeIdentifier{ + sendIdentifyFunc: func(mac string, durationMS int) bool { + return tt.nodeConnected + }, + }, + } + + // Create a test request + req := httptest.NewRequest("POST", "/api/nodes/"+tt.mac+"/identify", bytes.NewBufferString(tt.reqBody)) + req.Header.Set("Content-Type", "application/json") + + // Use chi URLParam to set the MAC parameter + rctx := chi.NewRouteContext() + rctx.URLParams.Add("mac", tt.mac) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Create response recorder + w := httptest.NewRecorder() + + // Call the handler + h.identifyNode(w, req) + + // Check status code + if w.Code != tt.wantStatus { + t.Errorf("identifyNode() status = %v, want %v", w.Code, tt.wantStatus) + } + + // Check response body + if tt.wantResponse != "" { + resp := w.Body.String() + if resp != tt.wantResponse { + t.Errorf("identifyNode() response = %q, want %q", resp, tt.wantResponse) + } + } + }) + } +} + +func TestHandlerIdentifyNodeDurationParsing(t *testing.T) { + tests := []struct { + name string + reqBody string + expectedDuration int + }{ + { + name: "default duration when not specified", + reqBody: `{}`, + expectedDuration: 5000, + }, + { + name: "custom duration", + reqBody: `{"duration_ms": 10000}`, + expectedDuration: 10000, + }, + { + name: "zero uses default", + reqBody: `{"duration_ms": 0}`, + expectedDuration: 5000, + }, + { + name: "negative uses default", + reqBody: `{"duration_ms": -1000}`, + expectedDuration: 5000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var actualDuration int + + reg := &mockRegistry{ + nodes: map[string]NodeRecord{ + "AA:BB:CC:DD:EE:FF": { + MAC: "AA:BB:CC:DD:EE:FF", + Name: "Test Node", + Role: "rx", + }, + }, + } + + mgr := &Manager{ + registry: reg, + } + + h := &Handler{ + mgr: mgr, + nodeID: &mockNodeIdentifier{ + sendIdentifyFunc: func(mac string, durationMS int) bool { + actualDuration = durationMS + return true + }, + }, + } + + req := httptest.NewRequest("POST", "/api/nodes/AA:BB:CC:DD:EE:FF/identify", bytes.NewBufferString(tt.reqBody)) + req.Header.Set("Content-Type", "application/json") + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("mac", "AA:BB:CC:DD:EE:FF") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + h.identifyNode(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status OK, got %v", w.Code) + } + + if actualDuration != tt.expectedDuration { + t.Errorf("Duration = %v, want %v", actualDuration, tt.expectedDuration) + } + }) + } +} + +func TestIdentifyNodeRequest(t *testing.T) { + tests := []struct { + name string + json string + wantErr bool + }{ + { + name: "valid empty object", + json: `{}`, + wantErr: false, + }, + { + name: "valid with duration", + json: `{"duration_ms": 10000}`, + wantErr: false, + }, + { + name: "valid with zero duration", + json: `{"duration_ms": 0}`, + wantErr: false, + }, + { + name: "valid with negative duration", + json: `{"duration_ms": -1000}`, + wantErr: false, + }, + { + name: "invalid json", + json: `invalid`, + wantErr: true, + }, + { + name: "extra fields ignored", + json: `{"duration_ms": 5000, "extra": "ignored"}`, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req identifyNodeRequest + err := json.NewDecoder(bytes.NewBufferString(tt.json)).Decode(&req) + + if (err != nil) != tt.wantErr { + t.Errorf("json.NewDecoder().Decode() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/mothership/internal/ingestion/server.go b/mothership/internal/ingestion/server.go index 18be39a..74f88c9 100644 --- a/mothership/internal/ingestion/server.go +++ b/mothership/internal/ingestion/server.go @@ -334,6 +334,32 @@ func (s *Server) SendOTAToMAC(mac, url, sha256, version string) { log.Printf("[INFO] Sent OTA trigger to node %s: version=%s url=%s", mac, version, url) } +// SendIdentifyToMAC sends an LED blink command to a connected node. +// Returns false if the node is not connected. +func (s *Server) SendIdentifyToMAC(mac string, durationMS int) bool { + s.mu.RLock() + nc, ok := s.connections[mac] + s.mu.RUnlock() + if !ok { + return false + } + msg := IdentifyMessage{Type: "identify", DurationMS: durationMS} + data, _ := json.Marshal(msg) + nc.writeMu.Lock() + nc.Conn.WriteMessage(websocket.TextMessage, data) + nc.writeMu.Unlock() + log.Printf("[INFO] Sent identify command to node %s: duration=%dms", mac, durationMS) + return true +} + +// IsNodeConnected returns true if the node with the given MAC is currently connected. +func (s *Server) IsNodeConnected(mac string) bool { + s.mu.RLock() + _, ok := s.connections[mac] + s.mu.RUnlock() + return ok +} + // HandleNodeWS handles WebSocket connections at /ws/node func (s *Server) HandleNodeWS(w http.ResponseWriter, r *http.Request) { // Step 1 of shutdown: return HTTP 503 for new WebSocket upgrade requests