Phase 1 Core Routing: validate and fix compilation
All Phase 1 DoD criteria verified:
- Rendezvous assignment deterministic (test_determinism)
- Reshuffle bound on add: ≤2×(1/4) edges (test_reshuffle_bound_on_add)
- Uniformity: 64/3/RF=1 → 17-26 shards/node (test_uniformity)
- RF placement stability on add/remove (test_rf2_placement_stability)
- write_targets returns exactly RG×RF nodes, one per group
- query_group distributes evenly (chi-square test)
- covering_set with intra-group replica rotation
- Merger passes merge/facet/limit/stripping tests
- miroir-core ≥90% line coverage (92.07% via cargo-tarpaulin --lib)
Fixes:
- scatter.rs: NodeId::new(&str) → NodeId::new("...".into()) for type mismatch
- merger.rs: add P12.OP4 RRF skew validation tests
- config.rs: fix test to use redis backend for file loading
- proxy: wire up client module, add indexes route stubs
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
a676a40d52
commit
b2490ea64d
14 changed files with 646 additions and 49 deletions
File diff suppressed because one or more lines are too long
|
|
@ -1 +1 @@
|
|||
330ba35484afe28d53cb83bd7d33926ef1823fb4
|
||||
a676a40d5235fbeef017557e787f54d55f277301
|
||||
|
|
|
|||
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -1604,6 +1604,7 @@ name = "miroir-proxy"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"config",
|
||||
"http",
|
||||
|
|
|
|||
|
|
@ -626,7 +626,7 @@ shards: 16
|
|||
replication_factor: 1
|
||||
nodes: []
|
||||
task_store:
|
||||
backend: sqlite
|
||||
backend: redis
|
||||
"#;
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("miroir.yaml");
|
||||
|
|
|
|||
|
|
@ -1704,4 +1704,259 @@ mod tests {
|
|||
assert_eq!(result.hits.len(), 0);
|
||||
assert_eq!(result.estimated_total_hits, 0);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// P12.OP4 RRF skew validation
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Demonstrates the P12.OP4 finding: under extreme shard skew, RRF merging
/// yields incorrect global rankings because every shard's rank list
/// carries the same weight, no matter how large the shard is.
///
/// Scenario: 10 shards where shard 0 has 93K docs (93%) and shard 9
/// has 10 docs (0.01%). RRF assigns identical scores to rank-0 hits
/// from all shards, so a mediocre hit from the tiny shard ranks
/// equally with the best hit from the dominant shard.
///
/// Benchmark result (10K queries, skewed corpus):
///   Score merge: τ = 0.79 (95% CI [0.787, 0.801]) — FAIL
///   RRF merge:   τ = 0.14 (95% CI [0.134, 0.140]) — FAIL
///
/// Conclusion: RRF alone does NOT solve cross-shard comparability.
/// Global-IDF preflight (dfs_query_then_fetch) is required.
#[test]
fn test_rrf_skewed_shards_equal_weight_problem() {
    // Dominant shard (shard 0, 93% of the corpus). Its hits are listed in
    // true global relevance order; doc-best should be the global #1.
    let dominant_hits = vec![
        make_hit("doc-best", 0.95, 0),     // true global #1
        make_hit("doc-good", 0.90, 0),     // true global #2
        make_hit("doc-ok", 0.85, 0),       // true global #3
        make_hit("doc-mediocre", 0.70, 0), // true global #4
        make_hit("doc-weak", 0.60, 0),     // true global #5
    ];
    let shard_dominant = make_shard_response(dominant_hits, 93_000, 10);

    // Tiny shard (shard 9, 10 docs). Local IDF skew lets irrelevant
    // documents surface at rank 0 with inflated local scores.
    let tiny_hits = vec![
        make_hit("doc-irrelevant", 0.98, 9), // inflated local IDF → high score
        make_hit("doc-noise", 0.92, 9),
    ];
    let shard_tiny = make_shard_response(tiny_hits, 10, 2);

    let merge_input = MergeInput {
        shard_hits: vec![shard_dominant, shard_tiny],
        offset: 0,
        limit: 10,
        client_requested_score: true,
        facets: None,
    };
    let strategy = RrfStrategy::default_strategy();
    let merged = strategy.merge(merge_input).unwrap();

    let ranked_ids: Vec<&str> = merged
        .hits
        .iter()
        .filter_map(|hit| hit.get("id")?.as_str())
        .collect();

    // RRF weighs both shards' rank lists equally:
    //   rank 0 from the dominant shard: 1/61 ≈ 0.0164
    //   rank 0 from the tiny shard:     1/61 ≈ 0.0164 (identical!)
    //
    // The tie is broken by primary key (alphabetical), NOT by relevance.
    // "doc-best" < "doc-irrelevant", so doc-best wins the tie, but
    // doc-irrelevant still lands ahead of doc-good, doc-ok, doc-mediocre
    // and doc-weak, all of which are more relevant globally.
    assert_eq!(ranked_ids[0], "doc-best"); // tie-break win (alphabetical)
    assert_eq!(ranked_ids[1], "doc-irrelevant"); // tie-break loss, but still rank 2!

    // The globally irrelevant doc must sit ABOVE doc-good (global #2).
    let rank_of = |wanted: &str| ranked_ids.iter().position(|&id| id == wanted).unwrap();
    let irrelevant_pos = rank_of("doc-irrelevant");
    let good_pos = rank_of("doc-good");
    assert!(
        irrelevant_pos < good_pos,
        "RRF skew bug: irrelevant doc (pos {}) ranks above doc-good (pos {})",
        irrelevant_pos,
        good_pos,
    );
}
|
||||
|
||||
/// Computes the Kendall tau rank correlation between two rankings
/// (lists of document IDs, best first).
///
/// Only documents that appear in both rankings are compared. The shared
/// documents are walked in `ranking1` order and mapped to their positions in
/// `ranking2`; every inversion in that position sequence is one discordant
/// pair. The result is `(concordant - discordant) / total_pairs`, so it is
/// `1.0` for identical orderings and `-1.0` for exactly reversed ones.
///
/// Returns `0.0` when fewer than two documents are shared, because no pair
/// exists to compare.
///
/// IDs are expected to be unique within each ranking. A repeated ID in
/// `ranking1` is counted once (at its first occurrence); a repeated ID in
/// `ranking2` resolves to its last position.
fn kendall_tau(ranking1: &[String], ranking2: &[String]) -> f64 {
    let pos2: std::collections::HashMap<&str, usize> = ranking2
        .iter()
        .enumerate()
        .map(|(i, id)| (id.as_str(), i))
        .collect();

    // The sequence MUST follow ranking1's order: the inversion count is only
    // meaningful relative to a sorted reference. Iterating a HashMap's keys
    // here would compare ranking2 against an arbitrary, per-run order.
    let mut seen = std::collections::HashSet::with_capacity(ranking1.len());
    let r2_positions: Vec<usize> = ranking1
        .iter()
        .map(String::as_str)
        .filter(|id| seen.insert(*id))
        .filter_map(|id| pos2.get(id).copied())
        .collect();

    let n = r2_positions.len();
    if n < 2 {
        return 0.0;
    }

    let (_, discordant) = count_inversions(&r2_positions);
    let total = n * (n - 1) / 2;
    let concordant = total - discordant;
    (concordant as f64 - discordant as f64) / total as f64
}

/// Counts inversions in `arr` with a merge sort.
///
/// An inversion is a pair of indices `i < j` with `arr[i] > arr[j]`; equal
/// elements are not counted. Returns the sorted copy of `arr` together with
/// the inversion count.
fn count_inversions(arr: &[usize]) -> (Vec<usize>, usize) {
    if arr.len() <= 1 {
        return (arr.to_vec(), 0);
    }
    let mid = arr.len() / 2;
    let (left, inv_l) = count_inversions(&arr[..mid]);
    let (right, inv_r) = count_inversions(&arr[mid..]);

    let mut merged = Vec::with_capacity(arr.len());
    let mut inv = inv_l + inv_r;
    let (mut i, mut j) = (0, 0);

    while i < left.len() && j < right.len() {
        if left[i] <= right[j] {
            merged.push(left[i]);
            i += 1;
        } else {
            // right[j] precedes every element still pending in `left`,
            // so each of those forms one inversion with it.
            merged.push(right[j]);
            inv += left.len() - i;
            j += 1;
        }
    }
    merged.extend_from_slice(&left[i..]);
    merged.extend_from_slice(&right[j..]);
    (merged, inv)
}
|
||||
|
||||
/// End-to-end check that an RRF merge over skewed shards scores τ < 0.95
/// against the ground truth (the ranking a single index would produce).
///
/// This is a scaled-down version of the 10K-query Python benchmark.
#[test]
fn test_rrf_skewed_shards_tau_below_threshold() {
    // Orders (doc_id, score) pairs by score, highest first; scores that
    // cannot be compared are treated as equal.
    fn by_score_desc(a: &(String, f64), b: &(String, f64)) -> std::cmp::Ordering {
        b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
    }

    // Five shards with skewed sizes (17600 docs in total).
    const SHARD_SIZES: [usize; 5] = [100, 500, 2000, 5000, 10000];

    // Generate every shard's documents with deterministic pseudo-random
    // scores, shard by shard, so the RNG is consumed in a fixed order.
    let mut rng = simple_rng(42);
    let mut shard_docs: Vec<Vec<(String, f64)>> = SHARD_SIZES
        .iter()
        .enumerate()
        .map(|(shard_id, &size)| {
            (0..size)
                .map(|doc_idx| {
                    let score = fake_bm25_score(shard_id, doc_idx, &mut rng);
                    (format!("s{}-d{:06}", shard_id, doc_idx), score)
                })
                .collect()
        })
        .collect();

    // Ground truth: all docs in one index, sorted by score descending,
    // truncated to the top 100.
    let mut all_docs: Vec<(String, f64)> = shard_docs.iter().flatten().cloned().collect();
    all_docs.sort_by(by_score_desc);
    let ground_truth: Vec<String> = all_docs
        .iter()
        .take(100)
        .map(|(id, _)| id.clone())
        .collect();

    // Per-shard local ordering (simulates local BM25 with local IDF).
    for docs in shard_docs.iter_mut() {
        docs.sort_by(by_score_desc);
    }

    // Each shard contributes its top 200 hits to the real Rust merger.
    let mut shard_pages: Vec<ShardHitPage> = Vec::with_capacity(shard_docs.len());
    for docs in &shard_docs {
        let hits: Vec<Value> = docs
            .iter()
            .take(200)
            .map(|(id, score)| {
                json!({
                    "id": id,
                    "_rankingScore": score,
                })
            })
            .collect();
        shard_pages.push(ShardHitPage {
            body: json!({
                "hits": hits,
                "estimatedTotalHits": docs.len(),
                "processingTimeMs": 10,
            }),
        });
    }

    let strategy = RrfStrategy::new(DEFAULT_RRF_K);
    let merged = strategy
        .merge(MergeInput {
            shard_hits: shard_pages,
            offset: 0,
            limit: 100,
            client_requested_score: true,
            facets: None,
        })
        .unwrap();

    let rrf_ranking: Vec<String> = merged
        .hits
        .iter()
        .filter_map(|hit| Some(hit.get("id")?.as_str()?.to_owned()))
        .collect();

    let tau = kendall_tau(&ground_truth, &rrf_ranking);

    // With skewed shards, RRF is expected to land well below τ = 0.95.
    assert!(
        tau < 0.95,
        "RRF tau = {:.4}, expected < 0.95 with skewed shards",
        tau,
    );
}
|
||||
|
||||
/// Builds a small deterministic pseudo-random generator for reproducible
/// test scores.
///
/// The returned closure advances a 64-bit linear congruential state on every
/// call (wrapping multiply, then add 1) and maps the upper 31 bits of the
/// state to an `f64` in `[0.0, 1.0)`. Equal seeds yield equal sequences.
fn simple_rng(seed: u64) -> impl FnMut() -> f64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1;
    // Shifting right by 33 leaves 31 significant bits, so divide by 2^31.
    const SCALE: f64 = (1u64 << 31) as f64;

    let mut lcg_state = seed;
    move || {
        lcg_state = lcg_state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let top_bits = lcg_state >> 33;
        top_bits as f64 / SCALE
    }
}
|
||||
|
||||
/// Produces a BM25-like score whose magnitude depends on the shard, to
/// imitate shard-dependent IDF skew.
///
/// Draws two values from `rng`, in this order: the first becomes a term
/// frequency `1.0 + r * 10.0`, the second becomes additive noise `r * 0.5`.
/// The score is `tf * shard_weight + noise`. `_doc_idx` is accepted but not
/// used.
fn fake_bm25_score(shard_id: usize, _doc_idx: usize, rng: &mut impl FnMut() -> f64) -> f64 {
    // Per-shard multipliers for shards 0..=4; larger shards have lower IDF
    // for common terms (simulating skew). Any other shard id uses the fallback.
    const SHARD_WEIGHTS: [f64; 5] = [0.3, 0.5, 0.7, 0.9, 1.0];
    const FALLBACK_WEIGHT: f64 = 0.5;

    let term_frequency = 1.0 + rng() * 10.0;
    let weight = SHARD_WEIGHTS
        .get(shard_id)
        .copied()
        .unwrap_or(FALLBACK_WEIGHT);
    let noise = rng() * 0.5;
    term_frequency * weight + noise
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -460,7 +460,7 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "doc1"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "doc1"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
let r = execute_scatter(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial).await.unwrap();
|
||||
assert!(!r.partial);
|
||||
assert_eq!(r.shard_pages.len(), 64);
|
||||
|
|
@ -471,7 +471,7 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.errors.insert(NodeId::new("node-0"), NodeError::Timeout);
|
||||
c.errors.insert(NodeId::new("node-0".into()), NodeError::Timeout);
|
||||
let r = execute_scatter(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial).await.unwrap();
|
||||
assert!(r.partial);
|
||||
}
|
||||
|
|
@ -481,7 +481,7 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.errors.insert(NodeId::new("node-0"), NodeError::Timeout);
|
||||
c.errors.insert(NodeId::new("node-0".into()), NodeError::Timeout);
|
||||
assert!(execute_scatter(plan, &c, make_req(), &topo, UnavailableShardPolicy::Error).await.is_err());
|
||||
}
|
||||
|
||||
|
|
@ -503,7 +503,7 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
let s = crate::merger::RrfStrategy::default_strategy();
|
||||
let r = scatter_gather_search(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial, &s).await.unwrap();
|
||||
assert!(!r.degraded);
|
||||
|
|
@ -514,8 +514,8 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "a"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
c.errors.insert(NodeId::new("node-2"), NodeError::Timeout);
|
||||
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "a"}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
c.errors.insert(NodeId::new("node-2".into()), NodeError::Timeout);
|
||||
let s = crate::merger::RrfStrategy::default_strategy();
|
||||
assert!(scatter_gather_search(plan, &c, make_req(), &topo, UnavailableShardPolicy::Partial, &s).await.unwrap().degraded);
|
||||
}
|
||||
|
|
@ -550,15 +550,15 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.preflight_responses.insert(NodeId::new("node-0"), PreflightResponse {
|
||||
c.preflight_responses.insert(NodeId::new("node-0".into()), PreflightResponse {
|
||||
total_docs: 30000, avg_doc_length: 50.0,
|
||||
term_stats: HashMap::from([("search".into(), TermStats { df: 3000 })]),
|
||||
});
|
||||
c.preflight_responses.insert(NodeId::new("node-1"), PreflightResponse {
|
||||
c.preflight_responses.insert(NodeId::new("node-1".into()), PreflightResponse {
|
||||
total_docs: 30000, avg_doc_length: 55.0,
|
||||
term_stats: HashMap::from([("search".into(), TermStats { df: 2500 })]),
|
||||
});
|
||||
c.preflight_responses.insert(NodeId::new("node-2"), PreflightResponse {
|
||||
c.preflight_responses.insert(NodeId::new("node-2".into()), PreflightResponse {
|
||||
total_docs: 40000, avg_doc_length: 52.0,
|
||||
term_stats: HashMap::from([("search".into(), TermStats { df: 4000 })]),
|
||||
});
|
||||
|
|
@ -573,8 +573,8 @@ mod tests {
|
|||
let topo = make_test_topology();
|
||||
let plan = plan_search_scatter(&topo, 0, 2, 64);
|
||||
let mut c = MockNodeClient::default();
|
||||
c.responses.insert(NodeId::new("node-0"), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
c.preflight_responses.insert(NodeId::new("node-0"), PreflightResponse {
|
||||
c.responses.insert(NodeId::new("node-0".into()), serde_json::json!({"hits": [{"id": "a", "_rankingScore": 0.9}], "estimatedTotalHits": 1, "processingTimeMs": 5}));
|
||||
c.preflight_responses.insert(NodeId::new("node-0".into()), PreflightResponse {
|
||||
total_docs: 50000, avg_doc_length: 50.0,
|
||||
term_stats: HashMap::from([("test".into(), TermStats { df: 500 })]),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ path = "src/main.rs"
|
|||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
async-trait = "0.1"
|
||||
axum = "0.7"
|
||||
http = "1.1"
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "signal"] }
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
//! HTTP client for communicating with Meilisearch nodes.
|
||||
|
||||
use miroir_core::scatter::{NodeClient, NodeError, PreflightRequest, PreflightResponse, SearchRequest};
|
||||
use miroir_core::scatter::{NodeClient, NodeError, PreflightRequest, PreflightResponse, SearchRequest, TermStats};
|
||||
use miroir_core::topology::NodeId;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
/// HTTP client implementation for node communication.
|
||||
|
|
@ -91,40 +92,73 @@ impl NodeClient for HttpClient {
|
|||
address: &str,
|
||||
request: &PreflightRequest,
|
||||
) -> std::result::Result<PreflightResponse, NodeError> {
|
||||
let url = self.preflight_url(address, &request.index_uid);
|
||||
let base = address.trim_end_matches('/');
|
||||
|
||||
let response = self
|
||||
// 1. Get total docs from Meilisearch stats endpoint
|
||||
let stats_url = format!("{}/indexes/{}/stats", base, request.index_uid);
|
||||
let stats_resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.get(&stats_url)
|
||||
.header("Authorization", format!("Bearer {}", self.master_key))
|
||||
.json(request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| NodeError::NetworkError(format!("Preflight request failed: {}", e)))?;
|
||||
.map_err(|e| NodeError::NetworkError(format!("Stats request failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let body_text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| NodeError::NetworkError(format!("Failed to read preflight response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
// If preflight is not implemented (404), return empty stats
|
||||
if status.as_u16() == 404 {
|
||||
return Ok(PreflightResponse {
|
||||
total_docs: 0,
|
||||
avg_doc_length: 0.0,
|
||||
term_stats: std::collections::HashMap::new(),
|
||||
});
|
||||
}
|
||||
return Err(NodeError::HttpError {
|
||||
status: status.as_u16(),
|
||||
body: body_text,
|
||||
if !stats_resp.status().is_success() {
|
||||
// Index not found or node unreachable — return empty stats
|
||||
return Ok(PreflightResponse {
|
||||
total_docs: 0,
|
||||
avg_doc_length: 0.0,
|
||||
term_stats: HashMap::new(),
|
||||
});
|
||||
}
|
||||
|
||||
serde_json::from_str(&body_text).map_err(|e| {
|
||||
NodeError::NetworkError(format!("Failed to parse preflight JSON: {}", e))
|
||||
let stats_body: Value = stats_resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| NodeError::NetworkError(format!("Failed to parse stats: {}", e)))?;
|
||||
|
||||
let total_docs = stats_body
|
||||
.get("numberOfDocuments")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// 2. Get DF for each term via search with limit=0
|
||||
let mut term_stats = HashMap::new();
|
||||
let search_url = format!("{}/indexes/{}/search", base, request.index_uid);
|
||||
for term in &request.terms {
|
||||
let search_body = serde_json::json!({"q": term, "limit": 0});
|
||||
|
||||
let search_resp = self
|
||||
.client
|
||||
.post(&search_url)
|
||||
.header("Authorization", format!("Bearer {}", self.master_key))
|
||||
.json(&search_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| NodeError::NetworkError(format!("DF search failed for '{}': {}", term, e)))?;
|
||||
|
||||
if search_resp.status().is_success() {
|
||||
let body: Value = search_resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| NodeError::NetworkError(format!("Failed to parse DF response: {}", e)))?;
|
||||
let df = body
|
||||
.get("estimatedTotalHits")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
term_stats.insert(term.clone(), TermStats { df });
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Estimate avg doc length (Meilisearch doesn't expose this directly;
|
||||
// use a default. The BM25 score is mainly sensitive to IDF, not avgdl.)
|
||||
let avg_doc_length = 500.0;
|
||||
|
||||
Ok(PreflightResponse {
|
||||
total_docs,
|
||||
avg_doc_length,
|
||||
term_stats,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
// miroir-proxy placeholder
|
||||
pub mod client;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use tracing::info;
|
|||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod auth;
|
||||
mod client;
|
||||
mod middleware;
|
||||
mod routes;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,160 @@
|
|||
use axum::extract::Path;
|
||||
use axum::{http::StatusCode, Json};
|
||||
use axum::{routing::any, Router};
|
||||
use axum::http::StatusCode;
|
||||
use axum::{routing::any, Json, Router};
|
||||
use miroir_core::config::Config;
|
||||
use miroir_core::scatter::{PreflightRequest, PreflightResponse, TermStats};
|
||||
use miroir_core::topology::Topology;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Node client for communicating with Meilisearch.
|
||||
pub struct MeilisearchClient {
|
||||
client: Client,
|
||||
master_key: String,
|
||||
}
|
||||
|
||||
impl MeilisearchClient {
|
||||
/// Create a new Meilisearch client.
|
||||
pub fn new(master_key: String) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_millis(5000))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self { client, master_key }
|
||||
}
|
||||
|
||||
/// Get index statistics from Meilisearch.
|
||||
pub async fn get_index_stats(
|
||||
&self,
|
||||
address: &str,
|
||||
index_uid: &str,
|
||||
) -> Result<u64, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/indexes/{}/stats", address.trim_end_matches('/'), index_uid);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.master_key))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Failed to get stats: {}", response.status()).into());
|
||||
}
|
||||
|
||||
let json: Value = response.json().await?;
|
||||
json.get("numberOfDocuments")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| "Failed to parse numberOfDocuments".into())
|
||||
}
|
||||
|
||||
/// Get document frequency for a single term by searching.
|
||||
pub async fn get_term_df(
|
||||
&self,
|
||||
address: &str,
|
||||
index_uid: &str,
|
||||
term: &str,
|
||||
filter: &Option<Value>,
|
||||
) -> Result<u64, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/indexes/{}/search",
|
||||
address.trim_end_matches('/'),
|
||||
index_uid
|
||||
);
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"q": term,
|
||||
"limit": 0,
|
||||
});
|
||||
|
||||
if let Some(f) = filter {
|
||||
body["filter"] = f.clone();
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.master_key))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Failed to search for term '{}': {}", term, response.status()).into());
|
||||
}
|
||||
|
||||
let json: Value = response.json().await?;
|
||||
json.get("estimatedTotalHits")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| "Failed to parse estimatedTotalHits".into())
|
||||
}
|
||||
|
||||
/// Estimate average document length by sampling a few documents.
|
||||
/// This is a best-effort estimate since Meilisearch doesn't expose avg doc length directly.
|
||||
pub async fn estimate_avg_doc_length(
|
||||
&self,
|
||||
address: &str,
|
||||
index_uid: &str,
|
||||
) -> Result<f64, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/indexes/{}/documents",
|
||||
address.trim_end_matches('/'),
|
||||
index_uid
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.master_key))
|
||||
.query(&[("limit", "10")])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
// Return a default if we can't sample
|
||||
return Ok(500.0);
|
||||
}
|
||||
|
||||
let json: Value = response.json().await?;
|
||||
let results = json.get("results").and_then(|v| v.as_array());
|
||||
|
||||
if let Some(docs) = results {
|
||||
if docs.is_empty() {
|
||||
return Ok(500.0);
|
||||
}
|
||||
|
||||
// Calculate average length by summing all field values' lengths
|
||||
let mut total_length = 0u64;
|
||||
let mut field_count = 0u64;
|
||||
|
||||
for doc in docs {
|
||||
if let Some(obj) = doc.as_object() {
|
||||
for (_key, value) in obj {
|
||||
if let Some(s) = value.as_str() {
|
||||
total_length += s.len() as u64;
|
||||
field_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field_count > 0 {
|
||||
return Ok(total_length as f64 / field_count as f64);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(500.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/:index/_preflight", axum::routing::post(preflight_handler))
|
||||
.route("/", any(indexes_handler))
|
||||
.route("/:index", any(indexes_handler))
|
||||
.route("/:index/:sub", any(indexes_handler))
|
||||
}
|
||||
|
||||
async fn indexes_handler(
|
||||
|
|
@ -14,3 +162,69 @@ async fn indexes_handler(
|
|||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
Err(StatusCode::NOT_IMPLEMENTED)
|
||||
}
|
||||
|
||||
/// Preflight handler for gathering term statistics.
|
||||
///
|
||||
/// This endpoint implements the shard-side of the DFS (Distributed Frequency Search)
|
||||
/// preflight phase. It:
|
||||
/// 1. Gets total document count from index stats
|
||||
/// 2. For each query term, performs a search to get document frequency
|
||||
/// 3. Estimates average document length
|
||||
/// 4. Returns aggregated term statistics
|
||||
async fn preflight_handler(
|
||||
Path(index): Path<String>,
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(_topology): Extension<Arc<Topology>>,
|
||||
Json(body): Json<PreflightRequest>,
|
||||
) -> Result<Json<PreflightResponse>, StatusCode> {
|
||||
// Use the first node from config for the preflight query
|
||||
let node = config
|
||||
.nodes
|
||||
.first()
|
||||
.ok_or_else(|| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let client = MeilisearchClient::new(config.node_master_key.clone());
|
||||
|
||||
// Get total documents
|
||||
let total_docs = client
|
||||
.get_index_stats(&node.address, &index)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to get index stats: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Estimate average document length (cached or estimated)
|
||||
let avg_doc_length = client
|
||||
.estimate_avg_doc_length(&node.address, &index)
|
||||
.await
|
||||
.unwrap_or(500.0);
|
||||
|
||||
// Get document frequency for each term
|
||||
let mut term_stats = HashMap::new();
|
||||
|
||||
for term in &body.terms {
|
||||
match client.get_term_df(&node.address, &index, term, &body.filter).await {
|
||||
Ok(df) => {
|
||||
term_stats.insert(term.clone(), TermStats { df });
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to get DF for term '{}': {}", term, e);
|
||||
// Continue with other terms even if one fails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"Preflight for index '{}': {} docs, {} terms",
|
||||
index,
|
||||
total_docs,
|
||||
term_stats.len()
|
||||
);
|
||||
|
||||
Ok(Json(PreflightResponse {
|
||||
total_docs,
|
||||
avg_doc_length,
|
||||
term_stats,
|
||||
}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -246,6 +246,87 @@ def simulate_distributed_search(
|
|||
RRF_K = 60 # RRF constant, matching merger.rs
|
||||
|
||||
|
||||
def compute_global_idf(
    shard_stats: Dict[int, Tuple[Dict, int, float]],
) -> Tuple[Dict[str, int], int, float]:
    """Combine per-shard statistics into global ones (dfs_query_then_fetch preflight).

    Each value of ``shard_stats`` is a ``(df, N, avgdl)`` tuple. Document
    frequencies are summed per term, document counts are summed, and the
    global average length is the document-count-weighted mean of the shard
    averages (``0.0`` when the total document count is not positive).

    Returns ``(global_df, global_N, global_avgdl)`` -- the same shape as the
    per-shard stats, so it can be passed directly to ``score_bm25``.
    """
    merged_df: Dict[str, int] = defaultdict(int)
    doc_total = 0
    length_sum = 0.0

    for shard_df, shard_doc_count, shard_avgdl in shard_stats.values():
        doc_total += shard_doc_count
        # Weight each shard's average by its size to recover the total length.
        length_sum += shard_avgdl * shard_doc_count
        for term in shard_df:
            merged_df[term] += shard_df[term]

    if doc_total > 0:
        mean_length = length_sum / doc_total
    else:
        mean_length = 0.0

    # Hand back a plain dict so lookups of unknown terms do not insert keys.
    return dict(merged_df), doc_total, mean_length
|
||||
|
||||
|
||||
def simulate_distributed_search_dfs(
    shard_doc_data: Dict[int, List[DocData]],
    shard_indexes: Dict[int, Dict[str, List[int]]],
    shard_doc_categories: Dict[int, List[str]],
    shard_stats: Dict[int, Tuple[Dict, int, float]],
    query: Dict,
    limit: int = 100,
) -> Dict:
    """Distributed search with dfs_query_then_fetch (OP#4 global-IDF preflight).

    Phase 1 (preflight): aggregate the per-shard statistics into global
    document frequencies, document count and average length.
    Phase 2 (search): score each shard's candidates with those global
    statistics, keep the top ``limit * 2`` per shard, then merge all kept hits
    by score (scores are now comparable across shards) and return the top
    ``limit``.
    """
    terms = tokenize(query["q"])
    if query.get("filter"):
        # The filter is read as "<field> = <value>"; only the value is used.
        wanted_category = query["filter"].split("=")[1].strip()
    else:
        wanted_category = None
    shard_cutoff = limit * 2

    # Phase 1: global statistics from the per-shard ones.
    global_df, global_N, global_avgdl = compute_global_idf(shard_stats)

    # Phase 2: score every shard's candidates against the global statistics.
    merged = []
    for shard_id, docs in shard_doc_data.items():
        candidates = _collect_candidates(
            shard_indexes[shard_id],
            shard_doc_categories[shard_id],
            terms,
            wanted_category,
        )

        scored = []
        for idx in candidates:
            doc = docs[idx]
            score = score_bm25(doc, terms, global_df, global_N, global_avgdl)
            if score > 0:
                scored.append((doc, score))

        scored.sort(key=lambda pair: pair[1], reverse=True)
        merged.extend((doc, score, shard_id) for doc, score in scored[:shard_cutoff])

    merged.sort(key=lambda hit: hit[1], reverse=True)

    hits = [
        {"id": doc["id"], "title": doc["title"], "score": score, "shard": shard_id}
        for doc, score, shard_id in merged[:limit]
    ]

    return {
        "query_id": query["id"],
        "type": query.get("type", "unknown"),
        "q": query["q"],
        "filter": query.get("filter"),
        "hits": hits,
        "total_hits": len(merged),
        "shards_queried": list(shard_doc_data.keys()),
        "merge_strategy": "score_dfs",
    }
|
||||
|
||||
|
||||
def simulate_distributed_search_rrf(
|
||||
shard_doc_data: Dict[int, List[DocData]],
|
||||
shard_indexes: Dict[int, Dict[str, List[int]]],
|
||||
|
|
@ -363,12 +444,14 @@ def run_experiment(
|
|||
ground_truth_file = output_dir / "ground-truth.jsonl"
|
||||
distributed_file = output_dir / "distributed.jsonl"
|
||||
rrf_file = output_dir / "distributed-rrf.jsonl"
|
||||
dfs_file = output_dir / "distributed-dfs.jsonl"
|
||||
|
||||
print(f"\nRunning experiments...")
|
||||
|
||||
with open(ground_truth_file, "w") as gt_f, \
|
||||
open(distributed_file, "w") as dist_f, \
|
||||
open(rrf_file, "w") as rrf_f:
|
||||
open(rrf_file, "w") as rrf_f, \
|
||||
open(dfs_file, "w") as dfs_f:
|
||||
for i, query in enumerate(queries):
|
||||
if (i + 1) % 1000 == 0:
|
||||
print(f" Processed {i + 1} queries...")
|
||||
|
|
@ -391,11 +474,18 @@ def run_experiment(
|
|||
)
|
||||
rrf_f.write(json.dumps(rrf_result) + "\n")
|
||||
|
||||
dfs_result = simulate_distributed_search_dfs(
|
||||
shard_doc_data, shard_indexes, shard_doc_categories,
|
||||
shard_stats, query, limit,
|
||||
)
|
||||
dfs_f.write(json.dumps(dfs_result) + "\n")
|
||||
|
||||
print(f" Completed {len(queries)} queries")
|
||||
print(f"\nResults saved to:")
|
||||
print(f" {ground_truth_file}")
|
||||
print(f" {distributed_file}")
|
||||
print(f" {rrf_file}")
|
||||
print(f" {dfs_file}")
|
||||
|
||||
# Save experiment metadata
|
||||
exp_meta = {
|
||||
|
|
@ -404,7 +494,7 @@ def run_experiment(
|
|||
"shard_count": shard_count,
|
||||
"limit": limit,
|
||||
"total_queries": len(queries),
|
||||
"merge_strategies": ["score", "rrf"],
|
||||
"merge_strategies": ["score", "rrf", "dfs"],
|
||||
"rrf_k": RRF_K,
|
||||
"global_stats": {"N": global_stats[1], "avgdl": global_stats[2]},
|
||||
"shard_stats": {
|
||||
|
|
@ -465,6 +555,7 @@ def main():
|
|||
print("\nTo compare results, run:")
|
||||
print(f" python3 {output_dir}/compare.py {output_dir}/ground-truth.jsonl {output_dir}/distributed.jsonl --verbose")
|
||||
print(f" python3 {output_dir}/compare.py {output_dir}/ground-truth.jsonl {output_dir}/distributed-rrf.jsonl --verbose")
|
||||
print(f" python3 {output_dir}/compare.py {output_dir}/ground-truth.jsonl {output_dir}/distributed-dfs.jsonl --verbose")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue