diff --git a/crates/miroir-core/src/config/advanced.rs b/crates/miroir-core/src/config/advanced.rs index b38a43e..f26ccfd 100644 --- a/crates/miroir-core/src/config/advanced.rs +++ b/crates/miroir-core/src/config/advanced.rs @@ -922,3 +922,19 @@ impl Default for TracingConfig { } } } + +// --------------------------------------------------------------------------- +// Conversions +// --------------------------------------------------------------------------- + +impl From for crate::vector::VectorSearchConfig { + fn from(config: VectorSearchConfig) -> Self { + Self { + enabled: config.enabled, + over_fetch_factor: config.over_fetch_factor, + merge_strategy: config.merge_strategy, + hybrid_alpha_default: config.hybrid_alpha_default, + rrf_k: config.rrf_k, + } + } +} diff --git a/crates/miroir-core/src/lib.rs b/crates/miroir-core/src/lib.rs index 4a1cb5d..d8132d8 100644 --- a/crates/miroir-core/src/lib.rs +++ b/crates/miroir-core/src/lib.rs @@ -56,6 +56,7 @@ pub mod task_store; pub mod tenant; pub mod topology; pub mod ttl; +pub mod vector; // Raft prototype temporarily disabled (openraft 0.9.22 fails on Rust 1.87) // #[cfg(feature = "raft-proto")] diff --git a/crates/miroir-core/src/merger.rs b/crates/miroir-core/src/merger.rs index dd4792c..ef7173a 100644 --- a/crates/miroir-core/src/merger.rs +++ b/crates/miroir-core/src/merger.rs @@ -3,10 +3,12 @@ //! Supports pluggable merge strategies via the [`MergeStrategy`] trait. //! The default strategy is Reciprocal Rank Fusion (RRF) with k=60. +use crate::scatter::VectorMode; +use crate::vector::{VectorHit, VectorMerger, VectorSearchConfig}; use crate::Result; use serde_json::{Map, Value}; use std::cmp::Ordering; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; /// Input to the merge operation. #[derive(Debug, Clone)] @@ -28,6 +30,27 @@ pub struct MergeInput { /// Failed shard IDs (for X-Miroir-Degraded header). pub failed_shards: Vec, + + /// Vector search mode (plan §13.12). + pub vector_mode: VectorMode, + + /// Vector search configuration (for merge strategy selection). + pub vector_config: Option, +} + +impl Default for MergeInput { + fn default() -> Self { + Self { + shard_hits: Vec::new(), + offset: 0, + limit: 20, + client_requested_score: false, + facets: None, + failed_shards: Vec::new(), + vector_mode: VectorMode::KeywordOnly, + vector_config: None, + } + } } /// Response from a single shard (node). @@ -488,6 +511,202 @@ fn score_merge(input: MergeInput) -> Result { }) } +// --------------------------------------------------------------------------- +// Vector merge strategy (plan §13.12) +// --------------------------------------------------------------------------- + +/// Vector-aware merge strategy for hybrid search (plan §13.12). +#[derive(Debug, Clone)] +pub struct VectorMergeStrategy { + merger: VectorMerger, +} + +impl VectorMergeStrategy { + pub fn new(config: &VectorSearchConfig) -> Self { + Self { + merger: VectorMerger::new(config), + } + } + + pub fn default_strategy() -> Self { + Self::new(&VectorSearchConfig::default()) + } +} + +impl Default for VectorMergeStrategy { + fn default() -> Self { + Self::default_strategy() + } +} + +impl MergeStrategy for VectorMergeStrategy { + fn merge(&self, input: MergeInput) -> Result { + vector_merge(&self.merger, &input) + } + + fn name(&self) -> &'static str { + "vector" + } +} + +fn vector_merge(merger: &VectorMerger, input: &MergeInput) -> Result { + let mut estimated_total_hits = 0u64; + let mut max_processing_time = 0u64; + let degraded = !input.failed_shards.is_empty(); + let mut shard_hits = Vec::new(); + + for shard_page in &input.shard_hits { + let body = &shard_page.body; + + if let Some(Value::Number(n)) = body.get("estimatedTotalHits") { + if let Some(n) = n.as_u64() { + estimated_total_hits = estimated_total_hits.saturating_add(n); + } + } + + if let Some(Value::Number(n)) = body.get("processingTimeMs") { + if let Some(n) = n.as_u64() { + max_processing_time = max_processing_time.max(n); + } + } + + if let Some(Value::Array(hits)) = body.get("hits") { + for hit in hits { + if let Value::Object(map) = hit { + let pk = map + .get("id") + .or_else(|| map.get("pk")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let ranking_score = map + .get("_rankingScore") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + + let semantic_score = map.get("_semanticScore").and_then(|v| v.as_f64()); + + let shard_id = map + .get("_miroir_shard") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + + let vector_hit = match input.vector_mode { + VectorMode::KeywordOnly => { + VectorHit::bm25_only(pk, ranking_score, shard_id) + } + VectorMode::VectorOnly => { + let semantic = semantic_score.unwrap_or(0.0); + let mut hit = VectorHit::bm25_only(pk, semantic, shard_id); + hit.semantic_score = Some(semantic); + hit.combined_score = semantic; + hit + } + VectorMode::Hybrid => { + if let Some(semantic) = semantic_score { + VectorHit::hybrid(pk, ranking_score, semantic, shard_id) + } else { + VectorHit::bm25_only(pk, ranking_score, shard_id) + } + } + }; + + shard_hits.push((shard_id, vector_hit)); + } + } + } + } + + let merged_hits = merger.merge(shard_hits, input.limit + input.offset); + let paginated_hits: Vec<_> = merged_hits + .into_iter() + .skip(input.offset) + .take(input.limit) + .collect(); + + let mut hits = Vec::with_capacity(paginated_hits.len()); + for doc in paginated_hits { + let mut hit_map = Map::new(); + hit_map.insert("id".to_string(), Value::String(doc.pk.clone())); + + if input.client_requested_score { + hit_map.insert( + "_rankingScore".to_string(), + Value::Number( + serde_json::Number::from_f64(doc.ranking_score) + .unwrap_or_else(|| serde_json::Number::from(0)), + ), + ); + } + + if input.client_requested_score { + if let Some(semantic) = doc.semantic_score { + hit_map.insert( + "_semanticScore".to_string(), + Value::Number( + serde_json::Number::from_f64(semantic) + .unwrap_or_else(|| serde_json::Number::from(0)), + ), + ); + } + } + + hit_map.retain(|k, _| !k.starts_with("_miroir_")); + hits.push(Value::Object(hit_map)); + } + + let facet_distribution = merge_facets(&input.shard_hits, input.facets.as_deref()); + + Ok(MergedSearchResult { + hits, + facet_distribution, + estimated_total_hits, + processing_time_ms: max_processing_time, + degraded, + failed_shards: input.failed_shards.clone(), + }) +} + +/// Adaptive merge strategy that selects vector or score merge based on mode. +#[derive(Debug, Clone)] +pub struct AdaptiveMergeStrategy { + vector_strategy: VectorMergeStrategy, + score_strategy: ScoreMergeStrategy, +} + +impl AdaptiveMergeStrategy { + pub fn new(vector_config: &VectorSearchConfig) -> Self { + Self { + vector_strategy: VectorMergeStrategy::new(vector_config), + score_strategy: ScoreMergeStrategy::new(), + } + } + + pub fn default_strategy() -> Self { + Self::new(&VectorSearchConfig::default()) + } +} + +impl Default for AdaptiveMergeStrategy { + fn default() -> Self { + Self::default_strategy() + } +} + +impl MergeStrategy for AdaptiveMergeStrategy { + fn merge(&self, input: MergeInput) -> Result { + match input.vector_mode { + VectorMode::KeywordOnly => self.score_strategy.merge(input), + VectorMode::VectorOnly | VectorMode::Hybrid => self.vector_strategy.merge(input), + } + } + + fn name(&self) -> &'static str { + "adaptive" + } +} + /// Merge facet distributions from multiple shards. /// /// Uses BTreeMap for stable ordering (deterministic serialization). diff --git a/crates/miroir-core/src/scatter.rs b/crates/miroir-core/src/scatter.rs index 53dba7c..e882737 100644 --- a/crates/miroir-core/src/scatter.rs +++ b/crates/miroir-core/src/scatter.rs @@ -347,6 +347,14 @@ pub enum VectorMode { Hybrid, } +impl Default for VectorMode { + fn default() -> Self { + Self::KeywordOnly + } +} + +use crate::vector::VectorSearchConfig; + #[derive(Debug, Clone)] pub struct SearchRequest { pub index_uid: String, @@ -364,6 +372,8 @@ pub struct SearchRequest { pub over_fetch_factor: u32, /// Vector search mode (keyword-only, vector-only, or hybrid). pub vector_mode: VectorMode, + /// Vector search configuration (for merge strategy). + pub vector_config: Option, } impl SearchRequest { @@ -1097,6 +1107,8 @@ pub async fn scatter_gather_search( client_requested_score: req.ranking_score, facets: req.facets.clone(), failed_shards, + vector_mode: req.vector_mode, + vector_config: req.vector_config.clone(), }; // Span for the merge operation @@ -1517,6 +1529,7 @@ mod tests { global_idf: None, over_fetch_factor: 1, vector_mode: VectorMode::KeywordOnly, + vector_config: None, } } @@ -2014,6 +2027,7 @@ mod tests { global_idf: None, over_fetch_factor: 1, vector_mode: VectorMode::KeywordOnly, + vector_config: None, }; let body = req.to_node_body(); @@ -2061,6 +2075,7 @@ mod tests { global_idf: None, over_fetch_factor: 1, vector_mode: VectorMode::KeywordOnly, + vector_config: None, }; let body = req.to_node_body(); @@ -2090,6 +2105,7 @@ mod tests { global_idf: None, over_fetch_factor: 1, vector_mode: VectorMode::KeywordOnly, + vector_config: None, }; let body = req.to_node_body(); @@ -2154,6 +2170,7 @@ mod tests { global_idf: None, over_fetch_factor: 3, vector_mode: VectorMode::KeywordOnly, + vector_config: None, }; let body = req.to_node_body(); assert_eq!(body.get("limit"), Some(&serde_json::json!(10))); diff --git a/crates/miroir-core/src/vector.rs b/crates/miroir-core/src/vector.rs index 7eeea7d..7334739 100644 --- a/crates/miroir-core/src/vector.rs +++ b/crates/miroir-core/src/vector.rs @@ -126,6 +126,7 @@ impl VectorHit { } /// Vector search merger — combines over-fetched results from multiple shards. +#[derive(Debug, Clone)] pub struct VectorMerger { /// Merge strategy. strategy: MergeStrategy, @@ -138,8 +139,8 @@ pub struct VectorMerger { impl VectorMerger { /// Create a new vector merger. pub fn new(config: &VectorSearchConfig) -> Self { - let strategy = MergeStrategy::from_str(&config.merge_strategy) - .unwrap_or(MergeStrategy::Convex); + let strategy = + MergeStrategy::from_str(&config.merge_strategy).unwrap_or(MergeStrategy::Convex); Self { strategy, alpha: config.hybrid_alpha_default, @@ -179,7 +180,7 @@ impl VectorMerger { .entry(hit.pk.clone()) .and_modify(|e: &mut VectorHit| { if hit.combined_score > e.combined_score { - *e = hit; + *e = hit.clone(); } }) .or_insert(hit); @@ -206,10 +207,7 @@ impl VectorMerger { // First, sort each shard's hits by their original ranking score let mut per_shard: HashMap> = HashMap::new(); for (shard_id, hit) in shard_hits { - per_shard - .entry(shard_id) - .or_insert_with(Vec::new) - .push(hit); + per_shard.entry(shard_id).or_insert_with(Vec::new).push(hit); } // Compute RRF scores @@ -278,7 +276,10 @@ mod tests { #[test] fn test_merge_strategy_from_str() { - assert_eq!(MergeStrategy::from_str("convex"), Some(MergeStrategy::Convex)); + assert_eq!( + MergeStrategy::from_str("convex"), + Some(MergeStrategy::Convex) + ); assert_eq!(MergeStrategy::from_str("rrf"), Some(MergeStrategy::Rrf)); assert_eq!(MergeStrategy::from_str("unknown"), None); }