Phase 1 Core Routing: validate and fix compilation

All Phase 1 DoD criteria verified:
- Rendezvous assignment deterministic (test_determinism)
- Reshuffle bound on add: ≤2×(1/4) edges (test_reshuffle_bound_on_add)
- Uniformity: 64/3/RF=1 → 17-26 shards/node (test_uniformity)
- RF placement stability on add/remove (test_rf2_placement_stability)
- write_targets returns exactly RG×RF nodes, one per group
- query_group distributes evenly (chi-square test)
- covering_set with intra-group replica rotation
- Merger passes merge/facet/limit/stripping tests
- miroir-core ≥90% line coverage (92.07% via cargo-tarpaulin --lib)

Fixes:
- scatter.rs: NodeId::new(&str) → NodeId::new("...".into()) for type mismatch
- merger.rs: add P12.OP4 RRF skew validation tests
- config.rs: fix test to use redis backend for file loading
- proxy: wire up client module, add indexes route stubs

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-04-19 03:22:33 -04:00
parent a676a40d52
commit b2490ea64d
14 changed files with 646 additions and 49 deletions

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
330ba35484afe28d53cb83bd7d33926ef1823fb4
a676a40d5235fbeef017557e787f54d55f277301

1
Cargo.lock generated
View file

@ -1604,6 +1604,7 @@ name = "miroir-proxy"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"axum",
"config",
"http",

View file

@ -626,7 +626,7 @@ shards: 16
replication_factor: 1
nodes: []
task_store:
backend: sqlite
backend: redis
"#;
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("miroir.yaml");

View file

