api: add GET /api/doctor endpoint for pre-flight configuration diagnostics

The /api/doctor endpoint complements /healthz (runtime state) with
configuration correctness checks. Returns 200 with a JSON report
containing all check results.

Checks implemented:
- data_dir_writable: verifies /data is writable with >100 MB free
- db_integrity: runs PRAGMA integrity_check
- firmware_dir: checks for *.bin files in /firmware
- mdns_binding: verifies mDNS service is registered (or SPAXEL_MDNS_ENABLED=false)
- mqtt_reachable: TCP connectivity test if SPAXEL_MQTT_BROKER is set
- ntp_reachable: UDP connectivity test if SPAXEL_NTP_SERVER is set
- install_secret: verifies install_secret row exists in auth table
- pin_configured: verifies pin_bcrypt is non-null in auth table
- node_token_consistency: verifies all nodes have valid MAC addresses

Response format includes overall status (ok/warn/error), individual check
results with name/status/message, and checked_at timestamp.

Requires session cookie authentication via authHandler.RequireAuth.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-05-05 08:16:53 -04:00
parent b036bd5183
commit 582e45734c
3 changed files with 1102 additions and 0 deletions

View file

@ -31,6 +31,7 @@ import (
"github.com/spaxel/mothership/internal/dashboard"
"github.com/spaxel/mothership/internal/db"
"github.com/spaxel/mothership/internal/diagnostics"
"github.com/spaxel/mothership/internal/doctor"
"github.com/spaxel/mothership/internal/eventbus"
"github.com/spaxel/mothership/internal/events"
"github.com/spaxel/mothership/internal/explainability"
@ -512,6 +513,8 @@ func main() {
settingsHandler.RegisterRoutes(r)
log.Printf("[INFO] Settings API registered at /api/settings")
// Note: Doctor API is registered after mdnsServer initialization
// Phase 6: Integration Settings REST API (MQTT + system webhook)
// Note: mqttClient and webhookPublisher are wired below after they are initialized.
integrationSettingsHandler := api.NewIntegrationSettingsHandler(mainDB, "")
@ -4068,6 +4071,45 @@ func main() {
}
}
// Phase 6: Pre-flight diagnostics API
// Get install secret from database for doctor checker
var installSecret []byte
err = mainDB.QueryRow("SELECT install_secret FROM auth WHERE id = 1").Scan(&installSecret)
if err != nil {
log.Printf("[WARN] Failed to load install secret for doctor: %v", err)
installSecret = nil
}
doctorChecker := doctor.New(doctor.Config{
DB: mainDB,
DataDir: cfg.DataDir,
FirmwareDir: cfg.SeedFirmwareDir,
MDNSEnabled: cfg.MDNSEnabled,
MQTTBroker: cfg.MQTTBroker,
NTPServer: cfg.NTPServer,
InstallSecret: installSecret,
FleetGetNodes: func() ([]doctor.NodeInfo, error) {
nodes, err := fleetReg.GetAllNodes()
if err != nil {
return nil, err
}
result := make([]doctor.NodeInfo, len(nodes))
for i, n := range nodes {
result[i] = doctor.NodeInfo{MAC: n.MAC}
}
return result, nil
},
MDNSIsRegistered: func() bool {
return mdnsServer != nil
},
})
if authHandler != nil {
r.Get("/api/doctor", doctorChecker.Handler(authHandler.RequireAuth))
log.Printf("[INFO] Doctor diagnostics API registered at /api/doctor")
} else {
log.Printf("[WARN] Auth handler not available, doctor endpoint requires auth")
}
srv := &http.Server{
Addr: cfg.BindAddr,
Handler: r,

View file

@ -0,0 +1,459 @@
// Package doctor provides pre-flight configuration diagnostics for the Spaxel mothership.
// It complements the /healthz endpoint (runtime state) with configuration correctness checks.
package doctor
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"syscall"
"time"
)
// Checker runs pre-flight diagnostics on the mothership configuration.
type Checker struct {
db *sql.DB
dataDir string
firmwareDir string
mdnsEnabled bool
mqttBroker string
ntpServer string
installSecret []byte
fleetGetNodes func() ([]NodeInfo, error)
mdnsIsRegistered func() bool
}
// Config holds checker configuration.
type Config struct {
DB *sql.DB
DataDir string
FirmwareDir string
MDNSEnabled bool
MQTTBroker string
NTPServer string
InstallSecret []byte
FleetGetNodes func() ([]NodeInfo, error)
MDNSIsRegistered func() bool
}
// NodeInfo represents minimal node information for token consistency checks.
type NodeInfo struct {
MAC string
}
// New creates a new doctor checker.
func New(cfg Config) *Checker {
return &Checker{
db: cfg.DB,
dataDir: cfg.DataDir,
firmwareDir: cfg.FirmwareDir,
mdnsEnabled: cfg.MDNSEnabled,
mqttBroker: cfg.MQTTBroker,
ntpServer: cfg.NTPServer,
installSecret: cfg.InstallSecret,
fleetGetNodes: cfg.FleetGetNodes,
mdnsIsRegistered: cfg.MDNSIsRegistered,
}
}
// CheckResult represents the result of a single diagnostic check.
type CheckResult struct {
Name string `json:"name"`
Status string `json:"status"` // "ok", "warn", "error"
Message string `json:"message"` // null if ok, error message otherwise
}
// Response is the doctor endpoint response.
type Response struct {
Checks []CheckResult `json:"checks"`
Overall string `json:"overall"` // "ok", "warn", "error"
CheckedAt string `json:"checked_at"`
}
// Check runs all pre-flight diagnostics and returns the results.
func (c *Checker) Check() Response {
results := []CheckResult{
c.checkDataDirWritable(),
c.checkDBIntegrity(),
c.checkFirmwareDir(),
c.checkMDNSBinding(),
c.checkMQTTReachable(),
c.checkNTPReachable(),
c.checkInstallSecret(),
c.checkPINConfigured(),
c.checkNodeTokenConsistency(),
}
overall := "ok"
for _, r := range results {
if r.Status == "error" {
overall = "error"
break
}
if r.Status == "warn" && overall == "ok" {
overall = "warn"
}
}
return Response{
Checks: results,
Overall: overall,
CheckedAt: time.Now().UTC().Format("2006-01-02T15:04:05Z"),
}
}
// checkDataDirWritable verifies /data is writable and has >100 MB free.
func (c *Checker) checkDataDirWritable() CheckResult {
// Check if directory is writable by attempting to create a temp file
testFile := filepath.Join(c.dataDir, ".doctor_write_test")
f, err := os.Create(testFile)
if err != nil {
return CheckResult{
Name: "data_dir_writable",
Status: "error",
Message: "Data directory not writable: " + err.Error(),
}
}
_ = f.Close()
_ = os.Remove(testFile)
// Check disk space using syscall.Statfs
var stat syscall.Statfs_t
if err := syscall.Statfs(c.dataDir, &stat); err != nil {
return CheckResult{
Name: "data_dir_writable",
Status: "warn",
Message: "Cannot check disk space: " + err.Error(),
}
}
// Calculate free space in MB: (Bavail * Frsize) / (1024 * 1024)
freeBytes := stat.Bavail * uint64(stat.Frsize)
freeMB := freeBytes / (1024 * 1024)
if freeMB < 100 {
return CheckResult{
Name: "data_dir_writable",
Status: "error",
Message: fmt.Sprintf("Disk space low: %d MB free (minimum 100 MB required)", freeMB),
}
}
return CheckResult{
Name: "data_dir_writable",
Status: "ok",
Message: "",
}
}
// checkDBIntegrity runs PRAGMA integrity_check on the database.
func (c *Checker) checkDBIntegrity() CheckResult {
var result string
err := c.db.QueryRow("PRAGMA integrity_check").Scan(&result)
if err != nil {
return CheckResult{
Name: "db_integrity",
Status: "error",
Message: "SQLite integrity check failed: " + err.Error(),
}
}
if result != "ok" {
return CheckResult{
Name: "db_integrity",
Status: "error",
Message: "SQLite integrity check failed: " + result,
}
}
return CheckResult{
Name: "db_integrity",
Status: "ok",
Message: "",
}
}
// checkFirmwareDir verifies /firmware contains at least one *.bin file.
func (c *Checker) checkFirmwareDir() CheckResult {
entries, err := os.ReadDir(c.firmwareDir)
if err != nil {
if os.IsNotExist(err) {
return CheckResult{
Name: "firmware_dir",
Status: "error",
Message: "Firmware directory does not exist: " + c.firmwareDir,
}
}
return CheckResult{
Name: "firmware_dir",
Status: "warn",
Message: "Cannot read firmware directory: " + err.Error(),
}
}
hasBin := false
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".bin") {
hasBin = true
break
}
}
if !hasBin {
return CheckResult{
Name: "firmware_dir",
Status: "error",
Message: "No firmware binaries found — OTA updates unavailable",
}
}
return CheckResult{
Name: "firmware_dir",
Status: "ok",
Message: "",
}
}
// checkMDNSBinding verifies mDNS service is registered or SPAXEL_MDNS_ENABLED=false.
func (c *Checker) checkMDNSBinding() CheckResult {
if !c.mdnsEnabled {
return CheckResult{
Name: "mdns_binding",
Status: "ok",
Message: "",
}
}
if c.mdnsIsRegistered != nil && c.mdnsIsRegistered() {
return CheckResult{
Name: "mdns_binding",
Status: "ok",
Message: "",
}
}
return CheckResult{
Name: "mdns_binding",
Status: "warn",
Message: "mDNS not advertising — nodes cannot auto-discover mothership",
}
}
// checkMQTTReachable tests TCP connectivity to MQTT broker if configured.
func (c *Checker) checkMQTTReachable() CheckResult {
if c.mqttBroker == "" {
return CheckResult{
Name: "mqtt_reachable",
Status: "ok",
Message: "",
}
}
// Parse broker URL to extract host:port
parts := strings.SplitN(c.mqttBroker, "://", 2)
if len(parts) != 2 {
return CheckResult{
Name: "mqtt_reachable",
Status: "warn",
Message: "Invalid MQTT broker URL: " + c.mqttBroker,
}
}
addr := parts[1]
// Remove any path component
if idx := strings.Index(addr, "/"); idx >= 0 {
addr = addr[:idx]
}
// Add default port if not specified
if !strings.Contains(addr, ":") {
addr += ":1883"
}
// Try to connect with 3s timeout
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return CheckResult{
Name: "mqtt_reachable",
Status: "warn",
Message: "MQTT broker unreachable: " + c.mqttBroker + " (" + err.Error() + ")",
}
}
_ = conn.Close()
return CheckResult{
Name: "mqtt_reachable",
Status: "ok",
Message: "",
}
}
// checkNTPReachable tests UDP connectivity to NTP server.
func (c *Checker) checkNTPReachable() CheckResult {
if c.ntpServer == "" {
return CheckResult{
Name: "ntp_reachable",
Status: "ok",
Message: "",
}
}
// NTP uses UDP port 123
addr := c.ntpServer + ":123"
// Try to connect with 3s timeout
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, "udp", addr)
if err != nil {
return CheckResult{
Name: "ntp_reachable",
Status: "warn",
Message: "NTP server unreachable — node clock sync may fail: " + c.ntpServer + " (" + err.Error() + ")",
}
}
_ = conn.Close()
return CheckResult{
Name: "ntp_reachable",
Status: "ok",
Message: "",
}
}
// checkInstallSecret verifies the install_secret row exists in auth table.
func (c *Checker) checkInstallSecret() CheckResult {
if len(c.installSecret) == 0 {
return CheckResult{
Name: "install_secret",
Status: "error",
Message: "Installation secret missing — re-run container to regenerate",
}
}
// Also verify it exists in the database
var secret []byte
err := c.db.QueryRow("SELECT install_secret FROM auth WHERE id = 1").Scan(&secret)
if err != nil {
if err == sql.ErrNoRows {
return CheckResult{
Name: "install_secret",
Status: "error",
Message: "Installation secret missing from database — re-run container to regenerate",
}
}
return CheckResult{
Name: "install_secret",
Status: "warn",
Message: "Cannot verify install_secret in database: " + err.Error(),
}
}
return CheckResult{
Name: "install_secret",
Status: "ok",
Message: "",
}
}
// checkPINConfigured verifies pin_bcrypt is non-null in auth table.
func (c *Checker) checkPINConfigured() CheckResult {
var pinBcrypt sql.NullString
err := c.db.QueryRow("SELECT pin_bcrypt FROM auth WHERE id = 1").Scan(&pinBcrypt)
if err != nil {
return CheckResult{
Name: "pin_configured",
Status: "warn",
Message: "Cannot check PIN configuration: " + err.Error(),
}
}
if !pinBcrypt.Valid || pinBcrypt.String == "" {
return CheckResult{
Name: "pin_configured",
Status: "error",
Message: "Dashboard PIN not configured — run first-time setup",
}
}
return CheckResult{
Name: "pin_configured",
Status: "ok",
Message: "",
}
}
// checkNodeTokenConsistency verifies all nodes in registry can derive valid tokens.
func (c *Checker) checkNodeTokenConsistency() CheckResult {
if c.fleetGetNodes == nil {
return CheckResult{
Name: "node_token_consistency",
Status: "ok",
Message: "",
}
}
nodes, err := c.fleetGetNodes()
if err != nil {
return CheckResult{
Name: "node_token_consistency",
Status: "warn",
Message: "Cannot check node tokens: " + err.Error(),
}
}
// All nodes can derive valid tokens since tokens are computed from MAC + install_secret
// This check is more informational - if there are nodes, they can all authenticate
if len(nodes) == 0 {
return CheckResult{
Name: "node_token_consistency",
Status: "ok",
Message: "",
}
}
// Verify each node has a valid MAC address
for _, node := range nodes {
if node.MAC == "" {
return CheckResult{
Name: "node_token_consistency",
Status: "error",
Message: "Node with empty MAC address found in registry",
}
}
}
return CheckResult{
Name: "node_token_consistency",
Status: "ok",
Message: "",
}
}
// Handler returns an http.HandlerFunc for the /api/doctor endpoint.
func (c *Checker) Handler(requireAuth func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
return requireAuth(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
response := c.Check()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(response)
})
}

