diff --git a/mothership/cmd/mothership/main.go b/mothership/cmd/mothership/main.go index 3eb4f81..2848169 100644 --- a/mothership/cmd/mothership/main.go +++ b/mothership/cmd/mothership/main.go @@ -475,7 +475,8 @@ func main() { } defer authHandler.Close() authHandler.RegisterRoutes(&handleFuncAdapter{router: r}) - log.Printf("[INFO] Auth handler registered at /api/auth/*") + r.Use(authHandler.Middleware) + log.Printf("[INFO] Auth handler registered at /api/auth/* (server-side enforcement enabled)") // Create load shedder — single source of truth for load shedding state shedder := loadshed.New() diff --git a/mothership/internal/auth/handler.go b/mothership/internal/auth/handler.go index 75814a2..4ec9a14 100644 --- a/mothership/internal/auth/handler.go +++ b/mothership/internal/auth/handler.go @@ -671,14 +671,16 @@ func (h *Handler) GetInstallSecretForNodes() ([]byte, error) { return h.GetInstallSecret() } -// Helper function to check if a path should be excluded from auth -func isPublicPath(path string) bool { +// IsPublicPath checks if a path should be excluded from auth. +func IsPublicPath(path string) bool { publicPaths := []string{ "/healthz", "/api/auth/status", "/api/auth/setup", "/api/auth/login", + "/api/auth/logout", "/api/provision", + "/ws/node", } for _, pp := range publicPaths { @@ -688,9 +690,51 @@ func isPublicPath(path string) bool { } // Firmware is served without auth (URL contains SHA256 for integrity) - if len(path) > 10 && path[:10] == "/firmware/" { + if strings.HasPrefix(path, "/firmware/") { return true } return false } + +// IsPINConfigured returns true if a PIN has been set. +func (h *Handler) IsPINConfigured() bool { + var pinBcrypt sql.NullString + err := h.db.QueryRow("SELECT pin_bcrypt FROM auth WHERE id = 1").Scan(&pinBcrypt) + return err == nil && pinBcrypt.Valid +} + +// Middleware returns chi-compatible middleware that enforces auth on API and +// WebSocket routes. Static files pass through so the login page can render. +// During onboarding (no PIN configured), all requests pass through. +func (h *Handler) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + if IsPublicPath(path) { + next.ServeHTTP(w, r) + return + } + + // Static files and HTML pages pass through so the login UI renders + if !strings.HasPrefix(path, "/api/") && !strings.HasPrefix(path, "/ws/") { + next.ServeHTTP(w, r) + return + } + + // During onboarding (no PIN set), allow everything + if !h.IsPINConfigured() { + next.ServeHTTP(w, r) + return + } + + if !h.IsAuthenticated(r) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"}) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/mothership/internal/auth/handler_test.go b/mothership/internal/auth/handler_test.go index 07fd14b..8f2b6be 100644 --- a/mothership/internal/auth/handler_test.go +++ b/mothership/internal/auth/handler_test.go @@ -407,17 +407,18 @@ func TestPublicPaths(t *testing.T) { {"/api/auth/status", true}, {"/api/auth/setup", true}, {"/api/auth/login", true}, + {"/api/auth/logout", true}, {"/api/provision", true}, + {"/ws/node", true}, {"/firmware/spaxel-1.0.0.bin", true}, {"/api/settings", false}, {"/api/nodes", false}, {"/ws/dashboard", false}, - {"/ws/node", false}, } for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { - result := isPublicPath(tt.path) + result := IsPublicPath(tt.path) if result != tt.expected { t.Errorf("isPublicPath(%q) = %v, want %v", tt.path, result, tt.expected) } @@ -888,7 +889,7 @@ func TestHandler_ChangePIN_Unauthenticated(t *testing.T) { // Try to change PIN without authentication (no cookie) changeReqBody := `{"old_pin": "1234", "new_pin": "5678"}` - req := httptest.NewRequest("POST", "/api/auth/change-pin", strings.NewReader(changeReqBody)) + req = httptest.NewRequest("POST", "/api/auth/change-pin", strings.NewReader(changeReqBody)) req.Header.Set("Content-Type", "application/json") w = httptest.NewRecorder()