feat(vector): implement VectorMergeStrategy for hybrid search (P5.12 §13.12)
Add vector/hybrid search sharding support per plan §13.12: - VectorMergeStrategy uses VectorMerger to combine over-fetched results - AdaptiveMergeStrategy selects vector or score merge based on query mode - Extend MergeInput with vector_mode and vector_config fields - Add Default impl for MergeInput to simplify test code - Add From<config::VectorSearchConfig> for vector::VectorSearchConfig - Wire up AdaptiveMergeStrategy in search handlers The implementation: - Detects vector mode (keyword-only, vector-only, hybrid) from request body - Applies over-fetch factor for vector/hybrid queries - Uses VectorMerger with convex or RRF merge strategies - Falls back to ScoreMergeStrategy for keyword-only queries Closes: miroir-uhj.12 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c37a2ae2d7
commit
ab523ef95e
5 changed files with 263 additions and 9 deletions
|
|
@ -922,3 +922,19 @@ impl Default for TracingConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Conversions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
impl From<VectorSearchConfig> for crate::vector::VectorSearchConfig {
|
||||
fn from(config: VectorSearchConfig) -> Self {
|
||||
Self {
|
||||
enabled: config.enabled,
|
||||
over_fetch_factor: config.over_fetch_factor,
|
||||
merge_strategy: config.merge_strategy,
|
||||
hybrid_alpha_default: config.hybrid_alpha_default,
|
||||
rrf_k: config.rrf_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ pub mod task_store;
|
|||
pub mod tenant;
|
||||
pub mod topology;
|
||||
pub mod ttl;
|
||||
pub mod vector;
|
||||
|
||||
// Raft prototype temporarily disabled (openraft 0.9.22 fails on Rust 1.87)
|
||||
// #[cfg(feature = "raft-proto")]
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@
|
|||
//! Supports pluggable merge strategies via the [`MergeStrategy`] trait.
|
||||
//! The default strategy is Reciprocal Rank Fusion (RRF) with k=60.
|
||||
|
||||
use crate::scatter::VectorMode;
|
||||
use crate::vector::{VectorHit, VectorMerger, VectorSearchConfig};
|
||||
use crate::Result;
|
||||
use serde_json::{Map, Value};
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
|
||||
/// Input to the merge operation.
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -28,6 +30,27 @@ pub struct MergeInput {
|
|||
|
||||
/// Failed shard IDs (for X-Miroir-Degraded header).
|
||||
pub failed_shards: Vec<u32>,
|
||||
|
||||
/// Vector search mode (plan §13.12).
|
||||
pub vector_mode: VectorMode,
|
||||
|
||||
/// Vector search configuration (for merge strategy selection).
|
||||
pub vector_config: Option<VectorSearchConfig>,
|
||||
}
|
||||
|
||||
impl Default for MergeInput {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
shard_hits: Vec::new(),
|
||||
offset: 0,
|
||||
limit: 20,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
failed_shards: Vec::new(),
|
||||
vector_mode: VectorMode::KeywordOnly,
|
||||
vector_config: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response from a single shard (node).
|
||||
|
|
@ -488,6 +511,202 @@ fn score_merge(input: MergeInput) -> Result<MergedSearchResult> {
|
|||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Vector merge strategy (plan §13.12)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Vector-aware merge strategy for hybrid search (plan §13.12).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VectorMergeStrategy {
|
||||
merger: VectorMerger,
|
||||
}
|
||||
|
||||
impl VectorMergeStrategy {
|
||||
pub fn new(config: &VectorSearchConfig) -> Self {
|
||||
Self {
|
||||
merger: VectorMerger::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_strategy() -> Self {
|
||||
Self::new(&VectorSearchConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VectorMergeStrategy {
|
||||
fn default() -> Self {
|
||||
Self::default_strategy()
|
||||
}
|
||||
}
|
||||
|
||||
impl MergeStrategy for VectorMergeStrategy {
|
||||
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult> {
|
||||
vector_merge(&self.merger, &input)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"vector"
|
||||
}
|
||||
}
|
||||
|
||||
fn vector_merge(merger: &VectorMerger, input: &MergeInput) -> Result<MergedSearchResult> {
|
||||
let mut estimated_total_hits = 0u64;
|
||||
let mut max_processing_time = 0u64;
|
||||
let degraded = !input.failed_shards.is_empty();
|
||||
let mut shard_hits = Vec::new();
|
||||
|
||||
for shard_page in &input.shard_hits {
|
||||
let body = &shard_page.body;
|
||||
|
||||
if let Some(Value::Number(n)) = body.get("estimatedTotalHits") {
|
||||
if let Some(n) = n.as_u64() {
|
||||
estimated_total_hits = estimated_total_hits.saturating_add(n);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(Value::Number(n)) = body.get("processingTimeMs") {
|
||||
if let Some(n) = n.as_u64() {
|
||||
max_processing_time = max_processing_time.max(n);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(Value::Array(hits)) = body.get("hits") {
|
||||
for hit in hits {
|
||||
if let Value::Object(map) = hit {
|
||||
let pk = map
|
||||
.get("id")
|
||||
.or_else(|| map.get("pk"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let ranking_score = map
|
||||
.get("_rankingScore")
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let semantic_score = map.get("_semanticScore").and_then(|v| v.as_f64());
|
||||
|
||||
let shard_id = map
|
||||
.get("_miroir_shard")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let vector_hit = match input.vector_mode {
|
||||
VectorMode::KeywordOnly => {
|
||||
VectorHit::bm25_only(pk, ranking_score, shard_id)
|
||||
}
|
||||
VectorMode::VectorOnly => {
|
||||
let semantic = semantic_score.unwrap_or(0.0);
|
||||
let mut hit = VectorHit::bm25_only(pk, semantic, shard_id);
|
||||
hit.semantic_score = Some(semantic);
|
||||
hit.combined_score = semantic;
|
||||
hit
|
||||
}
|
||||
VectorMode::Hybrid => {
|
||||
if let Some(semantic) = semantic_score {
|
||||
VectorHit::hybrid(pk, ranking_score, semantic, shard_id)
|
||||
} else {
|
||||
VectorHit::bm25_only(pk, ranking_score, shard_id)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
shard_hits.push((shard_id, vector_hit));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let merged_hits = merger.merge(shard_hits, input.limit + input.offset);
|
||||
let paginated_hits: Vec<_> = merged_hits
|
||||
.into_iter()
|
||||
.skip(input.offset)
|
||||
.take(input.limit)
|
||||
.collect();
|
||||
|
||||
let mut hits = Vec::with_capacity(paginated_hits.len());
|
||||
for doc in paginated_hits {
|
||||
let mut hit_map = Map::new();
|
||||
hit_map.insert("id".to_string(), Value::String(doc.pk.clone()));
|
||||
|
||||
if input.client_requested_score {
|
||||
hit_map.insert(
|
||||
"_rankingScore".to_string(),
|
||||
Value::Number(
|
||||
serde_json::Number::from_f64(doc.ranking_score)
|
||||
.unwrap_or_else(|| serde_json::Number::from(0)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if input.client_requested_score {
|
||||
if let Some(semantic) = doc.semantic_score {
|
||||
hit_map.insert(
|
||||
"_semanticScore".to_string(),
|
||||
Value::Number(
|
||||
serde_json::Number::from_f64(semantic)
|
||||
.unwrap_or_else(|| serde_json::Number::from(0)),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
hit_map.retain(|k, _| !k.starts_with("_miroir_"));
|
||||
hits.push(Value::Object(hit_map));
|
||||
}
|
||||
|
||||
let facet_distribution = merge_facets(&input.shard_hits, input.facets.as_deref());
|
||||
|
||||
Ok(MergedSearchResult {
|
||||
hits,
|
||||
facet_distribution,
|
||||
estimated_total_hits,
|
||||
processing_time_ms: max_processing_time,
|
||||
degraded,
|
||||
failed_shards: input.failed_shards.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Adaptive merge strategy that selects vector or score merge based on mode.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveMergeStrategy {
|
||||
vector_strategy: VectorMergeStrategy,
|
||||
score_strategy: ScoreMergeStrategy,
|
||||
}
|
||||
|
||||
impl AdaptiveMergeStrategy {
|
||||
pub fn new(vector_config: &VectorSearchConfig) -> Self {
|
||||
Self {
|
||||
vector_strategy: VectorMergeStrategy::new(vector_config),
|
||||
score_strategy: ScoreMergeStrategy::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_strategy() -> Self {
|
||||
Self::new(&VectorSearchConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AdaptiveMergeStrategy {
|
||||
fn default() -> Self {
|
||||
Self::default_strategy()
|
||||
}
|
||||
}
|
||||
|
||||
impl MergeStrategy for AdaptiveMergeStrategy {
|
||||
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult> {
|
||||
match input.vector_mode {
|
||||
VectorMode::KeywordOnly => self.score_strategy.merge(input),
|
||||
VectorMode::VectorOnly | VectorMode::Hybrid => self.vector_strategy.merge(input),
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"adaptive"
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge facet distributions from multiple shards.
|
||||
///
|
||||
/// Uses BTreeMap for stable ordering (deterministic serialization).
|
||||
|
|
|
|||
|
|
@ -347,6 +347,14 @@ pub enum VectorMode {
|
|||
Hybrid,
|
||||
}
|
||||
|
||||
impl Default for VectorMode {
|
||||
fn default() -> Self {
|
||||
Self::KeywordOnly
|
||||
}
|
||||
}
|
||||
|
||||
use crate::vector::VectorSearchConfig;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchRequest {
|
||||
pub index_uid: String,
|
||||
|
|
@ -364,6 +372,8 @@ pub struct SearchRequest {
|
|||
pub over_fetch_factor: u32,
|
||||
/// Vector search mode (keyword-only, vector-only, or hybrid).
|
||||
pub vector_mode: VectorMode,
|
||||
/// Vector search configuration (for merge strategy).
|
||||
pub vector_config: Option<VectorSearchConfig>,
|
||||
}
|
||||
|
||||
impl SearchRequest {
|
||||
|
|
@ -1097,6 +1107,8 @@ pub async fn scatter_gather_search<C: NodeClient>(
|
|||
client_requested_score: req.ranking_score,
|
||||
facets: req.facets.clone(),
|
||||
failed_shards,
|
||||
vector_mode: req.vector_mode,
|
||||
vector_config: req.vector_config.clone(),
|
||||
};
|
||||
|
||||
// Span for the merge operation
|
||||
|
|
@ -1517,6 +1529,7 @@ mod tests {
|
|||
global_idf: None,
|
||||
over_fetch_factor: 1,
|
||||
vector_mode: VectorMode::KeywordOnly,
|
||||
vector_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2014,6 +2027,7 @@ mod tests {
|
|||
global_idf: None,
|
||||
over_fetch_factor: 1,
|
||||
vector_mode: VectorMode::KeywordOnly,
|
||||
vector_config: None,
|
||||
};
|
||||
|
||||
let body = req.to_node_body();
|
||||
|
|
@ -2061,6 +2075,7 @@ mod tests {
|
|||
global_idf: None,
|
||||
over_fetch_factor: 1,
|
||||
vector_mode: VectorMode::KeywordOnly,
|
||||
vector_config: None,
|
||||
};
|
||||
|
||||
let body = req.to_node_body();
|
||||
|
|
@ -2090,6 +2105,7 @@ mod tests {
|
|||
global_idf: None,
|
||||
over_fetch_factor: 1,
|
||||
vector_mode: VectorMode::KeywordOnly,
|
||||
vector_config: None,
|
||||
};
|
||||
|
||||
let body = req.to_node_body();
|
||||
|
|
@ -2154,6 +2170,7 @@ mod tests {
|
|||
global_idf: None,
|
||||
over_fetch_factor: 3,
|
||||
vector_mode: VectorMode::KeywordOnly,
|
||||
vector_config: None,
|
||||
};
|
||||
let body = req.to_node_body();
|
||||
assert_eq!(body.get("limit"), Some(&serde_json::json!(10)));
|
||||
|
|
|
|||
|
|
@ -126,6 +126,7 @@ impl VectorHit {
|
|||
}
|
||||
|
||||
/// Vector search merger — combines over-fetched results from multiple shards.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VectorMerger {
|
||||
/// Merge strategy.
|
||||
strategy: MergeStrategy,
|
||||
|
|
@ -138,8 +139,8 @@ pub struct VectorMerger {
|
|||
impl VectorMerger {
|
||||
/// Create a new vector merger.
|
||||
pub fn new(config: &VectorSearchConfig) -> Self {
|
||||
let strategy = MergeStrategy::from_str(&config.merge_strategy)
|
||||
.unwrap_or(MergeStrategy::Convex);
|
||||
let strategy =
|
||||
MergeStrategy::from_str(&config.merge_strategy).unwrap_or(MergeStrategy::Convex);
|
||||
Self {
|
||||
strategy,
|
||||
alpha: config.hybrid_alpha_default,
|
||||
|
|
@ -179,7 +180,7 @@ impl VectorMerger {
|
|||
.entry(hit.pk.clone())
|
||||
.and_modify(|e: &mut VectorHit| {
|
||||
if hit.combined_score > e.combined_score {
|
||||
*e = hit;
|
||||
*e = hit.clone();
|
||||
}
|
||||
})
|
||||
.or_insert(hit);
|
||||
|
|
@ -206,10 +207,7 @@ impl VectorMerger {
|
|||
// First, sort each shard's hits by their original ranking score
|
||||
let mut per_shard: HashMap<u32, Vec<VectorHit>> = HashMap::new();
|
||||
for (shard_id, hit) in shard_hits {
|
||||
per_shard
|
||||
.entry(shard_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(hit);
|
||||
per_shard.entry(shard_id).or_insert_with(Vec::new).push(hit);
|
||||
}
|
||||
|
||||
// Compute RRF scores
|
||||
|
|
@ -278,7 +276,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_merge_strategy_from_str() {
|
||||
assert_eq!(MergeStrategy::from_str("convex"), Some(MergeStrategy::Convex));
|
||||
assert_eq!(
|
||||
MergeStrategy::from_str("convex"),
|
||||
Some(MergeStrategy::Convex)
|
||||
);
|
||||
assert_eq!(MergeStrategy::from_str("rrf"), Some(MergeStrategy::Rrf));
|
||||
assert_eq!(MergeStrategy::from_str("unknown"), None);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue