spaxel/mothership/internal/api/prediction_test.go
jedarden f99dc15a2d feat: complete crowd flow visualization implementation
- Fix Viz3D exports to include flow visualization functions
- Export setFlowLayerVisible, setDwellLayerVisible, setCorridorLayerVisible
- Export setFlowTimeFilter, setFlowData, setDwellData, setCorridorData
- Remove duplicate setDwellLayerVisible function definition

This completes the crowd flow visualization feature that was
already implemented in the backend (flow.go) and frontend
(crowdflow.js, viz3d.js) but had missing exports in the Viz3D module.
2026-04-11 07:27:21 -04:00

325 lines
8.6 KiB
Go

// Package api provides REST API handlers for presence prediction.
package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/spaxel/mothership/internal/prediction"
)
// mockZoneProvider is a ZoneProvider test double backed by a static
// zone-ID → display-name table.
type mockZoneProvider struct {
	// zones maps zone IDs to display names. Leaving it nil behaves like an
	// empty table.
	zones map[string]string
}

// GetZone looks id up in the static table and reports whether it was found.
// Indexing a nil map is legal in Go and yields the zero value with
// found == false, so no explicit nil guard is needed.
func (m *mockZoneProvider) GetZone(id string) (string, bool) {
	name, found := m.zones[id]
	return name, found
}
// mockPersonProvider is a PersonProvider test double that serves a fixed
// roster of people.
type mockPersonProvider struct {
	// people is handed back unchanged (same slice, same backing array) by
	// GetPeople.
	people []struct {
		ID   string
		Name string
	}
}

// GetPeople returns the configured roster as-is and never reports an error.
// If no roster was configured the result is a nil slice.
func (m *mockPersonProvider) GetPeople() (roster []struct {
	ID   string
	Name string
}, err error) {
	roster = m.people
	return roster, nil
}
// TestPredictionHandler_getPredictions checks that GET /api/predictions
// responds 200 with a JSON array, and that the array is empty when the
// model store has just been created.
func TestPredictionHandler_getPredictions(t *testing.T) {
	// Create a temporary directory for the databases; removed on return.
	tmpDir, err := os.MkdirTemp("", "prediction_api_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	// Create prediction components backed by files inside tmpDir.
	store, err := prediction.NewModelStore(filepath.Join(tmpDir, "predictions.db"))
	if err != nil {
		t.Fatalf("Failed to create model store: %v", err)
	}
	defer store.Close()

	accuracy, err := prediction.NewAccuracyTracker(filepath.Join(tmpDir, "accuracy.db"))
	if err != nil {
		t.Fatalf("Failed to create accuracy tracker: %v", err)
	}
	defer accuracy.Close()

	predictor := prediction.NewPredictor(store)
	horizon := prediction.NewHorizonPredictor(store, accuracy)
	history := prediction.NewHistoryUpdater(store)
	handler := NewPredictionHandler(predictor, history, accuracy, horizon)

	// Set mock providers so the handler has zone and person data available.
	zp := &mockZoneProvider{
		zones: map[string]string{
			"zone_1": "Kitchen",
			"zone_2": "Living Room",
		},
	}
	pp := &mockPersonProvider{
		people: []struct {
			ID   string
			Name string
		}{
			{ID: "person_1", Name: "Alice"},
		},
	}
	handler.SetZoneProvider(zp)
	handler.SetPersonProvider(pp)

	r := chi.NewRouter()
	handler.RegisterRoutes(r)

	req := httptest.NewRequest(http.MethodGet, "/api/predictions", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// Abort on a bad status: decoding an error body as a prediction list
	// would only add a second, misleading failure. Include the body so the
	// handler's error message shows up in the test output.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status 200, got %d (body: %s)", w.Code, w.Body.String())
	}

	var predictions []prediction.PersonPrediction
	if err := json.NewDecoder(w.Body).Decode(&predictions); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	// A freshly created store is expected to yield no predictions.
	if len(predictions) != 0 {
		t.Errorf("Expected 0 predictions, got %d: %+v", len(predictions), predictions)
	}
}
// TestPredictionHandler_getStats checks that GET /api/predictions/stats
// responds 200 with a JSON object containing the transition_count,
// data_age_days and has_minimum_data fields. The handler is built with a nil
// horizon predictor, so this also exercises the endpoint without one.
func TestPredictionHandler_getStats(t *testing.T) {
	// Create a temporary directory for the databases; removed on return.
	tmpDir, err := os.MkdirTemp("", "prediction_api_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	store, err := prediction.NewModelStore(filepath.Join(tmpDir, "predictions.db"))
	if err != nil {
		t.Fatalf("Failed to create model store: %v", err)
	}
	defer store.Close()

	accuracy, err := prediction.NewAccuracyTracker(filepath.Join(tmpDir, "accuracy.db"))
	if err != nil {
		t.Fatalf("Failed to create accuracy tracker: %v", err)
	}
	defer accuracy.Close()

	predictor := prediction.NewPredictor(store)
	history := prediction.NewHistoryUpdater(store)
	handler := NewPredictionHandler(predictor, history, accuracy, nil)

	r := chi.NewRouter()
	handler.RegisterRoutes(r)

	req := httptest.NewRequest(http.MethodGet, "/api/predictions/stats", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// Abort on a bad status rather than trying to decode an error body.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status 200, got %d (body: %s)", w.Code, w.Body.String())
	}

	var stats map[string]interface{}
	if err := json.NewDecoder(w.Body).Decode(&stats); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	// Only presence is checked; the values depend on stored history.
	for _, field := range []string{"transition_count", "data_age_days", "has_minimum_data"} {
		if _, ok := stats[field]; !ok {
			t.Errorf("Missing %s field", field)
		}
	}
}
// TestPredictionHandler_getAccuracyOverall checks that
// GET /api/predictions/accuracy/overall responds 200 with all summary fields
// present. It also checks that the reported target accuracy is 75 and the
// default horizon is 15 minutes.
func TestPredictionHandler_getAccuracyOverall(t *testing.T) {
	// Create a temporary directory for the databases; removed on return.
	tmpDir, err := os.MkdirTemp("", "prediction_api_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	store, err := prediction.NewModelStore(filepath.Join(tmpDir, "predictions.db"))
	if err != nil {
		t.Fatalf("Failed to create model store: %v", err)
	}
	defer store.Close()

	accuracy, err := prediction.NewAccuracyTracker(filepath.Join(tmpDir, "accuracy.db"))
	if err != nil {
		t.Fatalf("Failed to create accuracy tracker: %v", err)
	}
	defer accuracy.Close()

	predictor := prediction.NewPredictor(store)
	history := prediction.NewHistoryUpdater(store)
	handler := NewPredictionHandler(predictor, history, accuracy, nil)

	r := chi.NewRouter()
	handler.RegisterRoutes(r)

	req := httptest.NewRequest(http.MethodGet, "/api/predictions/accuracy/overall", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// Abort on a bad status rather than trying to decode an error body.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status 200, got %d (body: %s)", w.Code, w.Body.String())
	}

	var result map[string]interface{}
	if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	requiredFields := []string{"accuracy_percent", "total_predictions", "pending_predictions", "target_accuracy", "meets_target", "horizon_minutes"}
	for _, field := range requiredFields {
		if _, ok := result[field]; !ok {
			t.Errorf("Missing required field: %s", field)
		}
	}

	// encoding/json decodes JSON numbers into float64 when the target is
	// interface{}, hence the float64 assertions below.
	if target, ok := result["target_accuracy"].(float64); !ok || target != 75.0 {
		t.Errorf("Expected target_accuracy 75.0, got %v", result["target_accuracy"])
	}
	if got, ok := result["horizon_minutes"].(float64); !ok || got != 15 {
		t.Errorf("Expected horizon_minutes 15, got %v", result["horizon_minutes"])
	}
}
// TestPredictionHandler_getHorizonPredictions checks that
// GET /api/predictions/horizon honors the ?horizon= query parameter (echoed
// back as horizon_minutes) and that the response has a predictions field.
func TestPredictionHandler_getHorizonPredictions(t *testing.T) {
	// Create a temporary directory for the databases; removed on return.
	tmpDir, err := os.MkdirTemp("", "prediction_api_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	store, err := prediction.NewModelStore(filepath.Join(tmpDir, "predictions.db"))
	if err != nil {
		t.Fatalf("Failed to create model store: %v", err)
	}
	defer store.Close()

	accuracy, err := prediction.NewAccuracyTracker(filepath.Join(tmpDir, "accuracy.db"))
	if err != nil {
		t.Fatalf("Failed to create accuracy tracker: %v", err)
	}
	defer accuracy.Close()

	predictor := prediction.NewPredictor(store)
	horizon := prediction.NewHorizonPredictor(store, accuracy)
	history := prediction.NewHistoryUpdater(store)
	handler := NewPredictionHandler(predictor, history, accuracy, horizon)

	// Set mock providers so the handler has zone and person data available.
	zp := &mockZoneProvider{
		zones: map[string]string{
			"zone_1": "Kitchen",
		},
	}
	pp := &mockPersonProvider{
		people: []struct {
			ID   string
			Name string
		}{
			{ID: "person_1", Name: "Alice"},
		},
	}
	handler.SetZoneProvider(zp)
	handler.SetPersonProvider(pp)

	r := chi.NewRouter()
	handler.RegisterRoutes(r)

	req := httptest.NewRequest(http.MethodGet, "/api/predictions/horizon?horizon=30", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// Abort on a bad status rather than trying to decode an error body.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status 200, got %d (body: %s)", w.Code, w.Body.String())
	}

	var result map[string]interface{}
	if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	// The requested horizon must be reflected in the response. Use a name
	// that does not shadow the outer horizon predictor.
	if got, ok := result["horizon_minutes"].(float64); !ok || got != 30 {
		t.Errorf("Expected horizon_minutes 30, got %v", result["horizon_minutes"])
	}

	// Only presence is checked: the value may legitimately be empty or null
	// when there is no history.
	if _, ok := result["predictions"]; !ok {
		t.Error("Missing predictions field")
	}
}
func TestLogPredictionAccuracy(t *testing.T) {
// Create temporary database
tmpDir, err := os.MkdirTemp("", "prediction_api_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
accuracy, err := prediction.NewAccuracyTracker(filepath.Join(tmpDir, "accuracy.db"))
if err != nil {
t.Fatalf("Failed to create accuracy tracker: %v", err)
}
defer accuracy.Close()
// Record some predictions
_ = accuracy.RecordPrediction("person1", "zone_a", "zone_b", 0.8, 15*time.Minute)
_ = accuracy.RecordPrediction("person1", "zone_a", "zone_b", 0.9, 15*time.Minute)
// Evaluate them as if they were correct
actualPositions := map[string]string{"person1": "zone_b"}
_, _, _ = accuracy.EvaluatePending(actualPositions)
// Log accuracy (should not crash)
LogPredictionAccuracy(accuracy)
}