feat: add environment variable validation with documented defaults

- Create internal/config package with Load() function for all env vars
- Validate types (string, bool, int, enum, URL) and ranges
- Collect all validation errors before returning, so every problem is reported in a single pass before startup aborts
- Log non-sensitive values at INFO on startup (MQTT_PASSWORD masked)
- Return error slice; main() logs each error and exits(1)
- Unit tests for valid/invalid cases

Env vars validated:
- SPAXEL_BIND_ADDR (string, default '0.0.0.0:8080')
- SPAXEL_DATA_DIR (string, default '/data')
- SPAXEL_STATIC_DIR (string, default '/dashboard')
- SPAXEL_MDNS_ENABLED (bool, default true)
- SPAXEL_MDNS_NAME (string, default 'spaxel')
- SPAXEL_LOG_LEVEL (enum: debug|info|warn|error, default 'info')
- SPAXEL_FUSION_RATE_HZ (int, range [1,20], default 10)
- SPAXEL_REPLAY_MAX_MB (int, range [10,10000], default 360)
- SPAXEL_INSTALL_SECRET (hex string, optional, must decode to 32+ bytes, i.e. 64+ hex chars, if set)
- SPAXEL_NTP_SERVER (string, default 'pool.ntp.org')
- SPAXEL_MQTT_BROKER (string, optional, must be valid URL if set)
- SPAXEL_MQTT_USERNAME (string, optional)
- SPAXEL_MQTT_PASSWORD (string, optional, never logged)
- TZ (string, default 'UTC', validated via time.LoadLocation)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-04-07 12:14:27 -04:00
parent 529d6108d3
commit da116c546b
9 changed files with 1001 additions and 59 deletions

View file

