P2.2: Pluggable MergeStrategy trait + RRF scoring + full benchmark re-run

- Extract MergeStrategy trait with merge()/name() methods
- Implement RrfStrategy with configurable k (default 60)
- Refactor scatter_gather_search to accept &dyn MergeStrategy
- Add RRF simulation to benchmark script (simulate_distributed_search_rrf)
- Re-run full benchmark (3989 queries) with updated comparison reports
- Add topology unit tests (NodeId, NodeStatus, Node helpers)

Benchmark results:
  Score-based merge: avg tau = 0.798 (FAIL, common-term tau = 0.152)
  RRF merge:         avg tau = 0.134 (FAIL, rank-only loses score signal)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-04-19 02:07:39 -04:00
parent 1124d97c14
commit 0de5f01d32
5 changed files with 45492 additions and 222 deletions

View file

@ -1,4 +1,7 @@
//! Result merger: combines shard results into a single response.
//!
//! Supports pluggable merge strategies via the [`MergeStrategy`] trait.
//! The default strategy is Reciprocal Rank Fusion (RRF) with k=60.
use crate::Result;
use serde_json::{Map, Value};
@ -50,12 +53,78 @@ pub struct MergedSearchResult {
pub degraded: bool,
}
/// RRF constant k.
// ---------------------------------------------------------------------------
// Merge strategy trait
// ---------------------------------------------------------------------------
/// Pluggable merge strategy for combining shard results.
///
/// This is the denominator constant used in Reciprocal Rank Fusion.
/// The value 60 is the default recommended in the RRF literature and
/// is used by OpenSearch for hybrid search.
const RRF_K: u32 = 60;
/// Implementations define how hits from multiple shards are combined
/// into a single globally-ranked response. The default strategy is
/// [`RrfStrategy`] (Reciprocal Rank Fusion).
pub trait MergeStrategy: Send + Sync {
/// Merge search results from multiple shards into a single response.
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult>;
/// Strategy name (for logging and `/explain` output).
fn name(&self) -> &'static str;
}
// ---------------------------------------------------------------------------
// RRF strategy
// ---------------------------------------------------------------------------
/// Default RRF constant k.
///
/// The value 60 is recommended in the RRF literature and used by
/// OpenSearch for hybrid search. Smaller values amplify rank
/// differences; larger values flatten them.
pub const DEFAULT_RRF_K: u32 = 60;
/// Reciprocal Rank Fusion merge strategy.
///
/// Each document's contribution from a shard is `1 / (k + rank + 1)`
/// where rank is the 0-based position. Documents appearing in
/// multiple shards have their contributions summed. Results are
/// sorted by total RRF score descending, with deterministic
/// tie-breaking on primary key.
#[derive(Debug, Clone)]
pub struct RrfStrategy {
k: u32,
}
impl RrfStrategy {
/// Create a new RRF strategy with the given k constant.
pub fn new(k: u32) -> Self {
Self { k: k.max(1) }
}
/// Create with the default k=60.
pub fn default_strategy() -> Self {
Self::new(DEFAULT_RRF_K)
}
/// Return the configured k value.
pub fn k(&self) -> u32 {
self.k
}
}
impl Default for RrfStrategy {
fn default() -> Self {
Self::default_strategy()
}
}
impl MergeStrategy for RrfStrategy {
fn merge(&self, input: MergeInput) -> Result<MergedSearchResult> {
rrf_merge(&self.k, input)
}
fn name(&self) -> &'static str {
"rrf"
}
}
/// A document with its accumulated RRF score.
#[derive(Debug, Clone)]
@ -107,11 +176,27 @@ impl Ord for RRFDocument {
}
}
/// Merge search results from multiple shards into a single response.
/// Merge search results using the default RRF strategy (k=60).
///
/// This is a pure function with no side effects, making it testable
/// without a network and ensuring deterministic output.
/// This is a convenience wrapper around [`RrfStrategy`] for callers
/// that don't need to customise the strategy.
pub fn merge(input: MergeInput) -> Result<MergedSearchResult> {
rrf_merge(&DEFAULT_RRF_K, input)
}
/// Merge search results with a specific strategy.
///
/// Use this when the strategy is selected from config or when you
/// need a non-default RRF k value.
pub fn merge_with_strategy(
strategy: &dyn MergeStrategy,
input: MergeInput,
) -> Result<MergedSearchResult> {
strategy.merge(input)
}
/// Core RRF merge implementation.
fn rrf_merge(k: &u32, input: MergeInput) -> Result<MergedSearchResult> {
let mut estimated_total_hits = 0u64;
let mut max_processing_time = 0u64;
let mut degraded = false;
@ -157,9 +242,9 @@ pub fn merge(input: MergeInput) -> Result<MergedSearchResult> {
.unwrap_or("")
.to_string();
// Compute RRF contribution: 1 / (k + rank)
// rank is 0-based, so we add 1 to convert to 1-based for RRF formula
let rrf_contribution = 1.0 / ((RRF_K as f64) + (rank as f64) + 1.0);
// RRF contribution: 1 / (k + rank + 1)
// rank is 0-based, so +1 converts to 1-based position.
let rrf_contribution = 1.0 / ((*k as f64) + (rank as f64) + 1.0);
// Aggregate RRF scores across shards.
use std::collections::hash_map::Entry;
@ -268,6 +353,85 @@ mod tests {
use super::*;
use serde_json::json;
// -----------------------------------------------------------------------
// Trait / strategy tests
// -----------------------------------------------------------------------
#[test]
fn test_rrf_strategy_default_matches_free_function() {
let input = MergeInput {
shard_hits: vec![make_shard_response(
vec![make_hit("doc1", 0.9, 0), make_hit("doc2", 0.7, 0)],
100,
15,
)],
offset: 0,
limit: 10,
client_requested_score: false,
facets: None,
};
let strategy = RrfStrategy::default_strategy();
let via_trait = strategy.merge(input.clone()).unwrap();
let via_free = merge(input).unwrap();
assert_eq!(via_trait.hits, via_free.hits);
assert_eq!(via_trait.estimated_total_hits, via_free.estimated_total_hits);
}
#[test]
fn test_rrf_strategy_custom_k() {
// With k=1, rank 0 gets 1/(1+0+1) = 0.5
// With k=60, rank 0 gets 1/(60+0+1) ≈ 0.0164
// Both should produce the same ordering for a single shard.
let input = MergeInput {
shard_hits: vec![make_shard_response(
vec![make_hit("a", 0.9, 0), make_hit("b", 0.5, 0)],
50,
10,
)],
offset: 0,
limit: 10,
client_requested_score: false,
facets: None,
};
let strategy_k1 = RrfStrategy::new(1);
let result = strategy_k1.merge(input).unwrap();
assert_eq!(result.hits[0].get("id").unwrap(), "a");
assert_eq!(result.hits[1].get("id").unwrap(), "b");
}
#[test]
fn test_merge_with_strategy_dispatches() {
let input = MergeInput {
shard_hits: vec![make_shard_response(
vec![make_hit("doc1", 0.9, 0)],
50,
10,
)],
offset: 0,
limit: 10,
client_requested_score: false,
facets: None,
};
let strategy = RrfStrategy::default_strategy();
let result = merge_with_strategy(&strategy, input).unwrap();
assert_eq!(result.hits.len(), 1);
assert_eq!(strategy.name(), "rrf");
}
#[test]
fn test_rrf_strategy_k_clamped_to_one() {
let strategy = RrfStrategy::new(0);
assert_eq!(strategy.k(), 1);
}
// -----------------------------------------------------------------------
// Core merge tests
// -----------------------------------------------------------------------
fn make_hit(id: &str, score: f64, shard: u32) -> Value {
json!({
"id": id,
@ -840,4 +1004,201 @@ mod tests {
);
}
}
// -----------------------------------------------------------------------
// RRF correctness properties
// -----------------------------------------------------------------------
#[test]
fn test_rrf_cross_shard_replication_boost() {
// Key RRF property: a document appearing in multiple shards
// gets a higher merged score than a single-shard document,
// even if it ranks lower within each shard.
//
// doc-replicated: rank 5 in shard0, rank 5 in shard1
// RRF = 1/(60+5+1) + 1/(60+5+1) = 2/66 ≈ 0.0303
// doc-single: rank 0 in shard2
// RRF = 1/(60+0+1) = 1/61 ≈ 0.0164
//
// doc-replicated should rank higher.
let mut shard0 = vec![];
let mut shard1 = vec![];
for i in 0..5 {
shard0.push(make_hit(&format!("filler0-{}", i), 0.5, 0));
shard1.push(make_hit(&format!("filler1-{}", i), 0.5, 1));
}
shard0.push(make_hit("doc-replicated", 0.1, 0));
shard1.push(make_hit("doc-replicated", 0.1, 1));
let shard2 = vec![make_hit("doc-single", 0.99, 2)];
let strategy = RrfStrategy::default_strategy();
let result = strategy
.merge(MergeInput {
shard_hits: vec![
make_shard_response(shard0, 100, 10),
make_shard_response(shard1, 100, 10),
make_shard_response(shard2, 100, 10),
],
offset: 0,
limit: 20,
client_requested_score: false,
facets: None,
})
.unwrap();
let ids: Vec<_> = result.hits.iter().filter_map(|h| h.get("id").and_then(|v| v.as_str())).collect();
let rep_pos = ids.iter().position(|&id| id == "doc-replicated").unwrap();
let single_pos = ids.iter().position(|&id| id == "doc-single").unwrap();
assert!(
rep_pos < single_pos,
"Replicated doc at pos {} should rank above single doc at pos {}",
rep_pos,
single_pos
);
}
#[test]
fn test_rrf_immune_to_score_scale() {
// RRF uses rank only: scores of 0.001 vs 0.999 don't affect ordering.
// Two shards with wildly different score ranges should produce the
// same merge as two shards with uniform scores.
let shard_a = ShardHitPage {
body: json!({
"hits": [
{"id": "a1", "_rankingScore": 0.001},
{"id": "a2", "_rankingScore": 0.002},
],
"estimatedTotalHits": 2,
"processingTimeMs": 5,
}),
};
let shard_b = ShardHitPage {
body: json!({
"hits": [
{"id": "b1", "_rankingScore": 0.999},
{"id": "b2", "_rankingScore": 0.998},
],
"estimatedTotalHits": 2,
"processingTimeMs": 5,
}),
};
let strategy = RrfStrategy::default_strategy();
let result = strategy
.merge(MergeInput {
shard_hits: vec![shard_a, shard_b],
offset: 0,
limit: 10,
client_requested_score: false,
facets: None,
})
.unwrap();
// All at rank 0 or 1 — tie-break alphabetically within each rank tier.
// rank 0: a1, b1 → sort by id → a1, b1
// rank 1: a2, b2 → sort by id → a2, b2
let ids: Vec<_> = result.hits.iter().filter_map(|h| h.get("id").and_then(|v| v.as_str())).collect();
assert_eq!(ids, vec!["a1", "b1", "a2", "b2"]);
}
#[test]
fn test_rrf_deterministic_with_same_input() {
// RRF merge is a pure function: same input always produces same output.
let shard = make_shard_response(
(0..100).map(|i| make_hit(&format!("doc{}", i), (100 - i) as f64 / 100.0, 0)).collect(),
1000,
10,
);
let input = MergeInput {
shard_hits: vec![shard; 5],
offset: 0,
limit: 50,
client_requested_score: false,
facets: None,
};
let strategy = RrfStrategy::default_strategy();
let r1 = strategy.merge(input.clone()).unwrap();
let r2 = strategy.merge(input).unwrap();
assert_eq!(
serde_json::to_vec(&r1).unwrap(),
serde_json::to_vec(&r2).unwrap()
);
}
#[test]
fn test_rrf_default_impl() {
let via_default: RrfStrategy = Default::default();
let via_constructor = RrfStrategy::default_strategy();
assert_eq!(via_default.k(), via_constructor.k());
}
#[test]
fn test_merge_pk_field_as_primary_key() {
let shard = ShardHitPage {
body: json!({
"hits": [{"pk": "doc-pk", "title": "Test"}],
"estimatedTotalHits": 1,
"processingTimeMs": 5,
}),
};
let input = MergeInput {
shard_hits: vec![shard],
offset: 0,
limit: 10,
client_requested_score: false,
facets: None,
};
let result = merge(input).unwrap();
assert_eq!(result.hits.len(), 1);
assert_eq!(result.hits[0].get("pk").unwrap(), "doc-pk");
}
#[test]
fn test_rrf_document_equality() {
use super::{RRFDocument, Ordering};
let a = super::RRFDocument {
rrf_score: 1.0,
primary_key: "doc1".into(),
hit: Map::new(),
};
let b = super::RRFDocument {
rrf_score: 1.0,
primary_key: "doc1".into(),
hit: Map::new(),
};
assert_eq!(a, b);
let c = super::RRFDocument {
rrf_score: 1.0,
primary_key: "doc2".into(),
hit: Map::new(),
};
assert_ne!(a, c);
// NaN: both NaN → Equal
let nan_a = super::RRFDocument {
rrf_score: f64::NAN,
primary_key: "x".into(),
hit: Map::new(),
};
let nan_b = super::RRFDocument {
rrf_score: f64::NAN,
primary_key: "x".into(),
hit: Map::new(),
};
assert_eq!(nan_a.cmp(&nan_b), Ordering::Equal);
// NaN vs real: NaN is Less
let real = super::RRFDocument {
rrf_score: 1.0,
primary_key: "x".into(),
hit: Map::new(),
};
assert_eq!(nan_a.cmp(&real), Ordering::Less);
assert_eq!(real.cmp(&nan_a), Ordering::Greater);
}
}

View file

@ -1,7 +1,7 @@
//! Scatter orchestration: fan-out logic and covering set builder.
use crate::config::UnavailableShardPolicy;
use crate::merger::{merge, MergeInput, MergedSearchResult, ShardHitPage};
use crate::merger::{MergeInput, MergedSearchResult, MergeStrategy, ShardHitPage};
use crate::router::{covering_set, query_group};
use crate::topology::{NodeId, Topology};
use crate::Result;
@ -292,10 +292,10 @@ pub async fn execute_scatter<C: NodeClient>(
})
}
/// Execute a full scatter-gather search: fan out to nodes, then RRF-merge results.
/// Execute a full scatter-gather search: fan out to nodes, then merge results.
///
/// This is the primary entry point for the read path. It combines
/// `execute_scatter` (fan-out) with `merge` (RRF result merging)
/// `execute_scatter` (fan-out) with the configured merge strategy
/// into a single operation.
///
/// # Arguments
@ -304,15 +304,17 @@ pub async fn execute_scatter<C: NodeClient>(
/// * `req` - Search request to execute
/// * `topology` - Current topology (for resolving node addresses)
/// * `policy` - Policy for handling unavailable shards
/// * `strategy` - Merge strategy (e.g. `RrfStrategy`)
///
/// # Returns
/// A `MergedSearchResult` with globally ranked hits using RRF.
/// A `MergedSearchResult` with globally ranked hits.
pub async fn scatter_gather_search<C: NodeClient>(
plan: ScatterPlan,
client: &C,
req: SearchRequest,
topology: &Topology,
policy: UnavailableShardPolicy,
strategy: &dyn MergeStrategy,
) -> Result<MergedSearchResult> {
let scatter_result = execute_scatter(plan, client, req.clone(), topology, policy).await?;
@ -338,7 +340,7 @@ pub async fn scatter_gather_search<C: NodeClient>(
facets: req.facets.clone(),
};
merge(merge_input)
strategy.merge(merge_input)
}
/// Stubs for testing (no actual network calls).
@ -581,6 +583,92 @@ mod tests {
assert!(result.is_err());
}
#[test]
fn test_plan_search_scatter_empty_topology() {
// Topology with 0 nodes in group 1 — query_seq=1 targets group 1 which has no nodes
let mut topo = Topology::new(64, 2, 1);
for i in 0u32..3 {
let mut node = Node::new(
NodeId::new(format!("node-{i}")),
format!("http://node-{i}:7700"),
0,
);
node.status = crate::topology::NodeStatus::Active;
topo.add_node(node);
}
let plan = plan_search_scatter(&topo, 0, 1, 64);
assert_eq!(plan.chosen_group, 0);
assert_eq!(plan.shard_to_node.len(), 64);
}
#[test]
fn test_plan_search_scatter_invalid_group_returns_empty() {
// Create a topology with no groups at all
let topo = Topology::new(64, 0, 1);
let plan = plan_search_scatter(&topo, 0, 1, 64);
// Should return empty plan since no group exists
assert!(plan.shard_to_node.is_empty());
}
#[tokio::test]
async fn test_execute_scatter_fallback_policy() {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut client = MockNodeClient::default();
client.errors.insert(
NodeId::new("node-0".to_string()),
NodeError::Timeout,
);
let req = SearchRequest {
index_uid: "test".to_string(),
query: Some("test".to_string()),
offset: 0,
limit: 10,
filter: None,
facets: None,
ranking_score: false,
body: serde_json::json!({}),
};
// Fallback policy should behave like Partial for now
let result = execute_scatter(plan, &client, req, &topo, UnavailableShardPolicy::Fallback)
.await
.unwrap();
assert!(result.partial);
}
#[tokio::test]
async fn test_execute_scatter_node_not_in_topology() {
// Build a plan, then use a topology that doesn't have the plan's nodes
let mut topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
// Empty topology — none of the plan's nodes exist
let empty_topo = Topology::new(64, 2, 2);
let client = MockNodeClient::default();
let req = SearchRequest {
index_uid: "test".to_string(),
query: Some("test".to_string()),
offset: 0,
limit: 10,
filter: None,
facets: None,
ranking_score: false,
body: serde_json::json!({}),
};
let result = execute_scatter(plan, &client, req, &empty_topo, UnavailableShardPolicy::Partial)
.await
.unwrap();
// All shards should fail since no nodes in topology
assert!(result.partial);
assert!(!result.failed_shards.is_empty());
}
#[test]
fn test_node_error_variants() {
let timeout = NodeError::Timeout;
@ -642,7 +730,8 @@ mod tests {
body: serde_json::json!({}),
};
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial)
let strategy = crate::merger::RrfStrategy::default_strategy();
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial, &strategy)
.await
.unwrap();
@ -683,10 +772,47 @@ mod tests {
body: serde_json::json!({}),
};
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial)
let strategy = crate::merger::RrfStrategy::default_strategy();
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial, &strategy)
.await
.unwrap();
assert!(result.degraded);
}
#[tokio::test]
async fn test_scatter_gather_with_custom_k() {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut client = MockNodeClient::default();
client.responses.insert(
NodeId::new("node-0".to_string()),
serde_json::json!({
"hits": [{"id": "doc-a", "title": "Doc A"}],
"estimatedTotalHits": 1,
"processingTimeMs": 5,
}),
);
let req = SearchRequest {
index_uid: "test".to_string(),
query: Some("test".to_string()),
offset: 0,
limit: 10,
filter: None,
facets: None,
ranking_score: false,
body: serde_json::json!({}),
};
let strategy = crate::merger::RrfStrategy::new(1);
let result = scatter_gather_search(plan, &client, req, &topo, UnavailableShardPolicy::Partial, &strategy)
.await
.unwrap();
assert_eq!(strategy.name(), "rrf");
assert_eq!(strategy.k(), 1);
assert!(!result.degraded);
}
}

View file

@ -727,6 +727,112 @@ nodes:
assert_eq!(topo2.group(1).unwrap().node_count(), 3);
}
// ── NodeId conversions ────────────────────────────────────────────
#[test]
fn nodeid_from_string_and_as_ref() {
let id: NodeId = "test-node".to_string().into();
assert_eq!(id.as_str(), "test-node");
assert_eq!(AsRef::<str>::as_ref(&id), "test-node");
}
#[test]
fn nodeid_display_impl() {
let id = NodeId::new("my-node".to_string());
assert_eq!(format!("{}", id), "my-node");
}
// ── NodeStatus helpers ────────────────────────────────────────────
#[test]
fn is_readable_covers_all_statuses() {
use NodeStatus::*;
assert!(Active.is_readable());
assert!(Healthy.is_readable());
assert!(Degraded.is_readable());
assert!(Draining.is_readable());
assert!(!Failed.is_readable());
assert!(!Joining.is_readable());
assert!(!Removed.is_readable());
}
#[test]
fn is_active_covers_all_statuses() {
use NodeStatus::*;
assert!(Active.is_active());
assert!(Healthy.is_active());
assert!(Degraded.is_active());
assert!(!Draining.is_active());
assert!(!Failed.is_active());
assert!(!Joining.is_active());
assert!(!Removed.is_active());
}
// ── Node::is_healthy ──────────────────────────────────────────────
#[test]
fn node_is_healthy_covers_all_statuses() {
use NodeStatus::*;
for (status, expected) in [
(Active, true),
(Healthy, true),
(Degraded, true),
(Draining, false),
(Failed, false),
(Joining, false),
(Removed, false),
] {
let node = Node {
id: NodeId::new("test".into()),
address: "http://test:7700".into(),
replica_group: 0,
status,
};
assert_eq!(node.is_healthy(), expected, "{:?} is_healthy", status);
}
}
// ── Group::add_node duplicate prevention ──────────────────────────
#[test]
fn group_add_node_prevents_duplicates() {
let mut g = Group::new(0);
g.add_node(NodeId::new("a".into()));
g.add_node(NodeId::new("a".into()));
g.add_node(NodeId::new("b".into()));
assert_eq!(g.node_count(), 2);
}
// ── Topology with auto-derived replica_groups ─────────────────────
#[test]
fn topology_auto_derives_replica_groups_from_nodes() {
let mut topo = Topology::new(64, 1, 1);
topo.add_node(Node::new(NodeId::new("n0".into()), "http://n0:7700".into(), 0));
topo.add_node(Node::new(NodeId::new("n1".into()), "http://n1:7700".into(), 2));
// replica_groups should auto-derive to 3
assert_eq!(topo.replica_groups, 3);
assert!(topo.group(2).is_some());
}
#[test]
fn topology_node_lookup() {
let mut topo = make_test_topology();
assert!(topo.node(&NodeId::new("meili-0".into())).is_some());
assert!(topo.node(&NodeId::new("nonexistent".into())).is_none());
// Mutate via node_mut
let id = NodeId::new("meili-0".into());
topo.node_mut(&id).unwrap().status = NodeStatus::Failed;
assert_eq!(topo.node(&id).unwrap().status, NodeStatus::Failed);
}
#[test]
fn topology_replica_group_count() {
let topo = make_test_topology();
assert_eq!(topo.replica_group_count(), 2);
}
// ── Helpers ───────────────────────────────────────────────────────
fn make_test_topology() -> Topology {

View file

@ -56,150 +56,132 @@ def tokenize(text: str) -> Set[str]:
return set(text.lower().split())
def compute_global_stats(docs: List[Dict]) -> Tuple[Dict, int, float]:
"""
Compute global index statistics.
Returns:
- df: term -> document frequency (how many docs contain the term)
- N: total document count
- avgdl: average document length
"""
df = defaultdict(int)
total_length = 0
for doc in docs:
text = f"{doc['title']} {doc['content']}".lower()
terms = set(text.split())
for term in terms:
df[term] += 1
total_length += len(text.split())
N = len(docs)
avgdl = total_length / N if N > 0 else 0
return dict(df), N, avgdl
def compute_shard_stats(shard_docs: List[Dict]) -> Tuple[Dict, int, float]:
"""Compute statistics for a single shard (same as global but scoped)."""
return compute_global_stats(shard_docs)
def idf_global(df: int, N: int) -> float:
"""
Standard IDF formula used by most search engines.
IDF = log((N - df + 0.5) / (df + 0.5) + 1)
"""
"""Standard IDF formula: log((N - df + 0.5) / (df + 0.5) + 1)."""
if df == 0:
return 0.0
return math.log((N - df + 0.5) / (df + 0.5) + 1.0)
def score_document_bm25(
doc: Dict,
# Pre-computed per-document data for fast BM25 scoring.
DocData = Dict[str, object] # {"id": str, "tf": {term: count}, "len": int, "title": str, "category": str}
def precompute_doc_data(docs: List[Dict]) -> List[DocData]:
"""Pre-compute term frequencies and lengths for all documents."""
result = []
for doc in docs:
text = f"{doc['title']} {doc['content']}".lower()
words = text.split()
tf: Dict[str, int] = defaultdict(int)
for w in words:
tf[w] += 1
result.append({
"id": doc["id"],
"title": doc["title"],
"category": doc.get("category", ""),
"tf": dict(tf),
"len": len(words),
})
return result
def compute_stats(doc_data_list: List[DocData]) -> Tuple[Dict, int, float]:
"""Compute index statistics (df, N, avgdl) from pre-computed doc data."""
df: Dict[str, int] = defaultdict(int)
total_length = 0
for dd in doc_data_list:
for term in dd["tf"]:
df[term] += 1
total_length += dd["len"]
N = len(doc_data_list)
avgdl = total_length / N if N > 0 else 0
return dict(df), N, avgdl
def build_inverted_index(doc_data_list: List[DocData]) -> Dict[str, List[int]]:
"""Build inverted index: term -> [doc_index, ...]."""
index: Dict[str, List[int]] = defaultdict(list)
for i, dd in enumerate(doc_data_list):
for term in dd["tf"]:
index[term].append(i)
return dict(index)
def score_bm25(
dd: DocData,
query_terms: Set[str],
df: Dict,
df: Dict[str, int],
N: int,
avgdl: float,
k1: float = 1.2,
b: float = 0.75,
) -> float:
"""
Compute BM25 score for a document.
Simplified: we use the IDF component which is the source of the
score comparability problem. TF is kept simple (term frequency).
"""
text = f"{doc['title']} {doc['content']}".lower()
words = text.split()
doc_length = len(words)
# Count term frequencies
tf = defaultdict(int)
for word in words:
tf[word] += 1
"""Compute BM25 score using pre-computed per-doc data."""
doc_len = dd["len"]
tf = dd["tf"]
score = 0.0
for term in query_terms:
if term not in df:
continue
# IDF component
idf = idf_global(df[term], N)
# TF component (simplified)
term_freq = tf.get(term, 0)
tf_norm = term_freq * (k1 + 1) / (term_freq + k1 * (1 - b + b * doc_length / avgdl))
freq = tf.get(term, 0)
if freq == 0:
continue
tf_norm = freq * (k1 + 1) / (freq + k1 * (1 - b + b * doc_len / avgdl))
score += idf * tf_norm
return score
def build_inverted_index(docs: List[Dict]) -> Dict[str, List[Tuple[int, Dict]]]:
"""Build inverted index: term -> [(doc_index, doc), ...]."""
index: Dict[str, List[Tuple[int, Dict]]] = defaultdict(list)
for i, doc in enumerate(docs):
text = f"{doc['title']} {doc['content']}".lower()
terms = set(text.split())
for term in terms:
index[term].append((i, doc))
return dict(index)
def _collect_candidates(
inv_index: Dict[str, List[Tuple[int, Dict]]],
inv_index: Dict[str, List[int]],
doc_categories: List[str],
query_terms: Set[str],
category_filter: str | None,
) -> List[Dict]:
"""Collect unique candidate documents from inverted index."""
) -> List[int]:
"""Collect unique candidate doc indices from inverted index."""
seen: Set[int] = set()
candidates = []
for term in query_terms:
if term not in inv_index:
continue
for doc_idx, doc in inv_index[term]:
for doc_idx in inv_index[term]:
if doc_idx in seen:
continue
if category_filter and doc_categories[doc_idx] != category_filter:
continue
seen.add(doc_idx)
candidates.append(doc)
candidates.append(doc_idx)
return candidates
def simulate_search_indexed(
inv_index: Dict[str, List[Tuple[int, Dict]]],
def simulate_search(
doc_data: List[DocData],
inv_index: Dict[str, List[int]],
doc_categories: List[str],
query: Dict,
stats: Tuple[Dict, int, float],
limit: int = 100,
) -> Dict:
"""Simulate search using inverted index for fast lookup."""
"""Simulate search on a single index/shard using pre-computed data."""
df, N, avgdl = stats
query_terms = tokenize(query["q"])
category_filter = query["filter"].split("=")[1].strip() if query.get("filter") else None
candidates = _collect_candidates(inv_index, doc_categories, query_terms, category_filter)
candidate_indices = _collect_candidates(inv_index, doc_categories, query_terms, category_filter)
scores = []
for doc in candidates:
score = score_document_bm25(doc, query_terms, df, N, avgdl)
if score > 0:
scores.append((doc, score))
for idx in candidate_indices:
dd = doc_data[idx]
s = score_bm25(dd, query_terms, df, N, avgdl)
if s > 0:
scores.append((dd, s))
scores.sort(key=lambda x: x[1], reverse=True)
hits = []
for doc, score in scores[:limit]:
hits.append({
"id": doc["id"],
"title": doc["title"],
"score": score,
})
for dd, s in scores[:limit]:
hits.append({"id": dd["id"], "title": dd["title"], "score": s})
return {
"query_id": query["id"],
@ -211,8 +193,9 @@ def simulate_search_indexed(
}
def simulate_distributed_search_indexed(
shard_indexes: Dict[int, Dict[str, List[Tuple[int, Dict]]]],
def simulate_distributed_search(
shard_doc_data: Dict[int, List[DocData]],
shard_indexes: Dict[int, Dict[str, List[int]]],
shard_doc_categories: Dict[int, List[str]],
shard_stats: Dict[int, Tuple[Dict, int, float]],
query: Dict,
@ -224,31 +207,30 @@ def simulate_distributed_search_indexed(
per_shard_limit = limit * 2
all_hits = []
for shard_id, inv_index in shard_indexes.items():
for shard_id in shard_doc_data:
df, N, avgdl = shard_stats[shard_id]
doc_data = shard_doc_data[shard_id]
inv_index = shard_indexes[shard_id]
doc_cats = shard_doc_categories[shard_id]
candidates = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
candidate_indices = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
shard_scores = []
for doc in candidates:
score = score_document_bm25(doc, query_terms, df, N, avgdl)
if score > 0:
shard_scores.append((doc, score))
for idx in candidate_indices:
dd = doc_data[idx]
s = score_bm25(dd, query_terms, df, N, avgdl)
if s > 0:
shard_scores.append((dd, s))
shard_scores.sort(key=lambda x: x[1], reverse=True)
for doc, score in shard_scores[:per_shard_limit]:
all_hits.append((doc, score, shard_id))
for dd, s in shard_scores[:per_shard_limit]:
all_hits.append((dd, s, shard_id))
all_hits.sort(key=lambda x: x[1], reverse=True)
hits = []
for doc, score, shard_id in all_hits[:limit]:
hits.append({
"id": doc["id"],
"title": doc["title"],
"score": score,
"shard": shard_id,
})
for dd, s, shard_id in all_hits[:limit]:
hits.append({"id": dd["id"], "title": dd["title"], "score": s, "shard": shard_id})
return {
"query_id": query["id"],
@ -257,15 +239,16 @@ def simulate_distributed_search_indexed(
"filter": query.get("filter"),
"hits": hits,
"total_hits": len(all_hits),
"shards_queried": list(shard_indexes.keys()),
"shards_queried": list(shard_doc_data.keys()),
}
RRF_K = 60 # RRF constant, matching merger.rs
def simulate_distributed_search_rrf_indexed(
shard_indexes: Dict[int, Dict[str, List[Tuple[int, Dict]]]],
def simulate_distributed_search_rrf(
shard_doc_data: Dict[int, List[DocData]],
shard_indexes: Dict[int, Dict[str, List[int]]],
shard_doc_categories: Dict[int, List[str]],
shard_stats: Dict[int, Tuple[Dict, int, float]],
query: Dict,
@ -277,39 +260,38 @@ def simulate_distributed_search_rrf_indexed(
per_shard_limit = limit * 2
rrf_scores: Dict[str, float] = defaultdict(float)
doc_info: Dict[str, Tuple[Dict, int]] = {}
doc_info: Dict[str, Tuple[DocData, int]] = {}
for shard_id, inv_index in shard_indexes.items():
for shard_id in shard_doc_data:
df, N, avgdl = shard_stats[shard_id]
doc_data = shard_doc_data[shard_id]
inv_index = shard_indexes[shard_id]
doc_cats = shard_doc_categories[shard_id]
candidates = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
candidate_indices = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
shard_scores = []
for doc in candidates:
score = score_document_bm25(doc, query_terms, df, N, avgdl)
if score > 0:
shard_scores.append((doc, score))
for idx in candidate_indices:
dd = doc_data[idx]
s = score_bm25(dd, query_terms, df, N, avgdl)
if s > 0:
shard_scores.append((dd, s))
shard_scores.sort(key=lambda x: x[1], reverse=True)
for rank, (doc, _score) in enumerate(shard_scores[:per_shard_limit]):
doc_id = doc["id"]
for rank, (dd, _s) in enumerate(shard_scores[:per_shard_limit]):
doc_id = dd["id"]
rrf_contribution = 1.0 / (RRF_K + rank + 1)
rrf_scores[doc_id] += rrf_contribution
if doc_id not in doc_info:
doc_info[doc_id] = (doc, shard_id)
doc_info[doc_id] = (dd, shard_id)
sorted_docs = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
hits = []
for doc_id, rrf_score in sorted_docs[:limit]:
doc, shard_id = doc_info[doc_id]
hits.append({
"id": doc_id,
"title": doc["title"],
"score": rrf_score,
"shard": shard_id,
})
dd, shard_id = doc_info[doc_id]
hits.append({"id": doc_id, "title": dd["title"], "score": rrf_score, "shard": shard_id})
return {
"query_id": query["id"],
@ -318,7 +300,7 @@ def simulate_distributed_search_rrf_indexed(
"filter": query.get("filter"),
"hits": hits,
"total_hits": len(sorted_docs),
"shards_queried": list(shard_indexes.keys()),
"shards_queried": list(shard_doc_data.keys()),
"merge_strategy": "rrf",
}
@ -333,12 +315,11 @@ def run_experiment(
"""Run the full experiment."""
print("Loading corpus...")
docs, metadata = load_corpus(corpus_dir)
print(f" Total documents: {len(docs)}")
print(f" Shard count: {shard_count}")
# Load per-shard data
shards = {}
shards: Dict[int, List[Dict]] = {}
for i in range(shard_count):
shard_file = corpus_dir / f"shard-{i:02d}.jsonl"
if shard_file.exists():
@ -347,38 +328,35 @@ def run_experiment(
for line in f:
shard_docs.append(json.loads(line))
shards[i] = shard_docs
print(f" Loaded {len(shards)} shards")
# Compute statistics
print("\nComputing statistics...")
global_stats = compute_global_stats(docs)
print(f" Global: N={global_stats[1]}, avgdl={global_stats[2]:.1f}")
# Pre-compute per-document data
print("\nPre-computing document data...")
global_doc_data = precompute_doc_data(docs)
global_doc_categories = [dd["category"] for dd in global_doc_data]
global_inv_index = build_inverted_index(global_doc_data)
global_stats = compute_stats(global_doc_data)
print(f" Global: N={global_stats[1]}, avgdl={global_stats[2]:.1f}, {len(global_inv_index)} terms")
shard_doc_data: Dict[int, List[DocData]] = {}
shard_indexes: Dict[int, Dict[str, List[int]]] = {}
shard_doc_categories: Dict[int, List[str]] = {}
shard_stats: Dict[int, Tuple[Dict, int, float]] = {}
shard_stats = {}
for shard_id, shard_docs in shards.items():
stats = compute_shard_stats(shard_docs)
shard_stats[shard_id] = stats
print(f" Shard {shard_id}: N={stats[1]}, avgdl={stats[2]:.1f}")
sd = precompute_doc_data(shard_docs)
shard_doc_data[shard_id] = sd
shard_doc_categories[shard_id] = [dd["category"] for dd in sd]
shard_indexes[shard_id] = build_inverted_index(sd)
shard_stats[shard_id] = compute_stats(sd)
print(f" Shard {shard_id}: N={shard_stats[shard_id][1]}, "
f"avgdl={shard_stats[shard_id][2]:.1f}, {len(shard_indexes[shard_id])} terms")
# Load queries
print(f"\nLoading queries from {query_file}...")
queries = load_queries(query_file)
print(f" {len(queries)} queries")
# Build inverted indexes for fast lookup
print("\nBuilding inverted indexes...")
global_inv_index = build_inverted_index(docs)
global_doc_categories = [doc.get("category", "") for doc in docs]
print(f" Global index: {len(global_inv_index)} terms")
shard_indexes = {}
shard_doc_categories = {}
for shard_id, shard_docs in shards.items():
shard_indexes[shard_id] = build_inverted_index(shard_docs)
shard_doc_categories[shard_id] = [d.get("category", "") for d in shard_docs]
print(f" Shard {shard_id}: {len(shard_indexes[shard_id])} terms")
# Run experiments
output_dir.mkdir(parents=True, exist_ok=True)
@ -395,23 +373,20 @@ def run_experiment(
if (i + 1) % 1000 == 0:
print(f" Processed {i + 1} queries...")
# Ground truth: single index with global statistics
gt_result = simulate_search_indexed(
global_inv_index, global_doc_categories,
gt_result = simulate_search(
global_doc_data, global_inv_index, global_doc_categories,
query, global_stats, limit,
)
gt_f.write(json.dumps(gt_result) + "\n")
# Distributed: each shard uses local statistics (score-based merge)
dist_result = simulate_distributed_search_indexed(
shard_indexes, shard_doc_categories,
dist_result = simulate_distributed_search(
shard_doc_data, shard_indexes, shard_doc_categories,
shard_stats, query, limit,
)
dist_f.write(json.dumps(dist_result) + "\n")
# RRF: rank-based merge (no score comparability needed)
rrf_result = simulate_distributed_search_rrf_indexed(
shard_indexes, shard_doc_categories,
rrf_result = simulate_distributed_search_rrf(
shard_doc_data, shard_indexes, shard_doc_categories,
shard_stats, query, limit,
)
rrf_f.write(json.dumps(rrf_result) + "\n")