pdftract/crates/pdftract-core/tests/classifier_corpus.rs
jedarden 6000c654ce fix: resolve compilation errors across codebase
- Fixed missing fields in BlockJson, SpanJson, ExtractionOptions initializations
- Added feature gates to ocr_integration tests for conditional compilation
- Fixed McpServerState::new calls to include audit writer argument
- Fixed CCITTFaxDecoder::decode calls to use instance method
- Fixed type casts for ObjRef::new calls
- Fixed serde_json::Value method calls (is_some -> !is_null)
- Fixed ProfileType test feature gates
- Worked around lifetime issues in schema roundtrip tests

These changes fix numerous compilation errors that were blocking the
codebase from building. The main library and tests now compile successfully.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 08:38:04 -04:00

478 lines
14 KiB
Rust

//! Classifier corpus validation tests
//!
//! This module tests the document type classifier against the 200-document
//! labeled corpus at `tests/fixtures/classifier/`.
//!
//! The corpus is partitioned as:
//! - 50 invoices
//! - 50 scientific papers
//! - 50 contracts
//! - 50 misc (receipts, forms, bank statements, slide decks, legal filings, book excerpts, magazines)
//!
//! Acceptance criteria (from plan.md Phase 5.6):
//! - Per-class precision and recall >= 0.85
//! - Macro-F1 >= 0.88
//! - Reproducibility: classifying the same document twice produces identical output
use std::collections::HashMap;
use std::path::{Path, PathBuf};
// Import pdftract_core modules for classification
#[cfg(feature = "profiles")]
use pdftract_core::extract::extract_pdf;
#[cfg(feature = "profiles")]
use pdftract_core::options::ExtractionOptions;
#[cfg(feature = "profiles")]
use pdftract_core::profiles::{classify, extract_signals_from_results, load_builtins, ProfileType};
/// Locate the classifier corpus directory, `tests/fixtures/classifier` under the workspace root.
///
/// Candidates are tried in order and the first one that exists is returned:
/// 1. `$CARGO_MANIFEST_DIR/../../tests/fixtures/classifier`. The crate root
///    (e.g. `/path/to/crates/pdftract-core`) is taken from the value recorded at compile
///    time, which does not depend on the directory the test binary is launched from, or
///    else from the runtime environment.
/// 2. `tests/fixtures/classifier` relative to the current working directory.
/// 3. `../../tests/fixtures/classifier` relative to the current working directory, i.e. the
///    workspace corpus when running from the crate root (`crates/pdftract-core`), which is
///    where `cargo test` runs integration tests.
///
/// If none exists, the candidates are printed as a warning and the first one is returned
/// anyway instead of panicking. Callers check for the files they need. In particular,
/// `parse_manifest` skips the tests when `MANIFEST.tsv` is absent, and a panic here would
/// make that skip unreachable whenever the whole corpus directory is missing.
fn get_corpus_dir() -> PathBuf {
    const CORPUS_SUBDIR: &str = "tests/fixtures/classifier";

    // Prefer the compile-time manifest dir. `option_env!` (unlike `env!`) still compiles
    // when the crate is built outside cargo; fall back to the runtime variable.
    let manifest_dir: Option<PathBuf> = option_env!("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .or_else(|| std::env::var_os("CARGO_MANIFEST_DIR").map(PathBuf::from));

    let mut candidates: Vec<PathBuf> = Vec::with_capacity(3);
    if let Some(dir) = manifest_dir {
        // The crate root is two levels below the workspace root.
        candidates.push(dir.join("../..").join(CORPUS_SUBDIR));
    }
    candidates.push(PathBuf::from(CORPUS_SUBDIR));
    candidates.push(Path::new("../..").join(CORPUS_SUBDIR));

    if let Some(found) = candidates.iter().find(|p| p.exists()) {
        return found.clone();
    }

    eprintln!("WARNING: Classifier corpus directory not found. Tried:");
    for candidate in &candidates {
        eprintln!("  - {}", candidate.display());
    }
    // Never empty: the two working-directory candidates are always pushed above.
    candidates.swap_remove(0)
}
/// Acceptance bar for the macro-averaged F1 over all classes (plan.md Phase 5.6).
const MIN_MACRO_F1: f64 = 0.88;
/// Acceptance bar that every class's precision, and separately its recall, must reach
/// (plan.md Phase 5.6).
const MIN_PRECISION_RECALL: f64 = 0.85;
/// Outcome of classifying one corpus document: what the classifier said versus the
/// label recorded for it in `MANIFEST.tsv`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct ClassificationResult {
    /// Label produced by the classifier.
    predicted_type: String,
    /// Ground-truth label taken from the manifest.
    expected_type: String,
    /// Location of the classified document.
    path: PathBuf,
}
/// Confusion counts for a single document class, together with the metrics derived from them.
///
/// Every metric yields `0.0` rather than dividing by zero when its denominator is empty.
#[derive(Default, Debug)]
struct ClassStats {
    /// Documents of this class that were predicted as this class.
    tp: usize,
    /// Documents of another class that were predicted as this class.
    fp: usize,
    /// Documents of this class that were predicted as some other class.
    fn_val: usize,
}
impl ClassStats {
    /// `hits / total` as an `f64`, or `0.0` when `total` is zero.
    fn safe_ratio(hits: usize, total: usize) -> f64 {
        match total {
            0 => 0.0,
            n => hits as f64 / n as f64,
        }
    }

