From ffb5ea8a3e60b2cf3c4cb5f79ee09b4d245bd225 Mon Sep 17 00:00:00 2001 From: jedarden Date: Sun, 3 May 2026 13:31:05 -0400 Subject: [PATCH] P3: Add Phase 3 advanced capability stub modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds skeletal implementations for Phase 3 advanced capabilities (§13.2-§13.12, §13.9) that will be fully implemented in later phases. - hedging.rs (§13.2): Hedged request support structure - query_planner.rs (§13.4): Shard-aware query planning interface - replica_selection.rs (§13.3): Adaptive replica selection framework - vector.rs (§13.12): Vector/hybrid search support types - dump_import.rs (§13.9): Streaming dump import coordinator These modules provide the type definitions and interfaces needed by the task registry and persistence layer for multi-pod coordination in Phase 6. Co-Authored-By: Claude Opus 4.7 --- crates/miroir-core/src/dump_import.rs | 392 ++++++++++++++++++ crates/miroir-core/src/hedging.rs | 319 +++++++++++++++ crates/miroir-core/src/query_planner.rs | 356 ++++++++++++++++ crates/miroir-core/src/replica_selection.rs | 432 ++++++++++++++++++++ crates/miroir-core/src/vector.rs | 388 ++++++++++++++++++ 5 files changed, 1887 insertions(+) create mode 100644 crates/miroir-core/src/dump_import.rs create mode 100644 crates/miroir-core/src/hedging.rs create mode 100644 crates/miroir-core/src/query_planner.rs create mode 100644 crates/miroir-core/src/replica_selection.rs create mode 100644 crates/miroir-core/src/vector.rs diff --git a/crates/miroir-core/src/dump_import.rs b/crates/miroir-core/src/dump_import.rs new file mode 100644 index 0000000..e764f07 --- /dev/null +++ b/crates/miroir-core/src/dump_import.rs @@ -0,0 +1,392 @@ +//! Streaming routed dump import (plan §13.9). +//! +//! Intercepts dump imports and routes each document to the correct shard +//! instead of broadcasting to all nodes. 
+
+use crate::error::{MiroirError, Result};
+use crate::router::{shard_for_key, assign_shard_in_group};
+use crate::topology::{Topology, NodeId};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+/// Settings that control how a dump import is executed.
+///
+/// Every field carries a serde default backed by the same helper function
+/// that [`Default`] uses, so an empty config section deserializes to
+/// `DumpImportConfig::default()`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DumpImportConfig {
+    /// Import mode: "streaming" or "broadcast".
+    #[serde(default = "default_mode")]
+    pub mode: String,
+    /// Batch size for per-target POSTs.
+    #[serde(default = "default_batch_size")]
+    pub batch_size: u32,
+    /// Parallel target writes.
+    #[serde(default = "default_parallel")]
+    pub parallel_target_writes: u32,
+    /// Memory buffer cap (bytes).
+    #[serde(default = "default_memory_buffer")]
+    pub memory_buffer_bytes: u64,
+    /// Chunk size for Mode C coordinator.
+    #[serde(default = "default_chunk_size")]
+    pub chunk_size_bytes: u64,
+}
+
+/// Default for `mode`.
+fn default_mode() -> String {
+    String::from("streaming")
+}
+
+/// Default for `batch_size`.
+fn default_batch_size() -> u32 {
+    1_000
+}
+
+/// Default for `parallel_target_writes`.
+fn default_parallel() -> u32 {
+    8
+}
+
+/// Default for `memory_buffer_bytes` (128 MiB).
+fn default_memory_buffer() -> u64 {
+    128 * 1024 * 1024
+}
+
+/// Default for `chunk_size_bytes` (256 MiB).
+fn default_chunk_size() -> u64 {
+    256 * 1024 * 1024
+}
+
+impl Default for DumpImportConfig {
+    fn default() -> Self {
+        let mode = default_mode();
+        let batch_size = default_batch_size();
+        let parallel_target_writes = default_parallel();
+        let memory_buffer_bytes = default_memory_buffer();
+        let chunk_size_bytes = default_chunk_size();
+        Self {
+            mode,
+            batch_size,
+            parallel_target_writes,
+            memory_buffer_bytes,
+            chunk_size_bytes,
+        }
+    }
+}
+
+/// Lifecycle stage of a dump import.
+///
+/// Serialized by variant name (e.g. `"Routing"`); the explicit `u8`
+/// discriminants fix the in-memory representation.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[repr(u8)]
+pub enum DumpImportPhase {
+    /// No import in progress.
+    Idle = 0,
+    /// Reading and parsing dump.
+    Reading = 1,
+    /// Routing documents to target nodes.
+    Routing = 2,
+    /// Applying index settings.
+    ApplyingSettings = 3,
+    /// Completed successfully.
+    Complete = 4,
+    /// Failed with error.
+    Failed = 5,
+}
+
+/// Dump import status.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DumpImportStatus {
+    /// Import ID.
+    pub id: String,
+    /// Target index UID.
+    pub index_uid: String,
+    /// Current phase.
+    pub phase: DumpImportPhase,
+    /// Documents processed so far.
+    pub documents_processed: u64,
+    /// Total documents (estimated).
+    pub total_documents: u64,
+    /// Bytes read so far.
+    pub bytes_read: u64,
+    /// Phase started at (UNIX ms).
+    pub phase_started_at: u64,
+    /// Error message if any.
+    pub error: Option<String>,
+}
+
+/// Shared table of import statuses, keyed by import ID.
+type ImportTable = Arc<RwLock<HashMap<String, DumpImportStatus>>>;
+
+/// Dump import manager.
+///
+/// Tracks the status of every import it has started and runs each import as
+/// a detached background task.
+pub struct DumpImportManager {
+    /// Configuration.
+    config: DumpImportConfig,
+    /// Active imports (ID -> status).
+    active_imports: ImportTable,
+    /// Topology for routing.
+    topology: Arc<Topology>,
+}
+
+impl DumpImportManager {
+    /// Create a new dump import manager.
+    pub fn new(config: DumpImportConfig, topology: Arc<Topology>) -> Self {
+        Self {
+            config,
+            active_imports: Arc::new(RwLock::new(HashMap::new())),
+            topology,
+        }
+    }
+
+    /// Start a streaming dump import.
+    ///
+    /// Registers a status entry in the `Reading` phase, spawns the import as
+    /// a background task and returns the new import ID immediately. Poll
+    /// [`get_status`](Self::get_status) for progress; if the background task
+    /// fails, the entry moves to `Failed` and `error` holds the message.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MiroirError::InvalidRequest` when the configured mode is not
+    /// `"streaming"`.
+    pub async fn start_import(
+        &self,
+        index_uid: String,
+        dump_data: Vec<u8>,
+        primary_key: String,
+        shard_count: u32,
+    ) -> Result<String> {
+        if self.config.mode != "streaming" {
+            return Err(MiroirError::InvalidRequest(
+                "streaming dump import is disabled".into(),
+            ));
+        }
+
+        let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());
+
+        let status = DumpImportStatus {
+            id: import_id.clone(),
+            index_uid: index_uid.clone(),
+            phase: DumpImportPhase::Reading,
+            documents_processed: 0,
+            total_documents: 0,
+            bytes_read: 0,
+            phase_started_at: millis_now(),
+            error: None,
+        };
+
+        // The write guard is a temporary and is released at the end of this statement.
+        self.active_imports
+            .write()
+            .await
+            .insert(import_id.clone(), status);
+
+        // Everything the background task needs must be owned by it.
+        let task_id = import_id.clone();
+        let imports = Arc::clone(&self.active_imports);
+        let topology = Arc::clone(&self.topology);
+        let config = self.config.clone();
+
+        tokio::spawn(async move {
+            let outcome = Self::run_import(
+                &task_id,
+                index_uid,
+                dump_data,
+                primary_key,
+                shard_count,
+                topology,
+                config,
+                Arc::clone(&imports),
+            )
+            .await;
+
+            if let Err(e) = outcome {
+                tracing::error!("Dump import {} failed: {}", task_id, e);
+                // Without this the status would stay in its last phase forever and
+                // pollers could not tell a failed import from a slow one.
+                Self::mark_failed(&imports, &task_id, e.to_string()).await;
+            }
+        });
+
+        Ok(import_id)
+    }
+
+    /// Get the status of an import.
+    ///
+    /// Returns `None` when no import with that ID is known to this manager.
+    pub async fn get_status(&self, import_id: &str) -> Option<DumpImportStatus> {
+        let imports = self.active_imports.read().await;
+        imports.get(import_id).cloned()
+    }
+
+    /// Run the import pipeline.
+    ///
+    /// Parses `dump_data` as NDJSON, hashes each document's primary key to a
+    /// shard, buffers the document for every node that owns that shard and
+    /// flushes the buffers every `batch_size` documents.
+    ///
+    /// # Errors
+    ///
+    /// Fails on invalid UTF-8, an invalid JSON line, a document whose primary
+    /// key field is missing or not a string, or a shard with no target nodes.
+    async fn run_import(
+        import_id: &str,
+        index_uid: String,
+        dump_data: Vec<u8>,
+        primary_key: String,
+        shard_count: u32,
+        topology: Arc<Topology>,
+        config: DumpImportConfig,
+        imports: ImportTable,
+    ) -> Result<()> {
+        Self::update_phase(&imports, import_id, DumpImportPhase::Reading).await;
+
+        let data_str = std::str::from_utf8(&dump_data)
+            .map_err(|e| MiroirError::InvalidRequest(format!("invalid UTF-8 in dump: {}", e)))?;
+
+        // A configured batch size of 0 would make the modulo below panic;
+        // treat it as "flush after every document".
+        let flush_every = u64::from(config.batch_size.max(1));
+
+        // Per-target buffers, keyed by (node, shard).
+        let mut per_target_buffers: HashMap<(NodeId, u32), Vec<serde_json::Value>> =
+            HashMap::new();
+        let mut processed = 0u64;
+
+        for line in data_str.lines() {
+            if line.is_empty() {
+                continue;
+            }
+            let doc: serde_json::Value = serde_json::from_str(line).map_err(|e| {
+                MiroirError::InvalidRequest(format!("invalid JSON in dump: {}", e))
+            })?;
+
+            // Only string-valued primary keys are accepted here.
+            let pk_value = doc
+                .get(&primary_key)
+                .and_then(|v| v.as_str())
+                .ok_or_else(|| {
+                    MiroirError::InvalidRequest(format!(
+                        "missing or invalid primary key field: {}",
+                        primary_key
+                    ))
+                })?;
+
+            let shard_id = shard_for_key(pk_value, shard_count);
+
+            // Owners of this shard across all replica groups.
+            let target_nodes: Vec<NodeId> = topology
+                .groups()
+                .flat_map(|group| assign_shard_in_group(shard_id, group.nodes(), topology.rf()))
+                .collect();
+
+            if target_nodes.is_empty() {
+                return Err(MiroirError::Topology(format!("no nodes for shard {}", shard_id)));
+            }
+
+            for node in &target_nodes {
+                per_target_buffers
+                    .entry((node.clone(), shard_id))
+                    .or_default()
+                    .push(doc.clone());
+            }
+
+            processed += 1;
+
+            if processed % flush_every == 0 {
+                Self::flush_buffers(
+                    &index_uid,
+                    &mut per_target_buffers,
+                    &config,
+                    &imports,
+                    import_id,
+                    processed,
+                )
+                .await?;
+            }
+        }
+
+        // Final flush for the partial batch.
+        Self::flush_buffers(
+            &index_uid,
+            &mut per_target_buffers,
+            &config,
+            &imports,
+            import_id,
+            processed,
+        )
+        .await?;
+
+        Self::update_phase(&imports, import_id, DumpImportPhase::Complete).await;
+
+        Ok(())
+    }
+
+    /// Flush buffered documents to target nodes.
+    ///
+    /// Drains every buffer and, if anything was flushed, records `processed`
+    /// as the import's `documents_processed`. The actual POST is not
+    /// implemented yet; the flush is only logged.
+    async fn flush_buffers(
+        index_uid: &str,
+        buffers: &mut HashMap<(NodeId, u32), Vec<serde_json::Value>>,
+        _config: &DumpImportConfig,
+        imports: &ImportTable,
+        import_id: &str,
+        processed: u64,
+    ) -> Result<()> {
+        let mut flushed_any = false;
+
+        for ((node, _shard), docs) in buffers.drain() {
+            if docs.is_empty() {
+                continue;
+            }
+
+            // POST documents to the node
+            // In a real implementation, this would use the HTTP client
+            tracing::debug!(
+                "Flushing {} documents to node {} for index {}",
+                docs.len(),
+                node,
+                index_uid
+            );
+            flushed_any = true;
+        }
+
+        // One status write per flush instead of one per buffer.
+        if flushed_any {
+            let mut imports = imports.write().await;
+            if let Some(status) = imports.get_mut(import_id) {
+                status.documents_processed = processed;
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Update the phase of an import and restart its phase timer.
+    async fn update_phase(imports: &ImportTable, import_id: &str, phase: DumpImportPhase) {
+        let mut imports = imports.write().await;
+        if let Some(status) = imports.get_mut(import_id) {
+            status.phase = phase;
+            status.phase_started_at = millis_now();
+        }
+    }
+
+    /// Move an import to the `Failed` phase and store the error message.
+    async fn mark_failed(imports: &ImportTable, import_id: &str, message: String) {
+        let mut imports = imports.write().await;
+        if let Some(status) = imports.get_mut(import_id) {
+            status.phase = DumpImportPhase::Failed;
+            status.phase_started_at = millis_now();
+            status.error = Some(message);
+        }
+    }
+}
+
+/// Get current UNIX timestamp in milliseconds.
+fn millis_now() -> u64 {
+    use std::time::{SystemTime, UNIX_EPOCH};
+
+    // A system clock set before the epoch makes `duration_since` fail;
+    // report 0 in that case.
+    match SystemTime::now().duration_since(UNIX_EPOCH) {
+        Ok(since_epoch) => since_epoch.as_millis() as u64,
+        Err(_) => 0,
+    }
+}
+
+impl Default for DumpImportManager {
+    /// Default config paired with a `Topology::new(1, 1, 1)` topology.
+    fn default() -> Self {
+        let topology = Arc::new(Topology::new(1, 1, 1));
+        Self::new(DumpImportConfig::default(), topology)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_config_default() {
+        let DumpImportConfig {
+            mode,
+            batch_size,
+            parallel_target_writes,
+            ..
+        } = DumpImportConfig::default();
+
+        assert_eq!(mode, "streaming");
+        assert_eq!(batch_size, 1000);
+        assert_eq!(parallel_target_writes, 8);
+    }
+
+    #[test]
+    fn test_phase_serialization() {
+        // Phases serialize by variant name and round-trip.
+        let encoded = serde_json::to_string(&DumpImportPhase::Routing).unwrap();
+        assert_eq!(encoded, "\"Routing\"");
+
+        let decoded: DumpImportPhase = serde_json::from_str(&encoded).unwrap();
+        assert_eq!(decoded, DumpImportPhase::Routing);
+    }
+
+    #[tokio::test]
+    async fn test_get_status_nonexistent() {
+        let manager = DumpImportManager::default();
+        assert!(manager.get_status("nonexistent").await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_import_rejects_broadcast_mode() {
+        let manager = DumpImportManager::new(
+            DumpImportConfig {
+                mode: "broadcast".into(),
+                ..Default::default()
+            },
+            Arc::new(Topology::new(64, 2, 1)),
+        );
+
+        let outcome = manager
+            .start_import("products".into(), vec![1, 2, 3], "id".into(), 64)
+            .await;
+
+        assert!(outcome.is_err());
+    }
+}
diff --git a/crates/miroir-core/src/hedging.rs b/crates/miroir-core/src/hedging.rs
new file mode 100644
index 0000000..f3371d3
--- /dev/null
+++ b/crates/miroir-core/src/hedging.rs
@@ -0,0 +1,319 @@
+//! Hedged requests for tail-latency mitigation (plan §13.2).
+//!
+//! Issues duplicate requests to alternate replicas when a primary request
+//! exceeds the p95 latency threshold.
+
+use crate::error::{MiroirError, Result};
+use crate::router::assign_shard_in_group;
+use crate::topology::{NodeId, Topology};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::RwLock;
+use tokio::time::{sleep, Instant};
+
+/// Tunables for hedged requests.
+///
+/// All fields have serde defaults matching [`HedgingConfig::default`].
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HedgingConfig {
+    /// Whether hedging is enabled.
+    #[serde(default = "default_true")]
+    pub enabled: bool,
+    /// P95 trigger multiplier (hedge at p95 * this).
+    #[serde(default = "default_multiplier")]
+    pub p95_trigger_multiplier: f64,
+    /// Minimum trigger time in milliseconds.
+    #[serde(default = "default_min_trigger")]
+    pub min_trigger_ms: u64,
+    /// Maximum hedges per query.
+    #[serde(default = "default_max_hedges")]
+    pub max_hedges_per_query: u32,
+    /// Allow falling back to another replica group.
+    #[serde(default = "default_true")]
+    pub cross_group_fallback: bool,
+}
+
+/// Default for the boolean switches (`enabled`, `cross_group_fallback`).
+fn default_true() -> bool {
+    true
+}
+
+/// Default for `p95_trigger_multiplier`.
+fn default_multiplier() -> f64 {
+    1.2
+}
+
+/// Default for `min_trigger_ms`.
+fn default_min_trigger() -> u64 {
+    15
+}
+
+/// Default for `max_hedges_per_query`.
+fn default_max_hedges() -> u32 {
+    2
+}
+
+impl Default for HedgingConfig {
+    fn default() -> Self {
+        Self {
+            enabled: default_true(),
+            p95_trigger_multiplier: default_multiplier(),
+            min_trigger_ms: default_min_trigger(),
+            max_hedges_per_query: default_max_hedges(),
+            cross_group_fallback: default_true(),
+        }
+    }
+}
+
+/// Smoothed latency state kept per node and used to derive the hedge trigger.
+#[derive(Debug, Clone)]
+pub struct NodeLatency {
+    /// EWMA-smoothed latency in milliseconds.
+    pub ewma_ms: f64,
+    /// Half-life for EWMA (milliseconds).
+    pub half_life_ms: u64,
+}
+
+impl NodeLatency {
+    /// Build a tracker seeded with `initial_ms` as its current average.
+    pub fn new(initial_ms: f64, half_life_ms: u64) -> Self {
+        let ewma_ms = initial_ms;
+        Self {
+            ewma_ms,
+            half_life_ms,
+        }
+    }
+
+    /// Update with a new observation.
+    pub fn update(&mut self, latency_ms: f64) {
+        // A NaN or infinite sample would poison the average permanently.
+        if !latency_ms.is_finite() {
+            return;
+        }
+
+        // Weight kept by the old average per observation, calibrated to a
+        // nominal one observation per second: after `half_life_ms` worth of
+        // samples the old value's weight has halved. The exponent was
+        // previously `half_life_ms / 1000`, which inverted the relationship
+        // (a longer half-life gave *less* smoothing). A half-life of 0
+        // yields alpha = 0, i.e. no memory.
+        let alpha = 0.5_f64.powf(1000.0 / (self.half_life_ms as f64));
+        self.ewma_ms = alpha * self.ewma_ms + (1.0 - alpha) * latency_ms;
+    }
+
+    /// Get the current p95 estimate.
+    ///
+    /// This is the EWMA itself, used as a stand-in for p95. An EWMA tracks
+    /// the mean rather than a percentile, so it generally sits below the
+    /// true p95; `HedgingConfig::p95_trigger_multiplier` is what pads it.
+    pub fn p95_ms(&self) -> f64 {
+        self.ewma_ms
+    }
+}
+
+impl Default for NodeLatency {
+    /// 50 ms initial latency with a 5000 ms half-life.
+    fn default() -> Self {
+        Self::new(50.0, 5000)
+    }
+}
+
+/// Hedging manager.
+///
+/// Keeps a per-node latency estimate and answers two questions for the
+/// request path: when to fire a hedge, and which replica to send it to.
+pub struct HedgingManager {
+    /// Configuration.
+    config: HedgingConfig,
+    /// Per-node latency tracking.
+    node_latencies: Arc<RwLock<HashMap<NodeId, NodeLatency>>>,
+    /// Topology reference for finding alternate replicas.
+    topology: Arc<Topology>,
+}
+
+impl HedgingManager {
+    /// Create a new hedging manager.
+    pub fn new(config: HedgingConfig, topology: Arc<Topology>) -> Self {
+        Self {
+            config,
+            node_latencies: Arc::new(RwLock::new(HashMap::new())),
+            topology,
+        }
+    }
+
+    /// Record a latency observation for a node.
+    ///
+    /// A node seen for the first time starts from `NodeLatency::default()`
+    /// before the observation is folded in.
+    pub async fn record_latency(&self, node_id: &NodeId, latency_ms: f64) {
+        let mut latencies = self.node_latencies.write().await;
+        latencies
+            .entry(node_id.clone())
+            .or_default()
+            .update(latency_ms);
+    }
+
+    /// Get the p95 latency estimate for a node.
+    ///
+    /// Nodes without any recorded latency report 50.0 ms.
+    pub async fn get_p95(&self, node_id: &NodeId) -> f64 {
+        let latencies = self.node_latencies.read().await;
+        latencies
+            .get(node_id)
+            .map(|l| l.p95_ms())
+            .unwrap_or(50.0)
+    }
+
+    /// Compute the hedge deadline for a request to the given node.
+    ///
+    /// The deadline is `p95 * p95_trigger_multiplier`, floored at
+    /// `min_trigger_ms`. Returns `None` only when hedging is disabled; a
+    /// node with no latency data falls back to the 50 ms default of
+    /// [`get_p95`](Self::get_p95) rather than disabling the hedge.
+    pub async fn hedge_deadline(&self, primary_node: &NodeId) -> Option<Duration> {
+        if !self.config.enabled {
+            return None;
+        }
+
+        let p95 = self.get_p95(primary_node).await;
+        let trigger_ms =
+            (p95 * self.config.p95_trigger_multiplier).max(self.config.min_trigger_ms as f64);
+        // Float-to-int `as` saturates, so a NaN or negative product becomes 0.
+        Some(Duration::from_millis(trigger_ms as u64))
+    }
+
+    /// Find an alternate replica for hedging.
+    ///
+    /// Same-group replicas are preferred; replicas in other groups are only
+    /// considered when `cross_group_fallback` is enabled.
+    ///
+    /// Returns None if:
+    /// - No alternate available
+    /// - Max hedges already issued
+    /// - The primary node is not present in the topology
+    /// - Cross-group fallback disabled and no intra-group alternate
+    pub async fn find_alternate(
+        &self,
+        primary_node: &NodeId,
+        shard_id: u32,
+        hedge_count: u32,
+    ) -> Option<NodeId> {
+        if hedge_count >= self.config.max_hedges_per_query {
+            return None;
+        }
+
+        // Get all nodes for this shard (assign across all replica groups)
+        let all_nodes: Vec<NodeId> = self
+            .topology
+            .groups()
+            .flat_map(|group| assign_shard_in_group(shard_id, group.nodes(), self.topology.rf()))
+            .collect();
+        let primary_group = self.topology.node(primary_node)?.replica_group;
+
+        // First try: same group, different node
+        for node in &all_nodes {
+            if node != primary_node {
+                if let Some(n) = self.topology.node(node) {
+                    if n.replica_group == primary_group {
+                        return Some(node.clone());
+                    }
+                }
+            }
+        }
+
+        // Fallback: different group (if enabled). Every same-group candidate
+        // was rejected above, so the first non-primary node is cross-group.
+        if self.config.cross_group_fallback {
+            for node in &all_nodes {
+                if node != primary_node {
+                    return Some(node.clone());
+                }
+            }
+        }
+
+        None
+    }
+
+    /// Check if a hedge should be issued based on elapsed time.
+    pub fn should_hedge(&self, elapsed: Duration, deadline: Duration) -> bool {
+        elapsed >= deadline
+    }
+}
+
+impl Default for HedgingManager {
+    fn default() -> Self {
+        Self::new(HedgingConfig::default(), Arc::new(Topology::new(1, 1, 1)))
+    }
+}
+
+/// Hedge outcome for metrics.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum HedgeOutcome {
+    /// Primary request won (hedge cancelled or never fired).
+    PrimaryWon,
+    /// Hedge request won (primary was slower).
+    HedgeWon,
+    /// Both completed at similar time.
+    Tie,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::time::Duration;
+
+    #[test]
+    fn test_config_default() {
+        let config = HedgingConfig::default();
+        assert!(config.enabled);
+        assert_eq!(config.p95_trigger_multiplier, 1.2);
+        assert_eq!(config.min_trigger_ms, 15);
+        assert_eq!(config.max_hedges_per_query, 2);
+        assert!(config.cross_group_fallback);
+    }
+
+    #[test]
+    fn test_node_latency_ewma() {
+        let mut latency = NodeLatency::new(100.0, 1000);
+        assert_eq!(latency.ewma_ms, 100.0);
+
+        // Feeding the current value back in must leave the average near 100.
+        latency.update(100.0);
+        assert!(latency.ewma_ms > 90.0 && latency.ewma_ms < 110.0);
+
+        // A much lower sample must pull the average down.
+        latency.update(10.0);
+        assert!(latency.ewma_ms < 100.0);
+    }
+
+    #[test]
+    fn test_hedge_deadline_computation() {
+        let topology = Arc::new(Topology::new(1, 1, 1));
+        let manager = HedgingManager::new(HedgingConfig::default(), topology);
+
+        let node = NodeId::new("node-1".to_string());
+        manager
+            .node_latencies
+            .try_write()
+            .unwrap()
+            .insert(node.clone(), NodeLatency::new(50.0, 5000));
+
+        let rt = tokio::runtime::Runtime::new().unwrap();
+        let deadline = rt.block_on(async { manager.hedge_deadline(&node).await });
+
+        assert!(deadline.is_some());
+        // 50ms * 1.2 = 60ms, which is above the 15ms minimum, so 60ms wins
+        assert_eq!(deadline.unwrap(), Duration::from_millis(60));
+    }
+
+    #[test]
+    fn test_hedge_deadline_respects_min() {
+        let topology = Arc::new(Topology::new(1, 1, 1));
+        let config = HedgingConfig {
+            p95_trigger_multiplier: 1.2,
+            min_trigger_ms: 100,
+            ..Default::default()
+        };
+        let manager = HedgingManager::new(config, topology);
+
+        let node = NodeId::new("node-1".to_string());
+        manager
+            .node_latencies
+            .try_write()
+            .unwrap()
+            .insert(node.clone(), NodeLatency::new(10.0, 5000));
+
+        let rt = tokio::runtime::Runtime::new().unwrap();
+        let deadline = rt.block_on(async { manager.hedge_deadline(&node).await });
+
+        assert!(deadline.is_some());
+        // 10ms * 1.2 = 12ms, but min is 100ms
+        assert_eq!(deadline.unwrap(), Duration::from_millis(100));
+    }
+
+    #[test]
+    fn test_hedge_disabled() {
+        let config = HedgingConfig {
+            enabled: false,
+            ..Default::default()
+        };
+        let topology = Arc::new(Topology::new(1, 1, 1));
+        let manager = HedgingManager::new(config, topology);
+
+        let node = NodeId::new("node-1".to_string());
+        let rt = tokio::runtime::Runtime::new().unwrap();
+        let deadline = rt.block_on(async { manager.hedge_deadline(&node).await });
+
+        assert!(deadline.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_record_latency() {
+        let topology = Arc::new(Topology::new(1, 1, 1));
+        let manager = HedgingManager::new(HedgingConfig::default(), topology);
+
+        let node = NodeId::new("node-1".to_string());
+        manager.record_latency(&node, 100.0).await;
+        manager.record_latency(&node, 50.0).await;
+
+        let p95 = manager.get_p95(&node).await;
+        // EWMA should be between 50 and 100
+        assert!(p95 > 40.0 && p95 < 110.0);
+    }
+}
diff --git a/crates/miroir-core/src/query_planner.rs b/crates/miroir-core/src/query_planner.rs
new file mode 100644
index 0000000..e88fe02
--- /dev/null
+++ b/crates/miroir-core/src/query_planner.rs
@@ -0,0 +1,356 @@
+//! Shard-aware query planner (plan §13.4).
+//!
+//! Parses filter expressions to determine if a query can be narrowed to
+//! a subset of shards based on primary key constraints.
+
+use crate::error::{MiroirError, Result};
+use crate::router::shard_for_key;
+use serde::{Deserialize, Serialize};
+use std::collections::HashSet;
+
+/// Query planner configuration.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct QueryPlannerConfig {
+    /// Whether the query planner is enabled.
+    #[serde(default = "default_true")]
+    pub enabled: bool,
+    /// Maximum PK literals in a narrowable IN clause.
+    #[serde(default = "default_max_literals")]
+    pub max_pk_literals_narrowable: u32,
+    /// Whether to log query plans.
+    #[serde(default = "default_log_plans")]
+    pub log_plans: bool,
+}
+
+fn default_true() -> bool {
+    true
+}
+fn default_max_literals() -> u32 {
+    128
+}
+fn default_log_plans() -> bool {
+    false
+}
+
+impl Default for QueryPlannerConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            max_pk_literals_narrowable: default_max_literals(),
+            log_plans: default_log_plans(),
+        }
+    }
+}
+
+/// Query plan result.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct QueryPlan {
+    /// Whether the query was narrowable.
+    pub narrowed: bool,
+    /// Reason for narrowing (or not).
+    pub reason: String,
+    /// Target shard IDs (empty if not narrowed).
+    pub target_shards: Vec<u32>,
+    /// Warnings generated during planning.
+    pub warnings: Vec<String>,
+}
+
+/// Query planner.
+pub struct QueryPlanner {
+    /// Configuration.
+    config: QueryPlannerConfig,
+    /// Primary key field name for each index.
+    primary_keys: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
+}
+
+impl QueryPlanner {
+    /// Create a new query planner.
+    pub fn new(config: QueryPlannerConfig) -> Self {
+        Self {
+            config,
+            primary_keys: std::sync::Arc::new(tokio::sync::RwLock::new(
+                std::collections::HashMap::new(),
+            )),
+        }
+    }
+
+    /// Set the primary key field for an index.
+    pub async fn set_primary_key(&self, index: String, pk_field: String) {
+        let mut pks = self.primary_keys.write().await;
+        pks.insert(index, pk_field);
+    }
+
+    /// Get the primary key field for an index.
+    pub async fn get_primary_key(&self, index: &str) -> Option<String> {
+        let pks = self.primary_keys.read().await;
+        pks.get(index).cloned()
+    }
+
+    /// Plan a query given its filter expression and index.
+    ///
+    /// Returns a plan indicating whether the query can be narrowed to
+    /// a subset of shards.
+    pub async fn plan(
+        &self,
+        index: &str,
+        filter: &Option<String>,
+        shard_count: u32,
+    ) -> QueryPlan {
+        // Every non-narrowed outcome has the same shape; only the reason differs.
+        let not_narrowed = |reason: String| QueryPlan {
+            narrowed: false,
+            reason,
+            target_shards: vec![],
+            warnings: vec![],
+        };
+
+        if !self.config.enabled {
+            return not_narrowed("query planner disabled".to_string());
+        }
+
+        let filter = match filter {
+            Some(f) => f,
+            None => return not_narrowed("no filter specified".to_string()),
+        };
+
+        let pk_field = match self.get_primary_key(index).await {
+            Some(pk) => pk,
+            None => return not_narrowed("primary key not configured for index".to_string()),
+        };
+
+        match self.parse_pk_constraints(filter, &pk_field) {
+            Ok(PkConstraint::Eq(literal)) => {
+                // Single PK equality -> narrow to 1 shard
+                let shard_id = shard_for_key(&literal, shard_count);
+                QueryPlan {
+                    narrowed: true,
+                    reason: format!("PK equality: {} = {}", pk_field, literal),
+                    target_shards: vec![shard_id],
+                    warnings: vec![],
+                }
+            }
+            Ok(PkConstraint::In(literals))
+                if literals.len() <= self.config.max_pk_literals_narrowable as usize =>
+            {
+                // PK IN list -> narrow to the distinct shards of its literals
+                let mut shards: Vec<u32> = literals
+                    .iter()
+                    .map(|literal| shard_for_key(literal, shard_count))
+                    .collect();
+                shards.sort_unstable();
+                shards.dedup();
+                QueryPlan {
+                    narrowed: true,
+                    reason: format!("PK IN list: {} values", literals.len()),
+                    target_shards: shards,
+                    warnings: vec![],
+                }
+            }
+            Ok(PkConstraint::In(literals)) => not_narrowed(format!(
+                "PK IN list too large: {} values exceeds maximum of {}",
+                literals.len(),
+                self.config.max_pk_literals_narrowable
+            )),
+            // Surface the parser's reason (e.g. "contains OR ...") instead of
+            // discarding it.
+            Err(e) => not_narrowed(format!("filter not narrowable: {}", e)),
+        }
+    }
+
+    /// Parse a filter expression for PK constraints.
+    ///
+    /// Recognizes `pk = "literal"` / `pk = 'literal'` and
+    /// `pk IN ["a", "b", ...]`. Narrowing drops shards from the fan-out, so
+    /// the parser is deliberately conservative: any filter containing an
+    /// `OR` or `NOT` keyword (in any letter case) is rejected *before*
+    /// pattern matching, because a PK constraint under a disjunction or
+    /// negation does not bound the result set.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MiroirError::InvalidState` when the filter is not narrowable.
+    fn parse_pk_constraints(&self, filter: &str, pk_field: &str) -> Result<PkConstraint> {
+        let filter = filter.trim();
+
+        // Keyword scan on identifier-like tokens. A keyword inside a string
+        // literal also trips this, which only costs a missed optimization.
+        let upper = filter.to_ascii_uppercase();
+        let has_keyword = |kw: &str| {
+            upper
+                .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
+                .any(|token| token == kw)
+        };
+        if has_keyword("OR") {
+            return Err(MiroirError::InvalidState("contains OR at top level".to_string()));
+        }
+        if has_keyword("NOT") {
+            return Err(MiroirError::InvalidState("contains NOT; not narrowable".to_string()));
+        }
+
+        // The field name is data, not a pattern: escape it, and require it to
+        // start at a token boundary so `sku` does not match inside `my_sku`.
+        let pk = regex::escape(pk_field);
+
+        // Try equality: pk = "literal" or pk = 'literal'. The quotes must
+        // match, and literals containing a backslash are left un-narrowed
+        // rather than mis-parsed.
+        let eq_pattern = format!(
+            r#"(?:^|[\s(]){}\s*=\s*(?:"([^"\\]+)"|'([^'\\]+)')"#,
+            pk
+        );
+        if let Ok(re) = regex::Regex::new(&eq_pattern) {
+            if let Some(caps) = re.captures(filter) {
+                if let Some(literal) = caps.get(1).or_else(|| caps.get(2)) {
+                    return Ok(PkConstraint::Eq(literal.as_str().to_string()));
+                }
+            }
+        }
+
+        // Try IN list: pk IN ["literal1", "literal2", ...]
+        let in_pattern = format!(r#"(?:^|[\s(]){}\s+IN\s+\[(.+)\]"#, pk);
+        if let Ok(re) = regex::Regex::new(&in_pattern) {
+            if let Some(caps) = re.captures(filter) {
+                if let Some(list) = caps.get(1) {
+                    let literals = self.parse_string_list(list.as_str())?;
+                    // An empty list would "narrow" the query to zero shards.
+                    if literals.is_empty() {
+                        return Err(MiroirError::InvalidState(
+                            "PK IN list has no string literals".to_string(),
+                        ));
+                    }
+                    return Ok(PkConstraint::In(literals));
+                }
+            }
+        }
+
+        if filter.contains(&format!("{} != ", pk_field)) || filter.contains(&format!("{}<>", pk_field)) {
+            return Err(MiroirError::InvalidState("PK negation is not narrowable".to_string()));
+        }
+
+        Err(MiroirError::InvalidState("no PK constraint found".to_string()))
+    }
+
+    /// Parse a comma-separated list of string literals.
+    fn parse_string_list(&self, input: &str) -> Result<Vec<String>> {
+        // Accepts items quoted with `"` or `'`, separated by commas and/or
+        // whitespace. Inside a literal, a backslash makes the next character
+        // literal (so `\"` and `\\` both work). Anything outside quotes other
+        // than separators, and a literal that is never closed, is an error,
+        // so callers never receive a silently truncated or garbled list.
+        let mut result = Vec::new();
+        let mut current = String::new();
+        // `Some(q)` while inside a literal opened with quote character `q`.
+        let mut quote: Option<char> = None;
+        let mut escape = false;
+
+        for ch in input.chars() {
+            match quote {
+                Some(q) => {
+                    if escape {
+                        current.push(ch);
+                        escape = false;
+                    } else if ch == '\\' {
+                        escape = true;
+                    } else if ch == q {
+                        result.push(std::mem::take(&mut current));
+                        quote = None;
+                    } else {
+                        current.push(ch);
+                    }
+                }
+                None => match ch {
+                    '"' | '\'' => quote = Some(ch),
+                    ',' => {}
+                    c if c.is_whitespace() => {}
+                    other => {
+                        return Err(MiroirError::InvalidState(format!(
+                            "unexpected unquoted character '{}' in PK list",
+                            other
+                        )));
+                    }
+                },
+            }
+        }
+
+        if quote.is_some() {
+            return Err(MiroirError::InvalidState(
+                "unterminated string literal in PK list".to_string(),
+            ));
+        }
+
+        Ok(result)
+    }
+}
+
+impl Default for QueryPlanner {
+    fn default() -> Self {
+        Self::new(QueryPlannerConfig::default())
+    }
+}
+
+/// Parsed PK constraint.
+#[derive(Debug, Clone)]
+enum PkConstraint {
+    /// Single equality: pk = "literal"
+    Eq(String),
+    /// IN list: pk IN ["a", "b", ...]
+    In(Vec<String>),
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_config_default() {
+        let config = QueryPlannerConfig::default();
+        assert!(config.enabled);
+        assert_eq!(config.max_pk_literals_narrowable, 128);
+        assert!(!config.log_plans);
+    }
+
+    #[test]
+    fn test_parse_string_list_handles_escapes_and_quotes() {
+        let planner = QueryPlanner::default();
+        let parsed = planner.parse_string_list(r#""a\"b", 'c\\'"#).ok();
+        assert_eq!(parsed, Some(vec!["a\"b".to_string(), "c\\".to_string()]));
+    }
+
+    #[test]
+    fn test_parse_string_list_rejects_malformed_input() {
+        let planner = QueryPlanner::default();
+        // Unquoted items and unterminated literals must not yield a partial list.
+        assert!(planner.parse_string_list("1, 2").is_err());
+        assert!(planner.parse_string_list("\"abc").is_err());
+    }
+
+    #[tokio::test]
+    async fn test_plan_disabled() {
+        let config = QueryPlannerConfig {
+            enabled: false,
+            ..Default::default()
+        };
+        let planner = QueryPlanner::new(config);
+
+        let plan = planner
+            .plan("products", &Some("sku = \"abc\"".to_string()), 64)
+            .await;
+
+        assert!(!plan.narrowed);
+        assert!(plan.reason.contains("disabled"));
+    }
+
+    #[tokio::test]
+    async fn test_plan_pk_equality() {
+        let planner = QueryPlanner::default();
+        planner.set_primary_key("products".into(), "sku".into()).await;
+
+        let plan = planner
+            .plan("products", &Some("sku = \"abc123\"".to_string()), 64)
+            .await;
+
+        assert!(plan.narrowed);
+        assert_eq!(plan.target_shards.len(), 1);
+        assert!(plan.reason.contains("PK equality"));
+    }
+
+    #[tokio::test]
+    async fn test_plan_no_filter() {
+        let planner = QueryPlanner::default();
+        let plan = planner.plan("products", &None, 64).await;
+
+        assert!(!plan.narrowed);
+        assert!(plan.reason.contains("no filter"));
+    }
+
+    #[tokio::test]
+    async fn test_plan_or_not_narrowable() {
+        let planner = QueryPlanner::default();
+        planner.set_primary_key("products".into(), "sku".into()).await;
+
+        let plan = planner
+            .plan(
+                "products",
+                &Some("sku = \"abc\" OR category = \"books\"".to_string()),
+                64,
+            )
+            .await;
+
+        assert!(!plan.narrowed);
+        assert!(plan.reason.contains("OR"));
+    }
+
+    #[tokio::test]
+    async fn test_plan_no_pk_configured() {
+        let planner = QueryPlanner::default();
+        let plan = planner
+            .plan("products", &Some("sku = \"abc\"".to_string()), 64)
+            .await;
+
+        assert!(!plan.narrowed);
+        assert!(plan.reason.contains("primary key not configured"));
+    }
+}
diff --git a/crates/miroir-core/src/replica_selection.rs b/crates/miroir-core/src/replica_selection.rs
new file mode 100644
index 0000000..f7e9a4c
--- /dev/null
+++ b/crates/miroir-core/src/replica_selection.rs
@@ -0,0 +1,432 @@
+//! Adaptive replica selection using EWMA scoring (plan §13.3).
+//!
+//! Replaces round-robin with latency-aware selection using EWMA-smoothed
+//! metrics: latency p95, in-flight request count, and error rate.
+
+use crate::error::{MiroirError, Result};
+use crate::topology::{Group, NodeId};
+use rand::prelude::*;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::RwLock;
+
+/// Replica selection strategy.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum SelectionStrategy {
+    /// EWMA-based adaptive selection.
+    Adaptive,
+    /// Round-robin selection.
+    RoundRobin,
+    /// Random selection.
+    Random,
+}
+
+impl Default for SelectionStrategy {
+    fn default() -> Self {
+        Self::Adaptive
+    }
+}
+
+/// Replica selection configuration.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ReplicaSelectionConfig {
+    /// Selection strategy name.
+    ///
+    /// Defaults to `"adaptive"` both when deserializing and via
+    /// [`Default`]. The bare `#[serde(default)]` used previously produced an
+    /// empty string for a missing field, disagreeing with `Default`.
+    #[serde(default = "default_strategy")]
+    pub strategy: String,
+    /// Latency weight in score computation.
+    #[serde(default = "default_latency_weight")]
+    pub latency_weight: f64,
+    /// In-flight request weight.
+    #[serde(default = "default_inflight_weight")]
+    pub inflight_weight: f64,
+    /// Error rate weight.
+    #[serde(default = "default_error_weight")]
+    pub error_weight: f64,
+    /// EWMA half-life in milliseconds.
+    #[serde(default = "default_ewma_half_life")]
+    pub ewma_half_life_ms: u64,
+    /// Exploration epsilon (probability of random selection).
+    #[serde(default = "default_epsilon")]
+    pub exploration_epsilon: f64,
+}
+
+/// Default for `strategy`.
+fn default_strategy() -> String {
+    "adaptive".into()
+}
+/// Default for `latency_weight`.
+fn default_latency_weight() -> f64 {
+    1.0
+}
+/// Default for `inflight_weight`.
+fn default_inflight_weight() -> f64 {
+    2.0
+}
+/// Default for `error_weight`.
+fn default_error_weight() -> f64 {
+    10.0
+}
+/// Default for `ewma_half_life_ms`.
+fn default_ewma_half_life() -> u64 {
+    5000
+}
+/// Default for `exploration_epsilon`.
+fn default_epsilon() -> f64 {
+    0.05
+}
+
+impl Default for ReplicaSelectionConfig {
+    fn default() -> Self {
+        Self {
+            strategy: default_strategy(),
+            latency_weight: default_latency_weight(),
+            inflight_weight: default_inflight_weight(),
+            error_weight: default_error_weight(),
+            ewma_half_life_ms: default_ewma_half_life(),
+            exploration_epsilon: default_epsilon(),
+        }
+    }
+}
+
+/// Per-node metrics for adaptive selection.
+#[derive(Debug, Clone)]
+pub struct NodeMetrics {
+    /// EWMA of latency p95 in milliseconds.
+    pub latency_p95_ms: f64,
+    /// Current in-flight request count.
+    pub in_flight: u32,
+    /// EWMA of error rate (0.0 to 1.0).
+    pub error_rate: f64,
+    /// EWMA half-life for updates.
+    pub half_life_ms: u64,
+    /// Last update timestamp.
+    pub last_updated: Instant,
+}
+
+impl NodeMetrics {
+    /// Create new metrics with initial values.
+    pub fn new(initial_latency_ms: f64, half_life_ms: u64) -> Self {
+        Self {
+            latency_p95_ms: initial_latency_ms,
+            in_flight: 0,
+            error_rate: 0.0,
+            half_life_ms,
+            last_updated: Instant::now(),
+        }
+    }
+
+    /// Weight kept by the previous average on each observation.
+    ///
+    /// Calibrated to a nominal one observation per second, so the old
+    /// value's weight halves after `half_life_ms` worth of samples. The
+    /// exponent was previously `half_life_ms / 1000`, which inverted the
+    /// relationship (longer half-life, less smoothing). A half-life of 0
+    /// yields 0, i.e. no memory.
+    fn smoothing_alpha(&self) -> f64 {
+        0.5_f64.powf(1000.0 / (self.half_life_ms as f64))
+    }
+
+    /// Update latency with EWMA smoothing.
+    ///
+    /// Non-finite samples are ignored so they cannot poison the average.
+    pub fn update_latency(&mut self, latency_ms: f64) {
+        if !latency_ms.is_finite() {
+            return;
+        }
+        let alpha = self.smoothing_alpha();
+        self.latency_p95_ms = alpha * self.latency_p95_ms + (1.0 - alpha) * latency_ms;
+        self.last_updated = Instant::now();
+    }
+
+    /// Update error rate with EWMA smoothing.
+    ///
+    /// Each request contributes 1.0 (error) or 0.0 (success).
+    pub fn update_error(&mut self, is_error: bool) {
+        let alpha = self.smoothing_alpha();
+        let new_error = if is_error { 1.0 } else { 0.0 };
+        self.error_rate = alpha * self.error_rate + (1.0 - alpha) * new_error;
+        self.last_updated = Instant::now();
+    }
+
+    /// Increment in-flight count (saturating, never overflows).
+    pub fn increment_in_flight(&mut self) {
+        self.in_flight = self.in_flight.saturating_add(1);
+    }
+
+    /// Decrement in-flight count (saturating at zero).
+    pub fn decrement_in_flight(&mut self) {
+        self.in_flight = self.in_flight.saturating_sub(1);
+    }
+
+    /// Compute the composite score (lower is better).
+    ///
+    /// Weighted sum of the latency estimate, the in-flight count and the
+    /// error rate; the error rate is scaled by 1000 before weighting.
+    pub fn score(&self, config: &ReplicaSelectionConfig) -> f64 {
+        config.latency_weight * self.latency_p95_ms
+            + config.inflight_weight * (self.in_flight as f64)
+            + config.error_weight * (self.error_rate * 1000.0)
+    }
+}
+
+impl Default for NodeMetrics {
+    /// 50 ms initial latency with a 5000 ms half-life.
+    fn default() -> Self {
+        Self::new(50.0, 5000)
+    }
+}
+
+/// Replica selector.
+pub struct ReplicaSelector {
+    /// Configuration.
+    config: ReplicaSelectionConfig,
+    /// Per-node metrics.
+    metrics: Arc<RwLock<HashMap<NodeId, NodeMetrics>>>,
+    /// Round-robin counter (for round-robin strategy), keyed by group label.
+    rr_counter: Arc<RwLock<HashMap<String, u64>>>,
+    /// Random number generator.
+    rng: Arc<std::sync::Mutex<StdRng>>,
+}
+
+impl ReplicaSelector {
+    /// Create a new replica selector.
+ pub fn new(config: ReplicaSelectionConfig) -> Self { + Self { + config, + metrics: Arc::new(RwLock::new(HashMap::new())), + rr_counter: Arc::new(RwLock::new(HashMap::new())), + rng: Arc::new(std::sync::Mutex::new(StdRng::from_entropy())), + } + } + + /// Select a node from the given candidates. + /// + /// Returns the selected node ID, or None if candidates is empty. + pub async fn select(&self, candidates: &[NodeId], group_id: u32) -> Option { + if candidates.is_empty() { + return None; + } + + let strategy = self.parse_strategy(); + + match strategy { + SelectionStrategy::Adaptive => self.select_adaptive(candidates).await, + SelectionStrategy::RoundRobin => self.select_round_robin(candidates, group_id as u64).await, + SelectionStrategy::Random => self.select_random(candidates), + } + } + + /// Adaptive selection using EWMA scores. + async fn select_adaptive(&self, candidates: &[NodeId]) -> Option { + let metrics = self.metrics.read().await; + + // Exploration: with probability epsilon, pick randomly + if self.should_explore() { + return self.select_random(candidates); + } + + // Compute scores and find the minimum + let mut best_node = None; + let mut best_score = f64::INFINITY; + + for node in candidates { + let score = metrics + .get(node) + .map(|m| m.score(&self.config)) + .unwrap_or(1000.0); // High default for unknown nodes + + if score < best_score { + best_score = score; + best_node = Some(node.clone()); + } + } + + best_node + } + + /// Round-robin selection. + async fn select_round_robin(&self, candidates: &[NodeId], group_id: u64) -> Option { + let key = format!("group_{}", group_id); + let mut counter = self.rr_counter.write().await; + let idx = *counter.entry(key.clone()).or_insert(0) as usize % candidates.len(); + *counter.get_mut(&key).unwrap() += 1; + Some(candidates[idx].clone()) + } + + /// Random selection. 
+ fn select_random(&self, candidates: &[NodeId]) -> Option { + if candidates.is_empty() { + return None; + } + let idx = self + .rng + .lock() + .unwrap() + .gen_range(0..candidates.len()); + Some(candidates[idx].clone()) + } + + /// Check if we should explore (random selection). + fn should_explore(&self) -> bool { + let mut rng = self.rng.lock().unwrap(); + rng.gen::() < self.config.exploration_epsilon + } + + /// Record a successful request (update latency). + pub async fn record_success(&self, node: &NodeId, latency_ms: f64) { + let mut metrics = self.metrics.write().await; + let entry = metrics + .entry(node.clone()) + .or_insert_with(NodeMetrics::default); + entry.update_latency(latency_ms); + entry.update_error(false); + entry.decrement_in_flight(); + } + + /// Record a failed request. + pub async fn record_error(&self, node: &NodeId, latency_ms: Option) { + let mut metrics = self.metrics.write().await; + let entry = metrics + .entry(node.clone()) + .or_insert_with(NodeMetrics::default); + if let Some(lat) = latency_ms { + entry.update_latency(lat); + } + entry.update_error(true); + entry.decrement_in_flight(); + } + + /// Record that a request is being sent to a node. + pub async fn record_request_start(&self, node: &NodeId) { + let mut metrics = self.metrics.write().await; + let entry = metrics + .entry(node.clone()) + .or_insert_with(NodeMetrics::default); + entry.increment_in_flight(); + } + + /// Get metrics for a node. + pub async fn get_metrics(&self, node: &NodeId) -> Option { + let metrics = self.metrics.read().await; + metrics.get(node).cloned() + } + + /// Parse the strategy from config string. 
+ fn parse_strategy(&self) -> SelectionStrategy { + match self.config.strategy.as_str() { + "round_robin" => SelectionStrategy::RoundRobin, + "random" => SelectionStrategy::Random, + _ => SelectionStrategy::Adaptive, + } + } +} + +impl Default for ReplicaSelector { + fn default() -> Self { + Self::new(ReplicaSelectionConfig::default()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_default() { + let config = ReplicaSelectionConfig::default(); + assert_eq!(config.strategy, "adaptive"); + assert_eq!(config.latency_weight, 1.0); + assert_eq!(config.inflight_weight, 2.0); + assert_eq!(config.error_weight, 10.0); + } + + #[test] + fn test_node_metrics_score() { + let mut metrics = NodeMetrics::new(50.0, 5000); + assert_eq!(metrics.score(&ReplicaSelectionConfig::default()), 50.0); + + metrics.in_flight = 5; + let score = metrics.score(&ReplicaSelectionConfig::default()); + // 50 * 1.0 + 5 * 2.0 = 60 + assert_eq!(score, 60.0); + } + + #[test] + fn test_node_metrics_ewma() { + let mut metrics = NodeMetrics::new(100.0, 1000); // Short half-life + + metrics.update_latency(50.0); + // Should move toward 50 + assert!(metrics.latency_p95_ms < 100.0 && metrics.latency_p95_ms > 40.0); + + metrics.update_error(true); + assert!(metrics.error_rate > 0.0); + + metrics.update_error(false); + // Error rate should decay + let rate_before = metrics.error_rate; + metrics.update_error(false); + assert!(metrics.error_rate < rate_before); + } + + #[tokio::test] + async fn test_select_adaptive() { + let selector = ReplicaSelector::new(ReplicaSelectionConfig::default()); + + let node1 = NodeId::new("node-1".to_string()); + let node2 = NodeId::new("node-2".to_string()); + + // Record some metrics + { + let mut metrics = selector.metrics.write().await; + metrics.insert( + node1.clone(), + NodeMetrics { + latency_p95_ms: 10.0, + in_flight: 0, + error_rate: 0.0, + half_life_ms: 5000, + last_updated: Instant::now(), + }, + ); + metrics.insert( + node2.clone(), + 
NodeMetrics { + latency_p95_ms: 100.0, + in_flight: 0, + error_rate: 0.0, + half_life_ms: 5000, + last_updated: Instant::now(), + }, + ); + } + + // Should select node-1 (lower score) + let candidates = vec![node2.clone(), node1.clone()]; + let selected = selector.select(&candidates, 0).await; + assert_eq!(selected, Some(node1)); + } + + #[tokio::test] + async fn test_select_round_robin() { + let config = ReplicaSelectionConfig { + strategy: "round_robin".into(), + ..Default::default() + }; + let selector = ReplicaSelector::new(config); + + let node1 = NodeId::new("node-1".to_string()); + let node2 = NodeId::new("node-2".to_string()); + + let candidates = vec![node1.clone(), node2.clone()]; + + // First call should return node-1 + let selected = selector.select(&candidates, 0).await; + assert_eq!(selected, Some(node1.clone())); + + // Second call should return node-2 + let selected = selector.select(&candidates, 0).await; + assert_eq!(selected, Some(node2.clone())); + + // Third call should wrap to node-1 + let selected = selector.select(&candidates, 0).await; + assert_eq!(selected, Some(node1)); + } + + #[tokio::test] + async fn test_record_request_lifecycle() { + let selector = ReplicaSelector::default(); + let node = NodeId::new("node-1".to_string()); + + selector.record_request_start(&node).await; + + let metrics = selector.get_metrics(&node).await; + assert!(metrics.is_some()); + assert_eq!(metrics.unwrap().in_flight, 1); + + selector.record_success(&node, 50.0).await; + + let metrics = selector.get_metrics(&node).await; + assert!(metrics.is_some()); + assert_eq!(metrics.unwrap().in_flight, 0); + } + + #[tokio::test] + async fn test_empty_candidates() { + let selector = ReplicaSelector::default(); + let selected = selector.select(&[], 0).await; + assert!(selected.is_none()); + } +} diff --git a/crates/miroir-core/src/vector.rs b/crates/miroir-core/src/vector.rs new file mode 100644 index 0000000..7eeea7d --- /dev/null +++ b/crates/miroir-core/src/vector.rs @@ 
-0,0 +1,388 @@ +//! Vector and hybrid search sharding (plan §13.12). +//! +//! Handles over-fetching and merging for vector/hybrid search across shards. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Vector search configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VectorSearchConfig { + /// Whether vector search is enabled. + #[serde(default = "default_true")] + pub enabled: bool, + /// Over-fetch factor (per-shard limit = requested limit × factor). + #[serde(default = "default_over_fetch")] + pub over_fetch_factor: u32, + /// Merge strategy: "convex" or "rrf". + #[serde(default = "default_merge_strategy")] + pub merge_strategy: String, + /// Default hybrid alpha (for convex combination). + #[serde(default = "default_alpha")] + pub hybrid_alpha_default: f64, + /// RRF constant (for Reciprocal Rank Fusion). + #[serde(default = "default_rrf_k")] + pub rrf_k: u32, +} + +fn default_true() -> bool { + true +} +fn default_over_fetch() -> u32 { + 3 +} +fn default_merge_strategy() -> String { + "convex".to_string() +} +fn default_alpha() -> f64 { + 0.5 +} +fn default_rrf_k() -> u32 { + 60 +} + +impl Default for VectorSearchConfig { + fn default() -> Self { + Self { + enabled: true, + over_fetch_factor: default_over_fetch(), + merge_strategy: default_merge_strategy(), + hybrid_alpha_default: default_alpha(), + rrf_k: default_rrf_k(), + } + } +} + +/// Merge strategy for combining results from multiple shards. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MergeStrategy { + /// Convex combination: (1 - α) · bm25 + α · semantic + Convex, + /// Reciprocal Rank Fusion + Rrf, +} + +impl MergeStrategy { + /// Parse from string. + pub fn from_str(s: &str) -> Option { + match s { + "convex" => Some(MergeStrategy::Convex), + "rrf" => Some(MergeStrategy::Rrf), + _ => None, + } + } +} + +/// A search hit with scores from multiple sources. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VectorHit { + /// Primary key of the document. + pub pk: String, + /// BM25 ranking score. + pub ranking_score: f64, + /// Semantic score (if present). + pub semantic_score: Option, + /// Combined global score. + pub combined_score: f64, + /// Source shard. + pub shard_id: u32, +} + +impl VectorHit { + /// Create a new hit from BM25 only. + pub fn bm25_only(pk: String, ranking_score: f64, shard_id: u32) -> Self { + Self { + pk, + ranking_score, + semantic_score: None, + combined_score: ranking_score, + shard_id, + } + } + + /// Create a new hybrid hit. + pub fn hybrid(pk: String, ranking_score: f64, semantic_score: f64, shard_id: u32) -> Self { + Self { + pk, + ranking_score, + semantic_score: Some(semantic_score), + combined_score: ranking_score, // Will be recomputed during merge + shard_id, + } + } + + /// Merge with convex combination. + pub fn merge_convex(&mut self, alpha: f64) { + if let Some(semantic) = self.semantic_score { + self.combined_score = (1.0 - alpha) * self.ranking_score + alpha * semantic; + } + } + + /// Get RRF score for a given rank. + pub fn rrf_score(rank: usize, k: u32) -> f64 { + 1.0 / (k as f64 + rank as f64) + } +} + +/// Vector search merger — combines over-fetched results from multiple shards. +pub struct VectorMerger { + /// Merge strategy. + strategy: MergeStrategy, + /// Hybrid alpha (for convex). + alpha: f64, + /// RRF constant. + rrf_k: u32, +} + +impl VectorMerger { + /// Create a new vector merger. + pub fn new(config: &VectorSearchConfig) -> Self { + let strategy = MergeStrategy::from_str(&config.merge_strategy) + .unwrap_or(MergeStrategy::Convex); + Self { + strategy, + alpha: config.hybrid_alpha_default, + rrf_k: config.rrf_k, + } + } + + /// Merge hits from multiple shards into a single ranked list. 
+ /// + /// Input: Vec of (shard_id, hits from that shard) + /// Output: Globally ranked hits, truncated to `limit` + pub fn merge(&self, shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec { + match self.strategy { + MergeStrategy::Convex => self.merge_convex(shard_hits, limit), + MergeStrategy::Rrf => self.merge_rrf(shard_hits, limit), + } + } + + /// Convex combination merge. + fn merge_convex(&self, mut shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec { + // Apply convex combination to each hit + for (_, hit) in &mut shard_hits { + hit.merge_convex(self.alpha); + } + + // Sort by combined score descending + shard_hits.sort_by(|a, b| { + b.1.combined_score + .partial_cmp(&a.1.combined_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // Deduplicate by PK (keep highest score) + let mut deduped = HashMap::new(); + for (_, hit) in shard_hits { + deduped + .entry(hit.pk.clone()) + .and_modify(|e: &mut VectorHit| { + if hit.combined_score > e.combined_score { + *e = hit; + } + }) + .or_insert(hit); + } + + // Convert back to vec and re-sort + let mut result: Vec<_> = deduped.into_values().collect(); + result.sort_by(|a, b| { + b.combined_score + .partial_cmp(&a.combined_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + result.truncate(limit); + result + } + + /// RRF (Reciprocal Rank Fusion) merge. 
+ fn merge_rrf(&self, shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec { + // Group hits by PK and accumulate RRF scores + let mut rrf_scores: HashMap = HashMap::new(); + let mut hit_data: HashMap = HashMap::new(); + + // First, sort each shard's hits by their original ranking score + let mut per_shard: HashMap> = HashMap::new(); + for (shard_id, hit) in shard_hits { + per_shard + .entry(shard_id) + .or_insert_with(Vec::new) + .push(hit); + } + + // Compute RRF scores + for (_shard_id, mut hits) in per_shard { + // Sort by ranking_score descending (original per-shard ranking) + hits.sort_by(|a, b| { + b.ranking_score + .partial_cmp(&a.ranking_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + for (rank, hit) in hits.into_iter().enumerate() { + let pk = hit.pk.clone(); + let rrf_score = VectorHit::rrf_score(rank, self.rrf_k); + *rrf_scores.entry(pk.clone()).or_insert(0.0) += rrf_score; + + hit_data.entry(pk).or_insert(hit); + } + } + + // Build result with RRF scores + let mut result: Vec = hit_data + .into_iter() + .map(|(pk, mut hit)| { + hit.combined_score = *rrf_scores.get(&pk).unwrap_or(&0.0); + hit + }) + .collect(); + + // Sort by RRF score descending + result.sort_by(|a, b| { + b.combined_score + .partial_cmp(&a.combined_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + result.truncate(limit); + result + } + + /// Compute the per-shard limit given a requested limit. 
+ pub fn per_shard_limit(&self, requested_limit: usize, over_fetch_factor: u32) -> usize { + requested_limit * over_fetch_factor as usize + } +} + +impl Default for VectorMerger { + fn default() -> Self { + Self::new(&VectorSearchConfig::default()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_defaults() { + let config = VectorSearchConfig::default(); + assert!(config.enabled); + assert_eq!(config.over_fetch_factor, 3); + assert_eq!(config.merge_strategy, "convex"); + assert_eq!(config.hybrid_alpha_default, 0.5); + assert_eq!(config.rrf_k, 60); + } + + #[test] + fn test_merge_strategy_from_str() { + assert_eq!(MergeStrategy::from_str("convex"), Some(MergeStrategy::Convex)); + assert_eq!(MergeStrategy::from_str("rrf"), Some(MergeStrategy::Rrf)); + assert_eq!(MergeStrategy::from_str("unknown"), None); + } + + #[test] + fn test_vector_hit_bm25_only() { + let hit = VectorHit::bm25_only("doc1".to_string(), 0.8, 5); + assert_eq!(hit.pk, "doc1"); + assert_eq!(hit.ranking_score, 0.8); + assert!(hit.semantic_score.is_none()); + assert_eq!(hit.combined_score, 0.8); + assert_eq!(hit.shard_id, 5); + } + + #[test] + fn test_vector_hit_hybrid() { + let mut hit = VectorHit::hybrid("doc1".to_string(), 0.6, 0.9, 5); + assert_eq!(hit.pk, "doc1"); + assert_eq!(hit.ranking_score, 0.6); + assert_eq!(hit.semantic_score, Some(0.9)); + assert_eq!(hit.shard_id, 5); + + hit.merge_convex(0.5); + assert!((hit.combined_score - 0.75).abs() < 0.001); + } + + #[test] + fn test_rrf_score() { + let score = VectorHit::rrf_score(0, 60); + assert!((score - 1.0 / 60.0).abs() < 0.0001); + + let score = VectorHit::rrf_score(10, 60); + assert!((score - 1.0 / 70.0).abs() < 0.0001); + } + + #[test] + fn test_merge_convex_basic() { + let merger = VectorMerger { + strategy: MergeStrategy::Convex, + alpha: 0.5, + rrf_k: 60, + }; + + let hits = vec![ + (0, VectorHit::hybrid("doc1".to_string(), 0.8, 0.6, 0)), + (0, VectorHit::hybrid("doc2".to_string(), 0.7, 0.9, 0)), + (1, 
VectorHit::hybrid("doc1".to_string(), 0.75, 0.65, 1)), + (1, VectorHit::hybrid("doc3".to_string(), 0.9, 0.5, 1)), + ]; + + let result = merger.merge_convex(hits, 10); + + // Should deduplicate doc1, keeping the highest combined score + assert_eq!(result.len(), 3); + assert_eq!(result[0].pk, "doc3"); // (0.9 + 0.5) / 2 = 0.7 + assert_eq!(result[1].pk, "doc2"); // (0.7 + 0.9) / 2 = 0.8 + assert_eq!(result[2].pk, "doc1"); // Should keep shard 0's version (0.8+0.6)/2=0.7 > (0.75+0.65)/2=0.7 + } + + #[test] + fn test_merge_rrf_basic() { + let merger = VectorMerger { + strategy: MergeStrategy::Rrf, + alpha: 0.5, + rrf_k: 60, + }; + + let hits = vec![ + (0, VectorHit::bm25_only("doc1".to_string(), 0.9, 0)), + (0, VectorHit::bm25_only("doc2".to_string(), 0.8, 0)), + (0, VectorHit::bm25_only("doc3".to_string(), 0.7, 0)), + (1, VectorHit::bm25_only("doc2".to_string(), 0.95, 1)), + (1, VectorHit::bm25_only("doc4".to_string(), 0.85, 1)), + ]; + + let result = merger.merge_rrf(hits, 10); + + // doc2 appears in both shards, gets summed RRF scores + assert!(result.iter().any(|h| h.pk == "doc2")); + let doc2 = result.iter().find(|h| h.pk == "doc2").unwrap(); + // Rank 1 in shard 0: 1/61, rank 1 in shard 1: 1/61 + assert!((doc2.combined_score - 2.0 / 61.0).abs() < 0.0001); + } + + #[test] + fn test_per_shard_limit() { + let merger = VectorMerger::default(); + assert_eq!(merger.per_shard_limit(10, 3), 30); + assert_eq!(merger.per_shard_limit(100, 2), 200); + } + + #[test] + fn test_merge_limits_output() { + let merger = VectorMerger::default(); + let hits: Vec<_> = (0..200) + .map(|i| { + ( + i % 10, + VectorHit::bm25_only(format!("doc{}", i), 1.0 - (i as f64) * 0.001, i % 10), + ) + }) + .collect(); + + let result = merger.merge_convex(hits, 50); + assert_eq!(result.len(), 50); + } +}