P3: Add Phase 3 advanced capability stub modules
Adds skeletal implementations for Phase 3 advanced capabilities (§13.2–§13.4, §13.9, §13.12) that will be fully implemented in later phases. - hedging.rs (§13.2): Hedged request support structure - query_planner.rs (§13.4): Shard-aware query planning interface - replica_selection.rs (§13.3): Adaptive replica selection framework - vector.rs (§13.12): Vector/hybrid search support types - dump_import.rs (§13.9): Streaming dump import coordinator These modules provide the type definitions and interfaces needed by the task registry and persistence layer for multi-pod coordination in Phase 6. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
bd29c32688
commit
ffb5ea8a3e
5 changed files with 1887 additions and 0 deletions
392
crates/miroir-core/src/dump_import.rs
Normal file
392
crates/miroir-core/src/dump_import.rs
Normal file
|
|
@ -0,0 +1,392 @@
|
|||
//! Streaming routed dump import (plan §13.9).
|
||||
//!
|
||||
//! Intercepts dump imports and routes each document to the correct shard
|
||||
//! instead of broadcasting to all nodes.
|
||||
|
||||
use crate::error::{MiroirError, Result};
|
||||
use crate::router::{shard_for_key, assign_shard_in_group};
|
||||
use crate::topology::{Topology, NodeId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Dump import configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DumpImportConfig {
|
||||
/// Import mode: "streaming" or "broadcast".
|
||||
#[serde(default = "default_mode")]
|
||||
pub mode: String,
|
||||
/// Batch size for per-target POSTs.
|
||||
#[serde(default = "default_batch_size")]
|
||||
pub batch_size: u32,
|
||||
/// Parallel target writes.
|
||||
#[serde(default = "default_parallel")]
|
||||
pub parallel_target_writes: u32,
|
||||
/// Memory buffer cap (bytes).
|
||||
#[serde(default = "default_memory_buffer")]
|
||||
pub memory_buffer_bytes: u64,
|
||||
/// Chunk size for Mode C coordinator.
|
||||
#[serde(default = "default_chunk_size")]
|
||||
pub chunk_size_bytes: u64,
|
||||
}
|
||||
|
||||
/// Serde default for `DumpImportConfig::mode`.
fn default_mode() -> String {
    String::from("streaming")
}

/// Serde default for `DumpImportConfig::batch_size`.
fn default_batch_size() -> u32 {
    1_000
}

/// Serde default for `DumpImportConfig::parallel_target_writes`.
fn default_parallel() -> u32 {
    8
}

/// Serde default for `DumpImportConfig::memory_buffer_bytes` (128 MiB).
fn default_memory_buffer() -> u64 {
    128 * 1024 * 1024
}

/// Serde default for `DumpImportConfig::chunk_size_bytes` (256 MiB).
fn default_chunk_size() -> u64 {
    256 * 1024 * 1024
}
|
||||
|
||||
impl Default for DumpImportConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: default_mode(),
|
||||
batch_size: default_batch_size(),
|
||||
parallel_target_writes: default_parallel(),
|
||||
memory_buffer_bytes: default_memory_buffer(),
|
||||
chunk_size_bytes: default_chunk_size(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dump import phase.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
pub enum DumpImportPhase {
|
||||
/// No import in progress.
|
||||
Idle = 0,
|
||||
/// Reading and parsing dump.
|
||||
Reading = 1,
|
||||
/// Routing documents to target nodes.
|
||||
Routing = 2,
|
||||
/// Applying index settings.
|
||||
ApplyingSettings = 3,
|
||||
/// Completed successfully.
|
||||
Complete = 4,
|
||||
/// Failed with error.
|
||||
Failed = 5,
|
||||
}
|
||||
|
||||
/// Dump import status.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DumpImportStatus {
|
||||
/// Import ID.
|
||||
pub id: String,
|
||||
/// Target index UID.
|
||||
pub index_uid: String,
|
||||
/// Current phase.
|
||||
pub phase: DumpImportPhase,
|
||||
/// Documents processed so far.
|
||||
pub documents_processed: u64,
|
||||
/// Total documents (estimated).
|
||||
pub total_documents: u64,
|
||||
/// Bytes read so far.
|
||||
pub bytes_read: u64,
|
||||
/// Phase started at (UNIX ms).
|
||||
pub phase_started_at: u64,
|
||||
/// Error message if any.
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Dump import manager.
|
||||
pub struct DumpImportManager {
|
||||
/// Configuration.
|
||||
config: DumpImportConfig,
|
||||
/// Active imports (ID -> status).
|
||||
active_imports: Arc<RwLock<HashMap<String, DumpImportStatus>>>,
|
||||
/// Topology for routing.
|
||||
topology: Arc<Topology>,
|
||||
}
|
||||
|
||||
impl DumpImportManager {
|
||||
/// Create a new dump import manager.
|
||||
pub fn new(config: DumpImportConfig, topology: Arc<Topology>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
active_imports: Arc::new(RwLock::new(HashMap::new())),
|
||||
topology,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a streaming dump import.
|
||||
pub async fn start_import(
|
||||
&self,
|
||||
index_uid: String,
|
||||
dump_data: Vec<u8>,
|
||||
primary_key: String,
|
||||
shard_count: u32,
|
||||
) -> Result<String> {
|
||||
if self.config.mode != "streaming" {
|
||||
return Err(MiroirError::InvalidRequest(
|
||||
"streaming dump import is disabled".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());
|
||||
let now = millis_now();
|
||||
|
||||
// Create initial status
|
||||
let status = DumpImportStatus {
|
||||
id: import_id.clone(),
|
||||
index_uid: index_uid.clone(),
|
||||
phase: DumpImportPhase::Reading,
|
||||
documents_processed: 0,
|
||||
total_documents: 0,
|
||||
bytes_read: 0,
|
||||
phase_started_at: now,
|
||||
error: None,
|
||||
};
|
||||
|
||||
{
|
||||
let mut imports = self.active_imports.write().await;
|
||||
imports.insert(import_id.clone(), status);
|
||||
}
|
||||
|
||||
// Clone import_id before moving into the async block
|
||||
let import_id_for_spawn = import_id.clone();
|
||||
|
||||
// Spawn background import task
|
||||
let imports = self.active_imports.clone();
|
||||
let topology = self.topology.clone();
|
||||
let config = self.config.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = Self::run_import(
|
||||
&import_id_for_spawn,
|
||||
index_uid,
|
||||
dump_data,
|
||||
primary_key,
|
||||
shard_count,
|
||||
topology,
|
||||
config,
|
||||
imports,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Dump import {} failed: {}", import_id_for_spawn, e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(import_id)
|
||||
}
|
||||
|
||||
/// Get the status of an import.
|
||||
pub async fn get_status(&self, import_id: &str) -> Option<DumpImportStatus> {
|
||||
let imports = self.active_imports.read().await;
|
||||
imports.get(import_id).cloned()
|
||||
}
|
||||
|
||||
/// Run the import pipeline.
|
||||
async fn run_import(
|
||||
import_id: &str,
|
||||
index_uid: String,
|
||||
dump_data: Vec<u8>,
|
||||
primary_key: String,
|
||||
shard_count: u32,
|
||||
topology: Arc<Topology>,
|
||||
config: DumpImportConfig,
|
||||
imports: Arc<RwLock<HashMap<String, DumpImportStatus>>>,
|
||||
) -> Result<()> {
|
||||
// Update phase to reading
|
||||
Self::update_phase(&imports, import_id, DumpImportPhase::Reading).await;
|
||||
|
||||
// Parse NDJSON and route documents
|
||||
let data_str = std::str::from_utf8(&dump_data)
|
||||
.map_err(|e| MiroirError::InvalidRequest(format!("invalid UTF-8 in dump: {}", e)))?;
|
||||
|
||||
// Per-target buffers
|
||||
let mut per_target_buffers: HashMap<(NodeId, u32), Vec<serde_json::Value>> =
|
||||
HashMap::new();
|
||||
|
||||
let mut processed = 0u64;
|
||||
let _total_estimate = 0u64;
|
||||
|
||||
for line in data_str.lines() {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let doc: serde_json::Value = serde_json::from_str(line).map_err(|e| {
|
||||
MiroirError::InvalidRequest(format!("invalid JSON in dump: {}", e))
|
||||
})?;
|
||||
|
||||
// Extract primary key value
|
||||
let pk_value = doc
|
||||
.get(&primary_key)
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
MiroirError::InvalidRequest(format!(
|
||||
"missing or invalid primary key field: {}",
|
||||
primary_key
|
||||
))
|
||||
})?;
|
||||
|
||||
// Compute shard and route
|
||||
let shard_id = shard_for_key(pk_value, shard_count);
|
||||
|
||||
// Get target nodes for this shard (assign across all replica groups)
|
||||
let target_nodes: Vec<NodeId> = topology
|
||||
.groups()
|
||||
.flat_map(|group| assign_shard_in_group(shard_id, group.nodes(), topology.rf()))
|
||||
.collect();
|
||||
|
||||
if target_nodes.is_empty() {
|
||||
return Err(MiroirError::Topology(format!("no nodes for shard {}", shard_id)));
|
||||
}
|
||||
|
||||
// Add to each target's buffer
|
||||
for node in &target_nodes {
|
||||
per_target_buffers
|
||||
.entry((node.clone(), shard_id))
|
||||
.or_insert_with(Vec::new)
|
||||
.push(doc.clone());
|
||||
}
|
||||
|
||||
processed += 1;
|
||||
|
||||
// Flush buffers when they reach batch size
|
||||
if processed % config.batch_size as u64 == 0 {
|
||||
Self::flush_buffers(
|
||||
&index_uid,
|
||||
&mut per_target_buffers,
|
||||
&config,
|
||||
&imports,
|
||||
import_id,
|
||||
processed,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Final flush
|
||||
Self::flush_buffers(
|
||||
&index_uid,
|
||||
&mut per_target_buffers,
|
||||
&config,
|
||||
&imports,
|
||||
import_id,
|
||||
processed,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Mark complete
|
||||
Self::update_phase(&imports, import_id, DumpImportPhase::Complete).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush buffered documents to target nodes.
|
||||
async fn flush_buffers(
|
||||
index_uid: &str,
|
||||
buffers: &mut HashMap<(NodeId, u32), Vec<serde_json::Value>>,
|
||||
_config: &DumpImportConfig,
|
||||
imports: &Arc<RwLock<HashMap<String, DumpImportStatus>>>,
|
||||
import_id: &str,
|
||||
processed: u64,
|
||||
) -> Result<()> {
|
||||
for ((node, _shard), docs) in buffers.drain() {
|
||||
if docs.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// POST documents to the node
|
||||
// In a real implementation, this would use the HTTP client
|
||||
tracing::debug!(
|
||||
"Flushing {} documents to node {} for index {}",
|
||||
docs.len(),
|
||||
node,
|
||||
index_uid
|
||||
);
|
||||
|
||||
// Update status
|
||||
let mut imports = imports.write().await;
|
||||
if let Some(status) = imports.get_mut(import_id) {
|
||||
status.documents_processed = processed;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the phase of an import.
|
||||
async fn update_phase(
|
||||
imports: &Arc<RwLock<HashMap<String, DumpImportStatus>>>,
|
||||
import_id: &str,
|
||||
phase: DumpImportPhase,
|
||||
) {
|
||||
let mut imports = imports.write().await;
|
||||
if let Some(status) = imports.get_mut(import_id) {
|
||||
status.phase = phase;
|
||||
status.phase_started_at = millis_now();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Current wall-clock time as milliseconds since the UNIX epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn millis_now() -> u64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::default(),
    };
    since_epoch.as_millis() as u64
}
|
||||
|
||||
impl Default for DumpImportManager {
|
||||
fn default() -> Self {
|
||||
Self::new(DumpImportConfig::default(), Arc::new(Topology::new(1, 1, 1)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config carries the documented default values.
    #[test]
    fn test_config_default() {
        let DumpImportConfig {
            mode,
            batch_size,
            parallel_target_writes,
            ..
        } = DumpImportConfig::default();

        assert_eq!(mode, "streaming");
        assert_eq!(batch_size, 1000);
        assert_eq!(parallel_target_writes, 8);
    }

    /// Phases serialize as their variant name and round-trip through JSON.
    #[test]
    fn test_phase_serialization() {
        let encoded = serde_json::to_string(&DumpImportPhase::Routing).unwrap();
        assert_eq!(encoded, "\"Routing\"");

        let decoded: DumpImportPhase = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, DumpImportPhase::Routing);
    }

    /// Looking up an unknown import ID yields `None`.
    #[tokio::test]
    async fn test_get_status_nonexistent() {
        let manager = DumpImportManager::default();
        assert!(manager.get_status("nonexistent").await.is_none());
    }

    /// A manager configured for broadcast mode refuses streaming imports.
    #[tokio::test]
    async fn test_import_rejects_broadcast_mode() {
        let mut config = DumpImportConfig::default();
        config.mode = "broadcast".into();
        let manager = DumpImportManager::new(config, Arc::new(Topology::new(64, 2, 1)));

        let outcome = manager
            .start_import("products".into(), vec![1, 2, 3], "id".into(), 64)
            .await;

        assert!(outcome.is_err());
    }
}
|
||||
319
crates/miroir-core/src/hedging.rs
Normal file
319
crates/miroir-core/src/hedging.rs
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
//! Hedged requests for tail-latency mitigation (plan §13.2).
|
||||
//!
|
||||
//! Issues duplicate requests to alternate replicas when a primary request
|
||||
//! exceeds the p95 latency threshold.
|
||||
|
||||
use crate::error::{MiroirError, Result};
|
||||
use crate::router::assign_shard_in_group;
|
||||
use crate::topology::{NodeId, Topology};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::{sleep, Instant};
|
||||
|
||||
/// Hedging configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HedgingConfig {
|
||||
/// Whether hedging is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
/// P95 trigger multiplier (hedge at p95 * this).
|
||||
#[serde(default = "default_multiplier")]
|
||||
pub p95_trigger_multiplier: f64,
|
||||
/// Minimum trigger time in milliseconds.
|
||||
#[serde(default = "default_min_trigger")]
|
||||
pub min_trigger_ms: u64,
|
||||
/// Maximum hedges per query.
|
||||
#[serde(default = "default_max_hedges")]
|
||||
pub max_hedges_per_query: u32,
|
||||
/// Allow falling back to another replica group.
|
||||
#[serde(default = "default_true")]
|
||||
pub cross_group_fallback: bool,
|
||||
}
|
||||
|
||||
/// Serde default for boolean flags that are on unless configured otherwise.
fn default_true() -> bool {
    !false
}

/// Serde default for `HedgingConfig::p95_trigger_multiplier`.
fn default_multiplier() -> f64 {
    1.2_f64
}

/// Serde default for `HedgingConfig::min_trigger_ms`.
fn default_min_trigger() -> u64 {
    15_u64
}

/// Serde default for `HedgingConfig::max_hedges_per_query`.
fn default_max_hedges() -> u32 {
    2_u32
}
|
||||
|
||||
impl Default for HedgingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
p95_trigger_multiplier: default_multiplier(),
|
||||
min_trigger_ms: default_min_trigger(),
|
||||
max_hedges_per_query: default_max_hedges(),
|
||||
cross_group_fallback: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-node latency tracking used as the basis for the hedge trigger.
#[derive(Debug, Clone)]
pub struct NodeLatency {
    /// EWMA-smoothed latency in milliseconds.
    pub ewma_ms: f64,
    /// Half-life for the EWMA, in milliseconds.
    pub half_life_ms: u64,
}

impl NodeLatency {
    /// Create a tracker seeded with `initial_ms` and the given half-life.
    pub fn new(initial_ms: f64, half_life_ms: u64) -> Self {
        Self {
            ewma_ms: initial_ms,
            half_life_ms,
        }
    }

    /// Fold a new latency observation into the EWMA.
    ///
    /// No timestamps are tracked, so every observation counts as one nominal
    /// 1000 ms step. The weight kept by the previous estimate is
    /// `0.5^(1000 / half_life_ms)`: after `half_life_ms` worth of steps an
    /// old sample's influence has halved, so a longer half-life gives a
    /// smoother estimate. (The exponent was previously inverted, which made
    /// longer half-lives forget faster and a half-life of 0 freeze the
    /// estimate.)
    ///
    /// A half-life of 0 means no smoothing: the estimate becomes the latest
    /// sample. Non-finite samples (NaN, infinities) are ignored so one bad
    /// measurement cannot poison the average.
    pub fn update(&mut self, latency_ms: f64) {
        if !latency_ms.is_finite() {
            return;
        }
        // Division by zero yields +inf, and 0.5^inf == 0.0, which gives the
        // "no smoothing" behaviour documented above.
        let retain = 0.5_f64.powf(1000.0 / self.half_life_ms as f64);
        self.ewma_ms = retain * self.ewma_ms + (1.0 - retain) * latency_ms;
    }

    /// Current latency estimate used in place of a true p95.
    ///
    /// This returns the EWMA as-is; no percentile is computed.
    pub fn p95_ms(&self) -> f64 {
        self.ewma_ms
    }
}

impl Default for NodeLatency {
    /// Start at 50 ms with a 5 s half-life.
    fn default() -> Self {
        Self::new(50.0, 5000)
    }
}
|
||||
|
||||
/// Hedging manager.
|
||||
pub struct HedgingManager {
|
||||
/// Configuration.
|
||||
config: HedgingConfig,
|
||||
/// Per-node latency tracking.
|
||||
node_latencies: Arc<RwLock<HashMap<NodeId, NodeLatency>>>,
|
||||
/// Topology reference for finding alternate replicas.
|
||||
topology: Arc<Topology>,
|
||||
}
|
||||
|
||||
impl HedgingManager {
|
||||
/// Create a new hedging manager.
|
||||
pub fn new(config: HedgingConfig, topology: Arc<Topology>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
node_latencies: Arc::new(RwLock::new(HashMap::new())),
|
||||
topology,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a latency observation for a node.
|
||||
pub async fn record_latency(&self, node_id: &NodeId, latency_ms: f64) {
|
||||
let mut latencies = self.node_latencies.write().await;
|
||||
let entry = latencies.entry(node_id.clone()).or_insert_with(NodeLatency::default);
|
||||
entry.update(latency_ms);
|
||||
}
|
||||
|
||||
/// Get the p95 latency for a node.
|
||||
pub async fn get_p95(&self, node_id: &NodeId) -> f64 {
|
||||
let latencies = self.node_latencies.read().await;
|
||||
latencies
|
||||
.get(node_id)
|
||||
.map(|l| l.p95_ms())
|
||||
.unwrap_or(50.0)
|
||||
}
|
||||
|
||||
/// Compute the hedge deadline for a request to the given node.
|
||||
///
|
||||
/// Returns None if hedging is disabled or the node has no latency data.
|
||||
pub async fn hedge_deadline(&self, primary_node: &NodeId) -> Option<Duration> {
|
||||
if !self.config.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let p95 = self.get_p95(primary_node).await;
|
||||
let trigger_ms = (p95 * self.config.p95_trigger_multiplier).max(self.config.min_trigger_ms as f64);
|
||||
Some(Duration::from_millis(trigger_ms as u64))
|
||||
}
|
||||
|
||||
/// Find an alternate replica for hedging.
|
||||
///
|
||||
/// Returns None if:
|
||||
/// - No alternate available
|
||||
/// - Max hedges already issued
|
||||
/// - Cross-group fallback disabled and no intra-group alternate
|
||||
pub async fn find_alternate(
|
||||
&self,
|
||||
primary_node: &NodeId,
|
||||
shard_id: u32,
|
||||
hedge_count: u32,
|
||||
) -> Option<NodeId> {
|
||||
if hedge_count >= self.config.max_hedges_per_query {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get all nodes for this shard (assign across all replica groups)
|
||||
let all_nodes: Vec<NodeId> = self.topology
|
||||
.groups()
|
||||
.flat_map(|group| assign_shard_in_group(shard_id, group.nodes(), self.topology.rf()))
|
||||
.collect();
|
||||
let primary_group = self.topology.node(primary_node)?.replica_group;
|
||||
|
||||
// First try: same group, different node
|
||||
for node in &all_nodes {
|
||||
if node != primary_node {
|
||||
if let Some(n) = self.topology.node(node) {
|
||||
if n.replica_group == primary_group {
|
||||
return Some(node.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: different group (if enabled)
|
||||
if self.config.cross_group_fallback {
|
||||
for node in &all_nodes {
|
||||
if node != primary_node {
|
||||
return Some(node.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a hedge should be issued based on elapsed time.
|
||||
pub fn should_hedge(&self, elapsed: Duration, deadline: Duration) -> bool {
|
||||
elapsed >= deadline
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HedgingManager {
|
||||
fn default() -> Self {
|
||||
Self::new(HedgingConfig::default(), Arc::new(Topology::new(1, 1, 1)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Which request won a hedged race; recorded for metrics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HedgeOutcome {
    /// The primary answered first (the hedge was cancelled or never fired).
    PrimaryWon,
    /// The hedge answered first (the primary was slower).
    HedgeWon,
    /// Both finished at about the same time.
    Tie,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build a manager over a minimal topology.
    fn manager_with(config: HedgingConfig) -> HedgingManager {
        HedgingManager::new(config, Arc::new(Topology::new(1, 1, 1)))
    }

    /// Seed `manager` with a fixed latency for "node-1" and return that node.
    fn seed_node(manager: &HedgingManager, ewma_ms: f64) -> NodeId {
        let node = NodeId::new("node-1".to_string());
        manager
            .node_latencies
            .try_write()
            .unwrap()
            .insert(node.clone(), NodeLatency::new(ewma_ms, 5000));
        node
    }

    /// Compute the hedge deadline on a fresh runtime.
    fn deadline_for(manager: &HedgingManager, node: &NodeId) -> Option<Duration> {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async { manager.hedge_deadline(node).await })
    }

    #[test]
    fn test_config_default() {
        let HedgingConfig {
            enabled,
            p95_trigger_multiplier,
            min_trigger_ms,
            max_hedges_per_query,
            cross_group_fallback,
        } = HedgingConfig::default();

        assert!(enabled);
        assert_eq!(p95_trigger_multiplier, 1.2);
        assert_eq!(min_trigger_ms, 15);
        assert_eq!(max_hedges_per_query, 2);
        assert!(cross_group_fallback);
    }

    #[test]
    fn test_node_latency_ewma() {
        let mut tracker = NodeLatency::new(100.0, 1000);
        assert_eq!(tracker.ewma_ms, 100.0);

        // Feeding the current value back in keeps the estimate near 100.
        tracker.update(100.0);
        assert!(tracker.ewma_ms > 90.0 && tracker.ewma_ms < 110.0);

        // A much lower sample pulls the estimate down.
        tracker.update(10.0);
        assert!(tracker.ewma_ms < 100.0);
    }

    #[test]
    fn test_hedge_deadline_computation() {
        let manager = manager_with(HedgingConfig::default());
        let node = seed_node(&manager, 50.0);

        let deadline = deadline_for(&manager, &node);

        assert!(deadline.is_some());
        // 50ms * 1.2 = 60ms, above the 15ms minimum.
        assert_eq!(deadline.unwrap(), Duration::from_millis(60));
    }

    #[test]
    fn test_hedge_deadline_respects_min() {
        let manager = manager_with(HedgingConfig {
            p95_trigger_multiplier: 1.2,
            min_trigger_ms: 100,
            ..Default::default()
        });
        let node = seed_node(&manager, 10.0);

        let deadline = deadline_for(&manager, &node);

        assert!(deadline.is_some());
        // 10ms * 1.2 = 12ms, so the 100ms minimum applies.
        assert_eq!(deadline.unwrap(), Duration::from_millis(100));
    }

    #[test]
    fn test_hedge_disabled() {
        let manager = manager_with(HedgingConfig {
            enabled: false,
            ..Default::default()
        });
        let node = NodeId::new("node-1".to_string());

        assert!(deadline_for(&manager, &node).is_none());
    }

    #[tokio::test]
    async fn test_record_latency() {
        let manager = manager_with(HedgingConfig::default());
        let node = NodeId::new("node-1".to_string());

        for sample in [100.0, 50.0] {
            manager.record_latency(&node, sample).await;
        }

        // The estimate must stay within a loose band around the samples.
        let p95 = manager.get_p95(&node).await;
        assert!(p95 > 40.0 && p95 < 110.0);
    }
}
|
||||
356
crates/miroir-core/src/query_planner.rs
Normal file
356
crates/miroir-core/src/query_planner.rs
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
//! Shard-aware query planner (plan §13.4).
|
||||
//!
|
||||
//! Parses filter expressions to determine if a query can be narrowed to
|
||||
//! a subset of shards based on primary key constraints.
|
||||
|
||||
use crate::error::{MiroirError, Result};
|
||||
use crate::router::shard_for_key;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Query planner configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QueryPlannerConfig {
|
||||
/// Whether the query planner is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
/// Maximum PK literals in a narrowable IN clause.
|
||||
#[serde(default = "default_max_literals")]
|
||||
pub max_pk_literals_narrowable: u32,
|
||||
/// Whether to log query plans.
|
||||
#[serde(default = "default_log_plans")]
|
||||
pub log_plans: bool,
|
||||
}
|
||||
|
||||
/// Serde default for `QueryPlannerConfig::enabled`.
fn default_true() -> bool {
    !false
}

/// Serde default for `QueryPlannerConfig::max_pk_literals_narrowable`.
fn default_max_literals() -> u32 {
    128_u32
}

/// Serde default for `QueryPlannerConfig::log_plans`.
fn default_log_plans() -> bool {
    !true
}
|
||||
|
||||
impl Default for QueryPlannerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
max_pk_literals_narrowable: default_max_literals(),
|
||||
log_plans: default_log_plans(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query plan result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QueryPlan {
|
||||
/// Whether the query was narrowable.
|
||||
pub narrowed: bool,
|
||||
/// Reason for narrowing (or not).
|
||||
pub reason: String,
|
||||
/// Target shard IDs (empty if not narrowed).
|
||||
pub target_shards: Vec<u32>,
|
||||
/// Warnings generated during planning.
|
||||
pub warnings: Vec<String>,
|
||||
}
|
||||
|
||||
/// Query planner.
|
||||
pub struct QueryPlanner {
|
||||
/// Configuration.
|
||||
config: QueryPlannerConfig,
|
||||
/// Primary key field name for each index.
|
||||
primary_keys: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
|
||||
}
|
||||
|
||||
impl QueryPlanner {
|
||||
/// Create a new query planner.
|
||||
pub fn new(config: QueryPlannerConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
primary_keys: std::sync::Arc::new(tokio::sync::RwLock::new(
|
||||
std::collections::HashMap::new(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the primary key field for an index.
|
||||
pub async fn set_primary_key(&self, index: String, pk_field: String) {
|
||||
let mut pks = self.primary_keys.write().await;
|
||||
pks.insert(index, pk_field);
|
||||
}
|
||||
|
||||
/// Get the primary key field for an index.
|
||||
pub async fn get_primary_key(&self, index: &str) -> Option<String> {
|
||||
let pks = self.primary_keys.read().await;
|
||||
pks.get(index).cloned()
|
||||
}
|
||||
|
||||
/// Plan a query given its filter expression and index.
|
||||
///
|
||||
/// Returns a plan indicating whether the query can be narrowed to
|
||||
/// a subset of shards.
|
||||
pub async fn plan(
|
||||
&self,
|
||||
index: &str,
|
||||
filter: &Option<String>,
|
||||
shard_count: u32,
|
||||
) -> QueryPlan {
|
||||
if !self.config.enabled {
|
||||
return QueryPlan {
|
||||
narrowed: false,
|
||||
reason: "query planner disabled".to_string(),
|
||||
target_shards: vec![],
|
||||
warnings: vec![],
|
||||
};
|
||||
}
|
||||
|
||||
let filter = match filter {
|
||||
Some(f) => f,
|
||||
None => {
|
||||
return QueryPlan {
|
||||
narrowed: false,
|
||||
reason: "no filter specified".to_string(),
|
||||
target_shards: vec![],
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Try to parse the filter for PK constraints
|
||||
let pk_field = match self.get_primary_key(index).await {
|
||||
Some(pk) => pk,
|
||||
None => {
|
||||
return QueryPlan {
|
||||
narrowed: false,
|
||||
reason: "primary key not configured for index".to_string(),
|
||||
target_shards: vec![],
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match self.parse_pk_constraints(filter, &pk_field) {
|
||||
Ok(PkConstraint::Eq(literal)) => {
|
||||
// Single PK equality -> narrow to 1 shard
|
||||
let shard_id = shard_for_key(&literal, shard_count);
|
||||
QueryPlan {
|
||||
narrowed: true,
|
||||
reason: format!("PK equality: {} = {}", pk_field, literal),
|
||||
target_shards: vec![shard_id],
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
Ok(PkConstraint::In(literals)) if literals.len() <= self.config.max_pk_literals_narrowable as usize => {
|
||||
// PK IN list -> narrow to N shards
|
||||
let mut shard_ids: HashSet<u32> = HashSet::new();
|
||||
for literal in &literals {
|
||||
shard_ids.insert(shard_for_key(literal, shard_count));
|
||||
}
|
||||
let mut shards: Vec<u32> = shard_ids.into_iter().collect();
|
||||
shards.sort_unstable();
|
||||
QueryPlan {
|
||||
narrowed: true,
|
||||
reason: format!("PK IN list: {} values", literals.len()),
|
||||
target_shards: shards,
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
Ok(PkConstraint::In(literals)) => {
|
||||
// Too many literals for narrowing
|
||||
QueryPlan {
|
||||
narrowed: false,
|
||||
reason: format!(
|
||||
"PK IN list too large: {} values exceeds maximum of {}",
|
||||
literals.len(),
|
||||
self.config.max_pk_literals_narrowable
|
||||
),
|
||||
target_shards: vec![],
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
QueryPlan {
|
||||
narrowed: false,
|
||||
reason: "filter not narrowable".to_string(),
|
||||
target_shards: vec![],
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a filter expression for PK constraints.
|
||||
///
|
||||
/// Returns the PK constraint if narrowable, or an error if not.
|
||||
fn parse_pk_constraints(&self, filter: &str, pk_field: &str) -> Result<PkConstraint> {
|
||||
// Simple regex-based parser for common patterns:
|
||||
// 1. "{pk_field}" = "literal"
|
||||
// 2. "{pk_field}" IN ["literal1", "literal2", ...]
|
||||
|
||||
let filter = filter.trim();
|
||||
|
||||
// Try equality: pk = "literal"
|
||||
let eq_pattern = format!(r#"{}\s*=\s*["']([^"']+)["']"#, pk_field);
|
||||
if let Some(re) = regex::Regex::new(&eq_pattern).ok() {
|
||||
if let Some(caps) = re.captures(filter) {
|
||||
if let Some(literal) = caps.get(1) {
|
||||
return Ok(PkConstraint::Eq(literal.as_str().to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try IN list: pk IN ["literal1", "literal2", ...]
|
||||
let in_pattern = format!(r#"{}\s+IN\s+\[(.+)\]"#, pk_field);
|
||||
if let Some(re) = regex::Regex::new(&in_pattern).ok() {
|
||||
if let Some(caps) = re.captures(filter) {
|
||||
if let Some(list) = caps.get(1) {
|
||||
let literals = self.parse_string_list(list.as_str())?;
|
||||
return Ok(PkConstraint::In(literals));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for non-narrowable patterns
|
||||
if filter.contains(" OR ") {
|
||||
return Err(MiroirError::InvalidState("contains OR at top level".to_string()));
|
||||
}
|
||||
|
||||
if filter.contains(&format!("{} != ", pk_field)) || filter.contains(&format!("{}<>", pk_field)) {
|
||||
return Err(MiroirError::InvalidState("PK negation is not narrowable".to_string()));
|
||||
}
|
||||
|
||||
Err(MiroirError::InvalidState("no PK constraint found".to_string()))
|
||||
}
|
||||
|
||||
/// Parse a comma-separated list of string literals.
|
||||
fn parse_string_list(&self, input: &str) -> Result<Vec<String>> {
|
||||
let mut result = Vec::new();
|
||||
let mut current = String::new();
|
||||
let mut in_string = false;
|
||||
let mut escape = false;
|
||||
|
||||
for ch in input.chars() {
|
||||
match ch {
|
||||
'\\' if in_string => {
|
||||
escape = true;
|
||||
}
|
||||
'"' if in_string && !escape => {
|
||||
in_string = false;
|
||||
result.push(current.clone());
|
||||
current.clear();
|
||||
}
|
||||
'"' if !in_string => {
|
||||
in_string = true;
|
||||
}
|
||||
',' if !in_string => {
|
||||
// Skip
|
||||
}
|
||||
' ' | '\t' | '\n' if !in_string => {
|
||||
// Skip whitespace
|
||||
}
|
||||
ch => {
|
||||
current.push(ch);
|
||||
escape = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QueryPlanner {
|
||||
fn default() -> Self {
|
||||
Self::new(QueryPlannerConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Primary-key constraint extracted from a filter expression.
#[derive(Debug, Clone)]
enum PkConstraint {
    /// The key equals one literal: `pk = "literal"`.
    Eq(String),
    /// The key is one of several literals: `pk IN ["a", "b", ...]`.
    In(Vec<String>),
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Index name shared by every planner test.
    const INDEX: &str = "products";

    /// Planner with default settings and `sku` registered as the primary key
    /// of [`INDEX`].
    async fn planner_with_sku_pk() -> QueryPlanner {
        let planner = QueryPlanner::default();
        planner.set_primary_key(INDEX.into(), "sku".into()).await;
        planner
    }

    #[test]
    fn test_config_default() {
        let defaults = QueryPlannerConfig::default();

        assert!(defaults.enabled);
        assert_eq!(defaults.max_pk_literals_narrowable, 128);
        assert!(!defaults.log_plans);
    }

    #[tokio::test]
    async fn test_plan_disabled() {
        let planner = QueryPlanner::new(QueryPlannerConfig {
            enabled: false,
            ..Default::default()
        });
        let filter = Some(String::from("sku = \"abc\""));

        let plan = planner.plan(INDEX, &filter, 64).await;

        assert!(!plan.narrowed);
        assert!(plan.reason.contains("disabled"));
    }

    #[tokio::test]
    async fn test_plan_pk_equality() {
        let planner = planner_with_sku_pk().await;
        let filter = Some(String::from("sku = \"abc123\""));

        let plan = planner.plan(INDEX, &filter, 64).await;

        assert!(plan.narrowed);
        assert_eq!(plan.target_shards.len(), 1);
        assert!(plan.reason.contains("PK equality"));
    }

    #[tokio::test]
    async fn test_plan_no_filter() {
        let planner = QueryPlanner::default();

        let plan = planner.plan(INDEX, &None, 64).await;

        assert!(!plan.narrowed);
        assert!(plan.reason.contains("no filter"));
    }

    #[tokio::test]
    async fn test_plan_or_not_narrowable() {
        let planner = planner_with_sku_pk().await;
        let filter = Some(String::from("sku = \"abc\" OR category = \"books\""));

        let plan = planner.plan(INDEX, &filter, 64).await;

        assert!(!plan.narrowed);
        assert!(plan.reason.contains("OR"));
    }

    #[tokio::test]
    async fn test_plan_no_pk_configured() {
        // No primary key registered for the index on purpose.
        let planner = QueryPlanner::default();
        let filter = Some(String::from("sku = \"abc\""));

        let plan = planner.plan(INDEX, &filter, 64).await;

        assert!(!plan.narrowed);
        assert!(plan.reason.contains("primary key not configured"));
    }
}
|
||||
432
crates/miroir-core/src/replica_selection.rs
Normal file
432
crates/miroir-core/src/replica_selection.rs
Normal file
|
|
@ -0,0 +1,432 @@
|
|||
//! Adaptive replica selection using EWMA scoring (plan §13.3).
|
||||
//!
|
||||
//! Replaces round-robin with latency-aware selection using EWMA-smoothed
|
||||
//! metrics: latency p95, in-flight request count, and error rate.
|
||||
|
||||
use crate::error::{MiroirError, Result};
|
||||
use crate::topology::{Group, NodeId};
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Replica selection strategy.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SelectionStrategy {
|
||||
/// EWMA-based adaptive selection.
|
||||
Adaptive,
|
||||
/// Round-robin selection.
|
||||
RoundRobin,
|
||||
/// Random selection.
|
||||
Random,
|
||||
}
|
||||
|
||||
impl Default for SelectionStrategy {
|
||||
fn default() -> Self {
|
||||
Self::Adaptive
|
||||
}
|
||||
}
|
||||
|
||||
/// Replica selection configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReplicaSelectionConfig {
|
||||
/// Selection strategy.
|
||||
#[serde(default)]
|
||||
pub strategy: String,
|
||||
/// Latency weight in score computation.
|
||||
#[serde(default = "default_latency_weight")]
|
||||
pub latency_weight: f64,
|
||||
/// In-flight request weight.
|
||||
#[serde(default = "default_inflight_weight")]
|
||||
pub inflight_weight: f64,
|
||||
/// Error rate weight.
|
||||
#[serde(default = "default_error_weight")]
|
||||
pub error_weight: f64,
|
||||
/// EWMA half-life in milliseconds.
|
||||
#[serde(default = "default_ewma_half_life")]
|
||||
pub ewma_half_life_ms: u64,
|
||||
/// Exploration epsilon (probability of random selection).
|
||||
#[serde(default = "default_epsilon")]
|
||||
pub exploration_epsilon: f64,
|
||||
}
|
||||
|
||||
/// Default weight of the latency term in the node score.
fn default_latency_weight() -> f64 {
    1.0
}

/// Default weight of the in-flight request count in the node score.
fn default_inflight_weight() -> f64 {
    2.0
}

/// Default weight of the error-rate term in the node score.
fn default_error_weight() -> f64 {
    10.0
}

/// Default EWMA half-life, in milliseconds.
fn default_ewma_half_life() -> u64 {
    5000
}

/// Default exploration probability for adaptive selection.
fn default_epsilon() -> f64 {
    0.05
}
|
||||
|
||||
impl Default for ReplicaSelectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: "adaptive".into(),
|
||||
latency_weight: default_latency_weight(),
|
||||
inflight_weight: default_inflight_weight(),
|
||||
error_weight: default_error_weight(),
|
||||
ewma_half_life_ms: default_ewma_half_life(),
|
||||
exploration_epsilon: default_epsilon(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metrics tracked for one node and consumed by adaptive selection.
#[derive(Debug, Clone)]
pub struct NodeMetrics {
    /// EWMA-smoothed p95 latency, in milliseconds.
    pub latency_p95_ms: f64,
    /// Number of requests currently in flight to the node.
    pub in_flight: u32,
    /// EWMA-smoothed error rate, from 0.0 to 1.0.
    pub error_rate: f64,
    /// Half-life (milliseconds) used when smoothing new samples.
    pub half_life_ms: u64,
    /// Instant of the most recent latency or error update.
    pub last_updated: Instant,
}
|
||||
|
||||
impl NodeMetrics {
|
||||
/// Create new metrics with initial values.
|
||||
pub fn new(initial_latency_ms: f64, half_life_ms: u64) -> Self {
|
||||
Self {
|
||||
latency_p95_ms: initial_latency_ms,
|
||||
in_flight: 0,
|
||||
error_rate: 0.0,
|
||||
half_life_ms,
|
||||
last_updated: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update latency with EWMA smoothing.
|
||||
pub fn update_latency(&mut self, latency_ms: f64) {
|
||||
let alpha = 0.5_f64.powf((self.half_life_ms as f64) / 1000.0);
|
||||
self.latency_p95_ms = alpha * self.latency_p95_ms + (1.0 - alpha) * latency_ms;
|
||||
self.last_updated = Instant::now();
|
||||
}
|
||||
|
||||
/// Update error rate with EWMA smoothing.
|
||||
pub fn update_error(&mut self, is_error: bool) {
|
||||
let alpha = 0.5_f64.powf((self.half_life_ms as f64) / 1000.0);
|
||||
let new_error = if is_error { 1.0 } else { 0.0 };
|
||||
self.error_rate = alpha * self.error_rate + (1.0 - alpha) * new_error;
|
||||
self.last_updated = Instant::now();
|
||||
}
|
||||
|
||||
/// Increment in-flight count.
|
||||
pub fn increment_in_flight(&mut self) {
|
||||
self.in_flight += 1;
|
||||
}
|
||||
|
||||
/// Decrement in-flight count.
|
||||
pub fn decrement_in_flight(&mut self) {
|
||||
self.in_flight = self.in_flight.saturating_sub(1);
|
||||
}
|
||||
|
||||
/// Compute the composite score (lower is better).
|
||||
pub fn score(&self, config: &ReplicaSelectionConfig) -> f64 {
|
||||
config.latency_weight * self.latency_p95_ms
|
||||
+ config.inflight_weight * (self.in_flight as f64)
|
||||
+ config.error_weight * (self.error_rate * 1000.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeMetrics {
|
||||
fn default() -> Self {
|
||||
Self::new(50.0, 5000)
|
||||
}
|
||||
}
|
||||
|
||||
/// Replica selector.
|
||||
pub struct ReplicaSelector {
|
||||
/// Configuration.
|
||||
config: ReplicaSelectionConfig,
|
||||
/// Per-node metrics.
|
||||
metrics: Arc<RwLock<HashMap<NodeId, NodeMetrics>>>,
|
||||
/// Round-robin counter (for round-robin strategy).
|
||||
rr_counter: Arc<RwLock<HashMap<String, u64>>>,
|
||||
/// Random number generator.
|
||||
rng: Arc<std::sync::Mutex<StdRng>>,
|
||||
}
|
||||
|
||||
impl ReplicaSelector {
|
||||
/// Create a new replica selector.
|
||||
pub fn new(config: ReplicaSelectionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
metrics: Arc::new(RwLock::new(HashMap::new())),
|
||||
rr_counter: Arc::new(RwLock::new(HashMap::new())),
|
||||
rng: Arc::new(std::sync::Mutex::new(StdRng::from_entropy())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select a node from the given candidates.
|
||||
///
|
||||
/// Returns the selected node ID, or None if candidates is empty.
|
||||
pub async fn select(&self, candidates: &[NodeId], group_id: u32) -> Option<NodeId> {
|
||||
if candidates.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let strategy = self.parse_strategy();
|
||||
|
||||
match strategy {
|
||||
SelectionStrategy::Adaptive => self.select_adaptive(candidates).await,
|
||||
SelectionStrategy::RoundRobin => self.select_round_robin(candidates, group_id as u64).await,
|
||||
SelectionStrategy::Random => self.select_random(candidates),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adaptive selection using EWMA scores.
|
||||
async fn select_adaptive(&self, candidates: &[NodeId]) -> Option<NodeId> {
|
||||
let metrics = self.metrics.read().await;
|
||||
|
||||
// Exploration: with probability epsilon, pick randomly
|
||||
if self.should_explore() {
|
||||
return self.select_random(candidates);
|
||||
}
|
||||
|
||||
// Compute scores and find the minimum
|
||||
let mut best_node = None;
|
||||
let mut best_score = f64::INFINITY;
|
||||
|
||||
for node in candidates {
|
||||
let score = metrics
|
||||
.get(node)
|
||||
.map(|m| m.score(&self.config))
|
||||
.unwrap_or(1000.0); // High default for unknown nodes
|
||||
|
||||
if score < best_score {
|
||||
best_score = score;
|
||||
best_node = Some(node.clone());
|
||||
}
|
||||
}
|
||||
|
||||
best_node
|
||||
}
|
||||
|
||||
/// Round-robin selection.
|
||||
async fn select_round_robin(&self, candidates: &[NodeId], group_id: u64) -> Option<NodeId> {
|
||||
let key = format!("group_{}", group_id);
|
||||
let mut counter = self.rr_counter.write().await;
|
||||
let idx = *counter.entry(key.clone()).or_insert(0) as usize % candidates.len();
|
||||
*counter.get_mut(&key).unwrap() += 1;
|
||||
Some(candidates[idx].clone())
|
||||
}
|
||||
|
||||
/// Random selection.
|
||||
fn select_random(&self, candidates: &[NodeId]) -> Option<NodeId> {
|
||||
if candidates.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let idx = self
|
||||
.rng
|
||||
.lock()
|
||||
.unwrap()
|
||||
.gen_range(0..candidates.len());
|
||||
Some(candidates[idx].clone())
|
||||
}
|
||||
|
||||
/// Check if we should explore (random selection).
|
||||
fn should_explore(&self) -> bool {
|
||||
let mut rng = self.rng.lock().unwrap();
|
||||
rng.gen::<f64>() < self.config.exploration_epsilon
|
||||
}
|
||||
|
||||
/// Record a successful request (update latency).
|
||||
pub async fn record_success(&self, node: &NodeId, latency_ms: f64) {
|
||||
let mut metrics = self.metrics.write().await;
|
||||
let entry = metrics
|
||||
.entry(node.clone())
|
||||
.or_insert_with(NodeMetrics::default);
|
||||
entry.update_latency(latency_ms);
|
||||
entry.update_error(false);
|
||||
entry.decrement_in_flight();
|
||||
}
|
||||
|
||||
/// Record a failed request.
|
||||
pub async fn record_error(&self, node: &NodeId, latency_ms: Option<f64>) {
|
||||
let mut metrics = self.metrics.write().await;
|
||||
let entry = metrics
|
||||
.entry(node.clone())
|
||||
.or_insert_with(NodeMetrics::default);
|
||||
if let Some(lat) = latency_ms {
|
||||
entry.update_latency(lat);
|
||||
}
|
||||
entry.update_error(true);
|
||||
entry.decrement_in_flight();
|
||||
}
|
||||
|
||||
/// Record that a request is being sent to a node.
|
||||
pub async fn record_request_start(&self, node: &NodeId) {
|
||||
let mut metrics = self.metrics.write().await;
|
||||
let entry = metrics
|
||||
.entry(node.clone())
|
||||
.or_insert_with(NodeMetrics::default);
|
||||
entry.increment_in_flight();
|
||||
}
|
||||
|
||||
/// Get metrics for a node.
|
||||
pub async fn get_metrics(&self, node: &NodeId) -> Option<NodeMetrics> {
|
||||
let metrics = self.metrics.read().await;
|
||||
metrics.get(node).cloned()
|
||||
}
|
||||
|
||||
/// Parse the strategy from config string.
|
||||
fn parse_strategy(&self) -> SelectionStrategy {
|
||||
match self.config.strategy.as_str() {
|
||||
"round_robin" => SelectionStrategy::RoundRobin,
|
||||
"random" => SelectionStrategy::Random,
|
||||
_ => SelectionStrategy::Adaptive,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ReplicaSelector {
|
||||
fn default() -> Self {
|
||||
Self::new(ReplicaSelectionConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Idle, error-free metrics with the given smoothed latency.
    fn idle_metrics(latency_p95_ms: f64) -> NodeMetrics {
        NodeMetrics {
            latency_p95_ms,
            in_flight: 0,
            error_rate: 0.0,
            half_life_ms: 5000,
            last_updated: Instant::now(),
        }
    }

    #[test]
    fn test_config_default() {
        let config = ReplicaSelectionConfig::default();
        assert_eq!(config.strategy, "adaptive");
        assert_eq!(config.latency_weight, 1.0);
        assert_eq!(config.inflight_weight, 2.0);
        assert_eq!(config.error_weight, 10.0);
    }

    #[test]
    fn test_node_metrics_score() {
        let config = ReplicaSelectionConfig::default();
        let mut metrics = NodeMetrics::new(50.0, 5000);
        assert_eq!(metrics.score(&config), 50.0);

        metrics.in_flight = 5;
        // 50 * 1.0 + 5 * 2.0 = 60
        assert_eq!(metrics.score(&config), 60.0);
    }

    #[test]
    fn test_node_metrics_ewma() {
        let mut metrics = NodeMetrics::new(100.0, 1000); // Short half-life

        metrics.update_latency(50.0);
        // Should move toward 50
        assert!(metrics.latency_p95_ms < 100.0 && metrics.latency_p95_ms > 40.0);

        metrics.update_error(true);
        assert!(metrics.error_rate > 0.0);

        metrics.update_error(false);
        // Error rate should keep decaying on further successes
        let rate_before = metrics.error_rate;
        metrics.update_error(false);
        assert!(metrics.error_rate < rate_before);
    }

    #[tokio::test]
    async fn test_select_adaptive() {
        // Exploration is switched off: with the default epsilon (0.05) the
        // selector occasionally picks at random and the test becomes flaky.
        let selector = ReplicaSelector::new(ReplicaSelectionConfig {
            exploration_epsilon: 0.0,
            ..Default::default()
        });

        let node1 = NodeId::new("node-1".to_string());
        let node2 = NodeId::new("node-2".to_string());

        // Record some metrics
        {
            let mut metrics = selector.metrics.write().await;
            metrics.insert(node1.clone(), idle_metrics(10.0));
            metrics.insert(node2.clone(), idle_metrics(100.0));
        }

        // Should select node-1 (lower score) regardless of candidate order
        let candidates = vec![node2.clone(), node1.clone()];
        let selected = selector.select(&candidates, 0).await;
        assert_eq!(selected, Some(node1));
    }

    #[tokio::test]
    async fn test_select_round_robin() {
        let selector = ReplicaSelector::new(ReplicaSelectionConfig {
            strategy: "round_robin".into(),
            ..Default::default()
        });

        let node1 = NodeId::new("node-1".to_string());
        let node2 = NodeId::new("node-2".to_string());
        let candidates = vec![node1.clone(), node2.clone()];

        // First call should return node-1
        let selected = selector.select(&candidates, 0).await;
        assert_eq!(selected, Some(node1.clone()));

        // Second call should return node-2
        let selected = selector.select(&candidates, 0).await;
        assert_eq!(selected, Some(node2.clone()));

        // Third call should wrap to node-1
        let selected = selector.select(&candidates, 0).await;
        assert_eq!(selected, Some(node1));
    }

    #[tokio::test]
    async fn test_record_request_lifecycle() {
        let selector = ReplicaSelector::default();
        let node = NodeId::new("node-1".to_string());

        selector.record_request_start(&node).await;

        let metrics = selector.get_metrics(&node).await;
        assert!(metrics.is_some());
        assert_eq!(metrics.unwrap().in_flight, 1);

        selector.record_success(&node, 50.0).await;

        let metrics = selector.get_metrics(&node).await;
        assert!(metrics.is_some());
        assert_eq!(metrics.unwrap().in_flight, 0);
    }

    #[tokio::test]
    async fn test_empty_candidates() {
        let selector = ReplicaSelector::default();
        let selected = selector.select(&[], 0).await;
        assert!(selected.is_none());
    }
}
|
||||
388
crates/miroir-core/src/vector.rs
Normal file
388
crates/miroir-core/src/vector.rs
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
//! Vector and hybrid search sharding (plan §13.12).
|
||||
//!
|
||||
//! Handles over-fetching and merging for vector/hybrid search across shards.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Vector search configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorSearchConfig {
|
||||
/// Whether vector search is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
/// Over-fetch factor (per-shard limit = requested limit × factor).
|
||||
#[serde(default = "default_over_fetch")]
|
||||
pub over_fetch_factor: u32,
|
||||
/// Merge strategy: "convex" or "rrf".
|
||||
#[serde(default = "default_merge_strategy")]
|
||||
pub merge_strategy: String,
|
||||
/// Default hybrid alpha (for convex combination).
|
||||
#[serde(default = "default_alpha")]
|
||||
pub hybrid_alpha_default: f64,
|
||||
/// RRF constant (for Reciprocal Rank Fusion).
|
||||
#[serde(default = "default_rrf_k")]
|
||||
pub rrf_k: u32,
|
||||
}
|
||||
|
||||
/// Serde default for boolean flags that are on unless disabled.
fn default_true() -> bool {
    true
}

/// Default over-fetch factor applied to the requested limit per shard.
fn default_over_fetch() -> u32 {
    3
}

/// Default merge strategy name.
fn default_merge_strategy() -> String {
    String::from("convex")
}

/// Default hybrid alpha for the convex combination.
fn default_alpha() -> f64 {
    0.5
}

/// Default Reciprocal Rank Fusion constant.
fn default_rrf_k() -> u32 {
    60
}
|
||||
|
||||
impl Default for VectorSearchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
over_fetch_factor: default_over_fetch(),
|
||||
merge_strategy: default_merge_strategy(),
|
||||
hybrid_alpha_default: default_alpha(),
|
||||
rrf_k: default_rrf_k(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge strategy for combining results from multiple shards.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MergeStrategy {
|
||||
/// Convex combination: (1 - α) · bm25 + α · semantic
|
||||
Convex,
|
||||
/// Reciprocal Rank Fusion
|
||||
Rrf,
|
||||
}
|
||||
|
||||
impl MergeStrategy {
|
||||
/// Parse from string.
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"convex" => Some(MergeStrategy::Convex),
|
||||
"rrf" => Some(MergeStrategy::Rrf),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A search hit with scores from multiple sources.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorHit {
|
||||
/// Primary key of the document.
|
||||
pub pk: String,
|
||||
/// BM25 ranking score.
|
||||
pub ranking_score: f64,
|
||||
/// Semantic score (if present).
|
||||
pub semantic_score: Option<f64>,
|
||||
/// Combined global score.
|
||||
pub combined_score: f64,
|
||||
/// Source shard.
|
||||
pub shard_id: u32,
|
||||
}
|
||||
|
||||
impl VectorHit {
|
||||
/// Create a new hit from BM25 only.
|
||||
pub fn bm25_only(pk: String, ranking_score: f64, shard_id: u32) -> Self {
|
||||
Self {
|
||||
pk,
|
||||
ranking_score,
|
||||
semantic_score: None,
|
||||
combined_score: ranking_score,
|
||||
shard_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new hybrid hit.
|
||||
pub fn hybrid(pk: String, ranking_score: f64, semantic_score: f64, shard_id: u32) -> Self {
|
||||
Self {
|
||||
pk,
|
||||
ranking_score,
|
||||
semantic_score: Some(semantic_score),
|
||||
combined_score: ranking_score, // Will be recomputed during merge
|
||||
shard_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge with convex combination.
|
||||
pub fn merge_convex(&mut self, alpha: f64) {
|
||||
if let Some(semantic) = self.semantic_score {
|
||||
self.combined_score = (1.0 - alpha) * self.ranking_score + alpha * semantic;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get RRF score for a given rank.
|
||||
pub fn rrf_score(rank: usize, k: u32) -> f64 {
|
||||
1.0 / (k as f64 + rank as f64)
|
||||
}
|
||||
}
|
||||
|
||||
/// Vector search merger — combines over-fetched results from multiple shards.
|
||||
pub struct VectorMerger {
|
||||
/// Merge strategy.
|
||||
strategy: MergeStrategy,
|
||||
/// Hybrid alpha (for convex).
|
||||
alpha: f64,
|
||||
/// RRF constant.
|
||||
rrf_k: u32,
|
||||
}
|
||||
|
||||
impl VectorMerger {
|
||||
/// Create a new vector merger.
|
||||
pub fn new(config: &VectorSearchConfig) -> Self {
|
||||
let strategy = MergeStrategy::from_str(&config.merge_strategy)
|
||||
.unwrap_or(MergeStrategy::Convex);
|
||||
Self {
|
||||
strategy,
|
||||
alpha: config.hybrid_alpha_default,
|
||||
rrf_k: config.rrf_k,
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge hits from multiple shards into a single ranked list.
|
||||
///
|
||||
/// Input: Vec of (shard_id, hits from that shard)
|
||||
/// Output: Globally ranked hits, truncated to `limit`
|
||||
pub fn merge(&self, shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec<VectorHit> {
|
||||
match self.strategy {
|
||||
MergeStrategy::Convex => self.merge_convex(shard_hits, limit),
|
||||
MergeStrategy::Rrf => self.merge_rrf(shard_hits, limit),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convex combination merge.
|
||||
fn merge_convex(&self, mut shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec<VectorHit> {
|
||||
// Apply convex combination to each hit
|
||||
for (_, hit) in &mut shard_hits {
|
||||
hit.merge_convex(self.alpha);
|
||||
}
|
||||
|
||||
// Sort by combined score descending
|
||||
shard_hits.sort_by(|a, b| {
|
||||
b.1.combined_score
|
||||
.partial_cmp(&a.1.combined_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// Deduplicate by PK (keep highest score)
|
||||
let mut deduped = HashMap::new();
|
||||
for (_, hit) in shard_hits {
|
||||
deduped
|
||||
.entry(hit.pk.clone())
|
||||
.and_modify(|e: &mut VectorHit| {
|
||||
if hit.combined_score > e.combined_score {
|
||||
*e = hit;
|
||||
}
|
||||
})
|
||||
.or_insert(hit);
|
||||
}
|
||||
|
||||
// Convert back to vec and re-sort
|
||||
let mut result: Vec<_> = deduped.into_values().collect();
|
||||
result.sort_by(|a, b| {
|
||||
b.combined_score
|
||||
.partial_cmp(&a.combined_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
result.truncate(limit);
|
||||
result
|
||||
}
|
||||
|
||||
/// RRF (Reciprocal Rank Fusion) merge.
|
||||
fn merge_rrf(&self, shard_hits: Vec<(u32, VectorHit)>, limit: usize) -> Vec<VectorHit> {
|
||||
// Group hits by PK and accumulate RRF scores
|
||||
let mut rrf_scores: HashMap<String, f64> = HashMap::new();
|
||||
let mut hit_data: HashMap<String, VectorHit> = HashMap::new();
|
||||
|
||||
// First, sort each shard's hits by their original ranking score
|
||||
let mut per_shard: HashMap<u32, Vec<VectorHit>> = HashMap::new();
|
||||
for (shard_id, hit) in shard_hits {
|
||||
per_shard
|
||||
.entry(shard_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(hit);
|
||||
}
|
||||
|
||||
// Compute RRF scores
|
||||
for (_shard_id, mut hits) in per_shard {
|
||||
// Sort by ranking_score descending (original per-shard ranking)
|
||||
hits.sort_by(|a, b| {
|
||||
b.ranking_score
|
||||
.partial_cmp(&a.ranking_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for (rank, hit) in hits.into_iter().enumerate() {
|
||||
let pk = hit.pk.clone();
|
||||
let rrf_score = VectorHit::rrf_score(rank, self.rrf_k);
|
||||
*rrf_scores.entry(pk.clone()).or_insert(0.0) += rrf_score;
|
||||
|
||||
hit_data.entry(pk).or_insert(hit);
|
||||
}
|
||||
}
|
||||
|
||||
// Build result with RRF scores
|
||||
let mut result: Vec<VectorHit> = hit_data
|
||||
.into_iter()
|
||||
.map(|(pk, mut hit)| {
|
||||
hit.combined_score = *rrf_scores.get(&pk).unwrap_or(&0.0);
|
||||
hit
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by RRF score descending
|
||||
result.sort_by(|a, b| {
|
||||
b.combined_score
|
||||
.partial_cmp(&a.combined_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
result.truncate(limit);
|
||||
result
|
||||
}
|
||||
|
||||
/// Compute the per-shard limit given a requested limit.
|
||||
pub fn per_shard_limit(&self, requested_limit: usize, over_fetch_factor: u32) -> usize {
|
||||
requested_limit * over_fetch_factor as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VectorMerger {
|
||||
fn default() -> Self {
|
||||
Self::new(&VectorSearchConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let config = VectorSearchConfig::default();
        assert!(config.enabled);
        assert_eq!(config.over_fetch_factor, 3);
        assert_eq!(config.merge_strategy, "convex");
        assert_eq!(config.hybrid_alpha_default, 0.5);
        assert_eq!(config.rrf_k, 60);
    }

    #[test]
    fn test_merge_strategy_from_str() {
        assert_eq!(MergeStrategy::from_str("convex"), Some(MergeStrategy::Convex));
        assert_eq!(MergeStrategy::from_str("rrf"), Some(MergeStrategy::Rrf));
        assert_eq!(MergeStrategy::from_str("unknown"), None);
    }

    #[test]
    fn test_vector_hit_bm25_only() {
        let hit = VectorHit::bm25_only("doc1".to_string(), 0.8, 5);
        assert_eq!(hit.pk, "doc1");
        assert_eq!(hit.ranking_score, 0.8);
        assert!(hit.semantic_score.is_none());
        assert_eq!(hit.combined_score, 0.8);
        assert_eq!(hit.shard_id, 5);
    }

    #[test]
    fn test_vector_hit_hybrid() {
        let mut hit = VectorHit::hybrid("doc1".to_string(), 0.6, 0.9, 5);
        assert_eq!(hit.pk, "doc1");
        assert_eq!(hit.ranking_score, 0.6);
        assert_eq!(hit.semantic_score, Some(0.9));
        assert_eq!(hit.shard_id, 5);

        hit.merge_convex(0.5);
        // (0.6 + 0.9) / 2 = 0.75
        assert!((hit.combined_score - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_rrf_score() {
        let score = VectorHit::rrf_score(0, 60);
        assert!((score - 1.0 / 60.0).abs() < 0.0001);

        let score = VectorHit::rrf_score(10, 60);
        assert!((score - 1.0 / 70.0).abs() < 0.0001);
    }

    #[test]
    fn test_merge_convex_basic() {
        let merger = VectorMerger {
            strategy: MergeStrategy::Convex,
            alpha: 0.5,
            rrf_k: 60,
        };

        let hits = vec![
            (0, VectorHit::hybrid("doc1".to_string(), 0.8, 0.6, 0)),
            (0, VectorHit::hybrid("doc2".to_string(), 0.7, 0.9, 0)),
            (1, VectorHit::hybrid("doc1".to_string(), 0.75, 0.65, 1)),
            (1, VectorHit::hybrid("doc3".to_string(), 0.9, 0.5, 1)),
        ];

        let result = merger.merge_convex(hits, 10);

        // doc1 appears in both shards and must be deduplicated.
        assert_eq!(result.len(), 3);

        // doc2 has the highest combined score: (0.7 + 0.9) / 2 = 0.8
        assert_eq!(result[0].pk, "doc2");
        assert!((result[0].combined_score - 0.8).abs() < 1e-9);

        // doc1 and doc3 both score 0.7, so their relative order is not
        // asserted; only that both follow doc2.
        let mut rest: Vec<&str> = result[1..].iter().map(|h| h.pk.as_str()).collect();
        rest.sort_unstable();
        assert_eq!(rest, ["doc1", "doc3"]);
        for hit in &result[1..] {
            assert!((hit.combined_score - 0.7).abs() < 1e-9);
        }
    }

    #[test]
    fn test_merge_rrf_basic() {
        let merger = VectorMerger {
            strategy: MergeStrategy::Rrf,
            alpha: 0.5,
            rrf_k: 60,
        };

        let hits = vec![
            (0, VectorHit::bm25_only("doc1".to_string(), 0.9, 0)),
            (0, VectorHit::bm25_only("doc2".to_string(), 0.8, 0)),
            (0, VectorHit::bm25_only("doc3".to_string(), 0.7, 0)),
            (1, VectorHit::bm25_only("doc2".to_string(), 0.95, 1)),
            (1, VectorHit::bm25_only("doc4".to_string(), 0.85, 1)),
        ];

        let result = merger.merge_rrf(hits, 10);

        // doc2 appears in both shards, so its RRF contributions are summed.
        let doc2 = result
            .iter()
            .find(|h| h.pk == "doc2")
            .expect("doc2 must be present in the merged result");
        // Ranks are zero-based: rank 1 in shard 0 -> 1/61, rank 0 in shard 1 -> 1/60.
        let expected = 1.0 / 61.0 + 1.0 / 60.0;
        assert!((doc2.combined_score - expected).abs() < 1e-9);
    }

    #[test]
    fn test_per_shard_limit() {
        let merger = VectorMerger::default();
        assert_eq!(merger.per_shard_limit(10, 3), 30);
        assert_eq!(merger.per_shard_limit(100, 2), 200);
    }

    #[test]
    fn test_merge_limits_output() {
        let merger = VectorMerger::default();
        let hits: Vec<_> = (0..200)
            .map(|i| {
                (
                    i % 10,
                    VectorHit::bm25_only(format!("doc{}", i), 1.0 - (i as f64) * 0.001, i % 10),
                )
            })
            .collect();

        let result = merger.merge_convex(hits, 50);
        assert_eq!(result.len(), 50);
    }
}
|
||||
Loading…
Add table
Reference in a new issue