    /// Precision, TP / (TP + FP).
    fn precision(&self) -> f64 {
        Self::safe_ratio(self.tp, self.tp + self.fp)
    }

    /// Recall, TP / (TP + FN).
    fn recall(&self) -> f64 {
        Self::safe_ratio(self.tp, self.tp + self.fn_val)
    }

    /// F1, the harmonic mean 2·P·R / (P + R); `0.0` when precision and recall are both zero.
    fn f1(&self) -> f64 {
        let (precision, recall) = (self.precision(), self.recall());
        let sum = precision + recall;
        if sum != 0.0 {
            2.0 * (precision * recall) / sum
        } else {
            0.0
        }
    }
}
/// One data row of the corpus `MANIFEST.tsv`.
struct ManifestEntry {
    /// Document path, joined onto the corpus directory before use.
    path: PathBuf,
    /// Ground-truth document-type label for this document.
    expected_type: String,
    /// Source URL column; the manifest validity test requires it to be non-empty.
    source_url: String,
    /// License column; the manifest validity test requires it to be non-empty.
    license: String,
}
/// Parse the corpus manifest, `MANIFEST.tsv` in the corpus directory.
///
/// The manifest is tab-separated and its first line is a header. Every following row must
/// supply at least four columns: document path, expected type, source URL and license. Any
/// extra columns are ignored, and each field is trimmed of surrounding whitespace. Blank
/// rows are skipped silently. Rows with fewer than four columns are skipped with a warning
/// that names the 1-based line number, so a malformed manifest shows up in the test output
/// instead of silently shrinking the corpus.
///
/// If `MANIFEST.tsv` does not exist, a skip notice is printed and the process exits with
/// status 0, which ends every test in this binary as a pass.
/// NOTE(review): `std::process::exit` bypasses the libtest harness, so no summary is printed
/// and any other in-flight test is cut off. Returning an empty manifest and skipping in each
/// test would be cleaner, but the callers would have to change together.
///
/// # Panics
///
/// Panics if the manifest exists but cannot be read.
fn parse_manifest() -> Vec<ManifestEntry> {
    let corpus_dir = get_corpus_dir();
    let manifest_path = corpus_dir.join("MANIFEST.tsv");
    // Skip test if corpus not present (e.g., in CI without test data)
    if !manifest_path.exists() {
        eprintln!(
            "SKIPPED: Classifier corpus not found at {}",
            manifest_path.display()
        );
        eprintln!(
            "To run this test, generate the corpus using: python3 scripts/generate_test_corpus.py"
        );
        std::process::exit(0); // Exit with success since this is expected in some environments
    }
    let content = std::fs::read_to_string(&manifest_path).unwrap_or_else(|e| {
        panic!("Failed to read manifest {}: {e}", manifest_path.display())
    });

    let mut entries = Vec::new();
    // `skip(1)` drops the header row; `idx` stays the 0-based index into the whole file.
    for (idx, line) in content.lines().enumerate().skip(1) {
        if line.trim().is_empty() {
            continue;
        }
        let fields: Vec<&str> = line.split('\t').map(str::trim).collect();
        if fields.len() < 4 {
            eprintln!(
                "WARNING: {}:{}: expected at least 4 tab-separated fields, found {}; row skipped",
                manifest_path.display(),
                idx + 1,
                fields.len()
            );
            continue;
        }
        entries.push(ManifestEntry {
            path: PathBuf::from(fields[0]),
            expected_type: fields[1].to_string(),
            source_url: fields[2].to_string(),
            license: fields[3].to_string(),
        });
    }
    entries
}
/// Classify one PDF with the pdftract classifier and return its document-type label.
///
/// Steps:
/// 1. Extract the PDF with default options.
/// 2. Derive feature signals from every page's blocks and spans, plus whether the document
///    has signature or form fields.
/// 3. Score those signals against the built-in profiles.
///
/// The winning [`ProfileType`] is returned as its snake_case label. `None` is returned (after
/// a warning on stderr) if extraction fails or no built-in profiles are loaded.
///
/// NOTE(review): this label set has `book_chapter` and no magazine label, while the manifest's
/// misc subtypes include `book_excerpt` and `magazine`. Documents with those expected labels
/// can never be counted as correct. Confirm the intended mapping.
#[cfg(feature = "profiles")]
fn classify_document(path: &Path) -> Option<String> {
    let options = ExtractionOptions::default();
    let extraction = extract_pdf(path, &options)
        .map_err(|e| eprintln!("WARNING: Failed to extract PDF {}: {:?}", path.display(), e))
        .ok()?;

    // Document-level cues that feed the classifier alongside the per-page content.
    let has_signature_field = !extraction.signatures.is_empty();
    let has_form_field = !extraction.form_fields.is_empty();

    // One (blocks, spans) pair per page, the shape the signal extractor takes.
    let page_data: Vec<(Vec<_>, Vec<_>)> = extraction
        .pages
        .iter()
        .map(|page| (page.blocks.clone(), page.spans.clone()))
        .collect();
    let signals = extract_signals_from_results(&page_data, has_signature_field, has_form_field);

    let profiles = load_builtins();
    if profiles.is_empty() {
        eprintln!("WARNING: No built-in profiles available (profiles feature may be disabled)");
        return None;
    }

    let classification = classify(&signals, &profiles);
    // Label strings mirror the mapping used in classify.rs.
    let label = match classification.document_type {
        ProfileType::Invoice => "invoice",
        ProfileType::Receipt => "receipt",
        ProfileType::Contract => "contract",
        ProfileType::ScientificPaper => "scientific_paper",
        ProfileType::SlideDeck => "slide_deck",
        ProfileType::Form => "form",
        ProfileType::BankStatement => "bank_statement",
        ProfileType::LegalFiling => "legal_filing",
        ProfileType::BookChapter => "book_chapter",
        ProfileType::Unknown => "unknown",
    };
    Some(label.to_owned())
}
/// Stand-in used when the crate is built without the `profiles` feature.
///
/// No classifier is available, so every document is reported as unclassified (`None`),
/// which makes the corpus tests skip.
#[cfg(not(feature = "profiles"))]
fn classify_document(_path: &Path) -> Option<String> {
    Option::<String>::None
}
/// Classify every document listed in the corpus manifest.
///
/// Returns one [`ClassificationResult`] per document that `classify_document` could
/// classify. Documents for which it returns `None` (extraction failure, no built-in profiles,
/// or a build without the `profiles` feature) are left out. When some, but not all,
/// documents are left out, a warning with the count is printed: those documents are then
/// missing from the precision/recall figures, which would otherwise go unnoticed.
///
/// # Panics
///
/// Panics if a manifest entry points at a file that does not exist.
fn run_corpus_classification() -> Vec<ClassificationResult> {
    let manifest = parse_manifest();
    let corpus_base = get_corpus_dir();
    let mut results = Vec::with_capacity(manifest.len());
    let mut unclassified = 0usize;
    for entry in &manifest {
        let full_path = corpus_base.join(&entry.path);
        if !full_path.exists() {
            panic!("Corpus file not found: {}", full_path.display());
        }
        match classify_document(&full_path) {
            Some(predicted) => results.push(ClassificationResult {
                predicted_type: predicted,
                expected_type: entry.expected_type.clone(),
                path: full_path,
            }),
            None => unclassified += 1,
        }
    }
    // An empty result set means "no classifier available" and is reported by the caller;
    // only a partial loss needs flagging here.
    if unclassified > 0 && !results.is_empty() {
        eprintln!(
            "WARNING: {} of {} corpus documents could not be classified and are excluded from the accuracy metrics",
            unclassified,
            manifest.len()
        );
    }
    results
}
/// Build per-class confusion counts from a set of classification results.
///
/// A correct prediction adds a true positive to its class. A wrong prediction adds a false
/// positive to the predicted class and a false negative to the expected class. Only classes
/// that occur as a prediction or as an expectation get an entry.
fn compute_class_stats(results: &[ClassificationResult]) -> HashMap<String, ClassStats> {
    results
        .iter()
        .fold(HashMap::new(), |mut table: HashMap<String, ClassStats>, outcome| {
            if outcome.predicted_type == outcome.expected_type {
                table.entry(outcome.predicted_type.clone()).or_default().tp += 1;
            } else {
                table.entry(outcome.predicted_type.clone()).or_default().fp += 1;
                table.entry(outcome.expected_type.clone()).or_default().fn_val += 1;
            }
            table
        })
}
/// Calculate macro-F1 score (average of per-class F1 scores)
fn compute_macro_f1(stats: &HashMap<String, ClassStats>) -> f64 {
if stats.is_empty() {
return 0.0;
}
let total_f1: f64 = stats.values().map(|s| s.f1()).sum();
total_f1 / stats.len() as f64
}
#[test]
fn test_classifier_corpus_accuracy() {
    // Checks the Phase 5.6 acceptance criteria on the corpus. Every class's precision and
    // recall must reach MIN_PRECISION_RECALL, and the macro-F1 must reach MIN_MACRO_F1.
    let results = run_corpus_classification();
    if results.is_empty() {
        // Nothing was classified (e.g. a build without the `profiles` feature, where
        // `classify_document` always returns `None`), so there is nothing to score.
        eprintln!("SKIP: Classifier not yet implemented (Phase 5.6)");
        return;
    }

    // List every misclassified document so a failing run shows which inputs to inspect.
    for miss in results.iter().filter(|r| r.predicted_type != r.expected_type) {
        println!(
            "MISCLASSIFIED: {} (expected {}, predicted {})",
            miss.path.display(),
            miss.expected_type,
            miss.predicted_type
        );
    }

    let stats = compute_class_stats(&results);

    // Visit classes in sorted order so the report is stable between runs. Collect every
    // violated threshold instead of aborting on the first one.
    let mut class_names: Vec<&String> = stats.keys().collect();
    class_names.sort_unstable();

    let mut failures: Vec<String> = Vec::new();
    for class_name in class_names {
        let class_stats = &stats[class_name];
        let precision = class_stats.precision();
        let recall = class_stats.recall();
        println!(
            "{}: precision={:.3}, recall={:.3}, f1={:.3}",
            class_name,
            precision,
            recall,
            class_stats.f1()
        );
        if precision < MIN_PRECISION_RECALL {
            failures.push(format!(
                "{} precision ({:.3}) below threshold ({:.3})",
                class_name, precision, MIN_PRECISION_RECALL
            ));
        }
        if recall < MIN_PRECISION_RECALL {
            failures.push(format!(
                "{} recall ({:.3}) below threshold ({:.3})",
                class_name, recall, MIN_PRECISION_RECALL
            ));
        }
    }

    let macro_f1 = compute_macro_f1(&stats);
    println!("Macro-F1: {:.3}", macro_f1);
    if macro_f1 < MIN_MACRO_F1 {
        failures.push(format!(
            "Macro-F1 ({:.3}) below threshold ({:.3})",
            macro_f1, MIN_MACRO_F1
        ));
    }

    assert!(
        failures.is_empty(),
        "Classifier corpus acceptance criteria not met:\n  - {}",
        failures.join("\n  - ")
    );
}
#[test]
fn test_classifier_reproducibility() {
    // Classify a sample of the corpus twice and require the same predicted label both times.
    // A document that is unclassified on both passes (no classifier available) is accepted.
    // A document that is classified on only one pass is a failure.
    const SAMPLE_SIZE: usize = 20;

    let manifest = parse_manifest();
    let corpus_base = get_corpus_dir();

    for entry in manifest.iter().take(SAMPLE_SIZE) {
        let full_path = corpus_base.join(&entry.path);
        if !full_path.exists() {
            continue;
        }

        let first_pass = classify_document(&full_path);
        let second_pass = classify_document(&full_path);

        match (first_pass, second_pass) {
            (None, None) => {}
            (Some(a), Some(b)) => assert_eq!(
                a,
                b,
                "Classification not reproducible for {}",
                full_path.display()
            ),
            _ => panic!(
                "Inconsistent classification results for {}",
                full_path.display()
            ),
        }
    }
}
#[test]
fn test_corpus_manifest_validity() {
    // Checks that the manifest:
    // - is non-empty;
    // - points only at existing files;
    // - records provenance (source_url and license) for every entry;
    // - matches the documented split of 50 invoice / 50 scientific_paper / 50 contract /
    //   50 misc, with no labels outside the known set.
    const PRIMARY_TYPES: [(&str, &str); 3] = [
        ("invoice", "invoices"),
        ("scientific_paper", "scientific papers"),
        ("contract", "contracts"),
    ];
    const MISC_TYPES: [&str; 7] = [
        "receipt",
        "form",
        "bank_statement",
        "slide_deck",
        "legal_filing",
        "book_excerpt",
        "magazine",
    ];

    let manifest = parse_manifest();
    let corpus_base = get_corpus_dir();
    assert!(!manifest.is_empty(), "Manifest is empty");

    // Count documents per type while checking each entry.
    let mut type_counts: HashMap<&str, usize> = HashMap::new();
    for entry in &manifest {
        let full_path = corpus_base.join(&entry.path);
        assert!(
            full_path.exists(),
            "Referenced file not found: {}",
            full_path.display()
        );
        *type_counts.entry(entry.expected_type.as_str()).or_insert(0) += 1;
        assert!(
            !entry.source_url.is_empty(),
            "Missing source_url for {}",
            entry.path.display()
        );
        assert!(
            !entry.license.is_empty(),
            "Missing license for {}",
            entry.path.display()
        );
    }

    for (label, plural) in PRIMARY_TYPES {
        assert_eq!(
            type_counts.get(label).copied().unwrap_or(0),
            50,
            "Expected 50 {}",
            plural
        );
    }

    let misc_total: usize = MISC_TYPES
        .iter()
        .map(|t| type_counts.get(t).copied().unwrap_or(0))
        .sum();
    assert_eq!(misc_total, 50, "Expected 50 misc documents");

    // A typo'd or unexpected label would otherwise escape all the count checks above.
    let mut unrecognized: Vec<&str> = type_counts
        .keys()
        .copied()
        .filter(|t| !MISC_TYPES.contains(t) && !PRIMARY_TYPES.iter().any(|(p, _)| p == t))
        .collect();
    unrecognized.sort_unstable();
    assert!(
        unrecognized.is_empty(),
        "Manifest contains unrecognized document types: {:?}",
        unrecognized
    );

    println!("Manifest validity check passed:");
    println!(" - Total documents: {}", manifest.len());
    let mut sorted_counts: Vec<(&str, usize)> = type_counts.into_iter().collect();
    sorted_counts.sort_unstable();
    for (type_name, count) in sorted_counts {
        println!(" - {}: {}", type_name, count);
    }
}