miroir/docs/research/score-comparability/results/compare.py
jedarden 360378bde2 P11.8: Amend plan §12 to reflect Rust-idiomatic test layout
The plan §12 previously specified tests/ at root with integration/
and chaos/ subdirectories. However, the actual implementation uses
the idiomatic Rust convention with tests in crates/*/tests/.

This commit:
- Updates plan §12 repository structure to document the actual layout
- Moves tests/benches/score-comparability to docs/research/ (research artifacts)
- Removes the now-empty tests/ directory

CI already runs cargo test --all --all-features which correctly
discovers and runs all crate-level integration tests.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 06:49:04 -04:00

261 lines
8.6 KiB
Python
Executable file

#!/usr/bin/env python3
"""
Compare result rankings and compute Kendall tau correlation.
Kendall tau measures the ordinal association between two ranked sequences.
τ = (concordant pairs - discordant pairs) / total pairs
Range: [-1, 1], where 1 = perfect agreement, 0 = independent, -1 = perfect disagreement
"""
import argparse
import json
import math
from pathlib import Path
from typing import List, Dict, Tuple
def kendall_tau(ranking1: List[str], ranking2: List[str]) -> Tuple[float, Dict]:
    """
    Measure how closely two rankings agree using Kendall's tau.

    Only documents present in both rankings take part. They are laid out in
    ``ranking1`` order, and the inversions among their ``ranking2`` positions
    are counted with a merge sort; every inversion is one discordant pair.

    Returns a ``(tau, details)`` tuple. ``details`` holds the concordant,
    discordant and total pair counts, the number of shared documents, and how
    many documents appear in only one of the rankings. When fewer than two
    documents are shared, tau is 0.0 and all pair counts are 0.
    """
    rank_in_first = {doc_id: idx for idx, doc_id in enumerate(ranking1)}
    rank_in_second = {doc_id: idx for idx, doc_id in enumerate(ranking2)}

    # Documents seen by both systems; anything else cannot form a comparable pair.
    shared = rank_in_first.keys() & rank_in_second.keys()
    n_shared = len(shared)

    def inversion_count(values):
        """Count pairs (i < j) with values[i] > values[j] via bottom-up merge sort."""
        work = list(values)
        size = len(work)
        inversions = 0
        width = 1
        while width < size:
            merged = []
            for start in range(0, size, 2 * width):
                left = work[start:start + width]
                right = work[start + width:start + 2 * width]
                a = b = 0
                while a < len(left) and b < len(right):
                    if right[b] < left[a]:
                        # right[b] overtakes every element still waiting in left.
                        inversions += len(left) - a
                        merged.append(right[b])
                        b += 1
                    else:
                        merged.append(left[a])
                        a += 1
                merged.extend(left[a:])
                merged.extend(right[b:])
            work = merged
            width *= 2
        return inversions

    if n_shared < 2:
        # Not enough overlap to form even a single pair.
        tau = 0.0
        concordant = discordant = total_pairs = 0
    else:
        in_first_order = sorted(shared, key=rank_in_first.__getitem__)
        second_positions = [rank_in_second[doc] for doc in in_first_order]
        discordant = inversion_count(second_positions)
        total_pairs = n_shared * (n_shared - 1) // 2
        concordant = total_pairs - discordant
        tau = (concordant - discordant) / total_pairs

    details = {
        "concordant": concordant,
        "discordant": discordant,
        "total_pairs": total_pairs,
        "common_docs": n_shared,
        "only_in_r1": len(rank_in_first) - n_shared,
        "only_in_r2": len(rank_in_second) - n_shared,
    }
    return tau, details
def load_results(results_file: Path) -> Dict:
    """Load search results from a JSON file.

    The file is decoded explicitly as UTF-8 so that documents containing
    non-ASCII text load the same way on every platform, instead of depending
    on the locale's default encoding.

    Raises ``OSError`` if the file cannot be opened and
    ``json.JSONDecodeError`` if its content is not valid JSON.
    """
    with open(results_file, encoding="utf-8") as f:
        return json.load(f)
def extract_ranking(results: Dict, top_k: int = None) -> List[str]:
    """Extract document IDs from search results in ranking order.

    Each hit contributes its ``"id"`` value; when that is absent (or an empty
    string) the ``"_id"`` value is used instead, and ``""`` if neither exists.
    A falsy but valid id such as the integer ``0`` is kept as-is rather than
    being mistaken for a missing id.

    ``top_k`` limits the ranking to the first ``top_k`` hits; ``None`` or ``0``
    means no limit. A missing or null ``"hits"`` entry yields an empty list.
    """
    hits = results.get("hits") or []
    if top_k:
        hits = hits[:top_k]

    ranking = []
    for hit in hits:
        doc_id = hit.get("id")
        # Only fall back when the id is genuinely missing; a plain truthiness
        # test would wrongly discard ids like 0.
        if doc_id is None or doc_id == "":
            doc_id = hit.get("_id", "")
        ranking.append(doc_id)
    return ranking
def compare_query_sets(
    ground_truth_file: Path,
    distributed_file: Path,
    top_k: int = 100,
) -> Dict:
    """
    Compare two query result sets stored as JSONL files.

    Each line of either file is a JSON object with a ``query_id`` plus the
    search response for that query. Queries present in both files are compared
    by the Kendall tau of their top-``top_k`` rankings; queries missing from the
    distributed file, or with an empty ranking on either side, are skipped.

    Returns a dict with:
    - ``avg_tau`` / ``min_tau`` / ``max_tau`` / ``stddev`` / ``stderr`` and a
      normal-approximation 95% confidence interval ``ci_95`` over all queries
    - counts of queries whose tau falls below 0.95, 0.90 and 0.80
    - ``pass_criteria``: whether the average tau is at least 0.95
    - ``type_stats``: the same statistics grouped by the query's ``type``
    - ``per_query``: one entry per compared query

    If no query could be compared, returns ``{"error": "No common queries found"}``.
    Raises ``KeyError`` if a record lacks ``query_id`` and
    ``json.JSONDecodeError`` on a malformed line.
    """
    ground_truth = _read_jsonl_by_query_id(ground_truth_file)
    distributed = _read_jsonl_by_query_id(distributed_file)

    results = []
    tau_by_type = {}
    for query_id, gt_result in ground_truth.items():
        if query_id not in distributed:
            continue
        dist_result = distributed[query_id]

        gt_ranking = extract_ranking(gt_result, top_k)
        dist_ranking = extract_ranking(dist_result, top_k)
        if not gt_ranking or not dist_ranking:
            continue

        tau, details = kendall_tau(gt_ranking, dist_ranking)
        query_type = gt_result.get("type", "unknown")
        tau_by_type.setdefault(query_type, []).append(tau)
        results.append({
            "query_id": query_id,
            "query_type": query_type,
            "tau": tau,
            "details": details,
            "query": gt_result.get("q", ""),
        })

    if not results:
        return {"error": "No common queries found"}

    tau_values = [r["tau"] for r in results]
    # With a single query the spread is reported as 0 (the sample variance is
    # undefined there), so the interval collapses onto the mean instead of
    # dividing by zero.
    avg_tau, stddev, stderr = _mean_stddev_stderr(tau_values)

    # 95% confidence interval via the normal approximation; this is only
    # trustworthy for large samples.
    z = 1.96

    type_stats = {}
    for qtype, taus in tau_by_type.items():
        tmean, tsd, tse = _mean_stddev_stderr(taus)
        type_stats[qtype] = {
            "count": len(taus),
            "avg_tau": tmean,
            "min_tau": min(taus),
            "max_tau": max(taus),
            # No interval is reported for a type with a single query.
            "ci_95": [tmean - z * tse, tmean + z * tse] if len(taus) > 1 else None,
            "stddev": tsd,
        }

    return {
        "total_queries": len(results),
        "avg_tau": avg_tau,
        "min_tau": min(tau_values),
        "max_tau": max(tau_values),
        "ci_95": [avg_tau - z * stderr, avg_tau + z * stderr],
        "stddev": stddev,
        "stderr": stderr,
        "below_095_count": sum(1 for t in tau_values if t < 0.95),
        "below_090_count": sum(1 for t in tau_values if t < 0.90),
        "below_080_count": sum(1 for t in tau_values if t < 0.80),
        "pass_criteria": avg_tau >= 0.95,
        "type_stats": type_stats,
        "per_query": results,
    }


def _read_jsonl_by_query_id(path: Path) -> Dict:
    """Parse a JSONL file into a dict keyed by each record's ``query_id``."""
    records = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines, e.g. a stray newline at the end of the file.
            if not line.strip():
                continue
            record = json.loads(line)
            # A repeated query_id overwrites the earlier record (last one wins).
            records[record["query_id"]] = record
    return records


def _mean_stddev_stderr(values: List[float]) -> Tuple[float, float, float]:
    """Return (mean, sample stddev, standard error) of a non-empty list."""
    n = len(values)
    mean = sum(values) / n
    # Bessel-corrected sample variance; defined as 0 when there is one value.
    variance = sum((v - mean) ** 2 for v in values) / (n - 1) if n > 1 else 0
    stddev = math.sqrt(variance)
    return mean, stddev, stddev / math.sqrt(n)
def main():
    """Command-line entry point: compare two JSONL result files and report.

    Prints a summary of the Kendall tau statistics (overall and per query
    type), optionally the worst-scoring queries (``--verbose``), and optionally
    writes the full comparison as JSON (``--output``).

    If the comparison yields no comparable queries, the error is written to
    stderr and the process exits with status 1 so that calling scripts and CI
    can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Compare search result rankings")
    parser.add_argument("ground_truth", type=str, help="Ground truth results file (JSONL)")
    parser.add_argument("distributed", type=str, help="Distributed results file (JSONL)")
    parser.add_argument("--output", type=str, help="Output file for comparison report")
    parser.add_argument("--top-k", type=int, default=100, help="Compare top K results")
    parser.add_argument("--verbose", action="store_true", help="Show per-query details")
    args = parser.parse_args()

    result = compare_query_sets(
        Path(args.ground_truth),
        Path(args.distributed),
        args.top_k,
    )

    if "error" in result:
        import sys

        print(f"Error: {result['error']}", file=sys.stderr)
        sys.exit(1)

    total = result["total_queries"]
    ci = result["ci_95"]

    # Print summary
    print(f"Comparison Summary (top-{args.top_k})")
    print("=" * 50)
    print(f"Total queries: {total}")
    print(f"Avg Kendall tau: {result['avg_tau']:.4f} (95% CI: [{ci[0]:.4f}, {ci[1]:.4f}])")
    print(f"Min tau: {result['min_tau']:.4f}")
    print(f"Max tau: {result['max_tau']:.4f}")
    for label, key in (
        ("0.95", "below_095_count"),
        ("0.90", "below_090_count"),
        ("0.80", "below_080_count"),
    ):
        count = result[key]
        print(f"Queries below {label}: {count} ({100 * count / total:.1f}%)")
    print(f"Pass criteria (avg >= 0.95): {'PASS' if result['pass_criteria'] else 'FAIL'}")

    print("\nPer-query type:")
    for qtype, stats in result["type_stats"].items():
        # Types with a single query carry no confidence interval (ci_95 is None).
        type_ci = stats.get("ci_95")
        ci_str = f", 95% CI: [{type_ci[0]:.4f}, {type_ci[1]:.4f}]" if type_ci else ""
        print(f" {qtype}: avg={stats['avg_tau']:.4f}{ci_str}, min={stats['min_tau']:.4f}, max={stats['max_tau']:.4f} (n={stats['count']})")

    if args.verbose:
        print("\nPer-query details:")
        # Lowest tau first: these are the queries where the rankings diverge most.
        worst = sorted(result["per_query"], key=lambda x: x["tau"])[:10]
        for qr in worst:
            print(f" {qr['query_id']}: tau={qr['tau']:.4f} ({qr['query_type']}) - '{qr['query'][:50]}'")
        print(f" ... (showing {len(worst)} worst)")

    # Save to file if requested
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()