P3: Add Phase 3 advanced capability stub modules

Adds skeletal implementations for Phase 3 advanced capabilities
(§13.2-§13.12, §13.9) that will be fully implemented in later phases.

- hedging.rs (§13.2): Hedged request support structure
- query_planner.rs (§13.4): Shard-aware query planning interface
- replica_selection.rs (§13.3): Adaptive replica selection framework
- vector.rs (§13.12): Vector/hybrid search support types
- dump_import.rs (§13.9): Streaming dump import coordinator

These modules provide the type definitions and interfaces needed
by the task registry and persistence layer for multi-pod coordination
in Phase 6.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-05-03 13:31:05 -04:00
parent bd29c32688
commit ffb5ea8a3e
5 changed files with 1887 additions and 0 deletions

View file

@ -0,0 +1,392 @@
//! Streaming routed dump import (plan §13.9).
//!
//! Intercepts dump imports and routes each document to the correct shard
//! instead of broadcasting to all nodes.
use crate::error::{MiroirError, Result};
use crate::router::{shard_for_key, assign_shard_in_group};
use crate::topology::{Topology, NodeId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Dump import configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DumpImportConfig {
/// Import mode: "streaming" or "broadcast".
#[serde(default = "default_mode")]
pub mode: String,
/// Batch size for per-target POSTs.
#[serde(default = "default_batch_size")]
pub batch_size: u32,
/// Parallel target writes.
#[serde(default = "default_parallel")]
pub parallel_target_writes: u32,
/// Memory buffer cap (bytes).
#[serde(default = "default_memory_buffer")]
pub memory_buffer_bytes: u64,
/// Chunk size for Mode C coordinator.
#[serde(default = "default_chunk_size")]
pub chunk_size_bytes: u64,
}
fn default_mode() -> String {
"streaming".into()
}
fn default_batch_size() -> u32 {
1000
}
fn default_parallel() -> u32 {
8
}
fn default_memory_buffer() -> u64 {
134_217_728 // 128 MiB
}
fn default_chunk_size() -> u64 {
268_435_456 // 256 MiB
}
impl Default for DumpImportConfig {
fn default() -> Self {
Self {
mode: default_mode(),
batch_size: default_batch_size(),
parallel_target_writes: default_parallel(),
memory_buffer_bytes: default_memory_buffer(),
chunk_size_bytes: default_chunk_size(),
}
}
}
/// Dump import phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum DumpImportPhase {
/// No import in progress.
Idle = 0,
/// Reading and parsing dump.
Reading = 1,
/// Routing documents to target nodes.
Routing = 2,
/// Applying index settings.
ApplyingSettings = 3,
/// Completed successfully.
Complete = 4,
/// Failed with error.
Failed = 5,
}
/// Dump import status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DumpImportStatus {
/// Import ID.
pub id: String,
/// Target index UID.
pub index_uid: String,
/// Current phase.
pub phase: DumpImportPhase,
/// Documents processed so far.
pub documents_processed: u64,
/// Total documents (estimated).
pub total_documents: u64,
/// Bytes read so far.
pub bytes_read: u64,
/// Phase started at (UNIX ms).
pub phase_started_at: u64,
/// Error message if any.
pub error: Option<String>,
}
/// Dump import manager.
pub struct DumpImportManager {
/// Configuration.
config: DumpImportConfig,
/// Active imports (ID -> status).
active_imports: Arc<RwLock<HashMap<String, DumpImportStatus>>>,
/// Topology for routing.
topology: Arc<Topology>,
}
impl DumpImportManager {
/// Create a new dump import manager.
pub fn new(config: DumpImportConfig, topology: Arc<Topology>) -> Self {
Self {
config,
active_imports: Arc::new(RwLock::new(HashMap::new())),
topology,
}
}
/// Start a streaming dump import.
pub async fn start_import(
&self,
index_uid: String,
dump_data: Vec<u8>,
primary_key: String,
shard_count: u32,
) -> Result<String> {
if self.config.mode != "streaming" {
return Err(MiroirError::InvalidRequest(
"streaming dump import is disabled".into(),
));
}
let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());
let now = millis_now();
// Create initial status
let status = DumpImportStatus {
id: import_id.clone(),
index_uid: index_uid.clone(),
phase: DumpImportPhase::Reading,
documents_processed: 0,
total_documents: 0,
bytes_read: 0,
phase_started_at: now,
error: None,
};
{
let mut imports = self.active_imports.write().await;
imports.insert(import_id.clone(), status);
}
// Clone import_id before moving into the async block
let import_id_for_spawn = import_id.clone();
// Spawn background import task
let imports = self.active_imports.clone();
let topology = self.topology.clone();
let config = self.config.clone();
tokio::spawn(async move {
if let Err(e) = Self::run_import(
&import_id_for_spawn,
index_uid,
dump_data,
primary_key,
shard_count,
topology,
config,
imports,
)
.await
{
tracing::error!("Dump import {} failed: {}", import_id_for_spawn, e);
}
});
Ok(import_id)
}
/// Get the status of an import.
pub async fn get_status(&self, import_id: &str) -> Option<DumpImportStatus> {
let imports = self.active_imports.read().await;
imports.get(import_id).cloned()
}
/// Run the import pipeline.
async fn run_import(
import_id: &str,
index_uid: String,
dump_data: Vec<u8>,
primary_key: String,
shard_count: u32,
topology: Arc<Topology>,
config: DumpImportConfig,
imports: Arc<RwLock<HashMap<String, DumpImportStatus>>>,
) -> Result<()> {
// Update phase to reading
Self::update_phase(&imports, import_id, DumpImportPhase::Reading).await;
// Parse NDJSON and route documents
let data_str = std::str::from_utf8(&dump_data)
.map_err(|e| MiroirError::InvalidRequest(format!("invalid UTF-8 in dump: {}", e)))?;
// Per-target buffers
let mut per_target_buffers: HashMap<(NodeId, u32), Vec<serde_json::Value>> =
HashMap::new();
let mut processed = 0u64;
let _total_estimate = 0u64;
for line in data_str.lines() {
if line.is_empty() {
continue;
}
let doc: serde_json::Value = serde_json::from_str(line).map_err(|e| {
MiroirError::InvalidRequest(format!("invalid JSON in dump: {}", e))
})?;
// Extract primary key value
let pk_value = doc
.get(&primary_key)
.and_then(|v| v.as_str())
.ok_or_else(|| {
MiroirError::InvalidRequest(format!(
"missing or invalid primary key field: {}",
primary_key
))
})?;
// Compute shard and route
let shard_id = shard_for_key(pk_value, shard_count);
// Get target nodes for this shard (assign across all replica groups)
let target_nodes: Vec<NodeId> = topology
.groups()
.flat_map(|group| assign_shard_in_group(shard_id, group.nodes(), topology.rf()))
.collect();
if target_nodes.is_empty() {
return Err(MiroirError::Topology(format!("no nodes for shard {}", shard_id)));
}
// Add to each target's buffer
for node in &target_nodes {
per_target_buffers
.entry((node.clone(), shard_id))
.or_insert_with(Vec::new)
.push(doc.clone());
}
processed += 1;
// Flush buffers when they reach batch size
if processed % config.batch_size as u64 == 0 {
Self::flush_buffers(
&index_uid,
&mut per_target_buffers,
&config,
&imports,
import_id,
processed,
)
.await?;
}
}
// Final flush
Self::flush_buffers(
&index_uid,
&mut per_target_buffers,
&config,
&imports,
import_id,
processed,
)
.await?;
// Mark complete
Self::update_phase(&imports, import_id, DumpImportPhase::Complete).await;
Ok(())
}
/// Flush buffered documents to target nodes.
async fn flush_buffers(
index_uid: &str,
buffers: &mut HashMap<(NodeId, u32), Vec<serde_json::Value>>,
_config: &DumpImportConfig,
imports: &Arc<RwLock<HashMap<String, DumpImportStatus>>>,
import_id: &str,
processed: u64,
) -> Result<()> {
for ((node, _shard), docs) in buffers.drain() {
if docs.is_empty() {
continue;
}
// POST documents to the node
// In a real implementation, this would use the HTTP client
tracing::debug!(
"Flushing {} documents to node {} for index {}",
docs.len(),
node,
index_uid
);
// Update status
let mut imports = imports.write().await;
if let Some(status) = imports.get_mut(import_id) {
status.documents_processed = processed;
}
}
Ok(())
}
/// Update the phase of an import.
async fn update_phase(
imports: &Arc<RwLock<HashMap<String, DumpImportStatus>>>,
import_id: &str,
phase: DumpImportPhase,
) {
let mut imports = imports.write().await;
if let Some(status) = imports.get_mut(import_id) {
status.phase = phase;
status.phase_started_at = millis_now();
}
}
}
/// Get current UNIX timestamp in milliseconds.
fn millis_now() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
impl Default for DumpImportManager {
fn default() -> Self {
Self::new(DumpImportConfig::default(), Arc::new(Topology::new(1, 1, 1)))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = DumpImportConfig::default();
assert_eq!(config.mode, "streaming");
assert_eq!(config.batch_size, 1000);
assert_eq!(config.parallel_target_writes, 8);
}
#[test]
fn test_phase_serialization() {
let phase = DumpImportPhase::Routing;
let json = serde_json::to_string(&phase).unwrap();
assert_eq!(json, "\"Routing\"");
let deserialized: DumpImportPhase = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized, DumpImportPhase::Routing);
}
#[tokio::test]
async fn test_get_status_nonexistent() {
let manager = DumpImportManager::default();
let status = manager.get_status("nonexistent").await;
assert!(status.is_none());
}
#[tokio::test]
async fn test_import_rejects_broadcast_mode() {
let config = DumpImportConfig {
mode: "broadcast".into(),
..Default::default()
};
let topology = Arc::new(Topology::new(64, 2, 1));
let manager = DumpImportManager::new(config, topology);
let result = manager
.start_import("products".into(), vec![1, 2, 3], "id".into(), 64)
.await;
assert!(result.is_err());
}
}

