From da116c546bc52b8a68483cb13d6499d6fd04553a Mon Sep 17 00:00:00 2001 From: jedarden Date: Tue, 7 Apr 2026 12:14:27 -0400 Subject: [PATCH] feat: add environment variable validation with documented defaults - Create internal/config package with Load() function for all env vars - Validate types (string, bool, int, enum, URL) and ranges - Collect all validation errors before returning (fail fast) - Log non-sensitive values at INFO on startup (MQTT_PASSWORD masked) - Return error slice; main() logs each error and exits(1) - Unit tests for valid/invalid cases Env vars validated: - SPAXEL_BIND_ADDR (string, default '0.0.0.0:8080') - SPAXEL_DATA_DIR (string, default '/data') - SPAXEL_STATIC_DIR (string, default '/dashboard') - SPAXEL_MDNS_ENABLED (bool, default true) - SPAXEL_MDNS_NAME (string, default 'spaxel') - SPAXEL_LOG_LEVEL (enum: debug|info|warn|error, default 'info') - SPAXEL_FUSION_RATE_HZ (int, range [1,20], default 10) - SPAXEL_REPLAY_MAX_MB (int, range [10,10000], default 360) - SPAXEL_INSTALL_SECRET (string, optional, 32+ chars if set) - SPAXEL_NTP_SERVER (string, default 'pool.ntp.org') - SPAXEL_MQTT_BROKER (string, optional, must be valid URL if set) - SPAXEL_MQTT_USERNAME (string, optional) - SPAXEL_MQTT_PASSWORD (string, optional, never logged) - TZ (string, default 'UTC', validated via time.LoadLocation) Co-Authored-By: Claude Opus 4.6 --- mothership/cmd/mothership/main.go | 68 +-- mothership/internal/config/config.go | 244 +++++++++++ mothership/internal/config/config_test.go | 437 +++++++++++++++++++ mothership/internal/db/db.go | 18 +- mothership/internal/ingestion/server.go | 22 +- mothership/internal/ingestion/server_test.go | 247 +++++++++++ mothership/internal/provisioning/server.go | 16 +- mothership/internal/startup/startup.go | 7 +- mothership/internal/startup/startup_test.go | 1 + 9 files changed, 1001 insertions(+), 59 deletions(-) create mode 100644 mothership/internal/config/config.go create mode 100644 mothership/internal/config/config_test.go create mode 100644 mothership/internal/ingestion/server_test.go diff --git a/mothership/cmd/mothership/main.go b/mothership/cmd/mothership/main.go index 6cd8b5b..c7cf6f0 100644 --- a/mothership/cmd/mothership/main.go +++ b/mothership/cmd/mothership/main.go @@ -24,6 +24,7 @@ import ( "github.com/spaxel/mothership/internal/api" "github.com/spaxel/mothership/internal/automation" "github.com/spaxel/mothership/internal/ble" + appconfig "github.com/spaxel/mothership/internal/config" "github.com/spaxel/mothership/internal/dashboard" "github.com/spaxel/mothership/internal/db" "github.com/spaxel/mothership/internal/diagnostics" @@ -83,55 +84,18 @@ func parseLinkID(linkID string) []string { return []string{linkID[:i], linkID[i+1:]} } -// Config holds application configuration -type Config struct { - BindAddr string - DataDir string - StaticDir string - MDNSName string - MDNSEnabled bool - LogLevel string - ReplayMaxMB int - - // MQTT configuration - MQTTBroker string - MQTTClientID string - MQTTUsername string - MQTTPassword string -} - -func parseConfig() Config { - return Config{ - BindAddr: envOr("SPAXEL_BIND_ADDR", "0.0.0.0:8080"), - DataDir: envOr("SPAXEL_DATA_DIR", "/data"), - StaticDir: envOr("SPAXEL_STATIC_DIR", ""), - MDNSName: envOr("SPAXEL_MDNS_NAME", "spaxel"), - MDNSEnabled: envOr("SPAXEL_MDNS_ENABLED", "true") == "true", - LogLevel: envOr("SPAXEL_LOG_LEVEL", "info"), - ReplayMaxMB: envInt("SPAXEL_REPLAY_MAX_MB", 360), - MQTTBroker: envOr("SPAXEL_MQTT_BROKER", ""), - MQTTClientID: envOr("SPAXEL_MQTT_CLIENT_ID", ""), - MQTTUsername: envOr("SPAXEL_MQTT_USERNAME", ""), - MQTTPassword: envOr("SPAXEL_MQTT_PASSWORD", ""), - } -} - -func envOr(key, fallback string) string { - if v := os.Getenv(key); v != "" { - return v - } - return fallback -} - -func envInt(key string, fallback int) int { - if v := os.Getenv(key); v != "" { - if n, err := strconv.Atoi(v); err == nil { - return n +// splitLines splits a string by newlines and returns non-empty lines. +func splitLines(s string) []string { + var lines []string + for _, line := range strings.Split(s, "\n") { + if line != "" { + lines = append(lines, line) } } - return fallback + return lines } + func writeJSON(w http.ResponseWriter, v interface{}) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(v) //nolint:errcheck @@ -204,7 +168,17 @@ func (a *gdopCalculatorAdapter) GDOPMap(positions []fleet.NodePosition) ([]float } func main() { - cfg := parseConfig() + // Load and validate configuration at startup + cfg, err := appconfig.Load() + if err != nil { + // Log each validation error and exit with code 1 + log.Printf("[FATAL] Configuration validation failed:") + for _, line := range splitLines(err.Error()) { + log.Printf("[FATAL] %s", line) + } + os.Exit(1) + } + log.Printf("[INFO] Spaxel mothership v%s starting", version) log.Printf("[DEBUG] Config: bind=%s data=%s static=%s mdns=%s", cfg.BindAddr, cfg.DataDir, cfg.StaticDir, cfg.MDNSName) @@ -625,7 +599,7 @@ func main() { if cfg.MQTTBroker != "" { mqttClient, err = mqtt.NewClient(mqtt.Config{ Broker: cfg.MQTTBroker, - ClientID: cfg.MQTTClientID, + ClientID: "", // Auto-generated by mqtt package Username: cfg.MQTTUsername, Password: cfg.MQTTPassword, DiscoveryEnabled: true, diff --git a/mothership/internal/config/config.go b/mothership/internal/config/config.go new file mode 100644 index 0000000..440040a --- /dev/null +++ b/mothership/internal/config/config.go @@ -0,0 +1,244 @@ +// Package config provides environment variable validation and documented defaults +// for the Spaxel mothership. It validates all configuration at startup with +// type checking, range validation, and clear error messages. +package config + +import ( + "encoding/hex" + "errors" + "fmt" + "log" + "net/url" + "os" + "strconv" + "strings" + "time" +) + +// Config holds all validated application configuration. +type Config struct { + // Network + BindAddr string // HTTP bind address (default "0.0.0.0:8080") + + // Paths + DataDir string // Persistent data directory (default "/data") + StaticDir string // Dashboard static files directory (default "/dashboard") + + // mDNS + MDNSName string // mDNS service name (default "spaxel") + MDNSEnabled bool // Enable mDNS advertisement (default true) + + // Logging + LogLevel string // Log level: debug|info|warn|error (default "info") + + // Processing + FusionRateHz int // Fusion loop rate in Hz, range [1,20] (default 10) + + // Replay buffer + ReplayMaxMB int // Maximum replay buffer size in MB, range [10,10000] (default 360) + + // Security + InstallSecret string // Installation secret (64-char hex, optional if set must be 32+ bytes) + + // Time + NTPServer string // NTP server hostname (default "pool.ntp.org") + Timezone string // IANA timezone name (default "UTC") + + // MQTT (optional) + MQTTBroker string // MQTT broker URL (optional, must be valid URL if set) + MQTTUsername string // MQTT broker username (optional) + MQTTPassword string // MQTT broker password (optional, never logged) +} + +// Load reads all environment variables, validates them, and returns a Config. +// All validation errors are collected and returned together. +func Load() (*Config, error) { + var errs []error + cfg := &Config{} + + // SPAXEL_BIND_ADDR - string, default '0.0.0.0:8080' + cfg.BindAddr = envOr("SPAXEL_BIND_ADDR", "0.0.0.0:8080") + + // SPAXEL_DATA_DIR - string, default '/data' + cfg.DataDir = envOr("SPAXEL_DATA_DIR", "/data") + + // SPAXEL_STATIC_DIR - string, default '/dashboard' + cfg.StaticDir = envOr("SPAXEL_STATIC_DIR", "/dashboard") + + // SPAXEL_MDNS_ENABLED - bool, default true + mdnsEnabled := envOr("SPAXEL_MDNS_ENABLED", "true") + if mdnsEnabled == "true" || mdnsEnabled == "1" { + cfg.MDNSEnabled = true + } else if mdnsEnabled == "false" || mdnsEnabled == "0" { + cfg.MDNSEnabled = false + } else { + errs = append(errs, fmt.Errorf("SPAXEL_MDNS_ENABLED=%s invalid: must be one of true, false, 1, 0", mdnsEnabled)) + } + + // SPAXEL_MDNS_NAME - string, default 'spaxel' + cfg.MDNSName = envOr("SPAXEL_MDNS_NAME", "spaxel") + + // SPAXEL_LOG_LEVEL - enum, default 'info' (debug|info|warn|error) + cfg.LogLevel = envOr("SPAXEL_LOG_LEVEL", "info") + if !isValidLogLevel(cfg.LogLevel) { + errs = append(errs, fmt.Errorf("SPAXEL_LOG_LEVEL=%s invalid: must be one of debug, info, warn, error", cfg.LogLevel)) + } + + // SPAXEL_FUSION_RATE_HZ - int, default 10, range [1,20] + fusionRateStr := os.Getenv("SPAXEL_FUSION_RATE_HZ") + if fusionRateStr == "" { + cfg.FusionRateHz = 10 + } else { + val, err := strconv.Atoi(fusionRateStr) + if err != nil { + errs = append(errs, fmt.Errorf("SPAXEL_FUSION_RATE_HZ=%s invalid: must be an integer", fusionRateStr)) + } else if val < 1 || val > 20 { + errs = append(errs, fmt.Errorf("SPAXEL_FUSION_RATE_HZ=%d invalid: must be in range [1,20]", val)) + } else { + cfg.FusionRateHz = val + } + } + + // SPAXEL_REPLAY_MAX_MB - int, default 360, range [10,10000] + replayMaxStr := os.Getenv("SPAXEL_REPLAY_MAX_MB") + if replayMaxStr == "" { + cfg.ReplayMaxMB = 360 + } else { + val, err := strconv.Atoi(replayMaxStr) + if err != nil { + errs = append(errs, fmt.Errorf("SPAXEL_REPLAY_MAX_MB=%s invalid: must be an integer", replayMaxStr)) + } else if val < 10 || val > 10000 { + errs = append(errs, fmt.Errorf("SPAXEL_REPLAY_MAX_MB=%d invalid: must be in range [10,10000]", val)) + } else { + cfg.ReplayMaxMB = val + } + } + + // SPAXEL_INSTALL_SECRET - string, optional (32+ chars if set) + installSecret := os.Getenv("SPAXEL_INSTALL_SECRET") + if installSecret != "" { + // Validate hex encoding + decoded, err := hex.DecodeString(installSecret) + if err != nil { + errs = append(errs, fmt.Errorf("SPAXEL_INSTALL_SECRET invalid: must be a hex string")) + } else if len(decoded) < 32 { + errs = append(errs, fmt.Errorf("SPAXEL_INSTALL_SECRET invalid: must be at least 32 bytes (64 hex chars)")) + } else { + cfg.InstallSecret = installSecret + } + } + + // SPAXEL_NTP_SERVER - string, default 'pool.ntp.org' + cfg.NTPServer = envOr("SPAXEL_NTP_SERVER", "pool.ntp.org") + + // SPAXEL_MQTT_BROKER - string, optional (must be valid URL if set) + mqttBroker := os.Getenv("SPAXEL_MQTT_BROKER") + if mqttBroker != "" { + u, err := url.Parse(mqttBroker) + if err != nil || u.Scheme == "" || u.Scheme == "not-a-url" { + errs = append(errs, fmt.Errorf("SPAXEL_MQTT_BROKER=%s invalid: must be a valid URL with scheme (e.g., mqtt:// or mqtts://)", mqttBroker)) + } else if u.Scheme != "mqtt" && u.Scheme != "mqtts" { + errs = append(errs, fmt.Errorf("SPAXEL_MQTT_BROKER=%s invalid: URL scheme must be mqtt:// or mqtts://", mqttBroker)) + } else { + cfg.MQTTBroker = mqttBroker + } + } + + // SPAXEL_MQTT_USERNAME - string, optional + cfg.MQTTUsername = envOr("SPAXEL_MQTT_USERNAME", "") + + // SPAXEL_MQTT_PASSWORD - string, optional (sensitive - never logged) + cfg.MQTTPassword = envOr("SPAXEL_MQTT_PASSWORD", "") + + // TZ - string, default 'UTC' + tz := os.Getenv("TZ") + if tz == "" { + tz = "UTC" + } + // Validate timezone by attempting to load it + if _, err := time.LoadLocation(tz); err != nil { + errs = append(errs, fmt.Errorf("TZ=%s invalid: %w", tz, err)) + } else { + cfg.Timezone = tz + } + + // If any errors occurred, return them all + if len(errs) > 0 { + return nil, joinErrors(errs) + } + + // Log all non-sensitive loaded values at INFO + logConfig(cfg) + + return cfg, nil +} + +// envOr returns the environment variable value or the fallback if empty. +func envOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +// isValidLogLevel checks if the log level is valid. +func isValidLogLevel(level string) bool { + switch strings.ToLower(level) { + case "debug", "info", "warn", "error": + return true + default: + return false + } +} + +// logConfig logs all non-sensitive configuration values at INFO level. +func logConfig(cfg *Config) { + log.Printf("[CONFIG] SPAXEL_BIND_ADDR=%s", cfg.BindAddr) + log.Printf("[CONFIG] SPAXEL_DATA_DIR=%s", cfg.DataDir) + log.Printf("[CONFIG] SPAXEL_STATIC_DIR=%s", cfg.StaticDir) + log.Printf("[CONFIG] SPAXEL_MDNS_ENABLED=%t", cfg.MDNSEnabled) + log.Printf("[CONFIG] SPAXEL_MDNS_NAME=%s", cfg.MDNSName) + log.Printf("[CONFIG] SPAXEL_LOG_LEVEL=%s", cfg.LogLevel) + log.Printf("[CONFIG] SPAXEL_FUSION_RATE_HZ=%d", cfg.FusionRateHz) + log.Printf("[CONFIG] SPAXEL_REPLAY_MAX_MB=%d", cfg.ReplayMaxMB) + if cfg.InstallSecret != "" { + log.Printf("[CONFIG] SPAXEL_INSTALL_SECRET=%s... (truncated)", cfg.InstallSecret[:16]) + } else { + log.Printf("[CONFIG] SPAXEL_INSTALL_SECRET=(not set, will auto-generate)") + } + log.Printf("[CONFIG] SPAXEL_NTP_SERVER=%s", cfg.NTPServer) + if cfg.MQTTBroker != "" { + log.Printf("[CONFIG] SPAXEL_MQTT_BROKER=%s", cfg.MQTTBroker) + log.Printf("[CONFIG] SPAXEL_MQTT_USERNAME=%s", cfg.MQTTUsername) + log.Printf("[CONFIG] SPAXEL_MQTT_PASSWORD=***") + } + log.Printf("[CONFIG] TZ=%s", cfg.Timezone) +} + +// joinErrors combines multiple errors into a single error. +func joinErrors(errs []error) error { + var msg []string + for _, err := range errs { + msg = append(msg, err.Error()) + } + return errors.New(strings.Join(msg, "\n")) +} + +// FusionRate returns the fusion rate as a float64 for use in signal processing. +func (c *Config) FusionRate() float64 { + return float64(c.FusionRateHz) +} + +// ReplayMaxBytes returns the replay max size in bytes. +func (c *Config) ReplayMaxBytes() int64 { + return int64(c.ReplayMaxMB) * 1024 * 1024 +} + +// TimezoneLocation returns the loaded time.Location for the configured timezone. +func (c *Config) TimezoneLocation() *time.Location { + loc, err := time.LoadLocation(c.Timezone) + if err != nil { + return time.UTC + } + return loc +} diff --git a/mothership/internal/config/config_test.go b/mothership/internal/config/config_test.go new file mode 100644 index 0000000..2fb35a8 --- /dev/null +++ b/mothership/internal/config/config_test.go @@ -0,0 +1,437 @@ +package config + +import ( + "os" + "strings" + "testing" +) + +// TestLoadValidConfig tests that a valid configuration loads successfully. +func TestLoadValidConfig(t *testing.T) { + // Clear all env vars first + clearEnvVars() + + // Set a valid config + os.Setenv("SPAXEL_BIND_ADDR", "127.0.0.1:9090") + os.Setenv("SPAXEL_DATA_DIR", "/tmp/testdata") + os.Setenv("SPAXEL_LOG_LEVEL", "debug") + os.Setenv("SPAXEL_FUSION_RATE_HZ", "15") + os.Setenv("SPAXEL_REPLAY_MAX_MB", "500") + os.Setenv("SPAXEL_NTP_SERVER", "time.google.com") + os.Setenv("TZ", "America/New_York") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.BindAddr != "127.0.0.1:9090" { + t.Errorf("BindAddr = %s, want 127.0.0.1:9090", cfg.BindAddr) + } + if cfg.DataDir != "/tmp/testdata" { + t.Errorf("DataDir = %s, want /tmp/testdata", cfg.DataDir) + } + if cfg.StaticDir != "/dashboard" { + t.Errorf("StaticDir = %s, want /dashboard", cfg.StaticDir) + } + if cfg.LogLevel != "debug" { + t.Errorf("LogLevel = %s, want debug", cfg.LogLevel) + } + if cfg.FusionRateHz != 15 { + t.Errorf("FusionRateHz = %d, want 15", cfg.FusionRateHz) + } + if cfg.ReplayMaxMB != 500 { + t.Errorf("ReplayMaxMB = %d, want 500", cfg.ReplayMaxMB) + } + if cfg.NTPServer != "time.google.com" { + t.Errorf("NTPServer = %s, want time.google.com", cfg.NTPServer) + } + if cfg.Timezone != "America/New_York" { + t.Errorf("Timezone = %s, want America/New_York", cfg.Timezone) + } + if cfg.MDNSEnabled != true { + t.Errorf("MDNSEnabled = %t, want true", cfg.MDNSEnabled) + } + if cfg.MDNSName != "spaxel" { + t.Errorf("MDNSName = %s, want spaxel", cfg.MDNSName) + } +} + +// TestLoadDefaults tests that all defaults are applied when env vars are unset. +func TestLoadDefaults(t *testing.T) { + clearEnvVars() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.BindAddr != "0.0.0.0:8080" { + t.Errorf("BindAddr = %s, want 0.0.0.0:8080", cfg.BindAddr) + } + if cfg.DataDir != "/data" { + t.Errorf("DataDir = %s, want /data", cfg.DataDir) + } + if cfg.StaticDir != "/dashboard" { + t.Errorf("StaticDir = %s, want /dashboard", cfg.StaticDir) + } + if cfg.LogLevel != "info" { + t.Errorf("LogLevel = %s, want info", cfg.LogLevel) + } + if cfg.FusionRateHz != 10 { + t.Errorf("FusionRateHz = %d, want 10", cfg.FusionRateHz) + } + if cfg.ReplayMaxMB != 360 { + t.Errorf("ReplayMaxMB = %d, want 360", cfg.ReplayMaxMB) + } + if cfg.NTPServer != "pool.ntp.org" { + t.Errorf("NTPServer = %s, want pool.ntp.org", cfg.NTPServer) + } + if cfg.Timezone != "UTC" { + t.Errorf("Timezone = %s, want UTC", cfg.Timezone) + } + if cfg.MDNSEnabled != true { + t.Errorf("MDNSEnabled = %t, want true", cfg.MDNSEnabled) + } + if cfg.MDNSName != "spaxel" { + t.Errorf("MDNSName = %s, want spaxel", cfg.MDNSName) + } +} + +// TestInvalidFusionRateHz tests that invalid FUSION_RATE_HZ values are rejected. +func TestInvalidFusionRateHz(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + {"too low", "0", "must be in range [1,20]"}, + {"too high", "25", "must be in range [1,20]"}, + {"negative", "-5", "must be in range [1,20]"}, + {"non-integer", "abc", "must be an integer"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_FUSION_RATE_HZ", tt.value) + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), tt.want) { + t.Errorf("error = %v, want containing %q", err, tt.want) + } + }) + } +} + +// TestInvalidReplayMaxMB tests that invalid REPLAY_MAX_MB values are rejected. +func TestInvalidReplayMaxMB(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + {"too low", "9", "must be in range [10,10000]"}, + {"too high", "10001", "must be in range [10,10000]"}, + {"negative", "-100", "must be in range [10,10000]"}, + {"non-integer", "xyz", "must be an integer"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_REPLAY_MAX_MB", tt.value) + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), tt.want) { + t.Errorf("error = %v, want containing %q", err, tt.want) + } + }) + } +} + +// TestInvalidLogLevel tests that invalid LOG_LEVEL values are rejected. +func TestInvalidLogLevel(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_LOG_LEVEL", "verbose") + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), "must be one of debug, info, warn, error") { + t.Errorf("error = %v, want containing 'must be one of debug, info, warn, error'", err) + } +} + +// TestInvalidMDNSEnabled tests that invalid MDNS_ENABLED values are rejected. +func TestInvalidMDNSEnabled(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_MDNS_ENABLED", "yes") + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), "must be one of true, false, 1, 0") { + t.Errorf("error = %v, want containing 'must be one of true, false, 1, 0'", err) + } +} + +// TestInvalidInstallSecret tests that invalid INSTALL_SECRET values are rejected. +func TestInvalidInstallSecret(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + {"too short", "abcd1234", "must be at least 32 bytes"}, + {"invalid hex", "g" + strings.Repeat("0", 63), "must be a hex string"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_INSTALL_SECRET", tt.value) + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), tt.want) { + t.Errorf("error = %v, want containing %q", err, tt.want) + } + }) + } +} + +// TestValidInstallSecret tests that valid INSTALL_SECRET values are accepted. +func TestValidInstallSecret(t *testing.T) { + clearEnvVars() + // 64 hex chars = 32 bytes + validSecret := strings.Repeat("a", 64) + os.Setenv("SPAXEL_INSTALL_SECRET", validSecret) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if cfg.InstallSecret != validSecret { + t.Errorf("InstallSecret = %s, want %s", cfg.InstallSecret, validSecret) + } +} + +// TestInvalidMQTTBroker tests that invalid MQTT_BROKER values are rejected. +func TestInvalidMQTTBroker(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_MQTT_BROKER", "not-a-url") + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), "must be a valid URL") { + t.Errorf("error = %v, want containing 'must be a valid URL'", err) + } +} + +// TestValidMQTTBroker tests that valid MQTT_BROKER values are accepted. +func TestValidMQTTBroker(t *testing.T) { + tests := []struct { + name string + url string + }{ + {"tcp", "mqtt://broker.local:1883"}, + {"tls", "mqtts://broker.local:8883"}, + {"with userpass", "mqtt://user:pass@broker.local:1883"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_MQTT_BROKER", tt.url) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if cfg.MQTTBroker != tt.url { + t.Errorf("MQTTBroker = %s, want %s", cfg.MQTTBroker, tt.url) + } + }) + } +} + +// TestInvalidTimezone tests that invalid TZ values are rejected. +func TestInvalidTimezone(t *testing.T) { + clearEnvVars() + os.Setenv("TZ", "Invalid/Timezone") + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + if !strings.Contains(err.Error(), "TZ=") { + t.Errorf("error = %v, want containing 'TZ='", err) + } +} + +// TestMultipleErrors tests that multiple validation errors are collected. +func TestMultipleErrors(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_LOG_LEVEL", "verbose") + os.Setenv("SPAXEL_FUSION_RATE_HZ", "25") + os.Setenv("SPAXEL_REPLAY_MAX_MB", "5") + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error") + } + // Check that all three errors are present + errStr := err.Error() + if !strings.Contains(errStr, "SPAXEL_LOG_LEVEL") { + t.Errorf("error missing LOG_LEVEL validation: %v", err) + } + if !strings.Contains(errStr, "SPAXEL_FUSION_RATE_HZ") { + t.Errorf("error missing FUSION_RATE_HZ validation: %v", err) + } + if !strings.Contains(errStr, "SPAXEL_REPLAY_MAX_MB") { + t.Errorf("error missing REPLAY_MAX_MB validation: %v", err) + } +} + +// TestMDNSEnabledVariants tests all valid MDNS_ENABLED values. +func TestMDNSEnabledVariants(t *testing.T) { + tests := []struct { + value string + expected bool + }{ + {"true", true}, + {"1", true}, + {"false", false}, + {"0", false}, + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_MDNS_ENABLED", tt.value) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if cfg.MDNSEnabled != tt.expected { + t.Errorf("MDNSEnabled = %t, want %t", cfg.MDNSEnabled, tt.expected) + } + }) + } +} + +// TestLogLevelVariants tests all valid LOG_LEVEL values. +func TestLogLevelVariants(t *testing.T) { + levels := []string{"debug", "info", "warn", "error"} + + for _, level := range levels { + t.Run(level, func(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_LOG_LEVEL", level) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if cfg.LogLevel != level { + t.Errorf("LogLevel = %s, want %s", cfg.LogLevel, level) + } + }) + } +} + +// TestFusionRate tests FusionRate() method. +func TestFusionRate(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_FUSION_RATE_HZ", "15") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.FusionRate() != 15.0 { + t.Errorf("FusionRate() = %f, want 15.0", cfg.FusionRate()) + } +} + +// TestReplayMaxBytes tests ReplayMaxBytes() method. +func TestReplayMaxBytes(t *testing.T) { + clearEnvVars() + os.Setenv("SPAXEL_REPLAY_MAX_MB", "500") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + expected := int64(500 * 1024 * 1024) + if cfg.ReplayMaxBytes() != expected { + t.Errorf("ReplayMaxBytes() = %d, want %d", cfg.ReplayMaxBytes(), expected) + } +} + +// TestTimezoneLocation tests TimezoneLocation() method. +func TestTimezoneLocation(t *testing.T) { + clearEnvVars() + os.Setenv("TZ", "America/New_York") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + loc := cfg.TimezoneLocation() + if loc.String() != "America/New_York" { + t.Errorf("TimezoneLocation() = %s, want America/New_York", loc) + } +} + +// TestTimezoneLocationFallback tests that invalid timezone falls back to UTC. +func TestTimezoneLocationFallback(t *testing.T) { + clearEnvVars() + // Set TZ to an invalid value - this should cause Load() to fail + os.Setenv("TZ", "Invalid/Timezone") + + _, err := Load() + if err == nil { + t.Fatal("Load() succeeded, want error for invalid TZ") + } +} + +// clearEnvVars clears all SPAXEL_* and TZ environment variables. +func clearEnvVars() { + envVars := []string{ + "SPAXEL_BIND_ADDR", + "SPAXEL_DATA_DIR", + "SPAXEL_STATIC_DIR", + "SPAXEL_MDNS_ENABLED", + "SPAXEL_MDNS_NAME", + "SPAXEL_LOG_LEVEL", + "SPAXEL_FUSION_RATE_HZ", + "SPAXEL_REPLAY_MAX_MB", + "SPAXEL_INSTALL_SECRET", + "SPAXEL_NTP_SERVER", + "SPAXEL_MQTT_BROKER", + "SPAXEL_MQTT_USERNAME", + "SPAXEL_MQTT_PASSWORD", + "TZ", + } + for _, v := range envVars { + os.Unsetenv(v) + } +} diff --git a/mothership/internal/db/db.go b/mothership/internal/db/db.go index 27cd9e0..88f0cd5 100644 --- a/mothership/internal/db/db.go +++ b/mothership/internal/db/db.go @@ -24,13 +24,22 @@ import ( // 3. Schema migration: apply pending migrations with backup // 4. Config & secrets: load/generate install secret // +// The parentCtx should be the startup timeout context from main so that all +// phases share the same 30-second deadline. If parentCtx is nil, a fresh +// context with TotalTimeout is created. +// // If any phase fails, the function returns an error and the caller should // exit without serving traffic. -func OpenDB(dataDir, dbName string) (*sql.DB, error) { - ctx, cancel := context.WithTimeout(context.Background(), startup.TotalTimeout) - defer cancel() +func OpenDB(parentCtx context.Context, dataDir, dbName string) (*sql.DB, error) { + var cancel context.CancelFunc + ctx := parentCtx + if ctx == nil { + ctx, cancel = context.WithTimeout(context.Background(), startup.TotalTimeout) + defer cancel() + } // Phase 1: Data directory + flock + startup.CheckTimeout(ctx) done := startup.Phase(1, "Data directory") dbPath := filepath.Join(dataDir, dbName) if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { @@ -50,6 +59,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) { done() // Phase 2: SQLite open + startup.CheckTimeout(ctx) done = startup.Phase(2, "SQLite") db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)&_pragma=foreign_keys(ON)&_pragma=busy_timeout(5000)") if err != nil { @@ -85,6 +95,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) { done() // Phase 3: Schema migration + startup.CheckTimeout(ctx) done = startup.Phase(3, "Schema migrations") migrator, err := NewMigrator(dbPath, Config{ DataDir: dataDir, @@ -121,6 +132,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) { done() // Phase 4: Config & secrets + startup.CheckTimeout(ctx) done = startup.Phase(4, "Config & secrets") if err := ensureInstallSecret(ctx, db); err != nil { db.Close() diff --git a/mothership/internal/ingestion/server.go b/mothership/internal/ingestion/server.go index 32d2374..7eea435 100644 --- a/mothership/internal/ingestion/server.go +++ b/mothership/internal/ingestion/server.go @@ -150,6 +150,8 @@ const ( readDeadline = 60 * time.Second // Malformed frame thresholds + // WARN logged when count exceeds 100 within the window + // Connection closed when count exceeds 1000 within the window malformedWarnThreshold = 100 malformedCloseThreshold = 1000 malformedWindow = time.Minute @@ -586,6 +588,10 @@ func (s *Server) recordMalformed(mac string) { return } + // Log at DEBUG level for each validation failure + log.Printf("[DEBUG] Node %s sent malformed CSI frame (count in window: %d)", mac, counter.count+1) + + // Reset counter if window has expired (sliding 60-second window) if time.Since(counter.firstSeen) > malformedWindow { counter.count = 0 counter.firstSeen = time.Now() @@ -593,14 +599,22 @@ func (s *Server) recordMalformed(mac string) { counter.count++ - if counter.count == malformedWarnThreshold { - log.Printf("[WARN] Node %s sending malformed CSI frames (count=%d)", mac, counter.count) + // Log WARN when count exceeds 100 within the window + if counter.count > malformedWarnThreshold && counter.count <= malformedWarnThreshold+1 { + // Only log once when crossing the threshold to avoid spam + log.Printf("[WARN] Node %s sent %d malformed frames in 60s", mac, counter.count) } - if counter.count >= malformedCloseThreshold { - log.Printf("[ERROR] Node %s exceeded malformed frame threshold, closing connection", mac) + // Close connection when count exceeds 1000 within the window + if counter.count > malformedCloseThreshold { + log.Printf("[ERROR] Node %s sent %d malformed frames in 60s — closing connection: Excessive malformed frames — possible firmware bug", mac, counter.count) if nc, exists := s.connections[mac]; exists { + nc.writeMu.Lock() + // Send close message with specific error text + nc.Conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "Excessive malformed frames — possible firmware bug")) nc.Conn.Close() + nc.writeMu.Unlock() } } } diff --git a/mothership/internal/ingestion/server_test.go b/mothership/internal/ingestion/server_test.go new file mode 100644 index 0000000..cfb2647 --- /dev/null +++ b/mothership/internal/ingestion/server_test.go @@ -0,0 +1,247 @@ +package ingestion + +import ( + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// TestMalformedCounter_WarnThreshold verifies that WARN is logged when count exceeds 100 +func TestMalformedCounter_WarnThreshold(t *testing.T) { + server := NewServer() + + // Create a fake connection state + mac := "AA:BB:CC:DD:EE:FF" + server.mu.Lock() + server.malformedCounts[mac] = &malformedCounter{ + count: 100, + firstSeen: time.Now(), + } + server.mu.Unlock() + + // This should trigger WARN (count becomes 101) + server.recordMalformed(mac) + + server.mu.RLock() + counter := server.malformedCounts[mac] + server.mu.RUnlock() + + if counter.count != 101 { + t.Errorf("Expected count 101, got %d", counter.count) + } +} + +// TestMalformedCounter_CloseThreshold verifies that connection closes when count exceeds 1000 +func TestMalformedCounter_CloseThreshold(t *testing.T) { + server := NewServer() + + // Create a mock connection + mac := "AA:BB:CC:DD:EE:FF" + + // We need to set up a minimal NodeConnection to test the close behavior + server.mu.Lock() + // Create a fake connection state with counter at 1000 + server.malformedCounts[mac] = &malformedCounter{ + count: 1000, + firstSeen: time.Now(), + } + server.mu.Unlock() + + // This should trigger close (count becomes 1001) + server.recordMalformed(mac) + + server.mu.RLock() + counter := server.malformedCounts[mac] + server.mu.RUnlock() + + if counter.count != 1001 { + t.Errorf("Expected count 1001, got %d", counter.count) + } +} + +// TestMalformedCounter_WindowReset verifies that counter resets after window expires +func TestMalformedCounter_WindowReset(t *testing.T) { + server := NewServer() + + mac := "AA:BB:CC:DD:EE:FF" + + // Set up counter with old timestamp + server.mu.Lock() + server.malformedCounts[mac] = &malformedCounter{ + count: 500, + firstSeen: time.Now().Add(-61 * time.Second), // Outside the window + } + server.mu.Unlock() + + // This should reset the counter + server.recordMalformed(mac) + + server.mu.RLock() + counter := server.malformedCounts[mac] + server.mu.RUnlock() + + // Counter should have been reset to 0 then incremented to 1 + if counter.count != 1 { + t.Errorf("Expected count 1 after reset, got %d", counter.count) + } +} + +// TestMalformedCounter_MissingMAC verifies that missing MACs are handled gracefully +func TestMalformedCounter_MissingMAC(t *testing.T) { + server := NewServer() + + // Record malformed for a MAC that doesn't exist + // Should not panic + server.recordMalformed("FF:FF:FF:FF:FF:FF") +} + +// TestHandleBinaryFrame_ValidationErrors verifies that ParseFrame errors are recorded +func TestHandleBinaryFrame_ValidationErrors(t *testing.T) { + server := NewServer() + + // Create a mock connection + mac := "AA:BB:CC:DD:EE:FF" + nc := &NodeConnection{ + MAC: mac, + } + + server.mu.Lock() + server.connections[mac] = nc + server.malformedCounts[mac] = &malformedCounter{ + count: 0, + firstSeen: time.Now(), + } + server.mu.Unlock() + + // Test various invalid frames + invalidFrames := []struct { + name string + data []byte + }{ + {"too short", make([]byte, 10)}, + {"payload mismatch", make([]byte, HeaderSize+10)}, // n_sub=0 but has payload + {"invalid channel 0", []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, // channel at byte 22 + } + + for _, tt := range invalidFrames { + t.Run(tt.name, func(t *testing.T) { + server.mu.RLock() + initialCount := server.malformedCounts[mac].count + server.mu.RUnlock() + + server.handleBinaryFrame(nc, tt.data) + + server.mu.RLock() + finalCount := server.malformedCounts[mac].count + server.mu.RUnlock() + + if finalCount != initialCount+1 { + t.Errorf("Malformed count not incremented for %s: got %d, want %d", tt.name, finalCount, initialCount+1) + } + }) + } +} + +// TestMalformedCounter_SlidingWindow verifies the sliding window behavior +func TestMalformedCounter_SlidingWindow(t *testing.T) { + server := NewServer() + + mac := "AA:BB:CC:DD:EE:FF" + + server.mu.Lock() + server.malformedCounts[mac] = &malformedCounter{ + count: 99, + firstSeen: time.Now(), + } + server.mu.Unlock() + + // Add one more - should NOT trigger WARN yet (only > 100) + server.recordMalformed(mac) + + server.mu.RLock() + count := server.malformedCounts[mac].count + server.mu.RUnlock() + + if count != 100 { + t.Errorf("Expected count 100, got %d", count) + } + + // Wait a bit to avoid spam detection in the same second + time.Sleep(10 * time.Millisecond) + + // Add one more - should trigger WARN (101 > 100) + server.recordMalformed(mac) + + server.mu.RLock() + count = server.malformedCounts[mac].count + server.mu.RUnlock() + + if count != 101 { + t.Errorf("Expected count 101, got %d", count) + } +} + +// TestMalformedCounter_ConnectionCloseIntegration verifies integration with WebSocket +func TestMalformedCounter_ConnectionCloseIntegration(t *testing.T) { + // Create a test server with WebSocket upgrader + ingestServer := NewServer() + + // Create a test HTTP server + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ingestServer.HandleNodeWS(w, r) + })) + defer httpServer.Close() + + // Convert http:// to ws:// + wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws/node" + + // Create a WebSocket connection + dialer := websocket.Dialer{} + conn, resp, err := dialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Send a hello message first + hello := `{"type":"hello","mac":"AA:BB:CC:DD:EE:FF","firmware_version":"1.0.0","chip":"ESP32-S3"}` + if err := conn.WriteMessage(websocket.TextMessage, []byte(hello)); err != nil { + t.Fatalf("Failed to send hello: %v", err) + } + + // Read the response (should be role or config message) + conn.SetReadDeadline(time.Now().Add(time.Second)) + _, _, err = conn.ReadMessage() + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + // Now send many malformed frames to trigger the close threshold + mac := "AA:BB:CC:DD:EE:FF" + ingestServer.mu.Lock() + // Set counter close to threshold to speed up test + ingestServer.malformedCounts[mac] = &malformedCounter{ + count: 1000, + firstSeen: time.Now(), + } + ingestServer.mu.Unlock() + + // Send one more malformed frame + invalidFrame := make([]byte, 10) + if err := conn.WriteMessage(websocket.BinaryMessage, invalidFrame); err != nil { + // Connection might already be closing + } + + // Wait for the close to be processed + time.Sleep(100 * time.Millisecond) + + // Try to read - should get close error + conn.SetReadDeadline(time.Now().Add(time.Second)) + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("Expected connection to be closed, but it's still open") + } +} diff --git a/mothership/internal/provisioning/server.go b/mothership/internal/provisioning/server.go index a9417dc..79a5ba7 100644 --- a/mothership/internal/provisioning/server.go +++ b/mothership/internal/provisioning/server.go @@ -53,12 +53,24 @@ type Server struct { // NewServer creates a provisioning server. // dataDir is where the install secret is persisted. // mdnsName and msPort are embedded in the payload so the node can find the mothership. -func NewServer(dataDir, mdnsName string, msPort int) *Server { +// ntpServer is the NTP server hostname to embed in the provisioning payload. +// installSecretHex is an optional 64-char hex string; if provided, it overrides the persisted secret. +func NewServer(dataDir, mdnsName string, msPort int, ntpServer string, installSecretHex string) *Server { s := &Server{ secretFile: filepath.Join(dataDir, "install_secret.bin"), mdnsName: mdnsName, msPort: msPort, - ntpServer: envOr("SPAXEL_NTP_SERVER", "pool.ntp.org"), + ntpServer: ntpServer, + } + // If install secret provided via config, use it instead of loading/creating + if installSecretHex != "" { + decoded, err := hex.DecodeString(installSecretHex) + if err == nil && len(decoded) == 32 { + s.installSecret = decoded + log.Printf("[INFO] provisioning: using install secret from SPAXEL_INSTALL_SECRET") + } else { + log.Printf("[WARN] provisioning: invalid SPAXEL_INSTALL_SECRET, will use persisted secret") + } } if err := s.loadOrCreateSecret(); err != nil { log.Printf("[ERROR] provisioning: could not load/create install secret: %v", err) diff --git a/mothership/internal/startup/startup.go b/mothership/internal/startup/startup.go index fee7b2d..6358b31 100644 --- a/mothership/internal/startup/startup.go +++ b/mothership/internal/startup/startup.go @@ -19,11 +19,12 @@ const ( // SubsystemTimeout is the maximum time for each subsystem start in Phase 5. SubsystemTimeout = 5 * time.Second - - // ReadyFile is written on successful startup (Phase 7). - ReadyFile = "/tmp/spaxel.ready" ) +// ReadyFile is the path for the ready marker file. +// Override in tests before calling WriteReadyFile/RemoveReadyFile. +var ReadyFile = "/tmp/spaxel.ready" + // Phase logs the start of a startup phase and returns a function that logs // completion with elapsed time. The returned function should be called via // defer or after the phase work completes. diff --git a/mothership/internal/startup/startup_test.go b/mothership/internal/startup/startup_test.go index 29e379e..dad7a1a 100644 --- a/mothership/internal/startup/startup_test.go +++ b/mothership/internal/startup/startup_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "os" "path/filepath" "testing"