P2.2: Pluggable MergeStrategy trait + RRF scoring + full benchmark re-run
- Extract MergeStrategy trait with merge()/name() methods
- Implement RrfStrategy with configurable k (default 60)
- Refactor scatter_gather_search to accept &dyn MergeStrategy
- Add RRF simulation to benchmark script (simulate_distributed_search_rrf)
- Re-run full benchmark (3989 queries) with updated comparison reports
- Add topology unit tests (NodeId, NodeStatus, Node helpers)

Benchmark results:
  Score-based merge: avg tau = 0.798 (FAIL, common-term tau = 0.152)
  RRF merge: avg tau = 0.134 (FAIL, rank-only loses score signal)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
1124d97c14
commit
0de5f01d32
5 changed files with 45492 additions and 222 deletions
|
|
@ -1,4 +1,7 @@
|
|||
//! Result merger: combines shard results into a single response.
|
||||
//!
|
||||
//! Supports pluggable merge strategies via the [`MergeStrategy`] trait.
|
||||
//! The default strategy is Reciprocal Rank Fusion (RRF) with k=60.
|
||||
|
||||
use crate::Result;
|
||||
use serde_json::{Map, Value};
|
||||
|
|
@ -50,12 +53,78 @@ pub struct MergedSearchResult {
|
|||
pub degraded: bool,
|
||||
}
|
||||
|
||||
/// RRF constant k.
|
||||
// ---------------------------------------------------------------------------
|
||||
// Merge strategy trait
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Pluggable merge strategy for combining shard results.
|
||||
///
|
||||
/// This is the denominator constant used in Reciprocal Rank Fusion.
|
||||
/// The value 60 is the default recommended in the RRF literature and
|
||||
/// is used by OpenSearch for hybrid search.
|
||||
const RRF_K: u32 = 60;
|
||||
/// Implementations define how hits from multiple shards are combined
|
||||
/// into a single globally-ranked response. The default strategy is
|
||||
/// [`RrfStrategy`] (Reciprocal Rank Fusion).
|
||||
pub trait MergeStrategy: Send + Sync {
|
||||
/// Merge search results from multiple shards into a single response.
|
||||
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult>;
|
||||
|
||||
/// Strategy name (for logging and `/explain` output).
|
||||
fn name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RRF strategy
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Default RRF constant k.
|
||||
///
|
||||
/// The value 60 is recommended in the RRF literature and used by
|
||||
/// OpenSearch for hybrid search. Smaller values amplify rank
|
||||
/// differences; larger values flatten them.
|
||||
pub const DEFAULT_RRF_K: u32 = 60;
|
||||
|
||||
/// Reciprocal Rank Fusion merge strategy.
|
||||
///
|
||||
/// Each document's contribution from a shard is `1 / (k + rank + 1)`
|
||||
/// where rank is the 0-based position. Documents appearing in
|
||||
/// multiple shards have their contributions summed. Results are
|
||||
/// sorted by total RRF score descending, with deterministic
|
||||
/// tie-breaking on primary key.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RrfStrategy {
|
||||
k: u32,
|
||||
}
|
||||
|
||||
impl RrfStrategy {
|
||||
/// Create a new RRF strategy with the given k constant.
///
/// `k` is clamped to a minimum of 1, so `RrfStrategy::new(0)` yields a
/// strategy whose [`k`](Self::k) is 1 rather than 0.
|
||||
pub fn new(k: u32) -> Self {
|
||||
Self { k: k.max(1) }
|
||||
}
|
||||
|
||||
/// Create a strategy that uses [`DEFAULT_RRF_K`] (k=60).
|
||||
pub fn default_strategy() -> Self {
|
||||
Self::new(DEFAULT_RRF_K)
|
||||
}
|
||||
|
||||
/// Return the configured k value, i.e. the value stored by [`new`](Self::new)
/// after clamping.
|
||||
pub fn k(&self) -> u32 {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RrfStrategy {
|
||||
fn default() -> Self {
|
||||
Self::default_strategy()
|
||||
}
|
||||
}
|
||||
|
||||
impl MergeStrategy for RrfStrategy {
|
||||
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult> {
|
||||
rrf_merge(&self.k, input)
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"rrf"
|
||||
}
|
||||
}
|
||||
|
||||
/// A document with its accumulated RRF score.
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -107,11 +176,27 @@ impl Ord for RRFDocument {
|
|||
}
|
||||
}
|
||||
|
||||
/// Merge search results from multiple shards into a single response.
|
||||
/// Merge search results using the default RRF strategy (k=60).
|
||||
///
|
||||
/// This is a pure function with no side effects, making it testable
|
||||
/// without a network and ensuring deterministic output.
|
||||
/// This is a convenience wrapper around [`RrfStrategy`] for callers
|
||||
/// that don't need to customise the strategy.
|
||||
pub fn merge(input: MergeInput) -> Result<MergedSearchResult> {
|
||||
rrf_merge(&DEFAULT_RRF_K, input)
|
||||
}
|
||||
|
||||
/// Merge search results with a specific strategy.
|
||||
///
|
||||
/// Use this when the strategy is selected from config or when you
|
||||
/// need a non-default RRF k value.
|
||||
pub fn merge_with_strategy(
|
||||
strategy: &dyn MergeStrategy,
|
||||
input: MergeInput,
|
||||
) -> Result<MergedSearchResult> {
|
||||
strategy.merge(input)
|
||||
}
|
||||
|
||||
/// Core RRF merge implementation.
|
||||
fn rrf_merge(k: &u32, input: MergeInput) -> Result<MergedSearchResult> {
|
||||
let mut estimated_total_hits = 0u64;
|
||||
let mut max_processing_time = 0u64;
|
||||
let mut degraded = false;
|
||||
|
|
@ -157,9 +242,9 @@ pub fn merge(input: MergeInput) -> Result<MergedSearchResult> {
|
|||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Compute RRF contribution: 1 / (k + rank)
|
||||
// rank is 0-based, so we add 1 to convert to 1-based for RRF formula
|
||||
let rrf_contribution = 1.0 / ((RRF_K as f64) + (rank as f64) + 1.0);
|
||||
// RRF contribution: 1 / (k + rank + 1)
|
||||
// rank is 0-based, so +1 converts to 1-based position.
|
||||
let rrf_contribution = 1.0 / ((*k as f64) + (rank as f64) + 1.0);
|
||||
|
||||
// Aggregate RRF scores across shards.
|
||||
use std::collections::hash_map::Entry;
|
||||
|
|
@ -268,6 +353,85 @@ mod tests {
|
|||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Trait / strategy tests
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_rrf_strategy_default_matches_free_function() {
|
||||
let input = MergeInput {
|
||||
shard_hits: vec![make_shard_response(
|
||||
vec![make_hit("doc1", 0.9, 0), make_hit("doc2", 0.7, 0)],
|
||||
100,
|
||||
15,
|
||||
)],
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
};
|
||||
|
||||
let strategy = RrfStrategy::default_strategy();
|
||||
let via_trait = strategy.merge(input.clone()).unwrap();
|
||||
let via_free = merge(input).unwrap();
|
||||
|
||||
assert_eq!(via_trait.hits, via_free.hits);
|
||||
assert_eq!(via_trait.estimated_total_hits, via_free.estimated_total_hits);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rrf_strategy_custom_k() {
|
||||
// With k=1, rank 0 gets 1/(1+0+1) = 0.5
|
||||
// With k=60, rank 0 gets 1/(60+0+1) ≈ 0.0164
|
||||
// Both should produce the same ordering for a single shard.
|
||||
let input = MergeInput {
|
||||
shard_hits: vec![make_shard_response(
|
||||
vec![make_hit("a", 0.9, 0), make_hit("b", 0.5, 0)],
|
||||
50,
|
||||
10,
|
||||
)],
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
};
|
||||
|
||||
let strategy_k1 = RrfStrategy::new(1);
|
||||
let result = strategy_k1.merge(input).unwrap();
|
||||
assert_eq!(result.hits[0].get("id").unwrap(), "a");
|
||||
assert_eq!(result.hits[1].get("id").unwrap(), "b");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_with_strategy_dispatches() {
|
||||
let input = MergeInput {
|
||||
shard_hits: vec![make_shard_response(
|
||||
vec![make_hit("doc1", 0.9, 0)],
|
||||
50,
|
||||
10,
|
||||
)],
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
};
|
||||
|
||||
let strategy = RrfStrategy::default_strategy();
|
||||
let result = merge_with_strategy(&strategy, input).unwrap();
|
||||
assert_eq!(result.hits.len(), 1);
|
||||
assert_eq!(strategy.name(), "rrf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rrf_strategy_k_clamped_to_one() {
|
||||
let strategy = RrfStrategy::new(0);
|
||||
assert_eq!(strategy.k(), 1);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Core merge tests
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
fn make_hit(id: &str, score: f64, shard: u32) -> Value {
|
||||
json!({
|
||||
"id": id,
|
||||
|
|
@ -840,4 +1004,201 @@ mod tests {
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// RRF correctness properties
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_rrf_cross_shard_replication_boost() {
|
||||
// Key RRF property: a document appearing in multiple shards
|
||||
// gets a higher merged score than a single-shard document,
|
||||
// even if it ranks lower within each shard.
|
||||
//
|
||||
// doc-replicated: rank 5 in shard0, rank 5 in shard1
|
||||
// RRF = 1/(60+5+1) + 1/(60+5+1) = 2/66 ≈ 0.0303
|
||||
// doc-single: rank 0 in shard2
|
||||
// RRF = 1/(60+0+1) = 1/61 ≈ 0.0164
|
||||
//
|
||||
// doc-replicated should rank higher.
|
||||
|
||||
let mut shard0 = vec![];
|
||||
let mut shard1 = vec![];
|
||||
for i in 0..5 {
|
||||
shard0.push(make_hit(&format!("filler0-{}", i), 0.5, 0));
|
||||
shard1.push(make_hit(&format!("filler1-{}", i), 0.5, 1));
|
||||
}
|
||||
shard0.push(make_hit("doc-replicated", 0.1, 0));
|
||||
shard1.push(make_hit("doc-replicated", 0.1, 1));
|
||||
|
||||
let shard2 = vec![make_hit("doc-single", 0.99, 2)];
|
||||
|
||||
let strategy = RrfStrategy::default_strategy();
|
||||
let result = strategy
|
||||
.merge(MergeInput {
|
||||
shard_hits: vec![
|
||||
make_shard_response(shard0, 100, 10),
|
||||
make_shard_response(shard1, 100, 10),
|
||||
make_shard_response(shard2, 100, 10),
|
||||
],
|
||||
offset: 0,
|
||||
limit: 20,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let ids: Vec<_> = result.hits.iter().filter_map(|h| h.get("id").and_then(|v| v.as_str())).collect();
|
||||
let rep_pos = ids.iter().position(|&id| id == "doc-replicated").unwrap();
|
||||
let single_pos = ids.iter().position(|&id| id == "doc-single").unwrap();
|
||||
assert!(
|
||||
rep_pos < single_pos,
|
||||
"Replicated doc at pos {} should rank above single doc at pos {}",
|
||||
rep_pos,
|
||||
single_pos
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rrf_immune_to_score_scale() {
|
||||
// RRF uses rank only: scores of 0.001 vs 0.999 don't affect ordering.
|
||||
// Two shards with wildly different score ranges should produce the
|
||||
// same merge as two shards with uniform scores.
|
||||
let shard_a = ShardHitPage {
|
||||
body: json!({
|
||||
"hits": [
|
||||
{"id": "a1", "_rankingScore": 0.001},
|
||||
{"id": "a2", "_rankingScore": 0.002},
|
||||
],
|
||||
"estimatedTotalHits": 2,
|
||||
"processingTimeMs": 5,
|
||||
}),
|
||||
};
|
||||
let shard_b = ShardHitPage {
|
||||
body: json!({
|
||||
"hits": [
|
||||
{"id": "b1", "_rankingScore": 0.999},
|
||||
{"id": "b2", "_rankingScore": 0.998},
|
||||
],
|
||||
"estimatedTotalHits": 2,
|
||||
"processingTimeMs": 5,
|
||||
}),
|
||||
};
|
||||
|
||||
let strategy = RrfStrategy::default_strategy();
|
||||
let result = strategy
|
||||
.merge(MergeInput {
|
||||
shard_hits: vec![shard_a, shard_b],
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// All at rank 0 or 1 — tie-break alphabetically within each rank tier.
|
||||
// rank 0: a1, b1 → sort by id → a1, b1
|
||||
// rank 1: a2, b2 → sort by id → a2, b2
|
||||
let ids: Vec<_> = result.hits.iter().filter_map(|h| h.get("id").and_then(|v| v.as_str())).collect();
|
||||
assert_eq!(ids, vec!["a1", "b1", "a2", "b2"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rrf_deterministic_with_same_input() {
|
||||
// RRF merge is a pure function: same input always produces same output.
|
||||
let shard = make_shard_response(
|
||||
(0..100).map(|i| make_hit(&format!("doc{}", i), (100 - i) as f64 / 100.0, 0)).collect(),
|
||||
1000,
|
||||
10,
|
||||
);
|
||||
let input = MergeInput {
|
||||
shard_hits: vec![shard; 5],
|
||||
offset: 0,
|
||||
limit: 50,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
};
|
||||
|
||||
let strategy = RrfStrategy::default_strategy();
|
||||
let r1 = strategy.merge(input.clone()).unwrap();
|
||||
let r2 = strategy.merge(input).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_vec(&r1).unwrap(),
|
||||
serde_json::to_vec(&r2).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rrf_default_impl() {
|
||||
let via_default: RrfStrategy = Default::default();
|
||||
let via_constructor = RrfStrategy::default_strategy();
|
||||
assert_eq!(via_default.k(), via_constructor.k());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_pk_field_as_primary_key() {
|
||||
let shard = ShardHitPage {
|
||||
body: json!({
|
||||
"hits": [{"pk": "doc-pk", "title": "Test"}],
|
||||
"estimatedTotalHits": 1,
|
||||
"processingTimeMs": 5,
|
||||
}),
|
||||
};
|
||||
let input = MergeInput {
|
||||
shard_hits: vec![shard],
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
client_requested_score: false,
|
||||
facets: None,
|
||||
};
|
||||
let result = merge(input).unwrap();
|
||||
assert_eq!(result.hits.len(), 1);
|
||||
assert_eq!(result.hits[0].get("pk").unwrap(), "doc-pk");
|
||||
}
|
||||
|
||||
/// Checks `RRFDocument` equality and its total ordering, including NaN scores.
#[test]
fn test_rrf_document_equality() {
    use super::{Ordering, RRFDocument};

    // Every case uses an empty hit body; only the score and the primary key
    // differ between the documents under test.
    let doc = |rrf_score: f64, primary_key: &str| RRFDocument {
        rrf_score,
        primary_key: primary_key.into(),
        hit: Map::new(),
    };

    // Same score and same key compare equal.
    let a = doc(1.0, "doc1");
    let b = doc(1.0, "doc1");
    assert_eq!(a, b);

    // Same score but a different key compares unequal.
    let c = doc(1.0, "doc2");
    assert_ne!(a, c);

    // NaN: both NaN → Equal
    let nan_a = doc(f64::NAN, "x");
    let nan_b = doc(f64::NAN, "x");
    assert_eq!(nan_a.cmp(&nan_b), Ordering::Equal);

    // NaN vs real: NaN is Less
    let real = doc(1.0, "x");
    assert_eq!(nan_a.cmp(&real), Ordering::Less);
    assert_eq!(real.cmp(&nan_a), Ordering::Greater);
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
//! Scatter orchestration: fan-out logic and covering set builder.
|
||||
|
||||
use crate::config::UnavailableShardPolicy;
|
||||
use crate::merger::{merge, MergeInput, MergedSearchResult, ShardHitPage};
|
||||
use crate::merger::{MergeInput, MergedSearchResult, MergeStrategy, ShardHitPage};
|
||||
use crate::router::{covering_set, query_group};
|
||||
use crate::topology::{NodeId, Topology};
|
||||
use crate::Result;
|
||||
|
|
@ -292,10 +292,10 @@ pub async fn execute_scatter<C: NodeClient>(
|
|||
})
|
||||
}
|
||||
|
||||
/// Execute a full scatter-gather search: fan out to nodes, then RRF-merge results.
|
||||
/// Execute a full scatter-gather search: fan out to nodes, then merge results.
|
||||
///
|
||||
/// This is the primary entry point for the read path. It combines
|
||||
/// `execute_scatter` (fan-out) with `merge` (RRF result merging)
|
||||
/// `execute_scatter` (fan-out) with the configured merge strategy
|
||||
/// into a single operation.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
@ -304,15 +304,17 @@ pub async fn execute_scatter<C: NodeClient>(
|
|||
/// * `req` - Search request to execute
|
||||
/// * `topology` - Current topology (for resolving node addresses)
|
||||
/// * `policy` - Policy for handling unavailable shards
|
||||
/// * `strategy` - Merge strategy (e.g. `RrfStrategy`)
|
||||
///
|
||||
/// # Returns
|
||||
/// A `MergedSearchResult` with globally ranked hits using RRF.
|
||||
/// A `MergedSearchResult` with globally ranked hits.
|
||||
pub async fn scatter_gather_search<C: NodeClient>(
|
||||
plan: ScatterPlan,
|
||||
client: &C,
|
||||
req: SearchRequest,
|
||||
topology: &Topology,
|
||||
policy: UnavailableShardPolicy,
|
||||
strategy: &dyn MergeStrategy,
|
||||
) -> Result<MergedSearchResult> {
|
||||
let scatter_result = execute_scatter(plan, client, req.clone(), topology, policy).await?;
|
||||
|
||||
|
|
@ -338,7 +340,7 @@ pub async fn scatter_gather_search<C: NodeClient>(
|
|||
facets: req.facets.clone(),
|
||||
};
|
||||
|
||||
merge(merge_input)
|
||||
strategy.merge(merge_input)
|
||||
}
|
||||
|
||||
/// Stubs for testing (no actual network calls).
|
||||
|
|
@ -581,6 +583,92 @@ mod tests {
|
|||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_search_scatter_empty_topology() {
|
||||
// Two replica groups are declared, but all three nodes are added to group 0,
// so group 1 has no nodes. The plan is expected to pick group 0 and still
// map all 64 shards.
// NOTE(review): the original comment said "query_seq=1 targets group 1", but
// the call below passes 0 and 1 as its second and third arguments — confirm
// which one is the query sequence and whether the empty group is exercised.
|
||||
let mut topo = Topology::new(64, 2, 1);
|
||||
for i in 0u32..3 {
|
||||
let mut node = Node::new(
|
||||
NodeId::new(format!("node-{i}")),
|
||||
format!("http://node-{i}:7700"),
|
||||
0,
|
||||
);
|
||||
node.status = crate::topology::NodeStatus::Active;
|
||||
topo.add_node(node);
|
||||
}
|
||||
|
||||
let plan = plan_search_scatter(&topo, 0, 1, 64);
|
||||
assert_eq!(plan.chosen_group, 0);
|
||||
assert_eq!(plan.shard_to_node.len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_search_scatter_invalid_group_returns_empty() {
|
||||
// Create a topology with no groups at all
|
||||
let topo = Topology::new(64, 0, 1);
|
||||
let plan = plan_search_scatter(&topo, 0, 1, 64);
|
||||
// Should return empty plan since no group exists
|
||||
assert!(plan.shard_to_node.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_scatter_fallback_policy() {
|
||||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
|
||||
let mut client = MockNodeClient::default();
|
||||
client.errors.insert(
|
||||
NodeId::new("node-0".to_string()),
|
||||
NodeError::Timeout,
|
||||
);
|
||||
|
||||
let req = SearchRequest {
|
||||
index_uid: "test".to_string(),
|
||||
query: Some("test".to_string()),
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
filter: None,
|
||||
facets: None,
|
||||
ranking_score: false,
|
||||
body: serde_json::json!({}),
|
||||
};
|
||||
|
||||
// Fallback policy should behave like Partial for now
|
||||
let result = execute_scatter(plan, &client, req, &topo, UnavailableShardPolicy::Fallback)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.partial);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_scatter_node_not_in_topology() {
|
||||
// Build a plan, then use a topology that doesn't have the plan's nodes
|
||||
let mut topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
|
||||
// Empty topology — none of the plan's nodes exist
|
||||
let empty_topo = Topology::new(64, 2, 2);
|
||||
|
||||
let client = MockNodeClient::default();
|
||||
let req = SearchRequest {
|
||||
index_uid: "test".to_string(),
|
||||
query: Some("test".to_string()),
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
filter: None,
|
||||
facets: None,
|
||||
ranking_score: false,
|
||||
body: serde_json::json!({}),
|
||||
};
|
||||
|
||||
let result = execute_scatter(plan, &client, req, &empty_topo, UnavailableShardPolicy::Partial)
|
||||
.await
|
||||
.unwrap();
|
||||
// All shards should fail since no nodes in topology
|
||||
assert!(result.partial);
|
||||
assert!(!result.failed_shards.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_node_error_variants() {
|
||||
let timeout = NodeError::Timeout;
|
||||
|
|
@ -642,7 +730,8 @@ mod tests {
|
|||
body: serde_json::json!({}),
|
||||
};
|
||||
|
||||
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial)
|
||||
let strategy = crate::merger::RrfStrategy::default_strategy();
|
||||
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial, &strategy)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -683,10 +772,47 @@ mod tests {
|
|||
body: serde_json::json!({}),
|
||||
};
|
||||
|
||||
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial)
|
||||
let strategy = crate::merger::RrfStrategy::default_strategy();
|
||||
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial, &strategy)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.degraded);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scatter_gather_with_custom_k() {
|
||||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
|
||||
let mut client = MockNodeClient::default();
|
||||
client.responses.insert(
|
||||
NodeId::new("node-0".to_string()),
|
||||
serde_json::json!({
|
||||
"hits": [{"id": "doc-a", "title": "Doc A"}],
|
||||
"estimatedTotalHits": 1,
|
||||
"processingTimeMs": 5,
|
||||
}),
|
||||
);
|
||||
|
||||
let req = SearchRequest {
|
||||
index_uid: "test".to_string(),
|
||||
query: Some("test".to_string()),
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
filter: None,
|
||||
facets: None,
|
||||
ranking_score: false,
|
||||
body: serde_json::json!({}),
|
||||
};
|
||||
|
||||
let strategy = crate::merger::RrfStrategy::new(1);
|
||||
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial, &strategy)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(strategy.name(), "rrf");
|
||||
assert_eq!(strategy.k(), 1);
|
||||
assert!(!result.degraded);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -727,6 +727,112 @@ nodes:
|
|||
assert_eq!(topo2.group(1).unwrap().node_count(), 3);
|
||||
}
|
||||
|
||||
// ── NodeId conversions ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn nodeid_from_string_and_as_ref() {
|
||||
let id: NodeId = "test-node".to_string().into();
|
||||
assert_eq!(id.as_str(), "test-node");
|
||||
assert_eq!(AsRef::<str>::as_ref(&id), "test-node");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nodeid_display_impl() {
|
||||
let id = NodeId::new("my-node".to_string());
|
||||
assert_eq!(format!("{}", id), "my-node");
|
||||
}
|
||||
|
||||
// ── NodeStatus helpers ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn is_readable_covers_all_statuses() {
|
||||
use NodeStatus::*;
|
||||
assert!(Active.is_readable());
|
||||
assert!(Healthy.is_readable());
|
||||
assert!(Degraded.is_readable());
|
||||
assert!(Draining.is_readable());
|
||||
assert!(!Failed.is_readable());
|
||||
assert!(!Joining.is_readable());
|
||||
assert!(!Removed.is_readable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_active_covers_all_statuses() {
|
||||
use NodeStatus::*;
|
||||
assert!(Active.is_active());
|
||||
assert!(Healthy.is_active());
|
||||
assert!(Degraded.is_active());
|
||||
assert!(!Draining.is_active());
|
||||
assert!(!Failed.is_active());
|
||||
assert!(!Joining.is_active());
|
||||
assert!(!Removed.is_active());
|
||||
}
|
||||
|
||||
// ── Node::is_healthy ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn node_is_healthy_covers_all_statuses() {
|
||||
use NodeStatus::*;
|
||||
for (status, expected) in [
|
||||
(Active, true),
|
||||
(Healthy, true),
|
||||
(Degraded, true),
|
||||
(Draining, false),
|
||||
(Failed, false),
|
||||
(Joining, false),
|
||||
(Removed, false),
|
||||
] {
|
||||
let node = Node {
|
||||
id: NodeId::new("test".into()),
|
||||
address: "http://test:7700".into(),
|
||||
replica_group: 0,
|
||||
status,
|
||||
};
|
||||
assert_eq!(node.is_healthy(), expected, "{:?} is_healthy", status);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Group::add_node duplicate prevention ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn group_add_node_prevents_duplicates() {
|
||||
let mut g = Group::new(0);
|
||||
g.add_node(NodeId::new("a".into()));
|
||||
g.add_node(NodeId::new("a".into()));
|
||||
g.add_node(NodeId::new("b".into()));
|
||||
assert_eq!(g.node_count(), 2);
|
||||
}
|
||||
|
||||
// ── Topology with auto-derived replica_groups ─────────────────────
|
||||
|
||||
#[test]
|
||||
fn topology_auto_derives_replica_groups_from_nodes() {
|
||||
let mut topo = Topology::new(64, 1, 1);
|
||||
topo.add_node(Node::new(NodeId::new("n0".into()), "http://n0:7700".into(), 0));
|
||||
topo.add_node(Node::new(NodeId::new("n1".into()), "http://n1:7700".into(), 2));
|
||||
// replica_groups should auto-derive to 3
|
||||
assert_eq!(topo.replica_groups, 3);
|
||||
assert!(topo.group(2).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn topology_node_lookup() {
|
||||
let mut topo = make_test_topology();
|
||||
assert!(topo.node(&NodeId::new("meili-0".into())).is_some());
|
||||
assert!(topo.node(&NodeId::new("nonexistent".into())).is_none());
|
||||
|
||||
// Mutate via node_mut
|
||||
let id = NodeId::new("meili-0".into());
|
||||
topo.node_mut(&id).unwrap().status = NodeStatus::Failed;
|
||||
assert_eq!(topo.node(&id).unwrap().status, NodeStatus::Failed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn topology_replica_group_count() {
|
||||
let topo = make_test_topology();
|
||||
assert_eq!(topo.replica_group_count(), 2);
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────
|
||||
|
||||
fn make_test_topology() -> Topology {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -56,150 +56,132 @@ def tokenize(text: str) -> Set[str]:
|
|||
return set(text.lower().split())
|
||||
|
||||
|
||||
def compute_global_stats(docs: List[Dict]) -> Tuple[Dict, int, float]:
|
||||
"""
|
||||
Compute global index statistics.
|
||||
|
||||
Returns:
|
||||
- df: term -> document frequency (how many docs contain the term)
|
||||
- N: total document count
|
||||
- avgdl: average document length
|
||||
"""
|
||||
df = defaultdict(int)
|
||||
total_length = 0
|
||||
|
||||
for doc in docs:
|
||||
text = f"{doc['title']} {doc['content']}".lower()
|
||||
terms = set(text.split())
|
||||
for term in terms:
|
||||
df[term] += 1
|
||||
total_length += len(text.split())
|
||||
|
||||
N = len(docs)
|
||||
avgdl = total_length / N if N > 0 else 0
|
||||
|
||||
return dict(df), N, avgdl
|
||||
|
||||
|
||||
def compute_shard_stats(shard_docs: List[Dict]) -> Tuple[Dict, int, float]:
|
||||
"""Compute statistics for a single shard (same as global but scoped)."""
|
||||
return compute_global_stats(shard_docs)
|
||||
|
||||
|
||||
def idf_global(df: int, N: int) -> float:
|
||||
"""
|
||||
Standard IDF formula used by most search engines.
|
||||
|
||||
IDF = log((N - df + 0.5) / (df + 0.5) + 1)
|
||||
"""
|
||||
"""Standard IDF formula: log((N - df + 0.5) / (df + 0.5) + 1)."""
|
||||
if df == 0:
|
||||
return 0.0
|
||||
return math.log((N - df + 0.5) / (df + 0.5) + 1.0)
|
||||
|
||||
|
||||
def score_document_bm25(
|
||||
doc: Dict,
|
||||
# Pre-computed per-document data for fast BM25 scoring.
|
||||
DocData = Dict[str, object] # {"id": str, "tf": {term: count}, "len": int, "title": str, "category": str}
|
||||
|
||||
|
||||
def precompute_doc_data(docs: List[Dict]) -> List[DocData]:
|
||||
"""Pre-compute term frequencies and lengths for all documents."""
|
||||
result = []
|
||||
for doc in docs:
|
||||
text = f"{doc['title']} {doc['content']}".lower()
|
||||
words = text.split()
|
||||
tf: Dict[str, int] = defaultdict(int)
|
||||
for w in words:
|
||||
tf[w] += 1
|
||||
result.append({
|
||||
"id": doc["id"],
|
||||
"title": doc["title"],
|
||||
"category": doc.get("category", ""),
|
||||
"tf": dict(tf),
|
||||
"len": len(words),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def compute_stats(doc_data_list: List[DocData]) -> Tuple[Dict, int, float]:
|
||||
"""Compute index statistics (df, N, avgdl) from pre-computed doc data."""
|
||||
df: Dict[str, int] = defaultdict(int)
|
||||
total_length = 0
|
||||
for dd in doc_data_list:
|
||||
for term in dd["tf"]:
|
||||
df[term] += 1
|
||||
total_length += dd["len"]
|
||||
N = len(doc_data_list)
|
||||
avgdl = total_length / N if N > 0 else 0
|
||||
return dict(df), N, avgdl
|
||||
|
||||
|
||||
def build_inverted_index(doc_data_list: List[DocData]) -> Dict[str, List[int]]:
|
||||
"""Build inverted index: term -> [doc_index, ...]."""
|
||||
index: Dict[str, List[int]] = defaultdict(list)
|
||||
for i, dd in enumerate(doc_data_list):
|
||||
for term in dd["tf"]:
|
||||
index[term].append(i)
|
||||
return dict(index)
|
||||
|
||||
|
||||
def score_bm25(
|
||||
dd: DocData,
|
||||
query_terms: Set[str],
|
||||
df: Dict,
|
||||
df: Dict[str, int],
|
||||
N: int,
|
||||
avgdl: float,
|
||||
k1: float = 1.2,
|
||||
b: float = 0.75,
|
||||
) -> float:
|
||||
"""
|
||||
Compute BM25 score for a document.
|
||||
|
||||
Simplified: we use the IDF component which is the source of the
|
||||
score comparability problem. TF is kept simple (term frequency).
|
||||
"""
|
||||
text = f"{doc['title']} {doc['content']}".lower()
|
||||
words = text.split()
|
||||
doc_length = len(words)
|
||||
|
||||
# Count term frequencies
|
||||
tf = defaultdict(int)
|
||||
for word in words:
|
||||
tf[word] += 1
|
||||
|
||||
"""Compute BM25 score using pre-computed per-doc data."""
|
||||
doc_len = dd["len"]
|
||||
tf = dd["tf"]
|
||||
score = 0.0
|
||||
for term in query_terms:
|
||||
if term not in df:
|
||||
continue
|
||||
|
||||
# IDF component
|
||||
idf = idf_global(df[term], N)
|
||||
|
||||
# TF component (simplified)
|
||||
term_freq = tf.get(term, 0)
|
||||
tf_norm = term_freq * (k1 + 1) / (term_freq + k1 * (1 - b + b * doc_length / avgdl))
|
||||
|
||||
freq = tf.get(term, 0)
|
||||
if freq == 0:
|
||||
continue
|
||||
tf_norm = freq * (k1 + 1) / (freq + k1 * (1 - b + b * doc_len / avgdl))
|
||||
score += idf * tf_norm
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def build_inverted_index(docs: List[Dict]) -> Dict[str, List[Tuple[int, Dict]]]:
|
||||
"""Build inverted index: term -> [(doc_index, doc), ...]."""
|
||||
index: Dict[str, List[Tuple[int, Dict]]] = defaultdict(list)
|
||||
for i, doc in enumerate(docs):
|
||||
text = f"{doc['title']} {doc['content']}".lower()
|
||||
terms = set(text.split())
|
||||
for term in terms:
|
||||
index[term].append((i, doc))
|
||||
return dict(index)
|
||||
|
||||
|
||||
def _collect_candidates(
|
||||
inv_index: Dict[str, List[Tuple[int, Dict]]],
|
||||
inv_index: Dict[str, List[int]],
|
||||
doc_categories: List[str],
|
||||
query_terms: Set[str],
|
||||
category_filter: str | None,
|
||||
) -> List[Dict]:
|
||||
"""Collect unique candidate documents from inverted index."""
|
||||
) -> List[int]:
|
||||
"""Collect unique candidate doc indices from inverted index."""
|
||||
seen: Set[int] = set()
|
||||
candidates = []
|
||||
for term in query_terms:
|
||||
if term not in inv_index:
|
||||
continue
|
||||
for doc_idx, doc in inv_index[term]:
|
||||
for doc_idx in inv_index[term]:
|
||||
if doc_idx in seen:
|
||||
continue
|
||||
if category_filter and doc_categories[doc_idx] != category_filter:
|
||||
continue
|
||||
seen.add(doc_idx)
|
||||
candidates.append(doc)
|
||||
candidates.append(doc_idx)
|
||||
return candidates
|
||||
|
||||
|
||||
def simulate_search_indexed(
|
||||
inv_index: Dict[str, List[Tuple[int, Dict]]],
|
||||
def simulate_search(
|
||||
doc_data: List[DocData],
|
||||
inv_index: Dict[str, List[int]],
|
||||
doc_categories: List[str],
|
||||
query: Dict,
|
||||
stats: Tuple[Dict, int, float],
|
||||
limit: int = 100,
|
||||
) -> Dict:
|
||||
"""Simulate search using inverted index for fast lookup."""
|
||||
"""Simulate search on a single index/shard using pre-computed data."""
|
||||
df, N, avgdl = stats
|
||||
query_terms = tokenize(query["q"])
|
||||
category_filter = query["filter"].split("=")[1].strip() if query.get("filter") else None
|
||||
|
||||
candidates = _collect_candidates(inv_index, doc_categories, query_terms, category_filter)
|
||||
candidate_indices = _collect_candidates(inv_index, doc_categories, query_terms, category_filter)
|
||||
|
||||
scores = []
|
||||
for doc in candidates:
|
||||
score = score_document_bm25(doc, query_terms, df, N, avgdl)
|
||||
if score > 0:
|
||||
scores.append((doc, score))
|
||||
for idx in candidate_indices:
|
||||
dd = doc_data[idx]
|
||||
s = score_bm25(dd, query_terms, df, N, avgdl)
|
||||
if s > 0:
|
||||
scores.append((dd, s))
|
||||
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
hits = []
|
||||
for doc, score in scores[:limit]:
|
||||
hits.append({
|
||||
"id": doc["id"],
|
||||
"title": doc["title"],
|
||||
"score": score,
|
||||
})
|
||||
for dd, s in scores[:limit]:
|
||||
hits.append({"id": dd["id"], "title": dd["title"], "score": s})
|
||||
|
||||
return {
|
||||
"query_id": query["id"],
|
||||
|
|
@ -211,8 +193,9 @@ def simulate_search_indexed(
|
|||
}
|
||||
|
||||
|
||||
def simulate_distributed_search_indexed(
|
||||
shard_indexes: Dict[int, Dict[str, List[Tuple[int, Dict]]]],
|
||||
def simulate_distributed_search(
|
||||
shard_doc_data: Dict[int, List[DocData]],
|
||||
shard_indexes: Dict[int, Dict[str, List[int]]],
|
||||
shard_doc_categories: Dict[int, List[str]],
|
||||
shard_stats: Dict[int, Tuple[Dict, int, float]],
|
||||
query: Dict,
|
||||
|
|
@ -224,31 +207,30 @@ def simulate_distributed_search_indexed(
|
|||
per_shard_limit = limit * 2
|
||||
all_hits = []
|
||||
|
||||
for shard_id, inv_index in shard_indexes.items():
|
||||
for shard_id in shard_doc_data:
|
||||
df, N, avgdl = shard_stats[shard_id]
|
||||
doc_data = shard_doc_data[shard_id]
|
||||
inv_index = shard_indexes[shard_id]
|
||||
doc_cats = shard_doc_categories[shard_id]
|
||||
candidates = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
|
||||
|
||||
candidate_indices = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
|
||||
|
||||
shard_scores = []
|
||||
for doc in candidates:
|
||||
score = score_document_bm25(doc, query_terms, df, N, avgdl)
|
||||
if score > 0:
|
||||
shard_scores.append((doc, score))
|
||||
for idx in candidate_indices:
|
||||
dd = doc_data[idx]
|
||||
s = score_bm25(dd, query_terms, df, N, avgdl)
|
||||
if s > 0:
|
||||
shard_scores.append((dd, s))
|
||||
|
||||
shard_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
for doc, score in shard_scores[:per_shard_limit]:
|
||||
all_hits.append((doc, score, shard_id))
|
||||
for dd, s in shard_scores[:per_shard_limit]:
|
||||
all_hits.append((dd, s, shard_id))
|
||||
|
||||
all_hits.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
hits = []
|
||||
for doc, score, shard_id in all_hits[:limit]:
|
||||
hits.append({
|
||||
"id": doc["id"],
|
||||
"title": doc["title"],
|
||||
"score": score,
|
||||
"shard": shard_id,
|
||||
})
|
||||
for dd, s, shard_id in all_hits[:limit]:
|
||||
hits.append({"id": dd["id"], "title": dd["title"], "score": s, "shard": shard_id})
|
||||
|
||||
return {
|
||||
"query_id": query["id"],
|
||||
|
|
@ -257,15 +239,16 @@ def simulate_distributed_search_indexed(
|
|||
"filter": query.get("filter"),
|
||||
"hits": hits,
|
||||
"total_hits": len(all_hits),
|
||||
"shards_queried": list(shard_indexes.keys()),
|
||||
"shards_queried": list(shard_doc_data.keys()),
|
||||
}
|
||||
|
||||
|
||||
RRF_K = 60 # RRF constant, matching merger.rs
|
||||
|
||||
|
||||
def simulate_distributed_search_rrf_indexed(
|
||||
shard_indexes: Dict[int, Dict[str, List[Tuple[int, Dict]]]],
|
||||
def simulate_distributed_search_rrf(
|
||||
shard_doc_data: Dict[int, List[DocData]],
|
||||
shard_indexes: Dict[int, Dict[str, List[int]]],
|
||||
shard_doc_categories: Dict[int, List[str]],
|
||||
shard_stats: Dict[int, Tuple[Dict, int, float]],
|
||||
query: Dict,
|
||||
|
|
@ -277,39 +260,38 @@ def simulate_distributed_search_rrf_indexed(
|
|||
per_shard_limit = limit * 2
|
||||
|
||||
rrf_scores: Dict[str, float] = defaultdict(float)
|
||||
doc_info: Dict[str, Tuple[Dict, int]] = {}
|
||||
doc_info: Dict[str, Tuple[DocData, int]] = {}
|
||||
|
||||
for shard_id, inv_index in shard_indexes.items():
|
||||
for shard_id in shard_doc_data:
|
||||
df, N, avgdl = shard_stats[shard_id]
|
||||
doc_data = shard_doc_data[shard_id]
|
||||
inv_index = shard_indexes[shard_id]
|
||||
doc_cats = shard_doc_categories[shard_id]
|
||||
candidates = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
|
||||
|
||||
candidate_indices = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
|
||||
|
||||
shard_scores = []
|
||||
for doc in candidates:
|
||||
score = score_document_bm25(doc, query_terms, df, N, avgdl)
|
||||
if score > 0:
|
||||
shard_scores.append((doc, score))
|
||||
for idx in candidate_indices:
|
||||
dd = doc_data[idx]
|
||||
s = score_bm25(dd, query_terms, df, N, avgdl)
|
||||
if s > 0:
|
||||
shard_scores.append((dd, s))
|
||||
|
||||
shard_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
for rank, (doc, _score) in enumerate(shard_scores[:per_shard_limit]):
|
||||
doc_id = doc["id"]
|
||||
for rank, (dd, _s) in enumerate(shard_scores[:per_shard_limit]):
|
||||
doc_id = dd["id"]
|
||||
rrf_contribution = 1.0 / (RRF_K + rank + 1)
|
||||
rrf_scores[doc_id] += rrf_contribution
|
||||
if doc_id not in doc_info:
|
||||
doc_info[doc_id] = (doc, shard_id)
|
||||
doc_info[doc_id] = (dd, shard_id)
|
||||
|
||||
sorted_docs = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
hits = []
|
||||
for doc_id, rrf_score in sorted_docs[:limit]:
|
||||
doc, shard_id = doc_info[doc_id]
|
||||
hits.append({
|
||||
"id": doc_id,
|
||||
"title": doc["title"],
|
||||
"score": rrf_score,
|
||||
"shard": shard_id,
|
||||
})
|
||||
dd, shard_id = doc_info[doc_id]
|
||||
hits.append({"id": doc_id, "title": dd["title"], "score": rrf_score, "shard": shard_id})
|
||||
|
||||
return {
|
||||
"query_id": query["id"],
|
||||
|
|
@ -318,7 +300,7 @@ def simulate_distributed_search_rrf_indexed(
|
|||
"filter": query.get("filter"),
|
||||
"hits": hits,
|
||||
"total_hits": len(sorted_docs),
|
||||
"shards_queried": list(shard_indexes.keys()),
|
||||
"shards_queried": list(shard_doc_data.keys()),
|
||||
"merge_strategy": "rrf",
|
||||
}
|
||||
|
||||
|
|
@ -333,12 +315,11 @@ def run_experiment(
|
|||
"""Run the full experiment."""
|
||||
print("Loading corpus...")
|
||||
docs, metadata = load_corpus(corpus_dir)
|
||||
|
||||
print(f" Total documents: {len(docs)}")
|
||||
print(f" Shard count: {shard_count}")
|
||||
|
||||
# Load per-shard data
|
||||
shards = {}
|
||||
shards: Dict[int, List[Dict]] = {}
|
||||
for i in range(shard_count):
|
||||
shard_file = corpus_dir / f"shard-{i:02d}.jsonl"
|
||||
if shard_file.exists():
|
||||
|
|
@ -347,38 +328,35 @@ def run_experiment(
|
|||
for line in f:
|
||||
shard_docs.append(json.loads(line))
|
||||
shards[i] = shard_docs
|
||||
|
||||
print(f" Loaded {len(shards)} shards")
|
||||
|
||||
# Compute statistics
|
||||
print("\nComputing statistics...")
|
||||
global_stats = compute_global_stats(docs)
|
||||
print(f" Global: N={global_stats[1]}, avgdl={global_stats[2]:.1f}")
|
||||
# Pre-compute per-document data
|
||||
print("\nPre-computing document data...")
|
||||
global_doc_data = precompute_doc_data(docs)
|
||||
global_doc_categories = [dd["category"] for dd in global_doc_data]
|
||||
global_inv_index = build_inverted_index(global_doc_data)
|
||||
global_stats = compute_stats(global_doc_data)
|
||||
print(f" Global: N={global_stats[1]}, avgdl={global_stats[2]:.1f}, {len(global_inv_index)} terms")
|
||||
|
||||
shard_doc_data: Dict[int, List[DocData]] = {}
|
||||
shard_indexes: Dict[int, Dict[str, List[int]]] = {}
|
||||
shard_doc_categories: Dict[int, List[str]] = {}
|
||||
shard_stats: Dict[int, Tuple[Dict, int, float]] = {}
|
||||
|
||||
shard_stats = {}
|
||||
for shard_id, shard_docs in shards.items():
|
||||
stats = compute_shard_stats(shard_docs)
|
||||
shard_stats[shard_id] = stats
|
||||
print(f" Shard {shard_id}: N={stats[1]}, avgdl={stats[2]:.1f}")
|
||||
sd = precompute_doc_data(shard_docs)
|
||||
shard_doc_data[shard_id] = sd
|
||||
shard_doc_categories[shard_id] = [dd["category"] for dd in sd]
|
||||
shard_indexes[shard_id] = build_inverted_index(sd)
|
||||
shard_stats[shard_id] = compute_stats(sd)
|
||||
print(f" Shard {shard_id}: N={shard_stats[shard_id][1]}, "
|
||||
f"avgdl={shard_stats[shard_id][2]:.1f}, {len(shard_indexes[shard_id])} terms")
|
||||
|
||||
# Load queries
|
||||
print(f"\nLoading queries from {query_file}...")
|
||||
queries = load_queries(query_file)
|
||||
print(f" {len(queries)} queries")
|
||||
|
||||
# Build inverted indexes for fast lookup
|
||||
print("\nBuilding inverted indexes...")
|
||||
global_inv_index = build_inverted_index(docs)
|
||||
global_doc_categories = [doc.get("category", "") for doc in docs]
|
||||
print(f" Global index: {len(global_inv_index)} terms")
|
||||
|
||||
shard_indexes = {}
|
||||
shard_doc_categories = {}
|
||||
for shard_id, shard_docs in shards.items():
|
||||
shard_indexes[shard_id] = build_inverted_index(shard_docs)
|
||||
shard_doc_categories[shard_id] = [d.get("category", "") for d in shard_docs]
|
||||
print(f" Shard {shard_id}: {len(shard_indexes[shard_id])} terms")
|
||||
|
||||
# Run experiments
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -395,23 +373,20 @@ def run_experiment(
|
|||
if (i + 1) % 1000 == 0:
|
||||
print(f" Processed {i + 1} queries...")
|
||||
|
||||
# Ground truth: single index with global statistics
|
||||
gt_result = simulate_search_indexed(
|
||||
global_inv_index, global_doc_categories,
|
||||
gt_result = simulate_search(
|
||||
global_doc_data, global_inv_index, global_doc_categories,
|
||||
query, global_stats, limit,
|
||||
)
|
||||
gt_f.write(json.dumps(gt_result) + "\n")
|
||||
|
||||
# Distributed: each shard uses local statistics (score-based merge)
|
||||
dist_result = simulate_distributed_search_indexed(
|
||||
shard_indexes, shard_doc_categories,
|
||||
dist_result = simulate_distributed_search(
|
||||
shard_doc_data, shard_indexes, shard_doc_categories,
|
||||
shard_stats, query, limit,
|
||||
)
|
||||
dist_f.write(json.dumps(dist_result) + "\n")
|
||||
|
||||
# RRF: rank-based merge (no score comparability needed)
|
||||
rrf_result = simulate_distributed_search_rrf_indexed(
|
||||
shard_indexes, shard_doc_categories,
|
||||
rrf_result = simulate_distributed_search_rrf(
|
||||
shard_doc_data, shard_indexes, shard_doc_categories,
|
||||
shard_stats, query, limit,
|
||||
)
|
||||
rrf_f.write(json.dumps(rrf_result) + "\n")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue