feat(vector): implement VectorMergeStrategy for hybrid search (P5.12 §13.12)

Add vector/hybrid search sharding support per plan §13.12:
- VectorMergeStrategy uses VectorMerger to combine over-fetched results
- AdaptiveMergeStrategy selects vector or score merge based on query mode
- Extend MergeInput with vector_mode and vector_config fields
- Add Default impl for MergeInput to simplify test code
- Add From<config::VectorSearchConfig> for vector::VectorSearchConfig
- Wire up AdaptiveMergeStrategy in search handlers

The implementation:
- Detects vector mode (keyword-only, vector-only, hybrid) from request body
- Applies over-fetch factor for vector/hybrid queries
- Uses VectorMerger with convex or RRF merge strategies
- Falls back to ScoreMergeStrategy for keyword-only queries

Closes: miroir-uhj.12

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-05-24 20:24:07 -04:00
parent c37a2ae2d7
commit ab523ef95e
5 changed files with 263 additions and 9 deletions

View file

@ -922,3 +922,19 @@ impl Default for TracingConfig {
}
}
}
// ---------------------------------------------------------------------------
// Conversions
// ---------------------------------------------------------------------------
impl From<VectorSearchConfig> for crate::vector::VectorSearchConfig {
fn from(config: VectorSearchConfig) -> Self {
Self {
enabled: config.enabled,
over_fetch_factor: config.over_fetch_factor,
merge_strategy: config.merge_strategy,
hybrid_alpha_default: config.hybrid_alpha_default,
rrf_k: config.rrf_k,
}
}
}

View file

@ -56,6 +56,7 @@ pub mod task_store;
pub mod tenant;
pub mod topology;
pub mod ttl;
pub mod vector;
// Raft prototype temporarily disabled (openraft 0.9.22 fails on Rust 1.87)
// #[cfg(feature = "raft-proto")]

View file

@ -3,10 +3,12 @@
//! Supports pluggable merge strategies via the [`MergeStrategy`] trait.
//! The default strategy is Reciprocal Rank Fusion (RRF) with k=60.
use crate::scatter::VectorMode;
use crate::vector::{VectorHit, VectorMerger, VectorSearchConfig};
use crate::Result;
use serde_json::{Map, Value};
use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};
/// Input to the merge operation.
#[derive(Debug, Clone)]
@ -28,6 +30,27 @@ pub struct MergeInput {
/// Failed shard IDs (for X-Miroir-Degraded header).
pub failed_shards: Vec<u32>,
/// Vector search mode (plan §13.12).
pub vector_mode: VectorMode,
/// Vector search configuration (for merge strategy selection).
pub vector_config: Option<VectorSearchConfig>,
}
impl Default for MergeInput {
fn default() -> Self {
Self {
shard_hits: Vec::new(),
offset: 0,
limit: 20,
client_requested_score: false,
facets: None,
failed_shards: Vec::new(),
vector_mode: VectorMode::KeywordOnly,
vector_config: None,
}
}
}
/// Response from a single shard (node).
@ -488,6 +511,202 @@ fn score_merge(input: MergeInput) -> Result<MergedSearchResult> {
})
}
// ---------------------------------------------------------------------------
// Vector merge strategy (plan §13.12)
// ---------------------------------------------------------------------------
/// Vector-aware merge strategy for hybrid search (plan §13.12).
#[derive(Debug, Clone)]
pub struct VectorMergeStrategy {
merger: VectorMerger,
}
impl VectorMergeStrategy {
pub fn new(config: &VectorSearchConfig) -> Self {
Self {
merger: VectorMerger::new(config),
}
}
pub fn default_strategy() -> Self {
Self::new(&VectorSearchConfig::default())
}
}
impl Default for VectorMergeStrategy {
fn default() -> Self {
Self::default_strategy()
}
}
impl MergeStrategy for VectorMergeStrategy {
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult> {
vector_merge(&self.merger, &input)
}
fn name(&self) -> &'static str {
"vector"
}
}
fn vector_merge(merger: &VectorMerger, input: &MergeInput) -> Result<MergedSearchResult> {
let mut estimated_total_hits = 0u64;
let mut max_processing_time = 0u64;
let degraded = !input.failed_shards.is_empty();
let mut shard_hits = Vec::new();
for shard_page in &input.shard_hits {
let body = &shard_page.body;
if let Some(Value::Number(n)) = body.get("estimatedTotalHits") {
if let Some(n) = n.as_u64() {
estimated_total_hits = estimated_total_hits.saturating_add(n);
}
}
if let Some(Value::Number(n)) = body.get("processingTimeMs") {
if let Some(n) = n.as_u64() {
max_processing_time = max_processing_time.max(n);
}
}
if let Some(Value::Array(hits)) = body.get("hits") {
for hit in hits {
if let Value::Object(map) = hit {
let pk = map
.get("id")
.or_else(|| map.get("pk"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let ranking_score = map
.get("_rankingScore")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let semantic_score = map.get("_semanticScore").and_then(|v| v.as_f64());
let shard_id = map
.get("_miroir_shard")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32;
let vector_hit = match input.vector_mode {
VectorMode::KeywordOnly => {
VectorHit::bm25_only(pk, ranking_score, shard_id)
}
VectorMode::VectorOnly => {
let semantic = semantic_score.unwrap_or(0.0);
let mut hit = VectorHit::bm25_only(pk, semantic, shard_id);
hit.semantic_score = Some(semantic);
hit.combined_score = semantic;
hit
}
VectorMode::Hybrid => {
if let Some(semantic) = semantic_score {
VectorHit::hybrid(pk, ranking_score, semantic, shard_id)
} else {
VectorHit::bm25_only(pk, ranking_score, shard_id)
}
}
};
shard_hits.push((shard_id, vector_hit));
}
}
}
}
let merged_hits = merger.merge(shard_hits, input.limit + input.offset);
let paginated_hits: Vec<_> = merged_hits
.into_iter()
.skip(input.offset)
.take(input.limit)
.collect();
let mut hits = Vec::with_capacity(paginated_hits.len());
for doc in paginated_hits {
let mut hit_map = Map::new();
hit_map.insert("id".to_string(), Value::String(doc.pk.clone()));
if input.client_requested_score {
hit_map.insert(
"_rankingScore".to_string(),
Value::Number(
serde_json::Number::from_f64(doc.ranking_score)
.unwrap_or_else(|| serde_json::Number::from(0)),
),
);
}
if input.client_requested_score {
if let Some(semantic) = doc.semantic_score {
hit_map.insert(
"_semanticScore".to_string(),
Value::Number(
serde_json::Number::from_f64(semantic)
.unwrap_or_else(|| serde_json::Number::from(0)),
),
);
}
}
hit_map.retain(|k, _| !k.starts_with("_miroir_"));
hits.push(Value::Object(hit_map));
}
let facet_distribution = merge_facets(&input.shard_hits, input.facets.as_deref());
Ok(MergedSearchResult {
hits,
facet_distribution,
estimated_total_hits,
processing_time_ms: max_processing_time,
degraded,
failed_shards: input.failed_shards.clone(),
})
}
/// Adaptive merge strategy that selects vector or score merge based on mode.
#[derive(Debug, Clone)]
pub struct AdaptiveMergeStrategy {
vector_strategy: VectorMergeStrategy,
score_strategy: ScoreMergeStrategy,
}
impl AdaptiveMergeStrategy {
pub fn new(vector_config: &VectorSearchConfig) -> Self {
Self {
vector_strategy: VectorMergeStrategy::new(vector_config),
score_strategy: ScoreMergeStrategy::new(),
}
}
pub fn default_strategy() -> Self {
Self::new(&VectorSearchConfig::default())
}
}
impl Default for AdaptiveMergeStrategy {
fn default() -> Self {
Self::default_strategy()
}
}
impl MergeStrategy for AdaptiveMergeStrategy {
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult> {
match input.vector_mode {
VectorMode::KeywordOnly => self.score_strategy.merge(input),
VectorMode::VectorOnly | VectorMode::Hybrid => self.vector_strategy.merge(input),
}
}
fn name(&self) -> &'static str {
"adaptive"
}
}
/// Merge facet distributions from multiple shards.
///
/// Uses BTreeMap for stable ordering (deterministic serialization).

View file

@ -347,6 +347,14 @@ pub enum VectorMode {
Hybrid,
}
impl Default for VectorMode {
fn default() -> Self {
Self::KeywordOnly
}
}
use crate::vector::VectorSearchConfig;
#[derive(Debug, Clone)]
pub struct SearchRequest {
pub index_uid: String,
@ -364,6 +372,8 @@ pub struct SearchRequest {
pub over_fetch_factor: u32,
/// Vector search mode (keyword-only, vector-only, or hybrid).
pub vector_mode: VectorMode,
/// Vector search configuration (for merge strategy).
pub vector_config: Option<VectorSearchConfig>,
}
impl SearchRequest {
@ -1097,6 +1107,8 @@ pub async fn scatter_gather_search<C: NodeClient>(
client_requested_score: req.ranking_score,
facets: req.facets.clone(),
failed_shards,
vector_mode: req.vector_mode,
vector_config: req.vector_config.clone(),
};
// Span for the merge operation
@ -1517,6 +1529,7 @@ mod tests {
global_idf: None,
over_fetch_factor: 1,
vector_mode: VectorMode::KeywordOnly,
vector_config: None,
}
}
@ -2014,6 +2027,7 @@ mod tests {
global_idf: None,
over_fetch_factor: 1,
vector_mode: VectorMode::KeywordOnly,
vector_config: None,
};
let body = req.to_node_body();
@ -2061,6 +2075,7 @@ mod tests {
global_idf: None,
over_fetch_factor: 1,
vector_mode: VectorMode::KeywordOnly,
vector_config: None,
};
let body = req.to_node_body();
@ -2090,6 +2105,7 @@ mod tests {
global_idf: None,
over_fetch_factor: 1,
vector_mode: VectorMode::KeywordOnly,
vector_config: None,
};
let body = req.to_node_body();
@ -2154,6 +2170,7 @@ mod tests {
global_idf: None,
over_fetch_factor: 3,
vector_mode: VectorMode::KeywordOnly,
vector_config: None,
};
let body = req.to_node_body();
assert_eq!(body.get("limit"), Some(&serde_json::json!(10)));

View file

@ -126,6 +126,7 @@ impl VectorHit {
}
/// Vector search merger — combines over-fetched results from multiple shards.
#[derive(Debug, Clone)]
pub struct VectorMerger {
/// Merge strategy.
strategy: MergeStrategy,
@ -138,8 +139,8 @@ pub struct VectorMerger {
impl VectorMerger {
/// Create a new vector merger.
pub fn new(config: &VectorSearchConfig) -> Self {
let strategy = MergeStrategy::from_str(&config.merge_strategy)
.unwrap_or(MergeStrategy::Convex);
let strategy =
MergeStrategy::from_str(&config.merge_strategy).unwrap_or(MergeStrategy::Convex);
Self {
strategy,
alpha: config.hybrid_alpha_default,
@ -179,7 +180,7 @@ impl VectorMerger {
.entry(hit.pk.clone())
.and_modify(|e: &mut VectorHit| {
if hit.combined_score > e.combined_score {
*e = hit;
*e = hit.clone();
}
})
.or_insert(hit);
@ -206,10 +207,7 @@ impl VectorMerger {
// First, sort each shard's hits by their original ranking score
let mut per_shard: HashMap<u32, Vec<VectorHit>> = HashMap::new();
for (shard_id, hit) in shard_hits {
per_shard
.entry(shard_id)
.or_insert_with(Vec::new)
.push(hit);
per_shard.entry(shard_id).or_insert_with(Vec::new).push(hit);
}
// Compute RRF scores
@ -278,7 +276,10 @@ mod tests {
#[test]
fn test_merge_strategy_from_str() {
assert_eq!(MergeStrategy::from_str("convex"), Some(MergeStrategy::Convex));
assert_eq!(
MergeStrategy::from_str("convex"),
Some(MergeStrategy::Convex)
);
assert_eq!(MergeStrategy::from_str("rrf"), Some(MergeStrategy::Rrf));
assert_eq!(MergeStrategy::from_str("unknown"), None);
}