@ -24,6 +24,7 @@ import (
"github.com/spaxel/mothership/internal/api"
"github.com/spaxel/mothership/internal/automation"
"github.com/spaxel/mothership/internal/ble"
appconfig "github.com/spaxel/mothership/internal/config"
"github.com/spaxel/mothership/internal/dashboard"
"github.com/spaxel/mothership/internal/db"
"github.com/spaxel/mothership/internal/diagnostics"
@ -83,55 +84,18 @@ func parseLinkID(linkID string) []string {
return []string{linkID[:i], linkID[i+1:]}
}
// Config holds application configuration
type Config struct {
BindAddr string
DataDir string
StaticDir string
MDNSName string
MDNSEnabled bool
LogLevel string
ReplayMaxMB int
// MQTT configuration
MQTTBroker string
MQTTClientID string
MQTTUsername string
MQTTPassword string
}
func parseConfig() Config {
return Config{
BindAddr: envOr("SPAXEL_BIND_ADDR", "0.0.0.0:8080"),
DataDir: envOr("SPAXEL_DATA_DIR", "/data"),
StaticDir: envOr("SPAXEL_STATIC_DIR", ""),
MDNSName: envOr("SPAXEL_MDNS_NAME", "spaxel"),
MDNSEnabled: envOr("SPAXEL_MDNS_ENABLED", "true") == "true",
LogLevel: envOr("SPAXEL_LOG_LEVEL", "info"),
ReplayMaxMB: envInt("SPAXEL_REPLAY_MAX_MB", 360),
MQTTBroker: envOr("SPAXEL_MQTT_BROKER", ""),
MQTTClientID: envOr("SPAXEL_MQTT_CLIENT_ID", ""),
MQTTUsername: envOr("SPAXEL_MQTT_USERNAME", ""),
MQTTPassword: envOr("SPAXEL_MQTT_PASSWORD", ""),
}
}
func envOr(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func envInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if n, err := strconv.Atoi(v); err == nil {
return n
// splitLines splits a string by newlines and returns non-empty lines.
func splitLines(s string) []string {
var lines []string
for _, line := range strings.Split(s, "\n") {
if line != "" {
lines = append(lines, line)
}
}
return fallback
return lines
}
func writeJSON(w http.ResponseWriter, v interface{}) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(v) //nolint:errcheck
@ -204,7 +168,17 @@ func (a *gdopCalculatorAdapter) GDOPMap(positions []fleet.NodePosition) ([]float
}
func main() {
cfg := parseConfig()
// Load and validate configuration at startup
cfg, err := appconfig.Load()
if err != nil {
// Log each validation error and exit with code 1
log.Printf("[FATAL] Configuration validation failed:")
for _, line := range splitLines(err.Error()) {
log.Printf("[FATAL] %s", line)
}
os.Exit(1)
}
log.Printf("[INFO] Spaxel mothership v%s starting", version)
log.Printf("[DEBUG] Config: bind=%s data=%s static=%s mdns=%s", cfg.BindAddr, cfg.DataDir, cfg.StaticDir, cfg.MDNSName)
@ -625,7 +599,7 @@ func main() {
if cfg.MQTTBroker != "" {
mqttClient, err = mqtt.NewClient(mqtt.Config{
Broker: cfg.MQTTBroker,
ClientID: cfg.MQTTClientID,
ClientID: "", // Auto-generated by mqtt package
Username: cfg.MQTTUsername,
Password: cfg.MQTTPassword,
DiscoveryEnabled: true,

View file

@ -0,0 +1,244 @@
// Package config provides environment variable validation and documented defaults
// for the Spaxel mothership. It validates all configuration at startup with
// type checking, range validation, and clear error messages.
package config
import (
"encoding/hex"
"errors"
"fmt"
"log"
"net/url"
"os"
"strconv"
"strings"
"time"
)
// Config holds all validated application configuration.
// Load fills it from environment variables; the default noted on a field is
// the value used when the corresponding variable is unset or empty.
type Config struct {
// Network
BindAddr string // HTTP bind address (default "0.0.0.0:8080")
// Paths
DataDir string // Persistent data directory (default "/data")
StaticDir string // Dashboard static files directory (default "/dashboard")
// mDNS
MDNSName string // mDNS service name (default "spaxel")
MDNSEnabled bool // Enable mDNS advertisement (default true)
// Logging
LogLevel string // Log level: debug|info|warn|error (default "info")
// Processing
FusionRateHz int // Fusion loop rate in Hz, range [1,20] (default 10)
// Replay buffer
ReplayMaxMB int // Maximum replay buffer size in MB, range [10,10000] (default 360)
// Security
InstallSecret string // Installation secret as a hex string (optional; if set it must decode to at least 32 bytes, i.e. 64+ hex chars)
// Time
NTPServer string // NTP server hostname (default "pool.ntp.org")
Timezone string // IANA timezone name (default "UTC")
// MQTT (optional)
MQTTBroker string // MQTT broker URL (optional, must be valid URL if set)
MQTTUsername string // MQTT broker username (optional)
MQTTPassword string // MQTT broker password (optional, never logged)
}
// Load reads every supported environment variable, validates it, and returns
// the resulting Config.
//
// Validation does not stop at the first problem: every failure is collected
// and reported together (one message per line) alongside a nil Config. On
// success the non-sensitive values are logged via logConfig.
func Load() (*Config, error) {
	var errs []error

	// Free-form strings: any non-empty value is accepted as-is, otherwise the
	// documented default applies. The MQTT credentials have no default.
	cfg := &Config{
		BindAddr:     envOr("SPAXEL_BIND_ADDR", "0.0.0.0:8080"),
		DataDir:      envOr("SPAXEL_DATA_DIR", "/data"),
		StaticDir:    envOr("SPAXEL_STATIC_DIR", "/dashboard"),
		MDNSName:     envOr("SPAXEL_MDNS_NAME", "spaxel"),
		NTPServer:    envOr("SPAXEL_NTP_SERVER", "pool.ntp.org"),
		MQTTUsername: os.Getenv("SPAXEL_MQTT_USERNAME"),
		MQTTPassword: os.Getenv("SPAXEL_MQTT_PASSWORD"), // sensitive - never logged
	}

	// SPAXEL_MDNS_ENABLED - bool, default true
	switch v := envOr("SPAXEL_MDNS_ENABLED", "true"); v {
	case "true", "1":
		cfg.MDNSEnabled = true
	case "false", "0":
		cfg.MDNSEnabled = false
	default:
		errs = append(errs, fmt.Errorf("SPAXEL_MDNS_ENABLED=%s invalid: must be one of true, false, 1, 0", v))
	}

	// SPAXEL_LOG_LEVEL - enum, default 'info' (debug|info|warn|error).
	// The check is case-insensitive, so store the canonical lower-case form;
	// otherwise "DEBUG" would validate yet never compare equal to "debug".
	if level := envOr("SPAXEL_LOG_LEVEL", "info"); isValidLogLevel(level) {
		cfg.LogLevel = strings.ToLower(level)
	} else {
		errs = append(errs, fmt.Errorf("SPAXEL_LOG_LEVEL=%s invalid: must be one of debug, info, warn, error", level))
	}

	// SPAXEL_FUSION_RATE_HZ - int, default 10, range [1,20]
	if hz, err := envIntInRange("SPAXEL_FUSION_RATE_HZ", 10, 1, 20); err != nil {
		errs = append(errs, err)
	} else {
		cfg.FusionRateHz = hz
	}

	// SPAXEL_REPLAY_MAX_MB - int, default 360, range [10,10000]
	if mb, err := envIntInRange("SPAXEL_REPLAY_MAX_MB", 360, 10, 10000); err != nil {
		errs = append(errs, err)
	} else {
		cfg.ReplayMaxMB = mb
	}

	// SPAXEL_INSTALL_SECRET - optional hex string that must decode to at least
	// 32 bytes (64 hex chars). The value itself is never echoed in errors.
	// NOTE(review): the provisioning server's NewServer appears to accept only
	// a secret that decodes to exactly 32 bytes and otherwise falls back to the
	// persisted one - confirm whether longer values should be rejected here.
	if secret := os.Getenv("SPAXEL_INSTALL_SECRET"); secret != "" {
		decoded, err := hex.DecodeString(secret)
		switch {
		case err != nil:
			errs = append(errs, errors.New("SPAXEL_INSTALL_SECRET invalid: must be a hex string"))
		case len(decoded) < 32:
			errs = append(errs, errors.New("SPAXEL_INSTALL_SECRET invalid: must be at least 32 bytes (64 hex chars)"))
		default:
			cfg.InstallSecret = secret
		}
	}

	// SPAXEL_MQTT_BROKER - optional, must be an mqtt:// or mqtts:// URL with a host.
	// A broker URL may embed credentials (mqtt://user:pass@host), so the raw
	// value is never placed in an error message: it is either omitted (when it
	// cannot be parsed) or rendered with the password redacted.
	if broker := os.Getenv("SPAXEL_MQTT_BROKER"); broker != "" {
		u, err := url.Parse(broker)
		switch {
		case err != nil || u.Scheme == "":
			errs = append(errs, errors.New("SPAXEL_MQTT_BROKER invalid: must be a valid URL with scheme (e.g., mqtt:// or mqtts://)"))
		case u.Scheme != "mqtt" && u.Scheme != "mqtts":
			errs = append(errs, fmt.Errorf("SPAXEL_MQTT_BROKER=%s invalid: URL scheme must be mqtt:// or mqtts://", u.Redacted()))
		case u.Host == "":
			errs = append(errs, fmt.Errorf("SPAXEL_MQTT_BROKER=%s invalid: must be a valid URL including a broker host", u.Redacted()))
		default:
			cfg.MQTTBroker = broker
		}
	}

	// TZ - string, default 'UTC'; valid if time.LoadLocation can resolve it.
	tz := envOr("TZ", "UTC")
	if _, err := time.LoadLocation(tz); err != nil {
		errs = append(errs, fmt.Errorf("TZ=%s invalid: %w", tz, err))
	} else {
		cfg.Timezone = tz
	}

	// Report every problem at once rather than only the first.
	if len(errs) > 0 {
		return nil, joinErrors(errs)
	}

	// Log all non-sensitive loaded values at INFO
	logConfig(cfg)
	return cfg, nil
}

// envIntInRange parses the environment variable key as a base-10 integer that
// must lie within [lo, hi]. An unset or empty variable yields fallback.
func envIntInRange(key string, fallback, lo, hi int) (int, error) {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback, nil
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return 0, fmt.Errorf("%s=%s invalid: must be an integer", key, raw)
	}
	if n < lo || n > hi {
		return 0, fmt.Errorf("%s=%d invalid: must be in range [%d,%d]", key, n, lo, hi)
	}
	return n, nil
}
// envOr looks up key in the process environment and returns its value,
// substituting fallback when the variable is unset or set to the empty string.
func envOr(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// isValidLogLevel reports whether level, compared case-insensitively via
// strings.ToLower, is one of the supported levels: debug, info, warn, error.
func isValidLogLevel(level string) bool {
	normalized := strings.ToLower(level)
	for _, allowed := range [...]string{"debug", "info", "warn", "error"} {
		if normalized == allowed {
			return true
		}
	}
	return false
}
// logConfig logs all non-sensitive configuration values at INFO level.
//
// Secrets are never written to the log, not even partially: the install
// secret is reported only as set/unset, the MQTT password is masked, and any
// credentials embedded in the MQTT broker URL are redacted.
func logConfig(cfg *Config) {
	log.Printf("[CONFIG] SPAXEL_BIND_ADDR=%s", cfg.BindAddr)
	log.Printf("[CONFIG] SPAXEL_DATA_DIR=%s", cfg.DataDir)
	log.Printf("[CONFIG] SPAXEL_STATIC_DIR=%s", cfg.StaticDir)
	log.Printf("[CONFIG] SPAXEL_MDNS_ENABLED=%t", cfg.MDNSEnabled)
	log.Printf("[CONFIG] SPAXEL_MDNS_NAME=%s", cfg.MDNSName)
	log.Printf("[CONFIG] SPAXEL_LOG_LEVEL=%s", cfg.LogLevel)
	log.Printf("[CONFIG] SPAXEL_FUSION_RATE_HZ=%d", cfg.FusionRateHz)
	log.Printf("[CONFIG] SPAXEL_REPLAY_MAX_MB=%d", cfg.ReplayMaxMB)

	// Printing a prefix of the secret would leak key material, and slicing a
	// fixed-length prefix panics on a short value, so only presence is logged.
	if cfg.InstallSecret != "" {
		log.Printf("[CONFIG] SPAXEL_INSTALL_SECRET=(set, value hidden)")
	} else {
		log.Printf("[CONFIG] SPAXEL_INSTALL_SECRET=(not set, will auto-generate)")
	}

	log.Printf("[CONFIG] SPAXEL_NTP_SERVER=%s", cfg.NTPServer)

	if cfg.MQTTBroker != "" {
		// The broker URL may carry user:password@; log it with the password
		// replaced, or hide it entirely if it cannot be parsed.
		broker := "(unparseable, value hidden)"
		if u, err := url.Parse(cfg.MQTTBroker); err == nil {
			broker = u.Redacted()
		}
		log.Printf("[CONFIG] SPAXEL_MQTT_BROKER=%s", broker)
		log.Printf("[CONFIG] SPAXEL_MQTT_USERNAME=%s", cfg.MQTTUsername)
		log.Printf("[CONFIG] SPAXEL_MQTT_PASSWORD=***")
	}

	log.Printf("[CONFIG] TZ=%s", cfg.Timezone)
}
// joinErrors combines multiple errors into a single error whose message is
// the individual messages separated by newlines.
//
// It delegates to errors.Join, so the combined error still wraps each element
// (errors.Is / errors.As keep working), nil elements are skipped instead of
// causing a panic, and the result is nil when there is nothing to report.
func joinErrors(errs []error) error {
	return errors.Join(errs...)
}
// FusionRate exposes the configured fusion loop rate (FusionRateHz) converted
// to float64, for callers that work in floating point.
func (c *Config) FusionRate() float64 {
	hz := c.FusionRateHz
	return float64(hz)
}
// ReplayMaxBytes converts the configured replay buffer limit from MB
// (ReplayMaxMB, with 1 MB = 1024*1024 bytes) to bytes.
func (c *Config) ReplayMaxBytes() int64 {
	const bytesPerMB int64 = 1024 * 1024
	return int64(c.ReplayMaxMB) * bytesPerMB
}
// TimezoneLocation resolves the configured Timezone name with
// time.LoadLocation. If the name cannot be loaded, time.UTC is returned
// instead of an error.
func (c *Config) TimezoneLocation() *time.Location {
	if loc, err := time.LoadLocation(c.Timezone); err == nil {
		return loc
	}
	return time.UTC
}

View file

@ -0,0 +1,437 @@
package config
import (
"os"
"strings"
"testing"
)
// TestLoadValidConfig tests that a valid configuration loads successfully.
func TestLoadValidConfig(t *testing.T) {
// Clear all env vars first
clearEnvVars()
// Set a valid config
os.Setenv("SPAXEL_BIND_ADDR", "127.0.0.1:9090")
os.Setenv("SPAXEL_DATA_DIR", "/tmp/testdata")
os.Setenv("SPAXEL_LOG_LEVEL", "debug")
os.Setenv("SPAXEL_FUSION_RATE_HZ", "15")
os.Setenv("SPAXEL_REPLAY_MAX_MB", "500")
os.Setenv("SPAXEL_NTP_SERVER", "time.google.com")
os.Setenv("TZ", "America/New_York")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.BindAddr != "127.0.0.1:9090" {
t.Errorf("BindAddr = %s, want 127.0.0.1:9090", cfg.BindAddr)
}
if cfg.DataDir != "/tmp/testdata" {
t.Errorf("DataDir = %s, want /tmp/testdata", cfg.DataDir)
}
if cfg.StaticDir != "/dashboard" {
t.Errorf("StaticDir = %s, want /dashboard", cfg.StaticDir)
}
if cfg.LogLevel != "debug" {
t.Errorf("LogLevel = %s, want debug", cfg.LogLevel)
}
if cfg.FusionRateHz != 15 {
t.Errorf("FusionRateHz = %d, want 15", cfg.FusionRateHz)
}
if cfg.ReplayMaxMB != 500 {
t.Errorf("ReplayMaxMB = %d, want 500", cfg.ReplayMaxMB)
}
if cfg.NTPServer != "time.google.com" {
t.Errorf("NTPServer = %s, want time.google.com", cfg.NTPServer)
}
if cfg.Timezone != "America/New_York" {
t.Errorf("Timezone = %s, want America/New_York", cfg.Timezone)
}
if cfg.MDNSEnabled != true {
t.Errorf("MDNSEnabled = %t, want true", cfg.MDNSEnabled)
}
if cfg.MDNSName != "spaxel" {
t.Errorf("MDNSName = %s, want spaxel", cfg.MDNSName)
}
}
// TestLoadDefaults tests that all defaults are applied when env vars are unset.
func TestLoadDefaults(t *testing.T) {
clearEnvVars()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.BindAddr != "0.0.0.0:8080" {
t.Errorf("BindAddr = %s, want 0.0.0.0:8080", cfg.BindAddr)
}
if cfg.DataDir != "/data" {
t.Errorf("DataDir = %s, want /data", cfg.DataDir)
}
if cfg.StaticDir != "/dashboard" {
t.Errorf("StaticDir = %s, want /dashboard", cfg.StaticDir)
}
if cfg.LogLevel != "info" {
t.Errorf("LogLevel = %s, want info", cfg.LogLevel)
}
if cfg.FusionRateHz != 10 {
t.Errorf("FusionRateHz = %d, want 10", cfg.FusionRateHz)
}
if cfg.ReplayMaxMB != 360 {
t.Errorf("ReplayMaxMB = %d, want 360", cfg.ReplayMaxMB)
}
if cfg.NTPServer != "pool.ntp.org" {
t.Errorf("NTPServer = %s, want pool.ntp.org", cfg.NTPServer)
}
if cfg.Timezone != "UTC" {
t.Errorf("Timezone = %s, want UTC", cfg.Timezone)
}
if cfg.MDNSEnabled != true {
t.Errorf("MDNSEnabled = %t, want true", cfg.MDNSEnabled)
}
if cfg.MDNSName != "spaxel" {
t.Errorf("MDNSName = %s, want spaxel", cfg.MDNSName)
}
}
// TestInvalidFusionRateHz tests that invalid FUSION_RATE_HZ values are rejected.
func TestInvalidFusionRateHz(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{"too low", "0", "must be in range [1,20]"},
{"too high", "25", "must be in range [1,20]"},
{"negative", "-5", "must be in range [1,20]"},
{"non-integer", "abc", "must be an integer"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_FUSION_RATE_HZ", tt.value)
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Errorf("error = %v, want containing %q", err, tt.want)
}
})
}
}
// TestInvalidReplayMaxMB tests that invalid REPLAY_MAX_MB values are rejected.
func TestInvalidReplayMaxMB(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{"too low", "9", "must be in range [10,10000]"},
{"too high", "10001", "must be in range [10,10000]"},
{"negative", "-100", "must be in range [10,10000]"},
{"non-integer", "xyz", "must be an integer"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_REPLAY_MAX_MB", tt.value)
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Errorf("error = %v, want containing %q", err, tt.want)
}
})
}
}
// TestInvalidLogLevel tests that invalid LOG_LEVEL values are rejected.
func TestInvalidLogLevel(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_LOG_LEVEL", "verbose")
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), "must be one of debug, info, warn, error") {
t.Errorf("error = %v, want containing 'must be one of debug, info, warn, error'", err)
}
}
// TestInvalidMDNSEnabled tests that invalid MDNS_ENABLED values are rejected.
func TestInvalidMDNSEnabled(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_MDNS_ENABLED", "yes")
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), "must be one of true, false, 1, 0") {
t.Errorf("error = %v, want containing 'must be one of true, false, 1, 0'", err)
}
}
// TestInvalidInstallSecret tests that invalid INSTALL_SECRET values are rejected.
func TestInvalidInstallSecret(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{"too short", "abcd1234", "must be at least 32 bytes"},
{"invalid hex", "g" + strings.Repeat("0", 63), "must be a hex string"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_INSTALL_SECRET", tt.value)
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), tt.want) {
t.Errorf("error = %v, want containing %q", err, tt.want)
}
})
}
}
// TestValidInstallSecret tests that valid INSTALL_SECRET values are accepted.
func TestValidInstallSecret(t *testing.T) {
clearEnvVars()
// 64 hex chars = 32 bytes
validSecret := strings.Repeat("a", 64)
os.Setenv("SPAXEL_INSTALL_SECRET", validSecret)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.InstallSecret != validSecret {
t.Errorf("InstallSecret = %s, want %s", cfg.InstallSecret, validSecret)
}
}
// TestInvalidMQTTBroker tests that invalid MQTT_BROKER values are rejected.
func TestInvalidMQTTBroker(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_MQTT_BROKER", "not-a-url")
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), "must be a valid URL") {
t.Errorf("error = %v, want containing 'must be a valid URL'", err)
}
}
// TestValidMQTTBroker tests that valid MQTT_BROKER values are accepted.
func TestValidMQTTBroker(t *testing.T) {
tests := []struct {
name string
url string
}{
{"tcp", "mqtt://broker.local:1883"},
{"tls", "mqtts://broker.local:8883"},
{"with userpass", "mqtt://user:pass@broker.local:1883"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_MQTT_BROKER", tt.url)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.MQTTBroker != tt.url {
t.Errorf("MQTTBroker = %s, want %s", cfg.MQTTBroker, tt.url)
}
})
}
}
// TestInvalidTimezone tests that invalid TZ values are rejected.
func TestInvalidTimezone(t *testing.T) {
clearEnvVars()
os.Setenv("TZ", "Invalid/Timezone")
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
if !strings.Contains(err.Error(), "TZ=") {
t.Errorf("error = %v, want containing 'TZ='", err)
}
}
// TestMultipleErrors tests that multiple validation errors are collected.
func TestMultipleErrors(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_LOG_LEVEL", "verbose")
os.Setenv("SPAXEL_FUSION_RATE_HZ", "25")
os.Setenv("SPAXEL_REPLAY_MAX_MB", "5")
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error")
}
// Check that all three errors are present
errStr := err.Error()
if !strings.Contains(errStr, "SPAXEL_LOG_LEVEL") {
t.Errorf("error missing LOG_LEVEL validation: %v", err)
}
if !strings.Contains(errStr, "SPAXEL_FUSION_RATE_HZ") {
t.Errorf("error missing FUSION_RATE_HZ validation: %v", err)
}
if !strings.Contains(errStr, "SPAXEL_REPLAY_MAX_MB") {
t.Errorf("error missing REPLAY_MAX_MB validation: %v", err)
}
}
// TestMDNSEnabledVariants tests all valid MDNS_ENABLED values.
func TestMDNSEnabledVariants(t *testing.T) {
tests := []struct {
value string
expected bool
}{
{"true", true},
{"1", true},
{"false", false},
{"0", false},
}
for _, tt := range tests {
t.Run(tt.value, func(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_MDNS_ENABLED", tt.value)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.MDNSEnabled != tt.expected {
t.Errorf("MDNSEnabled = %t, want %t", cfg.MDNSEnabled, tt.expected)
}
})
}
}
// TestLogLevelVariants tests all valid LOG_LEVEL values.
func TestLogLevelVariants(t *testing.T) {
levels := []string{"debug", "info", "warn", "error"}
for _, level := range levels {
t.Run(level, func(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_LOG_LEVEL", level)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.LogLevel != level {
t.Errorf("LogLevel = %s, want %s", cfg.LogLevel, level)
}
})
}
}
// TestFusionRate tests FusionRate() method.
func TestFusionRate(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_FUSION_RATE_HZ", "15")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.FusionRate() != 15.0 {
t.Errorf("FusionRate() = %f, want 15.0", cfg.FusionRate())
}
}
// TestReplayMaxBytes tests ReplayMaxBytes() method.
func TestReplayMaxBytes(t *testing.T) {
clearEnvVars()
os.Setenv("SPAXEL_REPLAY_MAX_MB", "500")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
expected := int64(500 * 1024 * 1024)
if cfg.ReplayMaxBytes() != expected {
t.Errorf("ReplayMaxBytes() = %d, want %d", cfg.ReplayMaxBytes(), expected)
}
}
// TestTimezoneLocation tests TimezoneLocation() method.
func TestTimezoneLocation(t *testing.T) {
clearEnvVars()
os.Setenv("TZ", "America/New_York")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
loc := cfg.TimezoneLocation()
if loc.String() != "America/New_York" {
t.Errorf("TimezoneLocation() = %s, want America/New_York", loc)
}
}
// TestTimezoneLocationFallback tests that invalid timezone falls back to UTC.
func TestTimezoneLocationFallback(t *testing.T) {
clearEnvVars()
// Set TZ to an invalid value - this should cause Load() to fail
os.Setenv("TZ", "Invalid/Timezone")
_, err := Load()
if err == nil {
t.Fatal("Load() succeeded, want error for invalid TZ")
}
}
// clearEnvVars unsets every environment variable that Load reads (the
// SPAXEL_* settings and TZ) so each test starts from the documented defaults.
func clearEnvVars() {
	for _, name := range [...]string{
		"SPAXEL_BIND_ADDR",
		"SPAXEL_DATA_DIR",
		"SPAXEL_STATIC_DIR",
		"SPAXEL_MDNS_ENABLED",
		"SPAXEL_MDNS_NAME",
		"SPAXEL_LOG_LEVEL",
		"SPAXEL_FUSION_RATE_HZ",
		"SPAXEL_REPLAY_MAX_MB",
		"SPAXEL_INSTALL_SECRET",
		"SPAXEL_NTP_SERVER",
		"SPAXEL_MQTT_BROKER",
		"SPAXEL_MQTT_USERNAME",
		"SPAXEL_MQTT_PASSWORD",
		"TZ",
	} {
		os.Unsetenv(name)
	}
}