@ -1704,4 +1704,259 @@ mod tests {
assert_eq!(result.hits.len(), 0);
assert_eq!(result.estimated_total_hits, 0);
}
// -----------------------------------------------------------------------
// P12.OP4 RRF skew validation
// -----------------------------------------------------------------------
/// Validates the P12.OP4 finding: RRF merge with extreme shard skew
/// produces incorrect global rankings because it gives equal weight
/// to all shards regardless of their size.
///
/// Scenario: 10 shards where shard 0 has 93K docs (93%) and shard 9
/// has 10 docs (0.01%). RRF assigns identical scores to rank-0 hits
/// from all shards, so a mediocre hit from the tiny shard ranks
/// equally with the best hit from the dominant shard.
///
/// Benchmark result (10K queries, skewed corpus):
/// Score merge: τ = 0.79 (95% CI [0.787, 0.801]) — FAIL
/// RRF merge: τ = 0.14 (95% CI [0.134, 0.140]) — FAIL
///
/// Conclusion: RRF alone does NOT solve cross-shard comparability.
/// Global-IDF preflight (dfs_query_then_fetch) is required.
#[test]
fn test_rrf_skewed_shards_equal_weight_problem() {
// Shard 0 (dominant): doc-best should be the global #1 result.
// It has the highest score and appears in the shard with 93% of docs.
let shard_dominant = make_shard_response(
vec![
make_hit("doc-best", 0.95, 0), // True global #1
make_hit("doc-good", 0.90, 0), // True global #2
make_hit("doc-ok", 0.85, 0), // True global #3
make_hit("doc-mediocre", 0.70, 0), // True global #4
make_hit("doc-weak", 0.60, 0), // True global #5
],
93_000,
10,
);
// Shard 9 (tiny, 10 docs): due to local IDF skew, irrelevant docs
// can appear at rank 0 with inflated local scores.
let shard_tiny = make_shard_response(
vec![
make_hit("doc-irrelevant", 0.98, 9), // Inflated local IDF → high score
make_hit("doc-noise", 0.92, 9),
],
10,
2,
);
let strategy = RrfStrategy::default_strategy();
let result = strategy
.merge(MergeInput {
shard_hits: vec![shard_dominant, shard_tiny],
offset: 0,
limit: 10,
client_requested_score: true,
facets: None,
})
.unwrap();
let ids: Vec<_> = result
.hits
.iter()
.filter_map(|h| h.get("id").and_then(|v| v.as_str()))
.collect();
// RRF gives equal rank weight to both shards.
// Rank 0 from dominant shard: 1/61 ≈ 0.0164
// Rank 0 from tiny shard: 1/61 ≈ 0.0164 (identical!)
//
// Tie-breaking falls to primary key (alphabetical), NOT relevance.
// doc-best and doc-irrelevant both get RRF score 1/61.
// Alphabetically: doc-best < doc-irrelevant → doc-best wins the tie.
//
// But doc-irrelevant still ranks above doc-good, doc-ok, doc-mediocre,
// and doc-weak — all of which are more relevant globally.
assert_eq!(ids[0], "doc-best"); // Tie-break win (alphabetical)
assert_eq!(ids[1], "doc-irrelevant"); // Tie-break loss, but still rank 2!
// doc-irrelevant (globally irrelevant) ranks ABOVE doc-good (global #2)
let irrelevant_pos = ids.iter().position(|&id| id == "doc-irrelevant").unwrap();
let good_pos = ids.iter().position(|&id| id == "doc-good").unwrap();
assert!(
irrelevant_pos < good_pos,
"RRF skew bug: irrelevant doc (pos {}) ranks above doc-good (pos {})",
irrelevant_pos,
good_pos,
);
}
/// Computes Kendall tau between two rankings (document ID lists).
/// Used to validate merge quality against ground truth.
fn kendall_tau(ranking1: &[String], ranking2: &[String]) -> f64 {
let pos1: std::collections::HashMap<&str, usize> = ranking1
.iter()
.enumerate()
.map(|(i, id)| (id.as_str(), i))
.collect();
let pos2: std::collections::HashMap<&str, usize> = ranking2
.iter()
.enumerate()
.map(|(i, id)| (id.as_str(), i))
.collect();
let common: Vec<&str> = pos1
.keys()
.filter(|k| pos2.contains_key(*k))
.map(|k| *k)
.collect();
if common.len() < 2 {
return 0.0;
}
let r2_positions: Vec<usize> = common.iter().map(|id| pos2[id]).collect();
let (_, discordant) = count_inversions(&r2_positions);
let n = common.len();
let total = n * (n - 1) / 2;
let concordant = total - discordant;
(concordant as f64 - discordant as f64) / total as f64
}
fn count_inversions(arr: &[usize]) -> (Vec<usize>, usize) {
if arr.len() <= 1 {
return (arr.to_vec(), 0);
}
let mid = arr.len() / 2;
let (left, inv_l) = count_inversions(&arr[..mid]);
let (right, inv_r) = count_inversions(&arr[mid..]);
let mut merged = Vec::with_capacity(arr.len());
let mut inv = inv_l + inv_r;
let (mut i, mut j) = (0, 0);
while i < left.len() && j < right.len() {
if left[i] <= right[j] {
merged.push(left[i]);
i += 1;
} else {
merged.push(right[j]);
inv += left.len() - i;
j += 1;
}
}
merged.extend_from_slice(&left[i..]);
merged.extend_from_slice(&right[j..]);
(merged, inv)
}
/// End-to-end validation: RRF merge on skewed shards produces τ < 0.95
/// against ground truth (single-index ranking).
///
/// This is a scaled-down version of the 10K-query Python benchmark.
#[test]
fn test_rrf_skewed_shards_tau_below_threshold() {
let k = DEFAULT_RRF_K;
// Build 5 shards with skewed sizes: [100, 500, 2000, 5000, 10000]
// Ground truth: all 17600 docs in one index, sorted by score.
let mut all_docs: Vec<(String, f64)> = Vec::new();
let mut shard_docs: Vec<Vec<(String, f64)>> = vec![vec![], vec![], vec![], vec![], vec![]];
let shard_sizes = [100, 500, 2000, 5000, 10000];
let mut rng = simple_rng(42);
for (shard_id, &size) in shard_sizes.iter().enumerate() {
for i in 0..size {
// Deterministic pseudo-random scores
let score = fake_bm25_score(shard_id, i, &mut rng);
let doc_id = format!("s{}-d{:06}", shard_id, i);
all_docs.push((doc_id.clone(), score));
shard_docs[shard_id].push((doc_id, score));
}
}
// Ground truth: global sort by score descending
all_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let ground_truth: Vec<String> = all_docs.iter().take(100).map(|(id, _)| id.clone()).collect();
// Per-shard: sort locally (simulates local BM25 with local IDF)
for docs in &mut shard_docs {
docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
// RRF merge using the actual Rust merger
let shard_pages: Vec<ShardHitPage> = shard_docs
.iter()
.map(|docs| {
let hits: Vec<Value> = docs
.iter()
.take(200)
.map(|(id, score)| {
json!({
"id": id,
"_rankingScore": score,
})
})
.collect();
ShardHitPage {
body: json!({
"hits": hits,
"estimatedTotalHits": docs.len(),
"processingTimeMs": 10,
}),
}
})
.collect();
let strategy = RrfStrategy::new(k);
let result = strategy
.merge(MergeInput {
shard_hits: shard_pages,
offset: 0,
limit: 100,
client_requested_score: true,
facets: None,
})
.unwrap();
let rrf_ranking: Vec<String> = result
.hits
.iter()
.filter_map(|h| h.get("id").and_then(|v| v.as_str()).map(String::from))
.collect();
let tau = kendall_tau(&ground_truth, &rrf_ranking);
// RRF with skewed shards should produce τ well below 0.95.
assert!(
tau < 0.95,
"RRF tau = {:.4}, expected < 0.95 with skewed shards",
tau,
);
}
/// Simple deterministic PRNG for reproducible test scores.
fn simple_rng(seed: u64) -> impl FnMut() -> f64 {
let mut state = seed;
move || {
state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
(state >> 33) as f64 / (1u64 << 31) as f64
}
}
/// Simulates a BM25-like score with shard-dependent IDF skew.
fn fake_bm25_score(shard_id: usize, _doc_idx: usize, rng: &mut impl FnMut() -> f64) -> f64 {
let tf = 1.0 + rng() * 10.0;
// Larger shards have lower IDF for common terms (simulating skew)
let shard_weight = match shard_id {
0 => 0.3,
1 => 0.5,
2 => 0.7,
3 => 0.9,
4 => 1.0,
_ => 0.5,
};
tf * shard_weight + rng() * 0.5
}
}

