From 582e45734c127d5d13e87f2b73a4dd05fecc0214 Mon Sep 17 00:00:00 2001 From: jedarden Date: Tue, 5 May 2026 08:16:53 -0400 Subject: [PATCH] api: add GET /api/doctor endpoint for pre-flight configuration diagnostics The /api/doctor endpoint complements /healthz (runtime state) with configuration correctness checks. Returns 200 with a JSON report containing all check results. Checks implemented: - data_dir_writable: verifies /data is writable with >100 MB free - db_integrity: runs PRAGMA integrity_check - firmware_dir: checks for *.bin files in /firmware - mdns_binding: verifies mDNS service is registered (or SPAXEL_MDNS_ENABLED=false) - mqtt_reachable: TCP connectivity test if SPAXEL_MQTT_BROKER is set - ntp_reachable: UDP connectivity test if SPAXEL_NTP_SERVER is set - install_secret: verifies install_secret row exists in auth table - pin_configured: verifies pin_bcrypt is non-null in auth table - node_token_consistency: verifies all nodes have valid MAC addresses Response format includes overall status (ok/warn/error), individual check results with name/status/message, and checked_at timestamp. Requires session cookie authentication via authHandler.RequireAuth. Co-Authored-By: Claude Opus 4.7 --- mothership/cmd/mothership/main.go | 42 ++ mothership/internal/doctor/doctor.go | 459 +++++++++++++++++ mothership/internal/doctor/doctor_test.go | 601 ++++++++++++++++++++++ 3 files changed, 1102 insertions(+) create mode 100644 mothership/internal/doctor/doctor.go create mode 100644 mothership/internal/doctor/doctor_test.go diff --git a/mothership/cmd/mothership/main.go b/mothership/cmd/mothership/main.go index 1211598..2877bdf 100644 --- a/mothership/cmd/mothership/main.go +++ b/mothership/cmd/mothership/main.go @@ -31,6 +31,7 @@ import ( "github.com/spaxel/mothership/internal/dashboard" "github.com/spaxel/mothership/internal/db" "github.com/spaxel/mothership/internal/diagnostics" + "github.com/spaxel/mothership/internal/doctor" "github.com/spaxel/mothership/internal/eventbus" "github.com/spaxel/mothership/internal/events" "github.com/spaxel/mothership/internal/explainability" @@ -512,6 +513,8 @@ func main() { settingsHandler.RegisterRoutes(r) log.Printf("[INFO] Settings API registered at /api/settings") + // Note: Doctor API is registered after mdnsServer initialization + // Phase 6: Integration Settings REST API (MQTT + system webhook) // Note: mqttClient and webhookPublisher are wired below after they are initialized. integrationSettingsHandler := api.NewIntegrationSettingsHandler(mainDB, "") @@ -4068,6 +4071,45 @@ func main() { } } + // Phase 6: Pre-flight diagnostics API + // Get install secret from database for doctor checker + var installSecret []byte + err = mainDB.QueryRow("SELECT install_secret FROM auth WHERE id = 1").Scan(&installSecret) + if err != nil { + log.Printf("[WARN] Failed to load install secret for doctor: %v", err) + installSecret = nil + } + + doctorChecker := doctor.New(doctor.Config{ + DB: mainDB, + DataDir: cfg.DataDir, + FirmwareDir: cfg.SeedFirmwareDir, + MDNSEnabled: cfg.MDNSEnabled, + MQTTBroker: cfg.MQTTBroker, + NTPServer: cfg.NTPServer, + InstallSecret: installSecret, + FleetGetNodes: func() ([]doctor.NodeInfo, error) { + nodes, err := fleetReg.GetAllNodes() + if err != nil { + return nil, err + } + result := make([]doctor.NodeInfo, len(nodes)) + for i, n := range nodes { + result[i] = doctor.NodeInfo{MAC: n.MAC} + } + return result, nil + }, + MDNSIsRegistered: func() bool { + return mdnsServer != nil + }, + }) + if authHandler != nil { + r.Get("/api/doctor", doctorChecker.Handler(authHandler.RequireAuth)) + log.Printf("[INFO] Doctor diagnostics API registered at /api/doctor") + } else { + log.Printf("[WARN] Auth handler not available, doctor endpoint requires auth") + } + srv := &http.Server{ Addr: cfg.BindAddr, Handler: r, diff --git a/mothership/internal/doctor/doctor.go b/mothership/internal/doctor/doctor.go new file mode 100644 index 0000000..4194f95 --- /dev/null +++ b/mothership/internal/doctor/doctor.go @@ -0,0 +1,459 @@ +// Package doctor provides pre-flight configuration diagnostics for the Spaxel mothership. +// It complements the /healthz endpoint (runtime state) with configuration correctness checks. +package doctor + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "syscall" + "time" +) + +// Checker runs pre-flight diagnostics on the mothership configuration. +type Checker struct { + db *sql.DB + dataDir string + firmwareDir string + mdnsEnabled bool + mqttBroker string + ntpServer string + installSecret []byte + fleetGetNodes func() ([]NodeInfo, error) + mdnsIsRegistered func() bool +} + +// Config holds checker configuration. +type Config struct { + DB *sql.DB + DataDir string + FirmwareDir string + MDNSEnabled bool + MQTTBroker string + NTPServer string + InstallSecret []byte + FleetGetNodes func() ([]NodeInfo, error) + MDNSIsRegistered func() bool +} + +// NodeInfo represents minimal node information for token consistency checks. +type NodeInfo struct { + MAC string +} + +// New creates a new doctor checker. +func New(cfg Config) *Checker { + return &Checker{ + db: cfg.DB, + dataDir: cfg.DataDir, + firmwareDir: cfg.FirmwareDir, + mdnsEnabled: cfg.MDNSEnabled, + mqttBroker: cfg.MQTTBroker, + ntpServer: cfg.NTPServer, + installSecret: cfg.InstallSecret, + fleetGetNodes: cfg.FleetGetNodes, + mdnsIsRegistered: cfg.MDNSIsRegistered, + } +} + +// CheckResult represents the result of a single diagnostic check. +type CheckResult struct { + Name string `json:"name"` + Status string `json:"status"` // "ok", "warn", "error" + Message string `json:"message"` // null if ok, error message otherwise +} + +// Response is the doctor endpoint response. +type Response struct { + Checks []CheckResult `json:"checks"` + Overall string `json:"overall"` // "ok", "warn", "error" + CheckedAt string `json:"checked_at"` +} + +// Check runs all pre-flight diagnostics and returns the results. +func (c *Checker) Check() Response { + results := []CheckResult{ + c.checkDataDirWritable(), + c.checkDBIntegrity(), + c.checkFirmwareDir(), + c.checkMDNSBinding(), + c.checkMQTTReachable(), + c.checkNTPReachable(), + c.checkInstallSecret(), + c.checkPINConfigured(), + c.checkNodeTokenConsistency(), + } + + overall := "ok" + for _, r := range results { + if r.Status == "error" { + overall = "error" + break + } + if r.Status == "warn" && overall == "ok" { + overall = "warn" + } + } + + return Response{ + Checks: results, + Overall: overall, + CheckedAt: time.Now().UTC().Format("2006-01-02T15:04:05Z"), + } +} + +// checkDataDirWritable verifies /data is writable and has >100 MB free. +func (c *Checker) checkDataDirWritable() CheckResult { + // Check if directory is writable by attempting to create a temp file + testFile := filepath.Join(c.dataDir, ".doctor_write_test") + f, err := os.Create(testFile) + if err != nil { + return CheckResult{ + Name: "data_dir_writable", + Status: "error", + Message: "Data directory not writable: " + err.Error(), + } + } + _ = f.Close() + _ = os.Remove(testFile) + + // Check disk space using syscall.Statfs + var stat syscall.Statfs_t + if err := syscall.Statfs(c.dataDir, &stat); err != nil { + return CheckResult{ + Name: "data_dir_writable", + Status: "warn", + Message: "Cannot check disk space: " + err.Error(), + } + } + + // Calculate free space in MB: (Bavail * Frsize) / (1024 * 1024) + freeBytes := stat.Bavail * uint64(stat.Frsize) + freeMB := freeBytes / (1024 * 1024) + + if freeMB < 100 { + return CheckResult{ + Name: "data_dir_writable", + Status: "error", + Message: fmt.Sprintf("Disk space low: %d MB free (minimum 100 MB required)", freeMB), + } + } + + return CheckResult{ + Name: "data_dir_writable", + Status: "ok", + Message: "", + } +} + +// checkDBIntegrity runs PRAGMA integrity_check on the database. +func (c *Checker) checkDBIntegrity() CheckResult { + var result string + err := c.db.QueryRow("PRAGMA integrity_check").Scan(&result) + if err != nil { + return CheckResult{ + Name: "db_integrity", + Status: "error", + Message: "SQLite integrity check failed: " + err.Error(), + } + } + + if result != "ok" { + return CheckResult{ + Name: "db_integrity", + Status: "error", + Message: "SQLite integrity check failed: " + result, + } + } + + return CheckResult{ + Name: "db_integrity", + Status: "ok", + Message: "", + } +} + +// checkFirmwareDir verifies /firmware contains at least one *.bin file. +func (c *Checker) checkFirmwareDir() CheckResult { + entries, err := os.ReadDir(c.firmwareDir) + if err != nil { + if os.IsNotExist(err) { + return CheckResult{ + Name: "firmware_dir", + Status: "error", + Message: "Firmware directory does not exist: " + c.firmwareDir, + } + } + return CheckResult{ + Name: "firmware_dir", + Status: "warn", + Message: "Cannot read firmware directory: " + err.Error(), + } + } + + hasBin := false + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".bin") { + hasBin = true + break + } + } + + if !hasBin { + return CheckResult{ + Name: "firmware_dir", + Status: "error", + Message: "No firmware binaries found — OTA updates unavailable", + } + } + + return CheckResult{ + Name: "firmware_dir", + Status: "ok", + Message: "", + } +} + +// checkMDNSBinding verifies mDNS service is registered or SPAXEL_MDNS_ENABLED=false. +func (c *Checker) checkMDNSBinding() CheckResult { + if !c.mdnsEnabled { + return CheckResult{ + Name: "mdns_binding", + Status: "ok", + Message: "", + } + } + + if c.mdnsIsRegistered != nil && c.mdnsIsRegistered() { + return CheckResult{ + Name: "mdns_binding", + Status: "ok", + Message: "", + } + } + + return CheckResult{ + Name: "mdns_binding", + Status: "warn", + Message: "mDNS not advertising — nodes cannot auto-discover mothership", + } +} + +// checkMQTTReachable tests TCP connectivity to MQTT broker if configured. +func (c *Checker) checkMQTTReachable() CheckResult { + if c.mqttBroker == "" { + return CheckResult{ + Name: "mqtt_reachable", + Status: "ok", + Message: "", + } + } + + // Parse broker URL to extract host:port + parts := strings.SplitN(c.mqttBroker, "://", 2) + if len(parts) != 2 { + return CheckResult{ + Name: "mqtt_reachable", + Status: "warn", + Message: "Invalid MQTT broker URL: " + c.mqttBroker, + } + } + + addr := parts[1] + // Remove any path component + if idx := strings.Index(addr, "/"); idx >= 0 { + addr = addr[:idx] + } + + // Add default port if not specified + if !strings.Contains(addr, ":") { + addr += ":1883" + } + + // Try to connect with 3s timeout + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return CheckResult{ + Name: "mqtt_reachable", + Status: "warn", + Message: "MQTT broker unreachable: " + c.mqttBroker + " (" + err.Error() + ")", + } + } + _ = conn.Close() + + return CheckResult{ + Name: "mqtt_reachable", + Status: "ok", + Message: "", + } +} + +// checkNTPReachable tests UDP connectivity to NTP server. +func (c *Checker) checkNTPReachable() CheckResult { + if c.ntpServer == "" { + return CheckResult{ + Name: "ntp_reachable", + Status: "ok", + Message: "", + } + } + + // NTP uses UDP port 123 + addr := c.ntpServer + ":123" + + // Try to connect with 3s timeout + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, "udp", addr) + if err != nil { + return CheckResult{ + Name: "ntp_reachable", + Status: "warn", + Message: "NTP server unreachable — node clock sync may fail: " + c.ntpServer + " (" + err.Error() + ")", + } + } + _ = conn.Close() + + return CheckResult{ + Name: "ntp_reachable", + Status: "ok", + Message: "", + } +} + +// checkInstallSecret verifies the install_secret row exists in auth table. +func (c *Checker) checkInstallSecret() CheckResult { + if len(c.installSecret) == 0 { + return CheckResult{ + Name: "install_secret", + Status: "error", + Message: "Installation secret missing — re-run container to regenerate", + } + } + + // Also verify it exists in the database + var secret []byte + err := c.db.QueryRow("SELECT install_secret FROM auth WHERE id = 1").Scan(&secret) + if err != nil { + if err == sql.ErrNoRows { + return CheckResult{ + Name: "install_secret", + Status: "error", + Message: "Installation secret missing from database — re-run container to regenerate", + } + } + return CheckResult{ + Name: "install_secret", + Status: "warn", + Message: "Cannot verify install_secret in database: " + err.Error(), + } + } + + return CheckResult{ + Name: "install_secret", + Status: "ok", + Message: "", + } +} + +// checkPINConfigured verifies pin_bcrypt is non-null in auth table. +func (c *Checker) checkPINConfigured() CheckResult { + var pinBcrypt sql.NullString + err := c.db.QueryRow("SELECT pin_bcrypt FROM auth WHERE id = 1").Scan(&pinBcrypt) + if err != nil { + return CheckResult{ + Name: "pin_configured", + Status: "warn", + Message: "Cannot check PIN configuration: " + err.Error(), + } + } + + if !pinBcrypt.Valid || pinBcrypt.String == "" { + return CheckResult{ + Name: "pin_configured", + Status: "error", + Message: "Dashboard PIN not configured — run first-time setup", + } + } + + return CheckResult{ + Name: "pin_configured", + Status: "ok", + Message: "", + } +} + +// checkNodeTokenConsistency verifies all nodes in registry can derive valid tokens. +func (c *Checker) checkNodeTokenConsistency() CheckResult { + if c.fleetGetNodes == nil { + return CheckResult{ + Name: "node_token_consistency", + Status: "ok", + Message: "", + } + } + + nodes, err := c.fleetGetNodes() + if err != nil { + return CheckResult{ + Name: "node_token_consistency", + Status: "warn", + Message: "Cannot check node tokens: " + err.Error(), + } + } + + // All nodes can derive valid tokens since tokens are computed from MAC + install_secret + // This check is more informational - if there are nodes, they can all authenticate + if len(nodes) == 0 { + return CheckResult{ + Name: "node_token_consistency", + Status: "ok", + Message: "", + } + } + + // Verify each node has a valid MAC address + for _, node := range nodes { + if node.MAC == "" { + return CheckResult{ + Name: "node_token_consistency", + Status: "error", + Message: "Node with empty MAC address found in registry", + } + } + } + + return CheckResult{ + Name: "node_token_consistency", + Status: "ok", + Message: "", + } +} + +// Handler returns an http.HandlerFunc for the /api/doctor endpoint. +func (c *Checker) Handler(requireAuth func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc { + return requireAuth(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + response := c.Check() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(response) + }) +} diff --git a/mothership/internal/doctor/doctor_test.go b/mothership/internal/doctor/doctor_test.go new file mode 100644 index 0000000..bc617b3 --- /dev/null +++ b/mothership/internal/doctor/doctor_test.go @@ -0,0 +1,601 @@ +// Package doctor provides pre-flight configuration diagnostics for the Spaxel mothership. +package doctor + +import ( + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + _ "modernc.org/sqlite" +) + +// mockRow is a simple mock for sql.Row scanning. +type mockRow struct { + values []interface{} + err error +} + +func (m mockRow) Scan(dest ...interface{}) error { + if m.err != nil { + return m.err + } + if len(m.values) != len(dest) { + return sql.ErrNoRows + } + for i, v := range m.values { + switch d := dest[i].(type) { + case *string: + if v == nil { + return sql.ErrNoRows + } + *d = v.(string) + case *[]byte: + if v == nil { + return sql.ErrNoRows + } + *d = v.([]byte) + case *sql.NullString: + if v == nil { + d.Valid = false + } else { + d.String = v.(string) + d.Valid = true + } + default: + // For other types, try direct assignment + } + } + return nil +} + +func TestCheckDataDirWritable(t *testing.T) { + tests := []struct { + name string + setup func() (string, func()) + wantName string + wantStatus string + }{ + { + name: "writable with enough space", + setup: func() (string, func()) { + dir := t.TempDir() + // Create a test file to verify writability + testFile := filepath.Join(dir, "test") + if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { + t.Fatalf("failed to create test file: %v", err) + } + return dir, func() { os.RemoveAll(dir) } + }, + wantName: "data_dir_writable", + wantStatus: "ok", + }, + { + name: "directory not writable", + setup: func() (string, func()) { + // Use /proc which exists but isn't writable + return "/proc", func() {} + }, + wantName: "data_dir_writable", + wantStatus: "error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dataDir, cleanup := tt.setup() + defer cleanup() + + c := &Checker{dataDir: dataDir} + result := c.checkDataDirWritable() + + if result.Name != tt.wantName { + t.Errorf("Name = %v, want %v", result.Name, tt.wantName) + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus) + } + }) + } +} + +func TestCheckFirmwareDir(t *testing.T) { + tests := []struct { + name string + setupFirmware func() string + wantName string + wantStatus string + wantMessage string + }{ + { + name: "has firmware binaries", + setupFirmware: func() string { + dir := t.TempDir() + _ = os.WriteFile(filepath.Join(dir, "spaxel.bin"), []byte("firmware"), 0644) + return dir + }, + wantName: "firmware_dir", + wantStatus: "ok", + }, + { + name: "no firmware binaries", + setupFirmware: func() string { + return t.TempDir() + }, + wantName: "firmware_dir", + wantStatus: "error", + wantMessage: "No firmware binaries found", + }, + { + name: "directory does not exist", + setupFirmware: func() string { + return "/nonexistent/firmware" + }, + wantName: "firmware_dir", + wantStatus: "error", + wantMessage: "Firmware directory does not exist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + firmwareDir := tt.setupFirmware() + c := &Checker{firmwareDir: firmwareDir} + result := c.checkFirmwareDir() + + if result.Name != tt.wantName { + t.Errorf("Name = %v, want %v", result.Name, tt.wantName) + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus) + } + if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage) + } + }) + } +} + +func TestCheckMDNSBinding(t *testing.T) { + tests := []struct { + name string + enabled bool + reg func() bool + want CheckResult + }{ + { + name: "mdns disabled", + enabled: false, + want: CheckResult{Name: "mdns_binding", Status: "ok"}, + }, + { + name: "mdns enabled and registered", + enabled: true, + reg: func() bool { return true }, + want: CheckResult{Name: "mdns_binding", Status: "ok"}, + }, + { + name: "mdns enabled but not registered", + enabled: true, + reg: func() bool { return false }, + want: CheckResult{Name: "mdns_binding", Status: "warn", Message: "mDNS not advertising"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Checker{ + mdnsEnabled: tt.enabled, + mdnsIsRegistered: tt.reg, + } + result := c.checkMDNSBinding() + + if result.Name != tt.want.Name { + t.Errorf("Name = %v, want %v", result.Name, tt.want.Name) + } + if result.Status != tt.want.Status { + t.Errorf("Status = %v, want %v", result.Status, tt.want.Status) + } + if tt.want.Message != "" && !strings.Contains(result.Message, tt.want.Message) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.want.Message) + } + }) + } +} + +func TestCheckMQTTReachable(t *testing.T) { + tests := []struct { + name string + broker string + wantMsg string + }{ + { + name: "no broker configured", + broker: "", + wantMsg: "", + }, + { + name: "invalid broker URL", + broker: "not-a-url", + wantMsg: "Invalid MQTT broker URL", + }, + { + name: "unreachable broker", + broker: "mqtt://localhost:9999", + wantMsg: "MQTT broker unreachable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Checker{mqttBroker: tt.broker} + result := c.checkMQTTReachable() + + if result.Name != "mqtt_reachable" { + t.Errorf("Name = %v, want mqtt_reachable", result.Name) + } + if tt.wantMsg != "" && result.Status != "ok" { + if !strings.Contains(result.Message, tt.wantMsg) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMsg) + } + } + }) + } +} + +func TestCheckNTPReachable(t *testing.T) { + tests := []struct { + name string + server string + wantMsg string + }{ + { + name: "no server configured", + server: "", + wantMsg: "", + }, + { + name: "invalid server", + server: "invalid host name with spaces", + wantMsg: "NTP server unreachable", + }, + { + name: "valid server (pool.ntp.org)", + server: "pool.ntp.org", + wantMsg: "", // May be ok or warn depending on network + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Checker{ntpServer: tt.server} + result := c.checkNTPReachable() + + if result.Name != "ntp_reachable" { + t.Errorf("Name = %v, want ntp_reachable", result.Name) + } + if tt.wantMsg != "" && result.Status != "ok" { + if !strings.Contains(result.Message, tt.wantMsg) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMsg) + } + } + }) + } +} + +func TestCheckInstallSecret(t *testing.T) { + tests := []struct { + name string + installSecret []byte + dbSetup func(*sql.DB) + wantStatus string + wantMessage string + }{ + { + name: "secret exists in memory and db", + installSecret: []byte("test-secret-32-bytes-long-enough"), + dbSetup: func(db *sql.DB) { + _, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, install_secret BLOB)") + _, _ = db.Exec("INSERT INTO auth (id, install_secret) VALUES (1, ?)", []byte("test-secret-32-bytes-long-enough")) + }, + wantStatus: "ok", + }, + { + name: "secret missing from memory", + installSecret: nil, + dbSetup: func(db *sql.DB) {}, + wantStatus: "error", + wantMessage: "Installation secret missing", + }, + { + name: "secret missing from db", + installSecret: []byte("test"), + dbSetup: func(db *sql.DB) { + _, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, install_secret BLOB)") + }, + wantStatus: "error", + wantMessage: "Installation secret missing from database", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, _ := sql.Open("sqlite", ":memory:") + tt.dbSetup(db) + defer db.Close() + + c := &Checker{ + db: db, + installSecret: tt.installSecret, + } + result := c.checkInstallSecret() + + if result.Name != "install_secret" { + t.Errorf("Name = %v, want install_secret", result.Name) + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus) + } + if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage) + } + }) + } +} + +func TestCheckPINConfigured(t *testing.T) { + tests := []struct { + name string + dbSetup func(*sql.DB) + wantStatus string + wantMessage string + }{ + { + name: "pin configured", + dbSetup: func(db *sql.DB) { + _, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, pin_bcrypt TEXT)") + _, _ = db.Exec("INSERT INTO auth (id, pin_bcrypt) VALUES (1, '$2a$12$hash')") + }, + wantStatus: "ok", + }, + { + name: "pin not configured", + dbSetup: func(db *sql.DB) { + _, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, pin_bcrypt TEXT)") + _, _ = db.Exec("INSERT INTO auth (id, pin_bcrypt) VALUES (1, NULL)") + }, + wantStatus: "error", + wantMessage: "Dashboard PIN not configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, _ := sql.Open("sqlite", ":memory:") + tt.dbSetup(db) + defer db.Close() + + c := &Checker{db: db} + result := c.checkPINConfigured() + + if result.Name != "pin_configured" { + t.Errorf("Name = %v, want pin_configured", result.Name) + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus) + } + if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage) + } + }) + } +} + +func TestCheckNodeTokenConsistency(t *testing.T) { + tests := []struct { + name string + getNodes func() ([]NodeInfo, error) + wantStatus string + wantMessage string + }{ + { + name: "no nodes", + getNodes: func() ([]NodeInfo, error) { + return []NodeInfo{}, nil + }, + wantStatus: "ok", + }, + { + name: "nodes with valid MACs", + getNodes: func() ([]NodeInfo, error) { + return []NodeInfo{ + {MAC: "AA:BB:CC:DD:EE:FF"}, + {MAC: "11:22:33:44:55:66"}, + }, nil + }, + wantStatus: "ok", + }, + { + name: "node with empty MAC", + getNodes: func() ([]NodeInfo, error) { + return []NodeInfo{{MAC: ""}}, nil + }, + wantStatus: "error", + wantMessage: "Node with empty MAC address", + }, + { + name: "get nodes error", + getNodes: func() ([]NodeInfo, error) { + return nil, sql.ErrConnDone + }, + wantStatus: "warn", + wantMessage: "Cannot check node tokens", + }, + { + name: "nil getNodes function", + getNodes: nil, + wantStatus: "ok", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Checker{fleetGetNodes: tt.getNodes} + result := c.checkNodeTokenConsistency() + + if result.Name != "node_token_consistency" { + t.Errorf("Name = %v, want node_token_consistency", result.Name) + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus) + } + if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) { + t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage) + } + }) + } +} + +func TestCheckOverall(t *testing.T) { + tests := []struct { + name string + checks []CheckResult + wantOverall string + }{ + { + name: "all ok", + checks: []CheckResult{ + {Name: "check1", Status: "ok"}, + {Name: "check2", Status: "ok"}, + }, + wantOverall: "ok", + }, + { + name: "one warn", + checks: []CheckResult{ + {Name: "check1", Status: "ok"}, + {Name: "check2", Status: "warn"}, + }, + wantOverall: "warn", + }, + { + name: "one error", + checks: []CheckResult{ + {Name: "check1", Status: "ok"}, + {Name: "check2", Status: "error"}, + }, + wantOverall: "error", + }, + { + name: "warn and error", + checks: []CheckResult{ + {Name: "check1", Status: "warn"}, + {Name: "check2", Status: "error"}, + }, + wantOverall: "error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + overall := "ok" + for _, r := range tt.checks { + if r.Status == "error" { + overall = "error" + break + } + if r.Status == "warn" && overall == "ok" { + overall = "warn" + } + } + + if overall != tt.wantOverall { + t.Errorf("Overall = %v, want %v", overall, tt.wantOverall) + } + }) + } +} + +func TestHandler(t *testing.T) { + requireAuthCalled := false + requireAuth := func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + requireAuthCalled = true + next(w, r) + } + } + + // Create a test database + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + // Create auth table for tests + _, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, install_secret BLOB, pin_bcrypt TEXT)") + _, _ = db.Exec("INSERT INTO auth (id, install_secret, pin_bcrypt) VALUES (1, ?, '$2a$12$hash')", []byte("test-secret-32-bytes-long-enough")) + + c := &Checker{ + db: db, + dataDir: t.TempDir(), + firmwareDir: t.TempDir(), + installSecret: []byte("test-secret-32-bytes-long-enough"), + } + + // Create firmware file + _ = os.WriteFile(filepath.Join(c.firmwareDir, "test.bin"), []byte("firmware"), 0644) + + handler := c.Handler(requireAuth) + + req := httptest.NewRequest("GET", "/api/doctor", nil) + w := httptest.NewRecorder() + + handler(w, req) + + if !requireAuthCalled { + t.Error("requireAuth was not called") + } + + if w.Code != http.StatusOK { + t.Errorf("Status = %v, want %v", w.Code, http.StatusOK) + } + + var response Response + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response.Overall == "" { + t.Error("Overall status is empty") + } + + if response.CheckedAt == "" { + t.Error("CheckedAt is empty") + } + + if len(response.Checks) != 9 { + t.Errorf("Got %d checks, want 9", len(response.Checks)) + } +} + +func TestHandlerMethodNotAllowed(t *testing.T) { + requireAuth := func(next http.HandlerFunc) http.HandlerFunc { + return next + } + + c := &Checker{} + handler := c.Handler(requireAuth) + + req := httptest.NewRequest("POST", "/api/doctor", nil) + w := httptest.NewRecorder() + + handler(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Status = %v, want %v", w.Code, http.StatusMethodNotAllowed) + } +}