View file

@ -0,0 +1,319 @@
//! Hedged requests for tail-latency mitigation (plan §13.2).
//!
//! Issues duplicate requests to alternate replicas when a primary request
//! exceeds the p95 latency threshold.
use crate::error::{MiroirError, Result};
use crate::router::assign_shard_in_group;
use crate::topology::{NodeId, Topology};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::{sleep, Instant};
/// Hedging configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HedgingConfig {
/// Whether hedging is enabled.
#[serde(default = "default_true")]
pub enabled: bool,
/// P95 trigger multiplier (hedge at p95 * this).
#[serde(default = "default_multiplier")]
pub p95_trigger_multiplier: f64,
/// Minimum trigger time in milliseconds.
#[serde(default = "default_min_trigger")]
pub min_trigger_ms: u64,
/// Maximum hedges per query.
#[serde(default = "default_max_hedges")]
pub max_hedges_per_query: u32,
/// Allow falling back to another replica group.
#[serde(default = "default_true")]
pub cross_group_fallback: bool,
}
fn default_true() -> bool {
true
}
fn default_multiplier() -> f64 {
1.2
}
fn default_min_trigger() -> u64 {
15
}
fn default_max_hedges() -> u32 {
2
}
impl Default for HedgingConfig {
fn default() -> Self {
Self {
enabled: true,
p95_trigger_multiplier: default_multiplier(),
min_trigger_ms: default_min_trigger(),
max_hedges_per_query: default_max_hedges(),
cross_group_fallback: true,
}
}
}
/// Per-node latency tracking for p95 computation.
#[derive(Debug, Clone)]
pub struct NodeLatency {
/// EWMA-smoothed latency in milliseconds.
pub ewma_ms: f64,
/// Half-life for EWMA (milliseconds).
pub half_life_ms: u64,
}
impl NodeLatency {
/// Create a new latency tracker with initial value.
pub fn new(initial_ms: f64, half_life_ms: u64) -> Self {
Self {
ewma_ms: initial_ms,
half_life_ms,
}
}
/// Update with a new observation.
pub fn update(&mut self, latency_ms: f64) {
let alpha = 0.5_f64.powf((self.half_life_ms as f64) / 1000.0);
self.ewma_ms = alpha * self.ewma_ms + (1.0 - alpha) * latency_ms;
}
/// Get the current p95 estimate (conservative: use EWMA directly).
pub fn p95_ms(&self) -> f64 {
self.ewma_ms
}
}
impl Default for NodeLatency {
fn default() -> Self {
Self::new(50.0, 5000)
}
}
/// Hedging manager.
pub struct HedgingManager {
/// Configuration.
config: HedgingConfig,
/// Per-node latency tracking.
node_latencies: Arc<RwLock<HashMap<NodeId, NodeLatency>>>,
/// Topology reference for finding alternate replicas.
topology: Arc<Topology>,
}
impl HedgingManager {
/// Create a new hedging manager.
pub fn new(config: HedgingConfig, topology: Arc<Topology>) -> Self {
Self {
config,
node_latencies: Arc::new(RwLock::new(HashMap::new())),
topology,
}
}
/// Record a latency observation for a node.
pub async fn record_latency(&self, node_id: &NodeId, latency_ms: f64) {
let mut latencies = self.node_latencies.write().await;
let entry = latencies.entry(node_id.clone()).or_insert_with(NodeLatency::default);
entry.update(latency_ms);
}
/// Get the p95 latency for a node.
pub async fn get_p95(&self, node_id: &NodeId) -> f64 {
let latencies = self.node_latencies.read().await;
latencies
.get(node_id)
.map(|l| l.p95_ms())
.unwrap_or(50.0)
}
/// Compute the hedge deadline for a request to the given node.
///
/// Returns None if hedging is disabled or the node has no latency data.
pub async fn hedge_deadline(&self, primary_node: &NodeId) -> Option<Duration> {
if !self.config.enabled {
return None;
}
let p95 = self.get_p95(primary_node).await;
let trigger_ms = (p95 * self.config.p95_trigger_multiplier).max(self.config.min_trigger_ms as f64);
Some(Duration::from_millis(trigger_ms as u64))
}
/// Find an alternate replica for hedging.
///
/// Returns None if:
/// - No alternate available
/// - Max hedges already issued
/// - Cross-group fallback disabled and no intra-group alternate
pub async fn find_alternate(
&self,
primary_node: &NodeId,
shard_id: u32,
hedge_count: u32,
) -> Option<NodeId> {
if hedge_count >= self.config.max_hedges_per_query {
return None;
}
// Get all nodes for this shard (assign across all replica groups)
let all_nodes: Vec<NodeId> = self.topology
.groups()
.flat_map(|group| assign_shard_in_group(shard_id, group.nodes(), self.topology.rf()))
.collect();
let primary_group = self.topology.node(primary_node)?.replica_group;
// First try: same group, different node
for node in &all_nodes {
if node != primary_node {
if let Some(n) = self.topology.node(node) {
if n.replica_group == primary_group {
return Some(node.clone());
}
}
}
}
// Fallback: different group (if enabled)
if self.config.cross_group_fallback {
for node in &all_nodes {
if node != primary_node {
return Some(node.clone());
}
}
}
None
}
/// Check if a hedge should be issued based on elapsed time.
pub fn should_hedge(&self, elapsed: Duration, deadline: Duration) -> bool {
elapsed >= deadline
}
}
impl Default for HedgingManager {
fn default() -> Self {
Self::new(HedgingConfig::default(), Arc::new(Topology::new(1, 1, 1)))
}
}
/// Hedge outcome for metrics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HedgeOutcome {
/// Primary request won (hedge cancelled or never fired).
PrimaryWon,
/// Hedge request won (primary was slower).
HedgeWon,
/// Both completed at similar time.
Tie,
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn test_config_default() {
let config = HedgingConfig::default();
assert!(config.enabled);
assert_eq!(config.p95_trigger_multiplier, 1.2);
assert_eq!(config.min_trigger_ms, 15);
assert_eq!(config.max_hedges_per_query, 2);
assert!(config.cross_group_fallback);
}
#[test]
fn test_node_latency_ewma() {
let mut latency = NodeLatency::new(100.0, 1000);
assert_eq!(latency.ewma_ms, 100.0);
// Update with same value
latency.update(100.0);
// EWMA should move toward 100
assert!(latency.ewma_ms > 90.0 && latency.ewma_ms < 110.0);
// Update with much lower value
latency.update(10.0);
assert!(latency.ewma_ms < 100.0);
}
#[test]
fn test_hedge_deadline_computation() {
let topology = Arc::new(Topology::new(1, 1, 1));
let manager = HedgingManager::new(HedgingConfig::default(), topology);
let node = NodeId::new("node-1".to_string());
manager
.node_latencies
.try_write()
.unwrap()
.insert(node.clone(), NodeLatency::new(50.0, 5000));
let rt = tokio::runtime::Runtime::new().unwrap();
let deadline = rt.block_on(async { manager.hedge_deadline(&node).await });
assert!(deadline.is_some());
// 50ms * 1.2 = 60ms, but min is 15ms, so should be 60ms
assert_eq!(deadline.unwrap(), Duration::from_millis(60));
}
#[test]
fn test_hedge_deadline_respects_min() {
let topology = Arc::new(Topology::new(1, 1, 1));
let config = HedgingConfig {
p95_trigger_multiplier: 1.2,
min_trigger_ms: 100,
..Default::default()
};
let manager = HedgingManager::new(config, topology);
let node = NodeId::new("node-1".to_string());
manager
.node_latencies
.try_write()
.unwrap()
.insert(node.clone(), NodeLatency::new(10.0, 5000));
let rt = tokio::runtime::Runtime::new().unwrap();
let deadline = rt.block_on(async { manager.hedge_deadline(&node).await });
assert!(deadline.is_some());
// 10ms * 1.2 = 12ms, but min is 100ms
assert_eq!(deadline.unwrap(), Duration::from_millis(100));
}
#[test]
fn test_hedge_disabled() {
let config = HedgingConfig {
enabled: false,
..Default::default()
};
let topology = Arc::new(Topology::new(1, 1, 1));
let manager = HedgingManager::new(config, topology);
let node = NodeId::new("node-1".to_string());
let rt = tokio::runtime::Runtime::new().unwrap();
let deadline = rt.block_on(async { manager.hedge_deadline(&node).await });
assert!(deadline.is_none());
}
#[tokio::test]
async fn test_record_latency() {
let topology = Arc::new(Topology::new(1, 1, 1));
let manager = HedgingManager::new(HedgingConfig::default(), topology);
let node = NodeId::new("node-1".to_string());
manager.record_latency(&node, 100.0).await;
manager.record_latency(&node, 50.0).await;
let p95 = manager.get_p95(&node).await;
// EWMA should be between 50 and 100
assert!(p95 > 40.0 && p95 < 110.0);
}
}

View file

@ -0,0 +1,356 @@
//! Shard-aware query planner (plan §13.4).
//!
//! Parses filter expressions to determine if a query can be narrowed to
//! a subset of shards based on primary key constraints.
use crate::error::{MiroirError, Result};
use crate::router::shard_for_key;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Query planner configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlannerConfig {
/// Whether the query planner is enabled.
#[serde(default = "default_true")]
pub enabled: bool,
/// Maximum PK literals in a narrowable IN clause.
#[serde(default = "default_max_literals")]
pub max_pk_literals_narrowable: u32,
/// Whether to log query plans.
#[serde(default = "default_log_plans")]
pub log_plans: bool,
}
fn default_true() -> bool {
true
}
fn default_max_literals() -> u32 {
128
}
fn default_log_plans() -> bool {
false
}
impl Default for QueryPlannerConfig {
fn default() -> Self {
Self {
enabled: true,
max_pk_literals_narrowable: default_max_literals(),
log_plans: default_log_plans(),
}
}
}
/// Query plan result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
/// Whether the query was narrowable.
pub narrowed: bool,
/// Reason for narrowing (or not).
pub reason: String,
/// Target shard IDs (empty if not narrowed).
pub target_shards: Vec<u32>,
/// Warnings generated during planning.
pub warnings: Vec<String>,
}
/// Query planner.
pub struct QueryPlanner {
/// Configuration.
config: QueryPlannerConfig,
/// Primary key field name for each index.
primary_keys: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
}
impl QueryPlanner {
/// Create a new query planner.
pub fn new(config: QueryPlannerConfig) -> Self {
Self {
config,
primary_keys: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
}
}
/// Set the primary key field for an index.
pub async fn set_primary_key(&self, index: String, pk_field: String) {
let mut pks = self.primary_keys.write().await;
pks.insert(index, pk_field);
}
/// Get the primary key field for an index.
pub async fn get_primary_key(&self, index: &str) -> Option<String> {
let pks = self.primary_keys.read().await;
pks.get(index).cloned()
}
/// Plan a query given its filter expression and index.
///
/// Returns a plan indicating whether the query can be narrowed to
/// a subset of shards.
pub async fn plan(
&self,
index: &str,
filter: &Option<String>,
shard_count: u32,
) -> QueryPlan {
if !self.config.enabled {
return QueryPlan {
narrowed: false,
reason: "query planner disabled".to_string(),
target_shards: vec![],
warnings: vec![],
};
}
let filter = match filter {
Some(f) => f,
None => {
return QueryPlan {
narrowed: false,
reason: "no filter specified".to_string(),
target_shards: vec![],
warnings: vec![],
}
}
};
// Try to parse the filter for PK constraints
let pk_field = match self.get_primary_key(index).await {
Some(pk) => pk,
None => {
return QueryPlan {
narrowed: false,
reason: "primary key not configured for index".to_string(),
target_shards: vec![],
warnings: vec![],
}
}
};
match self.parse_pk_constraints(filter, &pk_field) {
Ok(PkConstraint::Eq(literal)) => {
// Single PK equality -> narrow to 1 shard
let shard_id = shard_for_key(&literal, shard_count);
QueryPlan {
narrowed: true,
reason: format!("PK equality: {} = {}", pk_field, literal),
target_shards: vec![shard_id],
warnings: vec![],
}
}
Ok(PkConstraint::In(literals)) if literals.len() <= self.config.max_pk_literals_narrowable as usize => {
// PK IN list -> narrow to N shards
let mut shard_ids: HashSet<u32> = HashSet::new();
for literal in &literals {
shard_ids.insert(shard_for_key(literal, shard_count));
}
let mut shards: Vec<u32> = shard_ids.into_iter().collect();
shards.sort_unstable();
QueryPlan {
narrowed: true,
reason: format!("PK IN list: {} values", literals.len()),
target_shards: shards,
warnings: vec![],
}
}
Ok(PkConstraint::In(literals)) => {
// Too many literals for narrowing
QueryPlan {
narrowed: false,
reason: format!(
"PK IN list too large: {} values exceeds maximum of {}",
literals.len(),
self.config.max_pk_literals_narrowable
),
target_shards: vec![],
warnings: vec![],
}
}
Err(_) => {
QueryPlan {
narrowed: false,
reason: "filter not narrowable".to_string(),
target_shards: vec![],
warnings: vec![],
}
}
}
}
/// Parse a filter expression for PK constraints.
///
/// Returns the PK constraint if narrowable, or an error if not.
fn parse_pk_constraints(&self, filter: &str, pk_field: &str) -> Result<PkConstraint> {
// Simple regex-based parser for common patterns:
// 1. "{pk_field}" = "literal"
// 2. "{pk_field}" IN ["literal1", "literal2", ...]
let filter = filter.trim();
// Try equality: pk = "literal"
let eq_pattern = format!(r#"{}\s*=\s*["']([^"']+)["']"#, pk_field);
if let Some(re) = regex::Regex::new(&eq_pattern).ok() {
if let Some(caps) = re.captures(filter) {
if let Some(literal) = caps.get(1) {
return Ok(PkConstraint::Eq(literal.as_str().to_string()));
}
}
}
// Try IN list: pk IN ["literal1", "literal2", ...]
let in_pattern = format!(r#"{}\s+IN\s+\[(.+)\]"#, pk_field);
if let Some(re) = regex::Regex::new(&in_pattern).ok() {
if let Some(caps) = re.captures(filter) {
if let Some(list) = caps.get(1) {
let literals = self.parse_string_list(list.as_str())?;
return Ok(PkConstraint::In(literals));
}
}
}
// Check for non-narrowable patterns
if filter.contains(" OR ") {
return Err(MiroirError::InvalidState("contains OR at top level".to_string()));
}
if filter.contains(&format!("{} != ", pk_field)) || filter.contains(&format!("{}<>", pk_field)) {
return Err(MiroirError::InvalidState("PK negation is not narrowable".to_string()));
}
Err(MiroirError::InvalidState("no PK constraint found".to_string()))
}
/// Parse a comma-separated list of string literals.
fn parse_string_list(&self, input: &str) -> Result<Vec<String>> {
let mut result = Vec::new();
let mut current = String::new();
let mut in_string = false;
let mut escape = false;
for ch in input.chars() {
match ch {
'\\' if in_string => {
escape = true;
}
'"' if in_string && !escape => {
in_string = false;
result.push(current.clone());
current.clear();
}
'"' if !in_string => {
in_string = true;
}
',' if !in_string => {
// Skip
}
' ' | '\t' | '\n' if !in_string => {
// Skip whitespace
}
ch => {
current.push(ch);
escape = false;
}
}
}
Ok(result)
}
}
impl Default for QueryPlanner {
fn default() -> Self {
Self::new(QueryPlannerConfig::default())
}
}
/// Parsed PK constraint.
#[derive(Debug, Clone)]
enum PkConstraint {
/// Single equality: pk = "literal"
Eq(String),
/// IN list: pk IN ["a", "b", ...]
In(Vec<String>),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = QueryPlannerConfig::default();
assert!(config.enabled);
assert_eq!(config.max_pk_literals_narrowable, 128);
assert!(!config.log_plans);
}
#[tokio::test]
async fn test_plan_disabled() {
let config = QueryPlannerConfig {
enabled: false,
..Default::default()
};
let planner = QueryPlanner::new(config);
let plan = planner
.plan("products", &Some("sku = \"abc\"".to_string()), 64)
.await;
assert!(!plan.narrowed);
assert!(plan.reason.contains("disabled"));
}
#[tokio::test]
async fn test_plan_pk_equality() {
let planner = QueryPlanner::default();
planner.set_primary_key("products".into(), "sku".into()).await;
let plan = planner
.plan("products", &Some("sku = \"abc123\"".to_string()), 64)
.await;
assert!(plan.narrowed);
assert_eq!(plan.target_shards.len(), 1);
assert!(plan.reason.contains("PK equality"));
}
#[tokio::test]
async fn test_plan_no_filter() {
let planner = QueryPlanner::default();
let plan = planner.plan("products", &None, 64).await;
assert!(!plan.narrowed);
assert!(plan.reason.contains("no filter"));
}
#[tokio::test]
async fn test_plan_or_not_narrowable() {
let planner = QueryPlanner::default();
planner.set_primary_key("products".into(), "sku".into()).await;
let plan = planner
.plan(
"products",
&Some("sku = \"abc\" OR category = \"books\"".to_string()),
64,
)
.await;
assert!(!plan.narrowed);
assert!(plan.reason.contains("OR"));
}
#[tokio::test]
async fn test_plan_no_pk_configured() {
let planner = QueryPlanner::default();
let plan = planner
.plan("products", &Some("sku = \"abc\"".to_string()), 64)
.await;
assert!(!plan.narrowed);
assert!(plan.reason.contains("primary key not configured"));
}
}

View file

@ -0,0 +1,432 @@
//! Adaptive replica selection using EWMA scoring (plan §13.3).
//!
//! Replaces round-robin with latency-aware selection using EWMA-smoothed
//! metrics: latency p95, in-flight request count, and error rate.
use crate::error::{MiroirError, Result};
use crate::topology::{Group, NodeId};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Replica selection strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SelectionStrategy {
/// EWMA-based adaptive selection.
Adaptive,
/// Round-robin selection.
RoundRobin,
/// Random selection.
Random,
}
impl Default for SelectionStrategy {
fn default() -> Self {
Self::Adaptive
}
}
/// Replica selection configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicaSelectionConfig {
/// Selection strategy.
#[serde(default)]
pub strategy: String,
/// Latency weight in score computation.
#[serde(default = "default_latency_weight")]
pub latency_weight: f64,
/// In-flight request weight.
#[serde(default = "default_inflight_weight")]
pub inflight_weight: f64,
/// Error rate weight.
#[serde(default = "default_error_weight")]
pub error_weight: f64,
/// EWMA half-life in milliseconds.
#[serde(default = "default_ewma_half_life")]
pub ewma_half_life_ms: u64,
/// Exploration epsilon (probability of random selection).
#[serde(default = "default_epsilon")]
pub exploration_epsilon: f64,
}
fn default_latency_weight() -> f64 {
1.0
}
fn default_inflight_weight() -> f64 {
2.0
}
fn default_error_weight() -> f64 {
10.0
}
fn default_ewma_half_life() -> u64 {
5000
}
fn default_epsilon() -> f64 {
0.05
}
impl Default for ReplicaSelectionConfig {
fn default() -> Self {
Self {
strategy: "adaptive".into(),
latency_weight: default_latency_weight(),
inflight_weight: default_inflight_weight(),
error_weight: default_error_weight(),
ewma_half_life_ms: default_ewma_half_life(),
exploration_epsilon: default_epsilon(),
}
}
}
/// Per-node metrics for adaptive selection.
#[derive(Debug, Clone)]
pub struct NodeMetrics {
/// EWMA of latency p95 in milliseconds.
pub latency_p95_ms: f64,
/// Current in-flight request count.
pub in_flight: u32,
/// EWMA of error rate (0.0 to 1.0).
pub error_rate: f64,
/// EWMA half-life for updates.
pub half_life_ms: u64,
/// Last update timestamp.
pub last_updated: Instant,
}
impl NodeMetrics {
/// Create new metrics with initial values.
pub fn new(initial_latency_ms: f64, half_life_ms: u64) -> Self {
Self {
latency_p95_ms: initial_latency_ms,
in_flight: 0,
error_rate: 0.0,
half_life_ms,
last_updated: Instant::now(),
}
}
/// Update latency with EWMA smoothing.
pub fn update_latency(&mut self, latency_ms: f64) {
let alpha = 0.5_f64.powf((self.half_life_ms as f64) / 1000.0);
self.latency_p95_ms = alpha * self.latency_p95_ms + (1.0 - alpha) * latency_ms;
self.last_updated = Instant::now();
}
/// Update error rate with EWMA smoothing.
pub fn update_error(&mut self, is_error: bool) {
let alpha = 0.5_f64.powf((self.half_life_ms as f64) / 1000.0);
let new_error = if is_error { 1.0 } else { 0.0 };
self.error_rate = alpha * self.error_rate + (1.0 - alpha) * new_error;
self.last_updated = Instant::now();
}
/// Increment in-flight count.
pub fn increment_in_flight(&mut self) {
self.in_flight += 1;
}
/// Decrement in-flight count.
pub fn decrement_in_flight(&mut self) {
self.in_flight = self.in_flight.saturating_sub(1);
}
/// Compute the composite score (lower is better).
pub fn score(&self, config: &ReplicaSelectionConfig) -> f64 {
config.latency_weight * self.latency_p95_ms
+ config.inflight_weight * (self.in_flight as f64)
+ config.error_weight * (self.error_rate * 1000.0)
}
}
impl Default for NodeMetrics {
fn default() -> Self {
Self::new(50.0, 5000)
}
}
/// Replica selector.
pub struct ReplicaSelector {
/// Configuration.
config: ReplicaSelectionConfig,
/// Per-node metrics.
metrics: Arc<RwLock<HashMap<NodeId, NodeMetrics>>>,
/// Round-robin counter (for round-robin strategy).
rr_counter: Arc<RwLock<HashMap<String, u64>>>,
/// Random number generator.
rng: Arc<std::sync::Mutex<StdRng>>,
}
impl ReplicaSelector {
/// Create a new replica selector.
pub fn new(config: ReplicaSelectionConfig) -> Self {
Self {
config,
metrics: Arc::new(RwLock::new(HashMap::new())),
rr_counter: Arc::new(RwLock::new(HashMap::new())),
rng: Arc::new(std::sync::Mutex::new(StdRng::from_entropy())),
}
}
/// Select a node from the given candidates.
///
/// Returns the selected node ID, or None if candidates is empty.
pub async fn select(&self, candidates: &[NodeId], group_id: u32) -> Option<NodeId> {
if candidates.is_empty() {
return None;
}
let strategy = self.parse_strategy();
match strategy {
SelectionStrategy::Adaptive => self.select_adaptive(candidates).await,
SelectionStrategy::RoundRobin => self.select_round_robin(candidates, group_id as u64).await,
SelectionStrategy::Random => self.select_random(candidates),
}
}
/// Adaptive selection using EWMA scores.
async fn select_adaptive(&self, candidates: &[NodeId]) -> Option<NodeId> {
let metrics = self.metrics.read().await;
// Exploration: with probability epsilon, pick randomly
if self.should_explore() {
return self.select_random(candidates);
}
// Compute scores and find the minimum
let mut best_node = None;
let mut best_score = f64::INFINITY;
for node in candidates {
let score = metrics
.get(node)
.map(|m| m.score(&self.config))
.unwrap_or(1000.0); // High default for unknown nodes
if score < best_score {
best_score = score;
best_node = Some(node.clone());
}
}
best_node
}
/// Round-robin selection.
async fn select_round_robin(&self, candidates: &[NodeId], group_id: u64) -> Option<NodeId> {
let key = format!("group_{}", group_id);
let mut counter = self.rr_counter.write().await;
let idx = *counter.entry(key.clone()).or_insert(0) as usize % candidates.len();
*counter.get_mut(&key).unwrap() += 1;
Some(candidates[idx].clone())
}
/// Random selection.
fn select_random(&self, candidates: &[NodeId]) -> Option<NodeId> {
if candidates.is_empty() {
return None;
}
let idx = self
.rng
.lock()
.unwrap()
.gen_range(0..candidates.len());
Some(candidates[idx].clone())
}
/// Check if we should explore (random selection).
fn should_explore(&self) -> bool {
let mut rng = self.rng.lock().unwrap();
rng.gen::<f64>() < self.config.exploration_epsilon
}
/// Record a successful request (update latency).
pub async fn record_success(&self, node: &NodeId, latency_ms: f64) {
let mut metrics = self.metrics.write().await;
let entry = metrics
.entry(node.clone())
.or_insert_with(NodeMetrics::default);
entry.update_latency(latency_ms);
entry.update_error(false);
entry.decrement_in_flight();
}
/// Record a failed request.
pub async fn record_error(&self, node: &NodeId, latency_ms: Option<f64>) {
let mut metrics = self.metrics.write().await;
let entry = metrics
.entry(node.clone())
.or_insert_with(NodeMetrics::default);
if let Some(lat) = latency_ms {
entry.update_latency(lat);
}
entry.update_error(true);
entry.decrement_in_flight();
}
/// Record that a request is being sent to a node.
pub async fn record_request_start(&self, node: &NodeId) {
let mut metrics = self.metrics.write().await;
let entry = metrics
.entry(node.clone())
.or_insert_with(NodeMetrics::default);
entry.increment_in_flight();
}
/// Get metrics for a node.
pub async fn get_metrics(&self, node: &NodeId) -> Option<NodeMetrics> {
let metrics = self.metrics.read().await;
metrics.get(node).cloned()
}
/// Parse the strategy from config string.
fn parse_strategy(&self) -> SelectionStrategy {
match self.config.strategy.as_str() {
"round_robin" => SelectionStrategy::RoundRobin,
"random" => SelectionStrategy::Random,
_ => SelectionStrategy::Adaptive,
}
}
}
impl Default for ReplicaSelector {
fn default() -> Self {
Self::new(ReplicaSelectionConfig::default())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = ReplicaSelectionConfig::default();
assert_eq!(config.strategy, "adaptive");
assert_eq!(config.latency_weight, 1.0);
assert_eq!(config.inflight_weight, 2.0);
assert_eq!(config.error_weight, 10.0);
}
#[test]
fn test_node_metrics_score() {
let mut metrics = NodeMetrics::new(50.0, 5000);
assert_eq!(metrics.score(&ReplicaSelectionConfig::default()), 50.0);
metrics.in_flight = 5;
let score = metrics.score(&ReplicaSelectionConfig::default());
// 50 * 1.0 + 5 * 2.0 = 60
assert_eq!(score, 60.0);
}
#[test]
fn test_node_metrics_ewma() {
let mut metrics = NodeMetrics::new(100.0, 1000); // Short half-life
metrics.update_latency(50.0);
// Should move toward 50
assert!(metrics.latency_p95_ms < 100.0 && metrics.latency_p95_ms > 40.0);
metrics.update_error(true);
assert!(metrics.error_rate > 0.0);
metrics.update_error(false);
// Error rate should decay
let rate_before = metrics.error_rate;
metrics.update_error(false);
assert!(metrics.error_rate < rate_before);
}
#[tokio::test]
async fn test_select_adaptive() {
let selector = ReplicaSelector::new(ReplicaSelectionConfig::default());
let node1 = NodeId::new("node-1".to_string());
let node2 = NodeId::new("node-2".to_string());
// Record some metrics
{
let mut metrics = selector.metrics.write().await;
metrics.insert(
node1.clone(),
NodeMetrics {
latency_p95_ms: 10.0,
in_flight: 0,
error_rate: 0.0,
half_life_ms: 5000,
last_updated: Instant::now(),
},
);
metrics.insert(
node2.clone(),
NodeMetrics {
latency_p95_ms: 100.0,
in_flight: 0,
error_rate: 0.0,
half_life_ms: 5000,
last_updated: Instant::now(),
},
);
}
// Should select node-1 (lower score)
let candidates = vec![node2.clone(), node1.clone()];
let selected = selector.select(&candidates, 0).await;
assert_eq!(selected, Some(node1));
}
#[tokio::test]
async fn test_select_round_robin() {
let config = ReplicaSelectionConfig {
strategy: "round_robin".into(),
..Default::default()
};
let selector = ReplicaSelector::new(config);
let node1 = NodeId::new("node-1".to_string());
let node2 = NodeId::new("node-2".to_string());
let candidates = vec![node1.clone(), node2.clone()];
// First call should return node-1
let selected = selector.select(&candidates, 0).await;
assert_eq!(selected, Some(node1.clone()));
// Second call should return node-2
let selected = selector.select(&candidates, 0).await;
assert_eq!(selected, Some(node2.clone()));
// Third call should wrap to node-1
let selected = selector.select(&candidates, 0).await;
assert_eq!(selected, Some(node1));
}
#[tokio::test]
async fn test_record_request_lifecycle() {
let selector = ReplicaSelector::default();
let node = NodeId::new("node-1".to_string());
selector.record_request_start(&node).await;
let metrics = selector.get_metrics(&node).await;
assert!(metrics.is_some());
assert_eq!(metrics.unwrap().in_flight, 1);
selector.record_success(&node, 50.0).await;
let metrics = selector.get_metrics(&node).await;
assert!(metrics.is_some());
assert_eq!(metrics.unwrap().in_flight, 0);
}
#[tokio::test]
async fn test_empty_candidates() {
let selector = ReplicaSelector::default();
let selected = selector.select(&[], 0).await;
assert!(selected.is_none());
}
}