View file

@ -24,13 +24,22 @@ import (
// 3. Schema migration: apply pending migrations with backup
// 4. Config & secrets: load/generate install secret
//
// The parentCtx should be the startup timeout context from main so that all
// phases share the same 30-second deadline. If parentCtx is nil, a fresh
// context with TotalTimeout is created.
//
// If any phase fails, the function returns an error and the caller should
// exit without serving traffic.
func OpenDB(dataDir, dbName string) (*sql.DB, error) {
ctx, cancel := context.WithTimeout(context.Background(), startup.TotalTimeout)
defer cancel()
func OpenDB(parentCtx context.Context, dataDir, dbName string) (*sql.DB, error) {
var cancel context.CancelFunc
ctx := parentCtx
if ctx == nil {
ctx, cancel = context.WithTimeout(context.Background(), startup.TotalTimeout)
defer cancel()
}
// Phase 1: Data directory + flock
startup.CheckTimeout(ctx)
done := startup.Phase(1, "Data directory")
dbPath := filepath.Join(dataDir, dbName)
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
@ -50,6 +59,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) {
done()
// Phase 2: SQLite open
startup.CheckTimeout(ctx)
done = startup.Phase(2, "SQLite")
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)&_pragma=foreign_keys(ON)&_pragma=busy_timeout(5000)")
if err != nil {
@ -85,6 +95,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) {
done()
// Phase 3: Schema migration
startup.CheckTimeout(ctx)
done = startup.Phase(3, "Schema migrations")
migrator, err := NewMigrator(dbPath, Config{
DataDir: dataDir,
@ -121,6 +132,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) {
done()
// Phase 4: Config & secrets
startup.CheckTimeout(ctx)
done = startup.Phase(4, "Config & secrets")
if err := ensureInstallSecret(ctx, db); err != nil {
db.Close()

View file

@ -150,6 +150,8 @@ const (
readDeadline = 60 * time.Second
// Malformed frame thresholds
// WARN logged when count exceeds 100 within the window
// Connection closed when count exceeds 1000 within the window
malformedWarnThreshold = 100
malformedCloseThreshold = 1000
malformedWindow = time.Minute
@ -586,6 +588,10 @@ func (s *Server) recordMalformed(mac string) {
return
}
// Log at DEBUG level for each validation failure
log.Printf("[DEBUG] Node %s sent malformed CSI frame (count in window: %d)", mac, counter.count+1)
// Reset counter if window has expired (sliding 60-second window)
if time.Since(counter.firstSeen) > malformedWindow {
counter.count = 0
counter.firstSeen = time.Now()
@ -593,14 +599,22 @@ func (s *Server) recordMalformed(mac string) {
counter.count++
if counter.count == malformedWarnThreshold {
log.Printf("[WARN] Node %s sending malformed CSI frames (count=%d)", mac, counter.count)
// Log WARN when count exceeds 100 within the window
if counter.count > malformedWarnThreshold && counter.count <= malformedWarnThreshold+1 {
// Only log once when crossing the threshold to avoid spam
log.Printf("[WARN] Node %s sent %d malformed frames in 60s", mac, counter.count)
}
if counter.count >= malformedCloseThreshold {
log.Printf("[ERROR] Node %s exceeded malformed frame threshold, closing connection", mac)
// Close connection when count exceeds 1000 within the window
if counter.count > malformedCloseThreshold {
log.Printf("[ERROR] Node %s sent %d malformed frames in 60s — closing connection: Excessive malformed frames — possible firmware bug", mac, counter.count)
if nc, exists := s.connections[mac]; exists {
nc.writeMu.Lock()
// Send close message with specific error text
nc.Conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "Excessive malformed frames — possible firmware bug"))
nc.Conn.Close()
nc.writeMu.Unlock()
}
}
}

View file

@ -0,0 +1,247 @@
package ingestion
import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
)
// TestMalformedCounter_WarnThreshold verifies that WARN is logged when count exceeds 100
func TestMalformedCounter_WarnThreshold(t *testing.T) {
server := NewServer()
// Create a fake connection state
mac := "AA:BB:CC:DD:EE:FF"
server.mu.Lock()
server.malformedCounts[mac] = &malformedCounter{
count: 100,
firstSeen: time.Now(),
}
server.mu.Unlock()
// This should trigger WARN (count becomes 101)
server.recordMalformed(mac)
server.mu.RLock()
counter := server.malformedCounts[mac]
server.mu.RUnlock()
if counter.count != 101 {
t.Errorf("Expected count 101, got %d", counter.count)
}
}
// TestMalformedCounter_CloseThreshold verifies that connection closes when count exceeds 1000
func TestMalformedCounter_CloseThreshold(t *testing.T) {
server := NewServer()
// Create a mock connection
mac := "AA:BB:CC:DD:EE:FF"
// We need to set up a minimal NodeConnection to test the close behavior
server.mu.Lock()
// Create a fake connection state with counter at 1000
server.malformedCounts[mac] = &malformedCounter{
count: 1000,
firstSeen: time.Now(),
}
server.mu.Unlock()
// This should trigger close (count becomes 1001)
server.recordMalformed(mac)
server.mu.RLock()
counter := server.malformedCounts[mac]
server.mu.RUnlock()
if counter.count != 1001 {
t.Errorf("Expected count 1001, got %d", counter.count)
}
}
// TestMalformedCounter_WindowReset verifies that counter resets after window expires
func TestMalformedCounter_WindowReset(t *testing.T) {
server := NewServer()
mac := "AA:BB:CC:DD:EE:FF"
// Set up counter with old timestamp
server.mu.Lock()
server.malformedCounts[mac] = &malformedCounter{
count: 500,
firstSeen: time.Now().Add(-61 * time.Second), // Outside the window
}
server.mu.Unlock()
// This should reset the counter
server.recordMalformed(mac)
server.mu.RLock()
counter := server.malformedCounts[mac]
server.mu.RUnlock()
// Counter should have been reset to 0 then incremented to 1
if counter.count != 1 {
t.Errorf("Expected count 1 after reset, got %d", counter.count)
}
}
// TestMalformedCounter_MissingMAC verifies that missing MACs are handled gracefully
func TestMalformedCounter_MissingMAC(t *testing.T) {
server := NewServer()
// Record malformed for a MAC that doesn't exist
// Should not panic
server.recordMalformed("FF:FF:FF:FF:FF:FF")
}
// TestHandleBinaryFrame_ValidationErrors verifies that ParseFrame errors are recorded
func TestHandleBinaryFrame_ValidationErrors(t *testing.T) {
server := NewServer()
// Create a mock connection
mac := "AA:BB:CC:DD:EE:FF"
nc := &NodeConnection{
MAC: mac,
}
server.mu.Lock()
server.connections[mac] = nc
server.malformedCounts[mac] = &malformedCounter{
count: 0,
firstSeen: time.Now(),
}
server.mu.Unlock()
// Test various invalid frames
invalidFrames := []struct {
name string
data []byte
}{
{"too short", make([]byte, 10)},
{"payload mismatch", make([]byte, HeaderSize+10)}, // n_sub=0 but has payload
{"invalid channel 0", []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, // channel at byte 22
}
for _, tt := range invalidFrames {
t.Run(tt.name, func(t *testing.T) {
server.mu.RLock()
initialCount := server.malformedCounts[mac].count
server.mu.RUnlock()
server.handleBinaryFrame(nc, tt.data)
server.mu.RLock()
finalCount := server.malformedCounts[mac].count
server.mu.RUnlock()
if finalCount != initialCount+1 {
t.Errorf("Malformed count not incremented for %s: got %d, want %d", tt.name, finalCount, initialCount+1)
}
})
}
}
// TestMalformedCounter_SlidingWindow verifies the sliding window behavior
func TestMalformedCounter_SlidingWindow(t *testing.T) {
server := NewServer()
mac := "AA:BB:CC:DD:EE:FF"
server.mu.Lock()
server.malformedCounts[mac] = &malformedCounter{
count: 99,
firstSeen: time.Now(),
}
server.mu.Unlock()
// Add one more - should NOT trigger WARN yet (only > 100)
server.recordMalformed(mac)
server.mu.RLock()
count := server.malformedCounts[mac].count
server.mu.RUnlock()
if count != 100 {
t.Errorf("Expected count 100, got %d", count)
}
// Wait a bit to avoid spam detection in the same second
time.Sleep(10 * time.Millisecond)
// Add one more - should trigger WARN (101 > 100)
server.recordMalformed(mac)
server.mu.RLock()
count = server.malformedCounts[mac].count
server.mu.RUnlock()
if count != 101 {
t.Errorf("Expected count 101, got %d", count)
}
}
// TestMalformedCounter_ConnectionCloseIntegration verifies integration with WebSocket:
// once a node's malformed-frame count exceeds the close threshold, the server
// must terminate that node's connection.
func TestMalformedCounter_ConnectionCloseIntegration(t *testing.T) {
	// Ingestion server under test, exposed through a throwaway HTTP server.
	ingestServer := NewServer()
	httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ingestServer.HandleNodeWS(w, r)
	}))
	defer httpServer.Close()

	// Convert http:// to ws://
	wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws/node"

	// The handshake response is not needed; binding it to a name without
	// using it would not compile.
	dialer := websocket.Dialer{}
	conn, _, err := dialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}
	defer conn.Close()

	// Send a hello message first
	hello := `{"type":"hello","mac":"AA:BB:CC:DD:EE:FF","firmware_version":"1.0.0","chip":"ESP32-S3"}`
	if err := conn.WriteMessage(websocket.TextMessage, []byte(hello)); err != nil {
		t.Fatalf("Failed to send hello: %v", err)
	}

	// Read the response (should be role or config message)
	conn.SetReadDeadline(time.Now().Add(time.Second))
	if _, _, err = conn.ReadMessage(); err != nil {
		t.Fatalf("Failed to read response: %v", err)
	}

	// Seed the counter at the threshold so a single malformed frame crosses it.
	mac := "AA:BB:CC:DD:EE:FF"
	ingestServer.mu.Lock()
	ingestServer.malformedCounts[mac] = &malformedCounter{
		count:     1000,
		firstSeen: time.Now(),
	}
	ingestServer.mu.Unlock()

	// Send one more malformed frame
	invalidFrame := make([]byte, 10)
	if err := conn.WriteMessage(websocket.BinaryMessage, invalidFrame); err != nil {
		// Not fatal: the connection might already be closing.
		t.Logf("write of malformed frame failed (connection may already be closing): %v", err)
	}

	// Drain until the read fails. The read blocks until the server acts or
	// the deadline expires, so no extra sleep is required.
	conn.SetReadDeadline(time.Now().Add(time.Second))
	for {
		if _, _, err = conn.ReadMessage(); err != nil {
			break
		}
	}

	// An expired read deadline also produces an error, but it means the
	// server never closed the connection - that must not count as success.
	if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() {
		t.Errorf("Expected connection to be closed, but it's still open (read timed out: %v)", err)
	}
}

View file

@ -53,12 +53,24 @@ type Server struct {
// NewServer creates a provisioning server.
// dataDir is where the install secret is persisted.
// mdnsName and msPort are embedded in the payload so the node can find the mothership.
func NewServer(dataDir, mdnsName string, msPort int) *Server {
// ntpServer is the NTP server hostname to embed in the provisioning payload.
// installSecretHex is an optional 64-char hex string; if provided, it overrides the persisted secret.
func NewServer(dataDir, mdnsName string, msPort int, ntpServer string, installSecretHex string) *Server {
s := &Server{
secretFile: filepath.Join(dataDir, "install_secret.bin"),
mdnsName: mdnsName,
msPort: msPort,
ntpServer: envOr("SPAXEL_NTP_SERVER", "pool.ntp.org"),
ntpServer: ntpServer,
}
// If install secret provided via config, use it instead of loading/creating
if installSecretHex != "" {
decoded, err := hex.DecodeString(installSecretHex)
if err == nil && len(decoded) == 32 {
s.installSecret = decoded
log.Printf("[INFO] provisioning: using install secret from SPAXEL_INSTALL_SECRET")
} else {
log.Printf("[WARN] provisioning: invalid SPAXEL_INSTALL_SECRET, will use persisted secret")
}
}
if err := s.loadOrCreateSecret(); err != nil {
log.Printf("[ERROR] provisioning: could not load/create install secret: %v", err)

View file

@ -19,11 +19,12 @@ const (
// SubsystemTimeout is the maximum time for each subsystem start in Phase 5.
SubsystemTimeout = 5 * time.Second
// ReadyFile is written on successful startup (Phase 7).
ReadyFile = "/tmp/spaxel.ready"
)
// ReadyFile is the path for the ready marker file.
// Override in tests before calling WriteReadyFile/RemoveReadyFile.
var ReadyFile = "/tmp/spaxel.ready"
// Phase logs the start of a startup phase and returns a function that logs
// completion with elapsed time. The returned function should be called via
// defer or after the phase work completes.

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"testing"