View file

@ -0,0 +1,601 @@
// Package doctor provides pre-flight configuration diagnostics for the Spaxel mothership.
package doctor
import (
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
_ "modernc.org/sqlite"
)
// mockRow is a simple mock for sql.Row scanning.
type mockRow struct {
values []interface{}
err error
}
func (m mockRow) Scan(dest ...interface{}) error {
if m.err != nil {
return m.err
}
if len(m.values) != len(dest) {
return sql.ErrNoRows
}
for i, v := range m.values {
switch d := dest[i].(type) {
case *string:
if v == nil {
return sql.ErrNoRows
}
*d = v.(string)
case *[]byte:
if v == nil {
return sql.ErrNoRows
}
*d = v.([]byte)
case *sql.NullString:
if v == nil {
d.Valid = false
} else {
d.String = v.(string)
d.Valid = true
}
default:
// For other types, try direct assignment
}
}
return nil
}
func TestCheckDataDirWritable(t *testing.T) {
tests := []struct {
name string
setup func() (string, func())
wantName string
wantStatus string
}{
{
name: "writable with enough space",
setup: func() (string, func()) {
dir := t.TempDir()
// Create a test file to verify writability
testFile := filepath.Join(dir, "test")
if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
}
return dir, func() { os.RemoveAll(dir) }
},
wantName: "data_dir_writable",
wantStatus: "ok",
},
{
name: "directory not writable",
setup: func() (string, func()) {
// Use /proc which exists but isn't writable
return "/proc", func() {}
},
wantName: "data_dir_writable",
wantStatus: "error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dataDir, cleanup := tt.setup()
defer cleanup()
c := &Checker{dataDir: dataDir}
result := c.checkDataDirWritable()
if result.Name != tt.wantName {
t.Errorf("Name = %v, want %v", result.Name, tt.wantName)
}
if result.Status != tt.wantStatus {
t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus)
}
})
}
}
func TestCheckFirmwareDir(t *testing.T) {
tests := []struct {
name string
setupFirmware func() string
wantName string
wantStatus string
wantMessage string
}{
{
name: "has firmware binaries",
setupFirmware: func() string {
dir := t.TempDir()
_ = os.WriteFile(filepath.Join(dir, "spaxel.bin"), []byte("firmware"), 0644)
return dir
},
wantName: "firmware_dir",
wantStatus: "ok",
},
{
name: "no firmware binaries",
setupFirmware: func() string {
return t.TempDir()
},
wantName: "firmware_dir",
wantStatus: "error",
wantMessage: "No firmware binaries found",
},
{
name: "directory does not exist",
setupFirmware: func() string {
return "/nonexistent/firmware"
},
wantName: "firmware_dir",
wantStatus: "error",
wantMessage: "Firmware directory does not exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
firmwareDir := tt.setupFirmware()
c := &Checker{firmwareDir: firmwareDir}
result := c.checkFirmwareDir()
if result.Name != tt.wantName {
t.Errorf("Name = %v, want %v", result.Name, tt.wantName)
}
if result.Status != tt.wantStatus {
t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus)
}
if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage)
}
})
}
}
func TestCheckMDNSBinding(t *testing.T) {
tests := []struct {
name string
enabled bool
reg func() bool
want CheckResult
}{
{
name: "mdns disabled",
enabled: false,
want: CheckResult{Name: "mdns_binding", Status: "ok"},
},
{
name: "mdns enabled and registered",
enabled: true,
reg: func() bool { return true },
want: CheckResult{Name: "mdns_binding", Status: "ok"},
},
{
name: "mdns enabled but not registered",
enabled: true,
reg: func() bool { return false },
want: CheckResult{Name: "mdns_binding", Status: "warn", Message: "mDNS not advertising"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Checker{
mdnsEnabled: tt.enabled,
mdnsIsRegistered: tt.reg,
}
result := c.checkMDNSBinding()
if result.Name != tt.want.Name {
t.Errorf("Name = %v, want %v", result.Name, tt.want.Name)
}
if result.Status != tt.want.Status {
t.Errorf("Status = %v, want %v", result.Status, tt.want.Status)
}
if tt.want.Message != "" && !strings.Contains(result.Message, tt.want.Message) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.want.Message)
}
})
}
}
func TestCheckMQTTReachable(t *testing.T) {
tests := []struct {
name string
broker string
wantMsg string
}{
{
name: "no broker configured",
broker: "",
wantMsg: "",
},
{
name: "invalid broker URL",
broker: "not-a-url",
wantMsg: "Invalid MQTT broker URL",
},
{
name: "unreachable broker",
broker: "mqtt://localhost:9999",
wantMsg: "MQTT broker unreachable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Checker{mqttBroker: tt.broker}
result := c.checkMQTTReachable()
if result.Name != "mqtt_reachable" {
t.Errorf("Name = %v, want mqtt_reachable", result.Name)
}
if tt.wantMsg != "" && result.Status != "ok" {
if !strings.Contains(result.Message, tt.wantMsg) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMsg)
}
}
})
}
}
func TestCheckNTPReachable(t *testing.T) {
tests := []struct {
name string
server string
wantMsg string
}{
{
name: "no server configured",
server: "",
wantMsg: "",
},
{
name: "invalid server",
server: "invalid host name with spaces",
wantMsg: "NTP server unreachable",
},
{
name: "valid server (pool.ntp.org)",
server: "pool.ntp.org",
wantMsg: "", // May be ok or warn depending on network
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Checker{ntpServer: tt.server}
result := c.checkNTPReachable()
if result.Name != "ntp_reachable" {
t.Errorf("Name = %v, want ntp_reachable", result.Name)
}
if tt.wantMsg != "" && result.Status != "ok" {
if !strings.Contains(result.Message, tt.wantMsg) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMsg)
}
}
})
}
}
func TestCheckInstallSecret(t *testing.T) {
tests := []struct {
name string
installSecret []byte
dbSetup func(*sql.DB)
wantStatus string
wantMessage string
}{
{
name: "secret exists in memory and db",
installSecret: []byte("test-secret-32-bytes-long-enough"),
dbSetup: func(db *sql.DB) {
_, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, install_secret BLOB)")
_, _ = db.Exec("INSERT INTO auth (id, install_secret) VALUES (1, ?)", []byte("test-secret-32-bytes-long-enough"))
},
wantStatus: "ok",
},
{
name: "secret missing from memory",
installSecret: nil,
dbSetup: func(db *sql.DB) {},
wantStatus: "error",
wantMessage: "Installation secret missing",
},
{
name: "secret missing from db",
installSecret: []byte("test"),
dbSetup: func(db *sql.DB) {
_, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, install_secret BLOB)")
},
wantStatus: "error",
wantMessage: "Installation secret missing from database",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, _ := sql.Open("sqlite", ":memory:")
tt.dbSetup(db)
defer db.Close()
c := &Checker{
db: db,
installSecret: tt.installSecret,
}
result := c.checkInstallSecret()
if result.Name != "install_secret" {
t.Errorf("Name = %v, want install_secret", result.Name)
}
if result.Status != tt.wantStatus {
t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus)
}
if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage)
}
})
}
}
func TestCheckPINConfigured(t *testing.T) {
tests := []struct {
name string
dbSetup func(*sql.DB)
wantStatus string
wantMessage string
}{
{
name: "pin configured",
dbSetup: func(db *sql.DB) {
_, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, pin_bcrypt TEXT)")
_, _ = db.Exec("INSERT INTO auth (id, pin_bcrypt) VALUES (1, '$2a$12$hash')")
},
wantStatus: "ok",
},
{
name: "pin not configured",
dbSetup: func(db *sql.DB) {
_, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, pin_bcrypt TEXT)")
_, _ = db.Exec("INSERT INTO auth (id, pin_bcrypt) VALUES (1, NULL)")
},
wantStatus: "error",
wantMessage: "Dashboard PIN not configured",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, _ := sql.Open("sqlite", ":memory:")
tt.dbSetup(db)
defer db.Close()
c := &Checker{db: db}
result := c.checkPINConfigured()
if result.Name != "pin_configured" {
t.Errorf("Name = %v, want pin_configured", result.Name)
}
if result.Status != tt.wantStatus {
t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus)
}
if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage)
}
})
}
}
func TestCheckNodeTokenConsistency(t *testing.T) {
tests := []struct {
name string
getNodes func() ([]NodeInfo, error)
wantStatus string
wantMessage string
}{
{
name: "no nodes",
getNodes: func() ([]NodeInfo, error) {
return []NodeInfo{}, nil
},
wantStatus: "ok",
},
{
name: "nodes with valid MACs",
getNodes: func() ([]NodeInfo, error) {
return []NodeInfo{
{MAC: "AA:BB:CC:DD:EE:FF"},
{MAC: "11:22:33:44:55:66"},
}, nil
},
wantStatus: "ok",
},
{
name: "node with empty MAC",
getNodes: func() ([]NodeInfo, error) {
return []NodeInfo{{MAC: ""}}, nil
},
wantStatus: "error",
wantMessage: "Node with empty MAC address",
},
{
name: "get nodes error",
getNodes: func() ([]NodeInfo, error) {
return nil, sql.ErrConnDone
},
wantStatus: "warn",
wantMessage: "Cannot check node tokens",
},
{
name: "nil getNodes function",
getNodes: nil,
wantStatus: "ok",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Checker{fleetGetNodes: tt.getNodes}
result := c.checkNodeTokenConsistency()
if result.Name != "node_token_consistency" {
t.Errorf("Name = %v, want node_token_consistency", result.Name)
}
if result.Status != tt.wantStatus {
t.Errorf("Status = %v, want %v", result.Status, tt.wantStatus)
}
if tt.wantMessage != "" && !strings.Contains(result.Message, tt.wantMessage) {
t.Errorf("Message = %v, want to contain %v", result.Message, tt.wantMessage)
}
})
}
}
func TestCheckOverall(t *testing.T) {
tests := []struct {
name string
checks []CheckResult
wantOverall string
}{
{
name: "all ok",
checks: []CheckResult{
{Name: "check1", Status: "ok"},
{Name: "check2", Status: "ok"},
},
wantOverall: "ok",
},
{
name: "one warn",
checks: []CheckResult{
{Name: "check1", Status: "ok"},
{Name: "check2", Status: "warn"},
},
wantOverall: "warn",
},
{
name: "one error",
checks: []CheckResult{
{Name: "check1", Status: "ok"},
{Name: "check2", Status: "error"},
},
wantOverall: "error",
},
{
name: "warn and error",
checks: []CheckResult{
{Name: "check1", Status: "warn"},
{Name: "check2", Status: "error"},
},
wantOverall: "error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
overall := "ok"
for _, r := range tt.checks {
if r.Status == "error" {
overall = "error"
break
}
if r.Status == "warn" && overall == "ok" {
overall = "warn"
}
}
if overall != tt.wantOverall {
t.Errorf("Overall = %v, want %v", overall, tt.wantOverall)
}
})
}
}
func TestHandler(t *testing.T) {
requireAuthCalled := false
requireAuth := func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
requireAuthCalled = true
next(w, r)
}
}
// Create a test database
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// Create auth table for tests
_, _ = db.Exec("CREATE TABLE auth (id INTEGER PRIMARY KEY, install_secret BLOB, pin_bcrypt TEXT)")
_, _ = db.Exec("INSERT INTO auth (id, install_secret, pin_bcrypt) VALUES (1, ?, '$2a$12$hash')", []byte("test-secret-32-bytes-long-enough"))
c := &Checker{
db: db,
dataDir: t.TempDir(),
firmwareDir: t.TempDir(),
installSecret: []byte("test-secret-32-bytes-long-enough"),
}
// Create firmware file
_ = os.WriteFile(filepath.Join(c.firmwareDir, "test.bin"), []byte("firmware"), 0644)
handler := c.Handler(requireAuth)
req := httptest.NewRequest("GET", "/api/doctor", nil)
w := httptest.NewRecorder()
handler(w, req)
if !requireAuthCalled {
t.Error("requireAuth was not called")
}
if w.Code != http.StatusOK {
t.Errorf("Status = %v, want %v", w.Code, http.StatusOK)
}
var response Response
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if response.Overall == "" {
t.Error("Overall status is empty")
}
if response.CheckedAt == "" {
t.Error("CheckedAt is empty")
}
if len(response.Checks) != 9 {
t.Errorf("Got %d checks, want 9", len(response.Checks))
}
}
func TestHandlerMethodNotAllowed(t *testing.T) {
requireAuth := func(next http.HandlerFunc) http.HandlerFunc {
return next
}
c := &Checker{}
handler := c.Handler(requireAuth)
req := httptest.NewRequest("POST", "/api/doctor", nil)
w := httptest.NewRecorder()
handler(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("Status = %v, want %v", w.Code, http.StatusMethodNotAllowed)
}
}