View file

@ -0,0 +1,388 @@
//! Vector and hybrid search sharding (plan §13.12).
//!
//! Handles over-fetching and merging for vector/hybrid search across shards.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Vector search configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchConfig {
/// Whether vector search is enabled.
#[serde(default = "default_true")]
pub enabled: bool,
/// Over-fetch factor (per-shard limit = requested limit × factor).
#[serde(default = "default_over_fetch")]
pub over_fetch_factor: u32,
/// Merge strategy: "convex" or "rrf".
#[serde(default = "default_merge_strategy")]
pub merge_strategy: String,
/// Default hybrid alpha (for convex combination).
#[serde(default = "default_alpha")]
pub hybrid_alpha_default: f64,
/// RRF constant (for Reciprocal Rank Fusion).
#[serde(default = "default_rrf_k")]
pub rrf_k: u32,
}
fn default_true() -> bool {
true
}
fn default_over_fetch() -> u32 {
3
}
fn default_merge_strategy() -> String {
"convex".to_string()
}
fn default_alpha() -> f64 {
0.5
}
fn default_rrf_k() -> u32 {
60
}
impl Default for VectorSearchConfig {
fn default() -> Self {
Self {
enabled: true,
over_fetch_factor: default_over_fetch(),
merge_strategy: default_merge_strategy(),
hybrid_alpha_default: default_alpha(),
rrf_k: default_rrf_k(),
}
}
}
/// Merge strategy for combining results from multiple shards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MergeStrategy {
/// Convex combination: (1 - α) · bm25 + α · semantic
Convex,
/// Reciprocal Rank Fusion
Rrf,
}
impl MergeStrategy {
/// Parse from string.
pub fn from_str(s: &str) -> Option<Self> {
match s {
"convex" => Some(MergeStrategy::Convex),
"rrf" => Some(MergeStrategy::Rrf),
_ => None,
}
}
}
/// A search hit with scores from multiple sources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorHit {
/// Primary key of the document.
pub pk: String,
/// BM25 ranking score.
pub ranking_score: f64,
/// Semantic score (if present).
pub semantic_score: Option<f64>,
/// Combined global score.
pub combined_score: f64,
/// Source shard.
pub shard_id: u32,
}
impl VectorHit {
/// Create a new hit from BM25 only.
pub fn bm25_only(pk: String, ranking_score: f64, shard_id: u32) -> Self {
Self {
pk,
ranking_score,
semantic_score: None,
combined_score: ranking_score,
shard_id,
}
}
/// Create a new hybrid hit.
pub fn hybrid(pk: String, ranking_score: f64, semantic_score: f64, shard_id: u32) -> Self {
Self {
pk,
ranking_score,
semantic_score: Some(semantic_score),
combined_score: ranking_score, // Will be recomputed during merge
shard_id,
}
}
/// Merge with convex combination.
pub fn merge_convex(&mut self, alpha: f64) {
if let Some(semantic) = self.semantic_score {
self.combined_score = (1.0 - alpha) * self.ranking_score + alpha * semantic;
}
}
/// Get RRF score for a given rank.
pub fn rrf_score(rank: usize, k: u32) -> f64 {
1.0 / (k as f64 + rank as f64)
}
}
/// Vector search merger — combines over-fetched results from multiple shards.
pub struct VectorMerger {
/// Merge strategy.
strategy: MergeStrategy,
/// Hybrid alpha (for convex).
alpha: f64,
/// RRF constant.
rrf_k: u32,
}
impl VectorMerger {
/// Create a new vector merger.
pub fn new(config: &VectorSearchConfig) -> Self {
let strategy = MergeStrategy::from_str(&config.merge_strategy)
.unwrap_or(MergeStrategy::Convex);
Self {
strategy,
alpha: config.hybrid_alpha_default,
rrf_k: config.rrf_k,
}
}
/// Merge hits from multiple shards into a single ranked list.
///
/// Input: Vec of (shard_id, hits from that shard)
/// Output: Globally ranked hits, truncated to `limit`
pub fn merge(&self, shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec<VectorHit> {
match self.strategy {
MergeStrategy::Convex => self.merge_convex(shard_hits, limit),
MergeStrategy::Rrf => self.merge_rrf(shard_hits, limit),
}
}
/// Convex combination merge.
fn merge_convex(&self, mut shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec<VectorHit> {
// Apply convex combination to each hit
for (_, hit) in &mut shard_hits {
hit.merge_convex(self.alpha);
}
// Sort by combined score descending
shard_hits.sort_by(|a, b| {
b.1.combined_score
.partial_cmp(&a.1.combined_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
// Deduplicate by PK (keep highest score)
let mut deduped = HashMap::new();
for (_, hit) in shard_hits {
deduped
.entry(hit.pk.clone())
.and_modify(|e: &mut VectorHit| {
if hit.combined_score > e.combined_score {
*e = hit;
}
})
.or_insert(hit);
}
// Convert back to vec and re-sort
let mut result: Vec<_> = deduped.into_values().collect();
result.sort_by(|a, b| {
b.combined_score
.partial_cmp(&a.combined_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
result.truncate(limit);
result
}
/// RRF (Reciprocal Rank Fusion) merge.
fn merge_rrf(&self, shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec<VectorHit> {
// Group hits by PK and accumulate RRF scores
let mut rrf_scores: HashMap<String, f64> = HashMap::new();
let mut hit_data: HashMap<String, VectorHit> = HashMap::new();
// First, sort each shard's hits by their original ranking score
let mut per_shard: HashMap<u32, Vec<VectorHit>> = HashMap::new();
for (shard_id, hit) in shard_hits {
per_shard
.entry(shard_id)
.or_insert_with(Vec::new)
.push(hit);
}
// Compute RRF scores
for (_shard_id, mut hits) in per_shard {
// Sort by ranking_score descending (original per-shard ranking)
hits.sort_by(|a, b| {
b.ranking_score
.partial_cmp(&a.ranking_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
for (rank, hit) in hits.into_iter().enumerate() {
let pk = hit.pk.clone();
let rrf_score = VectorHit::rrf_score(rank, self.rrf_k);
*rrf_scores.entry(pk.clone()).or_insert(0.0) += rrf_score;
hit_data.entry(pk).or_insert(hit);
}
}
// Build result with RRF scores
let mut result: Vec<VectorHit> = hit_data
.into_iter()
.map(|(pk, mut hit)| {
hit.combined_score = *rrf_scores.get(&pk).unwrap_or(&0.0);
hit
})
.collect();
// Sort by RRF score descending
result.sort_by(|a, b| {
b.combined_score
.partial_cmp(&a.combined_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
result.truncate(limit);
result
}
/// Compute the per-shard limit given a requested limit.
pub fn per_shard_limit(&self, requested_limit: usize, over_fetch_factor: u32) -> usize {
requested_limit * over_fetch_factor as usize
}
}
impl Default for VectorMerger {
fn default() -> Self {
Self::new(&VectorSearchConfig::default())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_defaults() {
let config = VectorSearchConfig::default();
assert!(config.enabled);
assert_eq!(config.over_fetch_factor, 3);
assert_eq!(config.merge_strategy, "convex");
assert_eq!(config.hybrid_alpha_default, 0.5);
assert_eq!(config.rrf_k, 60);
}
#[test]
fn test_merge_strategy_from_str() {
assert_eq!(MergeStrategy::from_str("convex"), Some(MergeStrategy::Convex));
assert_eq!(MergeStrategy::from_str("rrf"), Some(MergeStrategy::Rrf));
assert_eq!(MergeStrategy::from_str("unknown"), None);
}
#[test]
fn test_vector_hit_bm25_only() {
let hit = VectorHit::bm25_only("doc1".to_string(), 0.8, 5);
assert_eq!(hit.pk, "doc1");
assert_eq!(hit.ranking_score, 0.8);
assert!(hit.semantic_score.is_none());
assert_eq!(hit.combined_score, 0.8);
assert_eq!(hit.shard_id, 5);
}
#[test]
fn test_vector_hit_hybrid() {
let mut hit = VectorHit::hybrid("doc1".to_string(), 0.6, 0.9, 5);
assert_eq!(hit.pk, "doc1");
assert_eq!(hit.ranking_score, 0.6);
assert_eq!(hit.semantic_score, Some(0.9));
assert_eq!(hit.shard_id, 5);
hit.merge_convex(0.5);
assert!((hit.combined_score - 0.75).abs() < 0.001);
}
#[test]
fn test_rrf_score() {
let score = VectorHit::rrf_score(0, 60);
assert!((score - 1.0 / 60.0).abs() < 0.0001);
let score = VectorHit::rrf_score(10, 60);
assert!((score - 1.0 / 70.0).abs() < 0.0001);
}
#[test]
fn test_merge_convex_basic() {
let merger = VectorMerger {
strategy: MergeStrategy::Convex,
alpha: 0.5,
rrf_k: 60,
};
let hits = vec![
(0, VectorHit::hybrid("doc1".to_string(), 0.8, 0.6, 0)),
(0, VectorHit::hybrid("doc2".to_string(), 0.7, 0.9, 0)),
(1, VectorHit::hybrid("doc1".to_string(), 0.75, 0.65, 1)),
(1, VectorHit::hybrid("doc3".to_string(), 0.9, 0.5, 1)),
];
let result = merger.merge_convex(hits, 10);
// Should deduplicate doc1, keeping the highest combined score
assert_eq!(result.len(), 3);
assert_eq!(result[0].pk, "doc3"); // (0.9 + 0.5) / 2 = 0.7
assert_eq!(result[1].pk, "doc2"); // (0.7 + 0.9) / 2 = 0.8
assert_eq!(result[2].pk, "doc1"); // Should keep shard 0's version (0.8+0.6)/2=0.7 > (0.75+0.65)/2=0.7
}
#[test]
fn test_merge_rrf_basic() {
let merger = VectorMerger {
strategy: MergeStrategy::Rrf,
alpha: 0.5,
rrf_k: 60,
};
let hits = vec![
(0, VectorHit::bm25_only("doc1".to_string(), 0.9, 0)),
(0, VectorHit::bm25_only("doc2".to_string(), 0.8, 0)),
(0, VectorHit::bm25_only("doc3".to_string(), 0.7, 0)),
(1, VectorHit::bm25_only("doc2".to_string(), 0.95, 1)),
(1, VectorHit::bm25_only("doc4".to_string(), 0.85, 1)),
];
let result = merger.merge_rrf(hits, 10);
// doc2 appears in both shards, gets summed RRF scores
assert!(result.iter().any(|h| h.pk == "doc2"));
let doc2 = result.iter().find(|h| h.pk == "doc2").unwrap();
// Rank 1 in shard 0: 1/61, rank 1 in shard 1: 1/61
assert!((doc2.combined_score - 2.0 / 61.0).abs() < 0.0001);
}
#[test]
fn test_per_shard_limit() {
let merger = VectorMerger::default();
assert_eq!(merger.per_shard_limit(10, 3), 30);
assert_eq!(merger.per_shard_limit(100, 2), 200);
}
#[test]
fn test_merge_limits_output() {
let merger = VectorMerger::default();
let hits: Vec<_> = (0..200)
.map(|i| {
(
i % 10,
VectorHit::bm25_only(format!("doc{}", i), 1.0 - (i as f64) * 0.001, i % 10),
)
})
.collect();
let result = merger.merge_convex(hits, 50);
assert_eq!(result.len(), 50);
}
}