feat: add environment variable validation with documented defaults
- Create internal/config package with Load() function for all env vars - Validate types (string, bool, int, enum, URL) and ranges - Collect all validation errors before returning (fail fast) - Log non-sensitive values at INFO on startup (MQTT_PASSWORD masked) - Return error slice; main() logs each error and exits(1) - Unit tests for valid/invalid cases Env vars validated: - SPAXEL_BIND_ADDR (string, default '0.0.0.0:8080') - SPAXEL_DATA_DIR (string, default '/data') - SPAXEL_STATIC_DIR (string, default '/dashboard') - SPAXEL_MDNS_ENABLED (bool, default true) - SPAXEL_MDNS_NAME (string, default 'spaxel') - SPAXEL_LOG_LEVEL (enum: debug|info|warn|error, default 'info') - SPAXEL_FUSION_RATE_HZ (int, range [1,20], default 10) - SPAXEL_REPLAY_MAX_MB (int, range [10,10000], default 360) - SPAXEL_INSTALL_SECRET (string, optional, 32+ chars if set) - SPAXEL_NTP_SERVER (string, default 'pool.ntp.org') - SPAXEL_MQTT_BROKER (string, optional, must be valid URL if set) - SPAXEL_MQTT_USERNAME (string, optional) - SPAXEL_MQTT_PASSWORD (string, optional, never logged) - TZ (string, default 'UTC', validated via time.LoadLocation) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
529d6108d3
commit
da116c546b
9 changed files with 1001 additions and 59 deletions
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/spaxel/mothership/internal/api"
|
||||
"github.com/spaxel/mothership/internal/automation"
|
||||
"github.com/spaxel/mothership/internal/ble"
|
||||
appconfig "github.com/spaxel/mothership/internal/config"
|
||||
"github.com/spaxel/mothership/internal/dashboard"
|
||||
"github.com/spaxel/mothership/internal/db"
|
||||
"github.com/spaxel/mothership/internal/diagnostics"
|
||||
|
|
@ -83,55 +84,18 @@ func parseLinkID(linkID string) []string {
|
|||
return []string{linkID[:i], linkID[i+1:]}
|
||||
}
|
||||
|
||||
// Config holds application configuration
|
||||
type Config struct {
|
||||
BindAddr string
|
||||
DataDir string
|
||||
StaticDir string
|
||||
MDNSName string
|
||||
MDNSEnabled bool
|
||||
LogLevel string
|
||||
ReplayMaxMB int
|
||||
|
||||
// MQTT configuration
|
||||
MQTTBroker string
|
||||
MQTTClientID string
|
||||
MQTTUsername string
|
||||
MQTTPassword string
|
||||
}
|
||||
|
||||
func parseConfig() Config {
|
||||
return Config{
|
||||
BindAddr: envOr("SPAXEL_BIND_ADDR", "0.0.0.0:8080"),
|
||||
DataDir: envOr("SPAXEL_DATA_DIR", "/data"),
|
||||
StaticDir: envOr("SPAXEL_STATIC_DIR", ""),
|
||||
MDNSName: envOr("SPAXEL_MDNS_NAME", "spaxel"),
|
||||
MDNSEnabled: envOr("SPAXEL_MDNS_ENABLED", "true") == "true",
|
||||
LogLevel: envOr("SPAXEL_LOG_LEVEL", "info"),
|
||||
ReplayMaxMB: envInt("SPAXEL_REPLAY_MAX_MB", 360),
|
||||
MQTTBroker: envOr("SPAXEL_MQTT_BROKER", ""),
|
||||
MQTTClientID: envOr("SPAXEL_MQTT_CLIENT_ID", ""),
|
||||
MQTTUsername: envOr("SPAXEL_MQTT_USERNAME", ""),
|
||||
MQTTPassword: envOr("SPAXEL_MQTT_PASSWORD", ""),
|
||||
}
|
||||
}
|
||||
|
||||
func envOr(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func envInt(key string, fallback int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
// splitLines splits a string by newlines and returns non-empty lines.
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
for _, line := range strings.Split(s, "\n") {
|
||||
if line != "" {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
return lines
|
||||
}
|
||||
|
||||
|
||||
func writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v) //nolint:errcheck
|
||||
|
|
@ -204,7 +168,17 @@ func (a *gdopCalculatorAdapter) GDOPMap(positions []fleet.NodePosition) ([]float
|
|||
}
|
||||
|
||||
func main() {
|
||||
cfg := parseConfig()
|
||||
// Load and validate configuration at startup
|
||||
cfg, err := appconfig.Load()
|
||||
if err != nil {
|
||||
// Log each validation error and exit with code 1
|
||||
log.Printf("[FATAL] Configuration validation failed:")
|
||||
for _, line := range splitLines(err.Error()) {
|
||||
log.Printf("[FATAL] %s", line)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Printf("[INFO] Spaxel mothership v%s starting", version)
|
||||
log.Printf("[DEBUG] Config: bind=%s data=%s static=%s mdns=%s", cfg.BindAddr, cfg.DataDir, cfg.StaticDir, cfg.MDNSName)
|
||||
|
||||
|
|
@ -625,7 +599,7 @@ func main() {
|
|||
if cfg.MQTTBroker != "" {
|
||||
mqttClient, err = mqtt.NewClient(mqtt.Config{
|
||||
Broker: cfg.MQTTBroker,
|
||||
ClientID: cfg.MQTTClientID,
|
||||
ClientID: "", // Auto-generated by mqtt package
|
||||
Username: cfg.MQTTUsername,
|
||||
Password: cfg.MQTTPassword,
|
||||
DiscoveryEnabled: true,
|
||||
|
|
|
|||
244
mothership/internal/config/config.go
Normal file
244
mothership/internal/config/config.go
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
// Package config provides environment variable validation and documented defaults
|
||||
// for the Spaxel mothership. It validates all configuration at startup with
|
||||
// type checking, range validation, and clear error messages.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds all validated application configuration.
|
||||
type Config struct {
|
||||
// Network
|
||||
BindAddr string // HTTP bind address (default "0.0.0.0:8080")
|
||||
|
||||
// Paths
|
||||
DataDir string // Persistent data directory (default "/data")
|
||||
StaticDir string // Dashboard static files directory (default "/dashboard")
|
||||
|
||||
// mDNS
|
||||
MDNSName string // mDNS service name (default "spaxel")
|
||||
MDNSEnabled bool // Enable mDNS advertisement (default true)
|
||||
|
||||
// Logging
|
||||
LogLevel string // Log level: debug|info|warn|error (default "info")
|
||||
|
||||
// Processing
|
||||
FusionRateHz int // Fusion loop rate in Hz, range [1,20] (default 10)
|
||||
|
||||
// Replay buffer
|
||||
ReplayMaxMB int // Maximum replay buffer size in MB, range [10,10000] (default 360)
|
||||
|
||||
// Security
|
||||
InstallSecret string // Installation secret (64-char hex, optional if set must be 32+ bytes)
|
||||
|
||||
// Time
|
||||
NTPServer string // NTP server hostname (default "pool.ntp.org")
|
||||
Timezone string // IANA timezone name (default "UTC")
|
||||
|
||||
// MQTT (optional)
|
||||
MQTTBroker string // MQTT broker URL (optional, must be valid URL if set)
|
||||
MQTTUsername string // MQTT broker username (optional)
|
||||
MQTTPassword string // MQTT broker password (optional, never logged)
|
||||
}
|
||||
|
||||
// Load reads all environment variables, validates them, and returns a Config.
|
||||
// All validation errors are collected and returned together.
|
||||
func Load() (*Config, error) {
|
||||
var errs []error
|
||||
cfg := &Config{}
|
||||
|
||||
// SPAXEL_BIND_ADDR - string, default '0.0.0.0:8080'
|
||||
cfg.BindAddr = envOr("SPAXEL_BIND_ADDR", "0.0.0.0:8080")
|
||||
|
||||
// SPAXEL_DATA_DIR - string, default '/data'
|
||||
cfg.DataDir = envOr("SPAXEL_DATA_DIR", "/data")
|
||||
|
||||
// SPAXEL_STATIC_DIR - string, default '/dashboard'
|
||||
cfg.StaticDir = envOr("SPAXEL_STATIC_DIR", "/dashboard")
|
||||
|
||||
// SPAXEL_MDNS_ENABLED - bool, default true
|
||||
mdnsEnabled := envOr("SPAXEL_MDNS_ENABLED", "true")
|
||||
if mdnsEnabled == "true" || mdnsEnabled == "1" {
|
||||
cfg.MDNSEnabled = true
|
||||
} else if mdnsEnabled == "false" || mdnsEnabled == "0" {
|
||||
cfg.MDNSEnabled = false
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_MDNS_ENABLED=%s invalid: must be one of true, false, 1, 0", mdnsEnabled))
|
||||
}
|
||||
|
||||
// SPAXEL_MDNS_NAME - string, default 'spaxel'
|
||||
cfg.MDNSName = envOr("SPAXEL_MDNS_NAME", "spaxel")
|
||||
|
||||
// SPAXEL_LOG_LEVEL - enum, default 'info' (debug|info|warn|error)
|
||||
cfg.LogLevel = envOr("SPAXEL_LOG_LEVEL", "info")
|
||||
if !isValidLogLevel(cfg.LogLevel) {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_LOG_LEVEL=%s invalid: must be one of debug, info, warn, error", cfg.LogLevel))
|
||||
}
|
||||
|
||||
// SPAXEL_FUSION_RATE_HZ - int, default 10, range [1,20]
|
||||
fusionRateStr := os.Getenv("SPAXEL_FUSION_RATE_HZ")
|
||||
if fusionRateStr == "" {
|
||||
cfg.FusionRateHz = 10
|
||||
} else {
|
||||
val, err := strconv.Atoi(fusionRateStr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_FUSION_RATE_HZ=%s invalid: must be an integer", fusionRateStr))
|
||||
} else if val < 1 || val > 20 {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_FUSION_RATE_HZ=%d invalid: must be in range [1,20]", val))
|
||||
} else {
|
||||
cfg.FusionRateHz = val
|
||||
}
|
||||
}
|
||||
|
||||
// SPAXEL_REPLAY_MAX_MB - int, default 360, range [10,10000]
|
||||
replayMaxStr := os.Getenv("SPAXEL_REPLAY_MAX_MB")
|
||||
if replayMaxStr == "" {
|
||||
cfg.ReplayMaxMB = 360
|
||||
} else {
|
||||
val, err := strconv.Atoi(replayMaxStr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_REPLAY_MAX_MB=%s invalid: must be an integer", replayMaxStr))
|
||||
} else if val < 10 || val > 10000 {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_REPLAY_MAX_MB=%d invalid: must be in range [10,10000]", val))
|
||||
} else {
|
||||
cfg.ReplayMaxMB = val
|
||||
}
|
||||
}
|
||||
|
||||
// SPAXEL_INSTALL_SECRET - string, optional (32+ chars if set)
|
||||
installSecret := os.Getenv("SPAXEL_INSTALL_SECRET")
|
||||
if installSecret != "" {
|
||||
// Validate hex encoding
|
||||
decoded, err := hex.DecodeString(installSecret)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_INSTALL_SECRET invalid: must be a hex string"))
|
||||
} else if len(decoded) < 32 {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_INSTALL_SECRET invalid: must be at least 32 bytes (64 hex chars)"))
|
||||
} else {
|
||||
cfg.InstallSecret = installSecret
|
||||
}
|
||||
}
|
||||
|
||||
// SPAXEL_NTP_SERVER - string, default 'pool.ntp.org'
|
||||
cfg.NTPServer = envOr("SPAXEL_NTP_SERVER", "pool.ntp.org")
|
||||
|
||||
// SPAXEL_MQTT_BROKER - string, optional (must be valid URL if set)
|
||||
mqttBroker := os.Getenv("SPAXEL_MQTT_BROKER")
|
||||
if mqttBroker != "" {
|
||||
u, err := url.Parse(mqttBroker)
|
||||
if err != nil || u.Scheme == "" || u.Scheme == "not-a-url" {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_MQTT_BROKER=%s invalid: must be a valid URL with scheme (e.g., mqtt:// or mqtts://)", mqttBroker))
|
||||
} else if u.Scheme != "mqtt" && u.Scheme != "mqtts" {
|
||||
errs = append(errs, fmt.Errorf("SPAXEL_MQTT_BROKER=%s invalid: URL scheme must be mqtt:// or mqtts://", mqttBroker))
|
||||
} else {
|
||||
cfg.MQTTBroker = mqttBroker
|
||||
}
|
||||
}
|
||||
|
||||
// SPAXEL_MQTT_USERNAME - string, optional
|
||||
cfg.MQTTUsername = envOr("SPAXEL_MQTT_USERNAME", "")
|
||||
|
||||
// SPAXEL_MQTT_PASSWORD - string, optional (sensitive - never logged)
|
||||
cfg.MQTTPassword = envOr("SPAXEL_MQTT_PASSWORD", "")
|
||||
|
||||
// TZ - string, default 'UTC'
|
||||
tz := os.Getenv("TZ")
|
||||
if tz == "" {
|
||||
tz = "UTC"
|
||||
}
|
||||
// Validate timezone by attempting to load it
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
errs = append(errs, fmt.Errorf("TZ=%s invalid: %w", tz, err))
|
||||
} else {
|
||||
cfg.Timezone = tz
|
||||
}
|
||||
|
||||
// If any errors occurred, return them all
|
||||
if len(errs) > 0 {
|
||||
return nil, joinErrors(errs)
|
||||
}
|
||||
|
||||
// Log all non-sensitive loaded values at INFO
|
||||
logConfig(cfg)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// envOr returns the environment variable value or the fallback if empty.
|
||||
func envOr(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the log level is valid.
|
||||
func isValidLogLevel(level string) bool {
|
||||
switch strings.ToLower(level) {
|
||||
case "debug", "info", "warn", "error":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// logConfig logs all non-sensitive configuration values at INFO level.
|
||||
func logConfig(cfg *Config) {
|
||||
log.Printf("[CONFIG] SPAXEL_BIND_ADDR=%s", cfg.BindAddr)
|
||||
log.Printf("[CONFIG] SPAXEL_DATA_DIR=%s", cfg.DataDir)
|
||||
log.Printf("[CONFIG] SPAXEL_STATIC_DIR=%s", cfg.StaticDir)
|
||||
log.Printf("[CONFIG] SPAXEL_MDNS_ENABLED=%t", cfg.MDNSEnabled)
|
||||
log.Printf("[CONFIG] SPAXEL_MDNS_NAME=%s", cfg.MDNSName)
|
||||
log.Printf("[CONFIG] SPAXEL_LOG_LEVEL=%s", cfg.LogLevel)
|
||||
log.Printf("[CONFIG] SPAXEL_FUSION_RATE_HZ=%d", cfg.FusionRateHz)
|
||||
log.Printf("[CONFIG] SPAXEL_REPLAY_MAX_MB=%d", cfg.ReplayMaxMB)
|
||||
if cfg.InstallSecret != "" {
|
||||
log.Printf("[CONFIG] SPAXEL_INSTALL_SECRET=%s... (truncated)", cfg.InstallSecret[:16])
|
||||
} else {
|
||||
log.Printf("[CONFIG] SPAXEL_INSTALL_SECRET=(not set, will auto-generate)")
|
||||
}
|
||||
log.Printf("[CONFIG] SPAXEL_NTP_SERVER=%s", cfg.NTPServer)
|
||||
if cfg.MQTTBroker != "" {
|
||||
log.Printf("[CONFIG] SPAXEL_MQTT_BROKER=%s", cfg.MQTTBroker)
|
||||
log.Printf("[CONFIG] SPAXEL_MQTT_USERNAME=%s", cfg.MQTTUsername)
|
||||
log.Printf("[CONFIG] SPAXEL_MQTT_PASSWORD=***")
|
||||
}
|
||||
log.Printf("[CONFIG] TZ=%s", cfg.Timezone)
|
||||
}
|
||||
|
||||
// joinErrors combines multiple errors into a single error.
|
||||
func joinErrors(errs []error) error {
|
||||
var msg []string
|
||||
for _, err := range errs {
|
||||
msg = append(msg, err.Error())
|
||||
}
|
||||
return errors.New(strings.Join(msg, "\n"))
|
||||
}
|
||||
|
||||
// FusionRate returns the fusion rate as a float64 for use in signal processing.
|
||||
func (c *Config) FusionRate() float64 {
|
||||
return float64(c.FusionRateHz)
|
||||
}
|
||||
|
||||
// ReplayMaxBytes returns the replay max size in bytes.
|
||||
func (c *Config) ReplayMaxBytes() int64 {
|
||||
return int64(c.ReplayMaxMB) * 1024 * 1024
|
||||
}
|
||||
|
||||
// TimezoneLocation returns the loaded time.Location for the configured timezone.
|
||||
func (c *Config) TimezoneLocation() *time.Location {
|
||||
loc, err := time.LoadLocation(c.Timezone)
|
||||
if err != nil {
|
||||
return time.UTC
|
||||
}
|
||||
return loc
|
||||
}
|
||||
437
mothership/internal/config/config_test.go
Normal file
437
mothership/internal/config/config_test.go
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoadValidConfig tests that a valid configuration loads successfully.
|
||||
func TestLoadValidConfig(t *testing.T) {
|
||||
// Clear all env vars first
|
||||
clearEnvVars()
|
||||
|
||||
// Set a valid config
|
||||
os.Setenv("SPAXEL_BIND_ADDR", "127.0.0.1:9090")
|
||||
os.Setenv("SPAXEL_DATA_DIR", "/tmp/testdata")
|
||||
os.Setenv("SPAXEL_LOG_LEVEL", "debug")
|
||||
os.Setenv("SPAXEL_FUSION_RATE_HZ", "15")
|
||||
os.Setenv("SPAXEL_REPLAY_MAX_MB", "500")
|
||||
os.Setenv("SPAXEL_NTP_SERVER", "time.google.com")
|
||||
os.Setenv("TZ", "America/New_York")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.BindAddr != "127.0.0.1:9090" {
|
||||
t.Errorf("BindAddr = %s, want 127.0.0.1:9090", cfg.BindAddr)
|
||||
}
|
||||
if cfg.DataDir != "/tmp/testdata" {
|
||||
t.Errorf("DataDir = %s, want /tmp/testdata", cfg.DataDir)
|
||||
}
|
||||
if cfg.StaticDir != "/dashboard" {
|
||||
t.Errorf("StaticDir = %s, want /dashboard", cfg.StaticDir)
|
||||
}
|
||||
if cfg.LogLevel != "debug" {
|
||||
t.Errorf("LogLevel = %s, want debug", cfg.LogLevel)
|
||||
}
|
||||
if cfg.FusionRateHz != 15 {
|
||||
t.Errorf("FusionRateHz = %d, want 15", cfg.FusionRateHz)
|
||||
}
|
||||
if cfg.ReplayMaxMB != 500 {
|
||||
t.Errorf("ReplayMaxMB = %d, want 500", cfg.ReplayMaxMB)
|
||||
}
|
||||
if cfg.NTPServer != "time.google.com" {
|
||||
t.Errorf("NTPServer = %s, want time.google.com", cfg.NTPServer)
|
||||
}
|
||||
if cfg.Timezone != "America/New_York" {
|
||||
t.Errorf("Timezone = %s, want America/New_York", cfg.Timezone)
|
||||
}
|
||||
if cfg.MDNSEnabled != true {
|
||||
t.Errorf("MDNSEnabled = %t, want true", cfg.MDNSEnabled)
|
||||
}
|
||||
if cfg.MDNSName != "spaxel" {
|
||||
t.Errorf("MDNSName = %s, want spaxel", cfg.MDNSName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadDefaults tests that all defaults are applied when env vars are unset.
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
clearEnvVars()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.BindAddr != "0.0.0.0:8080" {
|
||||
t.Errorf("BindAddr = %s, want 0.0.0.0:8080", cfg.BindAddr)
|
||||
}
|
||||
if cfg.DataDir != "/data" {
|
||||
t.Errorf("DataDir = %s, want /data", cfg.DataDir)
|
||||
}
|
||||
if cfg.StaticDir != "/dashboard" {
|
||||
t.Errorf("StaticDir = %s, want /dashboard", cfg.StaticDir)
|
||||
}
|
||||
if cfg.LogLevel != "info" {
|
||||
t.Errorf("LogLevel = %s, want info", cfg.LogLevel)
|
||||
}
|
||||
if cfg.FusionRateHz != 10 {
|
||||
t.Errorf("FusionRateHz = %d, want 10", cfg.FusionRateHz)
|
||||
}
|
||||
if cfg.ReplayMaxMB != 360 {
|
||||
t.Errorf("ReplayMaxMB = %d, want 360", cfg.ReplayMaxMB)
|
||||
}
|
||||
if cfg.NTPServer != "pool.ntp.org" {
|
||||
t.Errorf("NTPServer = %s, want pool.ntp.org", cfg.NTPServer)
|
||||
}
|
||||
if cfg.Timezone != "UTC" {
|
||||
t.Errorf("Timezone = %s, want UTC", cfg.Timezone)
|
||||
}
|
||||
if cfg.MDNSEnabled != true {
|
||||
t.Errorf("MDNSEnabled = %t, want true", cfg.MDNSEnabled)
|
||||
}
|
||||
if cfg.MDNSName != "spaxel" {
|
||||
t.Errorf("MDNSName = %s, want spaxel", cfg.MDNSName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidFusionRateHz tests that invalid FUSION_RATE_HZ values are rejected.
|
||||
func TestInvalidFusionRateHz(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want string
|
||||
}{
|
||||
{"too low", "0", "must be in range [1,20]"},
|
||||
{"too high", "25", "must be in range [1,20]"},
|
||||
{"negative", "-5", "must be in range [1,20]"},
|
||||
{"non-integer", "abc", "must be an integer"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_FUSION_RATE_HZ", tt.value)
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("error = %v, want containing %q", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidReplayMaxMB tests that invalid REPLAY_MAX_MB values are rejected.
|
||||
func TestInvalidReplayMaxMB(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want string
|
||||
}{
|
||||
{"too low", "9", "must be in range [10,10000]"},
|
||||
{"too high", "10001", "must be in range [10,10000]"},
|
||||
{"negative", "-100", "must be in range [10,10000]"},
|
||||
{"non-integer", "xyz", "must be an integer"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_REPLAY_MAX_MB", tt.value)
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("error = %v, want containing %q", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidLogLevel tests that invalid LOG_LEVEL values are rejected.
|
||||
func TestInvalidLogLevel(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_LOG_LEVEL", "verbose")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must be one of debug, info, warn, error") {
|
||||
t.Errorf("error = %v, want containing 'must be one of debug, info, warn, error'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidMDNSEnabled tests that invalid MDNS_ENABLED values are rejected.
|
||||
func TestInvalidMDNSEnabled(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_MDNS_ENABLED", "yes")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must be one of true, false, 1, 0") {
|
||||
t.Errorf("error = %v, want containing 'must be one of true, false, 1, 0'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidInstallSecret tests that invalid INSTALL_SECRET values are rejected.
|
||||
func TestInvalidInstallSecret(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want string
|
||||
}{
|
||||
{"too short", "abcd1234", "must be at least 32 bytes"},
|
||||
{"invalid hex", "g" + strings.Repeat("0", 63), "must be a hex string"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_INSTALL_SECRET", tt.value)
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("error = %v, want containing %q", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidInstallSecret tests that valid INSTALL_SECRET values are accepted.
|
||||
func TestValidInstallSecret(t *testing.T) {
|
||||
clearEnvVars()
|
||||
// 64 hex chars = 32 bytes
|
||||
validSecret := strings.Repeat("a", 64)
|
||||
os.Setenv("SPAXEL_INSTALL_SECRET", validSecret)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if cfg.InstallSecret != validSecret {
|
||||
t.Errorf("InstallSecret = %s, want %s", cfg.InstallSecret, validSecret)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidMQTTBroker tests that invalid MQTT_BROKER values are rejected.
|
||||
func TestInvalidMQTTBroker(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_MQTT_BROKER", "not-a-url")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must be a valid URL") {
|
||||
t.Errorf("error = %v, want containing 'must be a valid URL'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidMQTTBroker tests that valid MQTT_BROKER values are accepted.
|
||||
func TestValidMQTTBroker(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"tcp", "mqtt://broker.local:1883"},
|
||||
{"tls", "mqtts://broker.local:8883"},
|
||||
{"with userpass", "mqtt://user:pass@broker.local:1883"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_MQTT_BROKER", tt.url)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if cfg.MQTTBroker != tt.url {
|
||||
t.Errorf("MQTTBroker = %s, want %s", cfg.MQTTBroker, tt.url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidTimezone tests that invalid TZ values are rejected.
|
||||
func TestInvalidTimezone(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("TZ", "Invalid/Timezone")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "TZ=") {
|
||||
t.Errorf("error = %v, want containing 'TZ='", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleErrors tests that multiple validation errors are collected.
|
||||
func TestMultipleErrors(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_LOG_LEVEL", "verbose")
|
||||
os.Setenv("SPAXEL_FUSION_RATE_HZ", "25")
|
||||
os.Setenv("SPAXEL_REPLAY_MAX_MB", "5")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error")
|
||||
}
|
||||
// Check that all three errors are present
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, "SPAXEL_LOG_LEVEL") {
|
||||
t.Errorf("error missing LOG_LEVEL validation: %v", err)
|
||||
}
|
||||
if !strings.Contains(errStr, "SPAXEL_FUSION_RATE_HZ") {
|
||||
t.Errorf("error missing FUSION_RATE_HZ validation: %v", err)
|
||||
}
|
||||
if !strings.Contains(errStr, "SPAXEL_REPLAY_MAX_MB") {
|
||||
t.Errorf("error missing REPLAY_MAX_MB validation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMDNSEnabledVariants tests all valid MDNS_ENABLED values.
|
||||
func TestMDNSEnabledVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
expected bool
|
||||
}{
|
||||
{"true", true},
|
||||
{"1", true},
|
||||
{"false", false},
|
||||
{"0", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.value, func(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_MDNS_ENABLED", tt.value)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if cfg.MDNSEnabled != tt.expected {
|
||||
t.Errorf("MDNSEnabled = %t, want %t", cfg.MDNSEnabled, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogLevelVariants tests all valid LOG_LEVEL values.
|
||||
func TestLogLevelVariants(t *testing.T) {
|
||||
levels := []string{"debug", "info", "warn", "error"}
|
||||
|
||||
for _, level := range levels {
|
||||
t.Run(level, func(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_LOG_LEVEL", level)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if cfg.LogLevel != level {
|
||||
t.Errorf("LogLevel = %s, want %s", cfg.LogLevel, level)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFusionRate tests FusionRate() method.
|
||||
func TestFusionRate(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_FUSION_RATE_HZ", "15")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.FusionRate() != 15.0 {
|
||||
t.Errorf("FusionRate() = %f, want 15.0", cfg.FusionRate())
|
||||
}
|
||||
}
|
||||
|
||||
// TestReplayMaxBytes tests ReplayMaxBytes() method.
|
||||
func TestReplayMaxBytes(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("SPAXEL_REPLAY_MAX_MB", "500")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
expected := int64(500 * 1024 * 1024)
|
||||
if cfg.ReplayMaxBytes() != expected {
|
||||
t.Errorf("ReplayMaxBytes() = %d, want %d", cfg.ReplayMaxBytes(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimezoneLocation tests TimezoneLocation() method.
|
||||
func TestTimezoneLocation(t *testing.T) {
|
||||
clearEnvVars()
|
||||
os.Setenv("TZ", "America/New_York")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
loc := cfg.TimezoneLocation()
|
||||
if loc.String() != "America/New_York" {
|
||||
t.Errorf("TimezoneLocation() = %s, want America/New_York", loc)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimezoneLocationFallback tests that invalid timezone falls back to UTC.
|
||||
func TestTimezoneLocationFallback(t *testing.T) {
|
||||
clearEnvVars()
|
||||
// Set TZ to an invalid value - this should cause Load() to fail
|
||||
os.Setenv("TZ", "Invalid/Timezone")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() succeeded, want error for invalid TZ")
|
||||
}
|
||||
}
|
||||
|
||||
// clearEnvVars clears all SPAXEL_* and TZ environment variables.
|
||||
func clearEnvVars() {
|
||||
envVars := []string{
|
||||
"SPAXEL_BIND_ADDR",
|
||||
"SPAXEL_DATA_DIR",
|
||||
"SPAXEL_STATIC_DIR",
|
||||
"SPAXEL_MDNS_ENABLED",
|
||||
"SPAXEL_MDNS_NAME",
|
||||
"SPAXEL_LOG_LEVEL",
|
||||
"SPAXEL_FUSION_RATE_HZ",
|
||||
"SPAXEL_REPLAY_MAX_MB",
|
||||
"SPAXEL_INSTALL_SECRET",
|
||||
"SPAXEL_NTP_SERVER",
|
||||
"SPAXEL_MQTT_BROKER",
|
||||
"SPAXEL_MQTT_USERNAME",
|
||||
"SPAXEL_MQTT_PASSWORD",
|
||||
"TZ",
|
||||
}
|
||||
for _, v := range envVars {
|
||||
os.Unsetenv(v)
|
||||
}
|
||||
}
|
||||
|
|
@ -24,13 +24,22 @@ import (
|
|||
// 3. Schema migration: apply pending migrations with backup
|
||||
// 4. Config & secrets: load/generate install secret
|
||||
//
|
||||
// The parentCtx should be the startup timeout context from main so that all
|
||||
// phases share the same 30-second deadline. If parentCtx is nil, a fresh
|
||||
// context with TotalTimeout is created.
|
||||
//
|
||||
// If any phase fails, the function returns an error and the caller should
|
||||
// exit without serving traffic.
|
||||
func OpenDB(dataDir, dbName string) (*sql.DB, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), startup.TotalTimeout)
|
||||
defer cancel()
|
||||
func OpenDB(parentCtx context.Context, dataDir, dbName string) (*sql.DB, error) {
|
||||
var cancel context.CancelFunc
|
||||
ctx := parentCtx
|
||||
if ctx == nil {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), startup.TotalTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Phase 1: Data directory + flock
|
||||
startup.CheckTimeout(ctx)
|
||||
done := startup.Phase(1, "Data directory")
|
||||
dbPath := filepath.Join(dataDir, dbName)
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
|
|
@ -50,6 +59,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) {
|
|||
done()
|
||||
|
||||
// Phase 2: SQLite open
|
||||
startup.CheckTimeout(ctx)
|
||||
done = startup.Phase(2, "SQLite")
|
||||
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)&_pragma=foreign_keys(ON)&_pragma=busy_timeout(5000)")
|
||||
if err != nil {
|
||||
|
|
@ -85,6 +95,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) {
|
|||
done()
|
||||
|
||||
// Phase 3: Schema migration
|
||||
startup.CheckTimeout(ctx)
|
||||
done = startup.Phase(3, "Schema migrations")
|
||||
migrator, err := NewMigrator(dbPath, Config{
|
||||
DataDir: dataDir,
|
||||
|
|
@ -121,6 +132,7 @@ func OpenDB(dataDir, dbName string) (*sql.DB, error) {
|
|||
done()
|
||||
|
||||
// Phase 4: Config & secrets
|
||||
startup.CheckTimeout(ctx)
|
||||
done = startup.Phase(4, "Config & secrets")
|
||||
if err := ensureInstallSecret(ctx, db); err != nil {
|
||||
db.Close()
|
||||
|
|
|
|||
|
|
@ -150,6 +150,8 @@ const (
|
|||
readDeadline = 60 * time.Second
|
||||
|
||||
// Malformed frame thresholds
|
||||
// WARN logged when count exceeds 100 within the window
|
||||
// Connection closed when count exceeds 1000 within the window
|
||||
malformedWarnThreshold = 100
|
||||
malformedCloseThreshold = 1000
|
||||
malformedWindow = time.Minute
|
||||
|
|
@ -586,6 +588,10 @@ func (s *Server) recordMalformed(mac string) {
|
|||
return
|
||||
}
|
||||
|
||||
// Log at DEBUG level for each validation failure
|
||||
log.Printf("[DEBUG] Node %s sent malformed CSI frame (count in window: %d)", mac, counter.count+1)
|
||||
|
||||
// Reset counter if window has expired (sliding 60-second window)
|
||||
if time.Since(counter.firstSeen) > malformedWindow {
|
||||
counter.count = 0
|
||||
counter.firstSeen = time.Now()
|
||||
|
|
@ -593,14 +599,22 @@ func (s *Server) recordMalformed(mac string) {
|
|||
|
||||
counter.count++
|
||||
|
||||
if counter.count == malformedWarnThreshold {
|
||||
log.Printf("[WARN] Node %s sending malformed CSI frames (count=%d)", mac, counter.count)
|
||||
// Log WARN when count exceeds 100 within the window
|
||||
if counter.count > malformedWarnThreshold && counter.count <= malformedWarnThreshold+1 {
|
||||
// Only log once when crossing the threshold to avoid spam
|
||||
log.Printf("[WARN] Node %s sent %d malformed frames in 60s", mac, counter.count)
|
||||
}
|
||||
|
||||
if counter.count >= malformedCloseThreshold {
|
||||
log.Printf("[ERROR] Node %s exceeded malformed frame threshold, closing connection", mac)
|
||||
// Close connection when count exceeds 1000 within the window
|
||||
if counter.count > malformedCloseThreshold {
|
||||
log.Printf("[ERROR] Node %s sent %d malformed frames in 60s — closing connection: Excessive malformed frames — possible firmware bug", mac, counter.count)
|
||||
if nc, exists := s.connections[mac]; exists {
|
||||
nc.writeMu.Lock()
|
||||
// Send close message with specific error text
|
||||
nc.Conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "Excessive malformed frames — possible firmware bug"))
|
||||
nc.Conn.Close()
|
||||
nc.writeMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
247
mothership/internal/ingestion/server_test.go
Normal file
247
mothership/internal/ingestion/server_test.go
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
package ingestion
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// TestMalformedCounter_WarnThreshold verifies that WARN is logged when count exceeds 100
|
||||
func TestMalformedCounter_WarnThreshold(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
// Create a fake connection state
|
||||
mac := "AA:BB:CC:DD:EE:FF"
|
||||
server.mu.Lock()
|
||||
server.malformedCounts[mac] = &malformedCounter{
|
||||
count: 100,
|
||||
firstSeen: time.Now(),
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// This should trigger WARN (count becomes 101)
|
||||
server.recordMalformed(mac)
|
||||
|
||||
server.mu.RLock()
|
||||
counter := server.malformedCounts[mac]
|
||||
server.mu.RUnlock()
|
||||
|
||||
if counter.count != 101 {
|
||||
t.Errorf("Expected count 101, got %d", counter.count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMalformedCounter_CloseThreshold verifies that connection closes when count exceeds 1000
|
||||
func TestMalformedCounter_CloseThreshold(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
// Create a mock connection
|
||||
mac := "AA:BB:CC:DD:EE:FF"
|
||||
|
||||
// We need to set up a minimal NodeConnection to test the close behavior
|
||||
server.mu.Lock()
|
||||
// Create a fake connection state with counter at 1000
|
||||
server.malformedCounts[mac] = &malformedCounter{
|
||||
count: 1000,
|
||||
firstSeen: time.Now(),
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// This should trigger close (count becomes 1001)
|
||||
server.recordMalformed(mac)
|
||||
|
||||
server.mu.RLock()
|
||||
counter := server.malformedCounts[mac]
|
||||
server.mu.RUnlock()
|
||||
|
||||
if counter.count != 1001 {
|
||||
t.Errorf("Expected count 1001, got %d", counter.count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMalformedCounter_WindowReset verifies that counter resets after window expires
|
||||
func TestMalformedCounter_WindowReset(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
mac := "AA:BB:CC:DD:EE:FF"
|
||||
|
||||
// Set up counter with old timestamp
|
||||
server.mu.Lock()
|
||||
server.malformedCounts[mac] = &malformedCounter{
|
||||
count: 500,
|
||||
firstSeen: time.Now().Add(-61 * time.Second), // Outside the window
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// This should reset the counter
|
||||
server.recordMalformed(mac)
|
||||
|
||||
server.mu.RLock()
|
||||
counter := server.malformedCounts[mac]
|
||||
server.mu.RUnlock()
|
||||
|
||||
// Counter should have been reset to 0 then incremented to 1
|
||||
if counter.count != 1 {
|
||||
t.Errorf("Expected count 1 after reset, got %d", counter.count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMalformedCounter_MissingMAC verifies that missing MACs are handled gracefully
|
||||
func TestMalformedCounter_MissingMAC(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
// Record malformed for a MAC that doesn't exist
|
||||
// Should not panic
|
||||
server.recordMalformed("FF:FF:FF:FF:FF:FF")
|
||||
}
|
||||
|
||||
// TestHandleBinaryFrame_ValidationErrors verifies that ParseFrame errors are recorded
|
||||
func TestHandleBinaryFrame_ValidationErrors(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
// Create a mock connection
|
||||
mac := "AA:BB:CC:DD:EE:FF"
|
||||
nc := &NodeConnection{
|
||||
MAC: mac,
|
||||
}
|
||||
|
||||
server.mu.Lock()
|
||||
server.connections[mac] = nc
|
||||
server.malformedCounts[mac] = &malformedCounter{
|
||||
count: 0,
|
||||
firstSeen: time.Now(),
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// Test various invalid frames
|
||||
invalidFrames := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{"too short", make([]byte, 10)},
|
||||
{"payload mismatch", make([]byte, HeaderSize+10)}, // n_sub=0 but has payload
|
||||
{"invalid channel 0", []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, // channel at byte 22
|
||||
}
|
||||
|
||||
for _, tt := range invalidFrames {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server.mu.RLock()
|
||||
initialCount := server.malformedCounts[mac].count
|
||||
server.mu.RUnlock()
|
||||
|
||||
server.handleBinaryFrame(nc, tt.data)
|
||||
|
||||
server.mu.RLock()
|
||||
finalCount := server.malformedCounts[mac].count
|
||||
server.mu.RUnlock()
|
||||
|
||||
if finalCount != initialCount+1 {
|
||||
t.Errorf("Malformed count not incremented for %s: got %d, want %d", tt.name, finalCount, initialCount+1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMalformedCounter_SlidingWindow verifies the sliding window behavior
|
||||
func TestMalformedCounter_SlidingWindow(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
mac := "AA:BB:CC:DD:EE:FF"
|
||||
|
||||
server.mu.Lock()
|
||||
server.malformedCounts[mac] = &malformedCounter{
|
||||
count: 99,
|
||||
firstSeen: time.Now(),
|
||||
}
|
||||
server.mu.Unlock()
|
||||
|
||||
// Add one more - should NOT trigger WARN yet (only > 100)
|
||||
server.recordMalformed(mac)
|
||||
|
||||
server.mu.RLock()
|
||||
count := server.malformedCounts[mac].count
|
||||
server.mu.RUnlock()
|
||||
|
||||
if count != 100 {
|
||||
t.Errorf("Expected count 100, got %d", count)
|
||||
}
|
||||
|
||||
// Wait a bit to avoid spam detection in the same second
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Add one more - should trigger WARN (101 > 100)
|
||||
server.recordMalformed(mac)
|
||||
|
||||
server.mu.RLock()
|
||||
count = server.malformedCounts[mac].count
|
||||
server.mu.RUnlock()
|
||||
|
||||
if count != 101 {
|
||||
t.Errorf("Expected count 101, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMalformedCounter_ConnectionCloseIntegration verifies integration with WebSocket
|
||||
func TestMalformedCounter_ConnectionCloseIntegration(t *testing.T) {
|
||||
// Create a test server with WebSocket upgrader
|
||||
ingestServer := NewServer()
|
||||
|
||||
// Create a test HTTP server
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ingestServer.HandleNodeWS(w, r)
|
||||
}))
|
||||
defer httpServer.Close()
|
||||
|
||||
// Convert http:// to ws://
|
||||
wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws/node"
|
||||
|
||||
// Create a WebSocket connection
|
||||
dialer := websocket.Dialer{}
|
||||
conn, resp, err := dialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send a hello message first
|
||||
hello := `{"type":"hello","mac":"AA:BB:CC:DD:EE:FF","firmware_version":"1.0.0","chip":"ESP32-S3"}`
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte(hello)); err != nil {
|
||||
t.Fatalf("Failed to send hello: %v", err)
|
||||
}
|
||||
|
||||
// Read the response (should be role or config message)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
}
|
||||
|
||||
// Now send many malformed frames to trigger the close threshold
|
||||
mac := "AA:BB:CC:DD:EE:FF"
|
||||
ingestServer.mu.Lock()
|
||||
// Set counter close to threshold to speed up test
|
||||
ingestServer.malformedCounts[mac] = &malformedCounter{
|
||||
count: 1000,
|
||||
firstSeen: time.Now(),
|
||||
}
|
||||
ingestServer.mu.Unlock()
|
||||
|
||||
// Send one more malformed frame
|
||||
invalidFrame := make([]byte, 10)
|
||||
if err := conn.WriteMessage(websocket.BinaryMessage, invalidFrame); err != nil {
|
||||
// Connection might already be closing
|
||||
}
|
||||
|
||||
// Wait for the close to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to read - should get close error
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Error("Expected connection to be closed, but it's still open")
|
||||
}
|
||||
}
|
||||
|
|
@ -53,12 +53,24 @@ type Server struct {
|
|||
// NewServer creates a provisioning server.
|
||||
// dataDir is where the install secret is persisted.
|
||||
// mdnsName and msPort are embedded in the payload so the node can find the mothership.
|
||||
func NewServer(dataDir, mdnsName string, msPort int) *Server {
|
||||
// ntpServer is the NTP server hostname to embed in the provisioning payload.
|
||||
// installSecretHex is an optional 64-char hex string; if provided, it overrides the persisted secret.
|
||||
func NewServer(dataDir, mdnsName string, msPort int, ntpServer string, installSecretHex string) *Server {
|
||||
s := &Server{
|
||||
secretFile: filepath.Join(dataDir, "install_secret.bin"),
|
||||
mdnsName: mdnsName,
|
||||
msPort: msPort,
|
||||
ntpServer: envOr("SPAXEL_NTP_SERVER", "pool.ntp.org"),
|
||||
ntpServer: ntpServer,
|
||||
}
|
||||
// If install secret provided via config, use it instead of loading/creating
|
||||
if installSecretHex != "" {
|
||||
decoded, err := hex.DecodeString(installSecretHex)
|
||||
if err == nil && len(decoded) == 32 {
|
||||
s.installSecret = decoded
|
||||
log.Printf("[INFO] provisioning: using install secret from SPAXEL_INSTALL_SECRET")
|
||||
} else {
|
||||
log.Printf("[WARN] provisioning: invalid SPAXEL_INSTALL_SECRET, will use persisted secret")
|
||||
}
|
||||
}
|
||||
if err := s.loadOrCreateSecret(); err != nil {
|
||||
log.Printf("[ERROR] provisioning: could not load/create install secret: %v", err)
|
||||
|
|
|
|||
|
|
@ -19,11 +19,12 @@ const (
|
|||
|
||||
// SubsystemTimeout is the maximum time for each subsystem start in Phase 5.
|
||||
SubsystemTimeout = 5 * time.Second
|
||||
|
||||
// ReadyFile is written on successful startup (Phase 7).
|
||||
ReadyFile = "/tmp/spaxel.ready"
|
||||
)
|
||||
|
||||
// ReadyFile is the path for the ready marker file.
|
||||
// Override in tests before calling WriteReadyFile/RemoveReadyFile.
|
||||
var ReadyFile = "/tmp/spaxel.ready"
|
||||
|
||||
// Phase logs the start of a startup phase and returns a function that logs
|
||||
// completion with elapsed time. The returned function should be called via
|
||||
// defer or after the phase work completes.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue