test(core): add Redis session TTL expiration test

test(proxy): fix middleware layer ordering for request ID propagation

- Add test_redis_sessions_expire to verify session keys get EXPIRE set and are deleted after TTL
- Reorder middleware stack: csrf_middleware now outermost, telemetry_middleware reads X-Request-Id set by request_id_middleware
- Add comment documenting layer order and request_id flow
- Change test_task_registry_impl to multi_thread flavor for Redis compatibility

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-04-25 16:11:15 -04:00
parent 5bec6e2bf3
commit bf081e5748
5 changed files with 78 additions and 24 deletions

View file

@ -3067,6 +3067,47 @@ mod tests {
assert_eq!(deleted, 0);
}
#[tokio::test]
async fn test_redis_sessions_expire() {
let (store, _url) = setup_redis_store().await;
store.migrate().expect("Migration should succeed");
// Create a session with a short TTL (1 second)
let session = SessionRow {
session_id: "sess-expire".to_string(),
last_write_mtask_id: Some("task-1".to_string()),
last_write_at: Some(now_ms()),
pinned_group: Some(1),
min_settings_version: 1,
ttl: now_ms() + 1000, // expires in 1 second
};
store.upsert_session(&session).expect("Upsert should succeed");
// Verify session exists immediately
let got = store
.get_session("sess-expire")
.expect("Get should succeed")
.expect("Session should exist immediately after creation");
assert_eq!(got.session_id, "sess-expire");
// Verify EXPIRE is set on the key (TTL should be > 0)
let key = "miroir:session:sess-expire";
let mut conn = store.pool.manager.lock().await;
let ttl: i64 = conn.ttl(key).await.expect("TTL should work");
assert!(ttl > 0, "Session key should have EXPIRE set, got TTL={}", ttl);
assert!(ttl <= 2, "TTL should be approximately 1 second, got {}", ttl);
drop(conn);
// Wait for expiration (2 seconds to be safe, allowing for Redis timing granularity)
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
// Verify session is gone after expiration
let got = store
.get_session("sess-expire")
.expect("Get should succeed");
assert!(got.is_none(), "Session should be expired and gone after TTL");
}
// --- Table 5: idempotency tests ---
#[tokio::test]

View file

@ -271,18 +271,14 @@ async fn main() -> anyhow::Result<()> {
.nest("/search", search::router::<UnifiedState>())
.nest("/settings", settings::router::<UnifiedState>())
.nest("/tasks", tasks::router::<UnifiedState>())
.layer(axum::middleware::from_fn(
middleware::request_id_middleware,
))
.layer(axum::extract::DefaultBodyLimit::max(
config.server.max_body_bytes as usize,
))
.layer(axum::Extension(state.admin.config.clone()))
.layer(axum::Extension(std::sync::Arc::new(state.admin.clone())))
.layer(axum::middleware::from_fn_with_state(
state.auth.clone(),
auth::auth_middleware,
))
// IMPORTANT: Layer order matters! Last layer() call = outermost = runs first.
// The middleware stack (from outermost to innermost):
// 1. csrf_middleware - runs first
// 2. auth_middleware
// 3. Extension layers
// 4. request_id_middleware - sets X-Request-Id header
// 5. telemetry_middleware - reads X-Request-Id, creates tracing span with request_id field
// The span's request_id field propagates to all child log events via with_current_span(true)
.layer(axum::middleware::from_fn_with_state(
auth::CsrfState {
auth: state.auth.clone(),
@ -290,6 +286,18 @@ async fn main() -> anyhow::Result<()> {
},
auth::csrf_middleware,
))
.layer(axum::middleware::from_fn_with_state(
state.auth.clone(),
auth::auth_middleware,
))
.layer(axum::Extension(std::sync::Arc::new(state.admin.clone())))
.layer(axum::Extension(state.admin.config.clone()))
.layer(axum::extract::DefaultBodyLimit::max(
config.server.max_body_bytes as usize,
))
.layer(axum::middleware::from_fn(
middleware::request_id_middleware,
))
.layer(axum::middleware::from_fn_with_state(
TelemetryState {
metrics: state.metrics.clone(),

View file

@ -932,6 +932,10 @@ fn extract_path_template(request: &Request) -> String {
}
/// Main middleware that combines request ID injection, structured logging, and Prometheus metrics.
///
/// IMPORTANT: This middleware must be applied AFTER request_id_middleware in the layer stack
/// (i.e., its layer() call must come BEFORE request_id_middleware's layer() call).
/// This ensures the request_id header is already set when this middleware runs.
pub async fn telemetry_middleware(
State(telemetry): State<TelemetryState>,
mut req: Request,
@ -943,11 +947,12 @@ pub async fn telemetry_middleware(
let metrics = telemetry.metrics.clone();
let pod_id = telemetry.pod_id.clone();
// Generate or extract request ID
// Extract request ID from header (set by request_id_middleware)
// The header must already exist because request_id_middleware runs first.
let request_id = req
.headers()
.get_request_id()
.unwrap_or_else(generate_request_id);
.expect("request_id header must be set by request_id_middleware");
req.headers_mut().set_request_id(&request_id);
// Create span for structured logging with pod_id included.

View file

@ -570,7 +570,7 @@ mod tests {
assert!(ts.contains("Z"));
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_task_registry_impl() {
let registry = TaskRegistryImpl::in_memory();
let mut node_tasks = HashMap::new();

View file

@ -19,7 +19,7 @@ fn parse_log_line(line: &str) -> Option<serde_json::Value> {
serde_json::from_str(line).ok()
}
/// Helper: check if a string is a valid 8-char hex request ID
/// Helper: check if a string is a valid 8-char hex request ID (from RequestId::new)
fn is_valid_request_id(s: &str) -> bool {
s.len() == 8 && s.chars().all(|c| c.is_ascii_hexdigit())
}
@ -96,7 +96,7 @@ fn test_request_id_format_in_logs() {
"level": "info",
"target": "miroir.request",
"message": "search completed",
"request_id": "a1b2c3d4e5f67890",
"request_id": "a1b2c3d4",
"pod_id": "test-pod",
"duration_ms": 42
}"#;
@ -104,8 +104,8 @@ fn test_request_id_format_in_logs() {
let parsed = parse_log_line(sample_log).unwrap();
let request_id = parsed["request_id"].as_str().unwrap();
// Request IDs should be 16 hex chars (from generate_request_id)
assert_eq!(request_id.len(), 16);
// Request IDs should be 8 hex chars (from RequestId::new())
assert_eq!(request_id.len(), 8);
assert!(request_id.chars().all(|c| c.is_ascii_hexdigit()));
}
@ -116,13 +116,13 @@ fn test_request_id_format_in_logs() {
#[test]
fn test_request_id_extraction_from_logs() {
let logs = vec![
r#"{"timestamp":"2026-05-01T12:00:00.000Z","level":"info","target":"miroir.request","request_id":"abc123def4567890","pod_id":"pod-1","message":"GET /search 200"}"#,
r#"{"timestamp":"2026-05-01T12:00:00.001Z","level":"debug","target":"miroir.node","request_id":"abc123def4567890","pod_id":"pod-1","node_id":"node-1","message":"node call started"}"#,
r#"{"timestamp":"2026-05-01T12:00:00.010Z","level":"info","target":"miroir.search","request_id":"abc123def4567890","pod_id":"pod-1","index":"products","message":"search completed"}"#,
r#"{"timestamp":"2026-05-01T12:00:00.000Z","level":"info","target":"miroir.request","request_id":"abc12345","pod_id":"pod-1","message":"GET /search 200"}"#,
r#"{"timestamp":"2026-05-01T12:00:00.001Z","level":"debug","target":"miroir.node","request_id":"abc12345","pod_id":"pod-1","node_id":"node-1","message":"node call started"}"#,
r#"{"timestamp":"2026-05-01T12:00:00.010Z","level":"info","target":"miroir.search","request_id":"abc12345","pod_id":"pod-1","index":"products","message":"search completed"}"#,
];
// Extract all logs with request_id = "abc123def4567890"
let target_id = "abc123def4567890";
// Extract all logs with request_id = "abc12345"
let target_id = "abc12345";
let matching_logs: Vec<_> = logs
.iter()
.filter(|line| line.contains(&format!("\"request_id\":\"{}\"", target_id)))