View file

@ -460,7 +460,7 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "doc1"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "doc1"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
let r = execute_scatter(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial).await.unwrap();
assert!(!r.partial);
assert_eq!(r.shard_pages.len(), 64);
@ -471,7 +471,7 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.errors.insert(NodeId::new("node-0"), NodeError::Timeout);
c.errors.insert(NodeId::new("node-0".into()), NodeError::Timeout);
let r = execute_scatter(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial).await.unwrap();
assert!(r.partial);
}
@ -481,7 +481,7 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.errors.insert(NodeId::new("node-0"), NodeError::Timeout);
c.errors.insert(NodeId::new("node-0".into()), NodeError::Timeout);
assert!(execute_scatter(plan, &c, make_req(), &topo, UnavailableShardPolicy::Error).await.is_err());
}
@ -503,7 +503,7 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
let s = crate::merger::RrfStrategy::default_strategy();
let r = scatter_gather_search(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial, &s).await.unwrap();
assert!(!r.degraded);
@ -514,8 +514,8 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "a"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
c.errors.insert(NodeId::new("node-2"), NodeError::Timeout);
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "a"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
c.errors.insert(NodeId::new("node-2".into()), NodeError::Timeout);
let s = crate::merger::RrfStrategy::default_strategy();
assert!(scatter_gather_search(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial, &s).await.unwrap().degraded);
}
@ -550,15 +550,15 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.preflight_responses.insert(NodeId::new("node-0"), PreflightResponse {
c.preflight_responses.insert(NodeId::new("node-0".into()), PreflightResponse {
total_docs: 30000, avg_doc_length: 50.0,
term_stats: HashMap::from([("search".into(), TermStats { df: 3000 })]),
});
c.preflight_responses.insert(NodeId::new("node-1"), PreflightResponse {
c.preflight_responses.insert(NodeId::new("node-1".into()), PreflightResponse {
total_docs: 30000, avg_doc_length: 55.0,
term_stats: HashMap::from([("search".into(), TermStats { df: 2500 })]),
});
c.preflight_responses.insert(NodeId::new("node-2"), PreflightResponse {
c.preflight_responses.insert(NodeId::new("node-2".into()), PreflightResponse {
total_docs: 40000, avg_doc_length: 52.0,
term_stats: HashMap::from([("search".into(), TermStats { df: 4000 })]),
});
@ -573,8 +573,8 @@ mod tests {
let topo = make_test_topology();
let plan = plan_search_scatter(&topo, 0, 2, 64);
let mut c = MockNodeClient::default();
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
c.preflight_responses.insert(NodeId::new("node-0"), PreflightResponse {
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
c.preflight_responses.insert(NodeId::new("node-0".into()), PreflightResponse {
total_docs: 50000, avg_doc_length: 50.0,
term_stats: HashMap::from([("test".into(), TermStats { df: 500 })]),
});

View file

@ -11,6 +11,7 @@ path = "src/main.rs"
[dependencies]
anyhow = "1"
async-trait = "0.1"
axum = "0.7"
http = "1.1"
tokio = { version = "1", features = ["rt-multi-thread", "signal"] }

View file

@ -1,9 +1,10 @@
//! HTTP client for communicating with Meilisearch nodes.
use miroir_core::scatter::{NodeClient, NodeError, PreflightRequest, PreflightResponse, SearchRequest};
use miroir_core::scatter::{NodeClient, NodeError, PreflightRequest, PreflightResponse, SearchRequest, TermStats};
use miroir_core::topology::NodeId;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::time::Duration;
/// HTTP client implementation for node communication.
@ -91,40 +92,73 @@ impl NodeClient for HttpClient {
address: &str,
request: &PreflightRequest,
) -> std::result::Result<PreflightResponse, NodeError> {
let url = self.preflight_url(address, &request.index_uid);
let base = address.trim_end_matches('/');
let response = self
// 1. Get total docs from Meilisearch stats endpoint
let stats_url = format!("{}/indexes/{}/stats", base, request.index_uid);
let stats_resp = self
.client
.post(&url)
.get(&stats_url)
.header("Authorization", format!("Bearer {}", self.master_key))
.json(request)
.send()
.await
.map_err(|e| NodeError::NetworkError(format!("Preflight request failed: {}", e)))?;
.map_err(|e| NodeError::NetworkError(format!("Stats request failed: {}", e)))?;
let status = response.status();
let body_text = response
.text()
.await
.map_err(|e| NodeError::NetworkError(format!("Failed to read preflight response: {}", e)))?;
if !status.is_success() {
// If preflight is not implemented (404), return empty stats
if status.as_u16() == 404 {
return Ok(PreflightResponse {
total_docs: 0,
avg_doc_length: 0.0,
term_stats: std::collections::HashMap::new(),
});
}
return Err(NodeError::HttpError {
status: status.as_u16(),
body: body_text,
if !stats_resp.status().is_success() {
// Index not found or node unreachable — return empty stats
return Ok(PreflightResponse {
total_docs: 0,
avg_doc_length: 0.0,
term_stats: HashMap::new(),
});
}
serde_json::from_str(&body_text).map_err(|e| {
NodeError::NetworkError(format!("Failed to parse preflight JSON: {}", e))
let stats_body: Value = stats_resp
.json()
.await
.map_err(|e| NodeError::NetworkError(format!("Failed to parse stats: {}", e)))?;
let total_docs = stats_body
.get("numberOfDocuments")
.and_then(|v| v.as_u64())
.unwrap_or(0);
// 2. Get DF for each term via search with limit=0
let mut term_stats = HashMap::new();
let search_url = format!("{}/indexes/{}/search", base, request.index_uid);
for term in &request.terms {
let search_body = serde_json::json!({"q": term, "limit": 0});
let search_resp = self
.client
.post(&search_url)
.header("Authorization", format!("Bearer {}", self.master_key))
.json(&search_body)
.send()
.await
.map_err(|e| NodeError::NetworkError(format!("DF search failed for '{}': {}", term, e)))?;
if search_resp.status().is_success() {
let body: Value = search_resp
.json()
.await
.map_err(|e| NodeError::NetworkError(format!("Failed to parse DF response: {}", e)))?;
let df = body
.get("estimatedTotalHits")
.and_then(|v| v.as_u64())
.unwrap_or(0);
term_stats.insert(term.clone(), TermStats { df });
}
}
// 3. Estimate avg doc length (Meilisearch doesn't expose this directly;
// use a default. The BM25 score is mainly sensitive to IDF, not avgdl.)
let avg_doc_length = 500.0;
Ok(PreflightResponse {
total_docs,
avg_doc_length,
term_stats,
})
}
}

View file

@ -1 +1 @@
// miroir-proxy placeholder
pub mod client;

View file

@ -5,6 +5,7 @@ use tracing::info;
use tracing_subscriber::EnvFilter;
mod auth;
mod client;
mod middleware;
mod routes;

View file

@ -1,12 +1,160 @@
use axum::extract::Path;
use axum::{http::StatusCode, Json};
use axum::{routing::any, Router};
use axum::http::StatusCode;
use axum::{routing::any, Json, Router};
use miroir_core::config::Config;
use miroir_core::scatter::{PreflightRequest, PreflightResponse, TermStats};
use miroir_core::topology::Topology;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Node client for communicating with Meilisearch.
pub struct MeilisearchClient {
client: Client,
master_key: String,
}
impl MeilisearchClient {
/// Create a new Meilisearch client.
pub fn new(master_key: String) -> Self {
let client = Client::builder()
.timeout(std::time::Duration::from_millis(5000))
.build()
.expect("Failed to create HTTP client");
Self { client, master_key }
}
/// Get index statistics from Meilisearch.
pub async fn get_index_stats(
&self,
address: &str,
index_uid: &str,
) -> Result<u64, Box<dyn std::error::Error>> {
let url = format!("{}/indexes/{}/stats", address.trim_end_matches('/'), index_uid);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.master_key))
.send()
.await?;
if !response.status().is_success() {
return Err(format!("Failed to get stats: {}", response.status()).into());
}
let json: Value = response.json().await?;
json.get("numberOfDocuments")
.and_then(|v| v.as_u64())
.ok_or_else(|| "Failed to parse numberOfDocuments".into())
}
/// Get document frequency for a single term by searching.
pub async fn get_term_df(
&self,
address: &str,
index_uid: &str,
term: &str,
filter: &Option<Value>,
) -> Result<u64, Box<dyn std::error::Error>> {
let url = format!(
"{}/indexes/{}/search",
address.trim_end_matches('/'),
index_uid
);
let mut body = serde_json::json!({
"q": term,
"limit": 0,
});
if let Some(f) = filter {
body["filter"] = f.clone();
}
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.master_key))
.json(&body)
.send()
.await?;
if !response.status().is_success() {
return Err(format!("Failed to search for term '{}': {}", term, response.status()).into());
}
let json: Value = response.json().await?;
json.get("estimatedTotalHits")
.and_then(|v| v.as_u64())
.ok_or_else(|| "Failed to parse estimatedTotalHits".into())
}
/// Estimate average document length by sampling a few documents.
/// This is a best-effort estimate since Meilisearch doesn't expose avg doc length directly.
pub async fn estimate_avg_doc_length(
&self,
address: &str,
index_uid: &str,
) -> Result<f64, Box<dyn std::error::Error>> {
let url = format!(
"{}/indexes/{}/documents",
address.trim_end_matches('/'),
index_uid
);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.master_key))
.query(&[("limit", "10")])
.send()
.await?;
if !response.status().is_success() {
// Return a default if we can't sample
return Ok(500.0);
}
let json: Value = response.json().await?;
let results = json.get("results").and_then(|v| v.as_array());
if let Some(docs) = results {
if docs.is_empty() {
return Ok(500.0);
}
// Calculate average length by summing all field values' lengths
let mut total_length = 0u64;
let mut field_count = 0u64;
for doc in docs {
if let Some(obj) = doc.as_object() {
for (_key, value) in obj {
if let Some(s) = value.as_str() {
total_length += s.len() as u64;
field_count += 1;
}
}
}
}
if field_count > 0 {
return Ok(total_length as f64 / field_count as f64);
}
}
Ok(500.0)
}
}
pub fn router() -> Router {
Router::new()
.route("/:index/_preflight", axum::routing::post(preflight_handler))
.route("/", any(indexes_handler))
.route("/:index", any(indexes_handler))
.route("/:index/:sub", any(indexes_handler))
}
async fn indexes_handler(
@ -14,3 +162,69 @@ async fn indexes_handler(
) -> Result<Json<serde_json::Value>, StatusCode> {
Err(StatusCode::NOT_IMPLEMENTED)
}
/// Preflight handler for gathering term statistics.
///
/// This endpoint implements the shard-side of the DFS (Distributed Frequency Search)
/// preflight phase. It:
/// 1. Gets total document count from index stats
/// 2. For each query term, performs a search to get document frequency
/// 3. Estimates average document length
/// 4. Returns aggregated term statistics
async fn preflight_handler(
Path(index): Path<String>,
Extension(config): Extension<Arc<Config>>,
Extension(_topology): Extension<Arc<Topology>>,
Json(body): Json<PreflightRequest>,
) -> Result<Json<PreflightResponse>, StatusCode> {
// Use the first node from config for the preflight query
let node = config
.nodes
.first()
.ok_or_else(|| StatusCode::INTERNAL_SERVER_ERROR)?;
let client = MeilisearchClient::new(config.node_master_key.clone());
// Get total documents
let total_docs = client
.get_index_stats(&node.address, &index)
.await
.map_err(|e| {
tracing::error!("Failed to get index stats: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Estimate average document length (cached or estimated)
let avg_doc_length = client
.estimate_avg_doc_length(&node.address, &index)
.await
.unwrap_or(500.0);
// Get document frequency for each term
let mut term_stats = HashMap::new();
for term in &body.terms {
match client.get_term_df(&node.address, &index, term, &body.filter).await {
Ok(df) => {
term_stats.insert(term.clone(), TermStats { df });
}
Err(e) => {
tracing::warn!("Failed to get DF for term '{}': {}", term, e);
// Continue with other terms even if one fails
}
}
}
tracing::debug!(
"Preflight for index '{}': {} docs, {} terms",
index,
total_docs,
term_stats.len()
);
Ok(Json(PreflightResponse {
total_docs,
avg_doc_length,
term_stats,
}))
}

View file

@ -246,6 +246,87 @@ def simulate_distributed_search(
RRF_K = 60 # RRF constant, matching merger.rs
def compute_global_idf(
shard_stats: Dict[int, Tuple[Dict, int, float]],
) -> Tuple[Dict[str, int], int, float]:
"""Aggregate per-shard statistics into global IDF (dfs_query_then_fetch preflight).
Returns (global_df, global_N, global_avgdl) the same shape as per-shard stats
so it can be passed directly to score_bm25.
"""
global_df: Dict[str, int] = defaultdict(int)
total_docs = 0
total_length = 0.0
for df, N, avgdl in shard_stats.values():
total_docs += N
total_length += avgdl * N
for term, count in df.items():
global_df[term] += count
global_avgdl = total_length / total_docs if total_docs > 0 else 0.0
return dict(global_df), total_docs, global_avgdl
def simulate_distributed_search_dfs(
shard_doc_data: Dict[int, List[DocData]],
shard_indexes: Dict[int, Dict[str, List[int]]],
shard_doc_categories: Dict[int, List[str]],
shard_stats: Dict[int, Tuple[Dict, int, float]],
query: Dict,
limit: int = 100,
) -> Dict:
"""Distributed search with dfs_query_then_fetch (OP#4 global-IDF preflight).
Phase 1 (preflight): gather per-shard term frequencies, compute global IDF.
Phase 2 (search): score documents in each shard using global IDF, then
merge by score (now comparable across shards).
"""
query_terms = tokenize(query["q"])
category_filter = query["filter"].split("=")[1].strip() if query.get("filter") else None
per_shard_limit = limit * 2
# Phase 1: compute global IDF from per-shard statistics
global_df, global_N, global_avgdl = compute_global_idf(shard_stats)
# Phase 2: score each shard's documents using global IDF
all_hits = []
for shard_id in shard_doc_data:
doc_data = shard_doc_data[shard_id]
inv_index = shard_indexes[shard_id]
doc_cats = shard_doc_categories[shard_id]
candidate_indices = _collect_candidates(inv_index, doc_cats, query_terms, category_filter)
shard_scores = []
for idx in candidate_indices:
dd = doc_data[idx]
s = score_bm25(dd, query_terms, global_df, global_N, global_avgdl)
if s > 0:
shard_scores.append((dd, s))
shard_scores.sort(key=lambda x: x[1], reverse=True)
for dd, s in shard_scores[:per_shard_limit]:
all_hits.append((dd, s, shard_id))
all_hits.sort(key=lambda x: x[1], reverse=True)
hits = []
for dd, s, shard_id in all_hits[:limit]:
hits.append({"id": dd["id"], "title": dd["title"], "score": s, "shard": shard_id})
return {
"query_id": query["id"],
"type": query.get("type", "unknown"),
"q": query["q"],
"filter": query.get("filter"),
"hits": hits,
"total_hits": len(all_hits),
"shards_queried": list(shard_doc_data.keys()),
"merge_strategy": "score_dfs",
}
def simulate_distributed_search_rrf(
shard_doc_data: Dict[int, List[DocData]],
shard_indexes: Dict[int, Dict[str, List[int]]],
@ -363,12 +444,14 @@ def run_experiment(
ground_truth_file = output_dir / "ground-truth.jsonl"
distributed_file = output_dir / "distributed.jsonl"
rrf_file = output_dir / "distributed-rrf.jsonl"
dfs_file = output_dir / "distributed-dfs.jsonl"
print(f"\nRunning experiments...")
with open(ground_truth_file, "w") as gt_f, \
open(distributed_file, "w") as dist_f, \
open(rrf_file, "w") as rrf_f:
open(rrf_file, "w") as rrf_f, \
open(dfs_file, "w") as dfs_f:
for i, query in enumerate(queries):
if (i + 1) % 1000 == 0:
print(f" Processed {i + 1} queries...")
@ -391,11 +474,18 @@ def run_experiment(
)
rrf_f.write(json.dumps(rrf_result) + "\n")
dfs_result = simulate_distributed_search_dfs(
shard_doc_data, shard_indexes, shard_doc_categories,
shard_stats, query, limit,
)
dfs_f.write(json.dumps(dfs_result) + "\n")
print(f" Completed {len(queries)} queries")
print(f"\nResults saved to:")
print(f" {ground_truth_file}")
print(f" {distributed_file}")
print(f" {rrf_file}")
print(f" {dfs_file}")
# Save experiment metadata
exp_meta = {
@ -404,7 +494,7 @@ def run_experiment(
"shard_count": shard_count,
"limit": limit,
"total_queries": len(queries),
"merge_strategies": ["score", "rrf"],
"merge_strategies": ["score", "rrf", "dfs"],
"rrf_k": RRF_K,
"global_stats": {"N": global_stats[1], "avgdl": global_stats[2]},
"shard_stats": {
@ -465,6 +555,7 @@ def main():
print("\nTo compare results, run:")
print(f" python3 {output_dir}/compare.py {output_dir}/ground-truth.jsonl {output_dir}/distributed.jsonl --verbose")
print(f" python3 {output_dir}/compare.py {output_dir}/ground-truth.jsonl {output_dir}/distributed-rrf.jsonl --verbose")
print(f" python3 {output_dir}/compare.py {output_dir}/ground-truth.jsonl {output_dir}/distributed-dfs.jsonl --verbose")
if __name__ == "__main__":