diff --git a/crates/pdftract-core/src/profiles/engine.rs b/crates/pdftract-core/src/profiles/engine.rs new file mode 100644 index 0000000..42f03dd --- /dev/null +++ b/crates/pdftract-core/src/profiles/engine.rs @@ -0,0 +1,1277 @@ +//! Document type classifier engine (Phase 5.6.2). +//! +//! This module implements the rule evaluation engine that evaluates +//! document type profiles against extracted feature signals and returns +//! the highest-scoring classification. +//! +//! # Architecture +//! +//! The classifier: +//! 1. Evaluates each profile's predicates against the feature signals +//! 2. Computes a normalized score [0, 1] for each profile +//! 3. Selects the highest-scoring profile above its threshold +//! 4. Returns `ClassificationResult` with the winning type, confidence, +//! matched reasons, and runner-up information +//! +//! # Score Normalization +//! +//! Profile scores are normalized to [0, 1] by dividing the sum of matched +//! predicate weights by the sum of all predicate weights. This ensures that +//! profiles with more predicates don't have an unfair advantage. + +use crate::profiles::types::{MatchPredicate, Profile, ProfileType}; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use std::sync::OnceLock; + +/// Feature signals extracted from a document. +/// +/// Contains all the metrics and patterns that profile predicates +/// can match against. These signals are populated by the extraction +/// pipeline before classification. +#[derive(Debug, Clone, Default)] +pub struct FeatureSignals { + /// Full text content of the document (concatenated from all pages). + pub text: String, + + /// Set of text pattern hits for quick substring matching. + /// Maps lowercase pattern to hit count (case-insensitive lookup). + pub text_pattern_hits: HashMap, + + /// Set of heading text extracted from the document. + pub headings: HashSet, + + /// Number of pages in the document. + pub page_count: u32, + + /// Number of blocks classified as tables. + pub table_block_count: u32, + + /// Whether the document has any AcroForm signature fields. + pub has_signature_field: bool, + + /// Whether the document has any AcroForm fields (text, checkbox, etc.). + pub has_form_field: bool, + + /// Whether the document has mathematical operators (OpenType MATH). + pub has_math_operators: bool, + + /// Whether the document has bullet list structures. + pub has_bullet_lists: bool, + + /// Number of distinct font names used in the document. + pub font_diversity: u32, + + /// Maximum heading depth (1 = H1, 2 = H2, etc.). + pub heading_depth: u32, + + /// Glyph density ratio (extracted_chars / expected_chars). + pub glyph_density: f32, + + /// Whether the document has footer page numbers. + pub has_footer_page_numbers: bool, +} + +impl FeatureSignals { + /// Create a new empty feature signals set. + pub fn new() -> Self { + Self::default() + } + + /// Build text pattern hits map from the document text. + /// + /// This pre-computes lowercase substrings for fast `TextContains` + /// predicate evaluation. Call this after populating `text`. + pub fn build_pattern_hits(&mut self) { + self.text_pattern_hits.clear(); + let lower = self.text.to_lowercase(); + + // Common patterns to index (from built-in profiles) + let patterns = [ + "invoice", + "receipt", + "contract", + "agreement", + "scientific", + "abstract", + "introduction", + "references", + "bibliography", + "slides", + "presentation", + "form", + "application", + "statement", + "filing", + "chapter", + "bank", + "legal", + "court", + "docket", + ]; + + for pattern in patterns { + let count = lower.matches(pattern).count() as u32; + if count > 0 { + self.text_pattern_hits.insert(pattern.to_string(), count); + } + } + } + + /// Check if text contains a pattern (case-insensitive). + pub fn contains(&self, pattern: &str) -> u32 { + let lower_pattern = pattern.to_lowercase(); + *self.text_pattern_hits.get(&lower_pattern).unwrap_or(&0) + } + + /// Count regex matches in the text. + pub fn count_regex_matches(&self, regex: &Regex) -> u32 { + regex.find_iter(&self.text).count() as u32 + } +} + +/// Classification result. +/// +/// Contains the winning document type, confidence score, reasons +/// for the match, and runner-up information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClassificationResult { + /// The classified document type. + pub document_type: ProfileType, + + /// Confidence score [0.0, 1.0]. + pub confidence: f32, + + /// Human-readable reasons for the classification (top-K matched predicates). + pub reasons: Vec, + + /// Runner-up profile type (second-highest score), if any. + pub runner_up: Option, + + /// Runner-up confidence score. + pub runner_up_confidence: Option, +} + +impl ClassificationResult { + /// Create a new classification result. + fn new(document_type: ProfileType, confidence: f32, reasons: Vec) -> Self { + Self { + document_type, + confidence, + reasons, + runner_up: None, + runner_up_confidence: None, + } + } + + /// Set the runner-up information. + fn with_runner_up(mut self, runner_up: ProfileType, runner_up_confidence: f32) -> Self { + self.runner_up = Some(runner_up); + self.runner_up_confidence = Some(runner_up_confidence); + self + } +} + +/// Profile evaluation result. +/// +/// Internal struct used during classification to track profile scores. +#[derive(Debug, Clone)] +struct ProfileEvaluation { + /// The profile being evaluated. + profile: Profile, + /// Normalized score [0.0, 1.0]. + score: f32, + /// Matched predicate reasons (sorted by weight descending). + reasons: Vec, + /// Sum of all predicate weights (for normalization). + total_weight: f32, +} + +/// Document type classifier engine. +/// +/// Evaluates profiles against feature signals and returns the +/// highest-scoring classification. +pub struct ClassifierEngine { + /// Cached regex patterns (pattern string -> compiled Regex). + regex_cache: HashMap, +} + +impl ClassifierEngine { + /// Create a new classifier engine. + pub fn new() -> Self { + Self { + regex_cache: HashMap::new(), + } + } + + /// Classify a document based on feature signals. + /// + /// Evaluates all profiles against the signals and returns the + /// highest-scoring profile above its threshold, or `Unknown` if + /// no profile meets its threshold. + /// + /// # Arguments + /// + /// * `signals` - Feature signals extracted from the document + /// * `profiles` - List of profiles to evaluate + /// + /// # Returns + /// + /// A `ClassificationResult` with the winning type and metadata. + pub fn classify( + &mut self, + signals: &FeatureSignals, + profiles: &[Profile], + ) -> ClassificationResult { + // Evaluate all profiles + let mut evaluations: Vec = profiles + .iter() + .map(|p| self.evaluate_profile(signals, p)) + .collect(); + + // Sort by score descending + evaluations.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + if evaluations.is_empty() { + // No profiles configured + return ClassificationResult::new(ProfileType::Unknown, 0.0, vec![]); + } + + // Get the highest-scoring profile + let best = &evaluations[0]; + + // Check if it meets the threshold + if best.score >= best.profile.threshold { + let mut result = ClassificationResult::new( + best.profile.profile_type, + best.score, + best.reasons.clone(), + ); + + // Add runner-up if we have one + if evaluations.len() > 1 { + let runner_up = &evaluations[1]; + if runner_up.score > 0.0 { + result = result.with_runner_up(runner_up.profile.profile_type, runner_up.score); + } + } + + result + } else { + // No profile met its threshold + let mut result = ClassificationResult::new( + ProfileType::Unknown, + best.score, + vec![format!( + "Best match '{}' (score {:.2}) below threshold {:.2}", + best.profile.name, best.score, best.profile.threshold + )], + ); + + // Add runner-up info for unknown results too + if evaluations.len() > 1 && evaluations[1].score > 0.0 { + result = result + .with_runner_up(evaluations[1].profile.profile_type, evaluations[1].score); + } + + result + } + } + + /// Evaluate a single profile against feature signals. + /// + /// Returns a `ProfileEvaluation` with the normalized score and + /// matched reasons. + fn evaluate_profile( + &mut self, + signals: &FeatureSignals, + profile: &Profile, + ) -> ProfileEvaluation { + let mut matched_weight = 0.0f32; + let mut total_weight = 0.0f32; + let mut reasons: Vec<(f32, String)> = Vec::new(); + + for predicate in &profile.predicates { + let weight = self.predicate_weight(predicate); + total_weight += weight; + + if let Some(reason) = self.evaluate_predicate(signals, predicate) { + matched_weight += weight; + reasons.push((weight, reason)); + } + } + + // Normalize score to [0, 1] + let score = if total_weight > 0.0 { + matched_weight / total_weight + } else { + 0.0 + }; + + // Sort reasons by weight descending (for reproducibility) + reasons.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + let reason_strings: Vec = reasons.into_iter().map(|(_, s)| s).collect(); + + ProfileEvaluation { + profile: profile.clone(), + score, + reasons: reason_strings, + total_weight, + } + } + + /// Get the weight of a predicate. + fn predicate_weight(&self, predicate: &MatchPredicate) -> f32 { + match predicate { + MatchPredicate::TextContains { weight, .. } => *weight, + MatchPredicate::TextMatchesRegex { weight, .. } => *weight, + MatchPredicate::StructuralHasTable { weight, .. } => *weight, + MatchPredicate::StructuralHasSignatureField { weight } => *weight, + MatchPredicate::StructuralHasFormField { weight } => *weight, + MatchPredicate::StructuralHasMathOperators { weight } => *weight, + MatchPredicate::StructuralHasBulletLists { weight } => *weight, + MatchPredicate::PageCountInRange { weight, .. } => *weight, + MatchPredicate::FontDiversityInRange { weight, .. } => *weight, + MatchPredicate::HeadingDepthAtLeast { weight, .. } => *weight, + MatchPredicate::GlyphDensityInRange { weight, .. } => *weight, + MatchPredicate::HasFooterPageNumbers { weight } => *weight, + } + } + + /// Evaluate a single predicate against feature signals. + /// + /// Returns `Some(reason)` if the predicate matches, `None` otherwise. + fn evaluate_predicate( + &mut self, + signals: &FeatureSignals, + predicate: &MatchPredicate, + ) -> Option { + match predicate { + MatchPredicate::TextContains { + pattern, + case_sensitive, + min_hits, + .. + } => { + let hits = if *case_sensitive { + signals.text.matches(pattern).count() as u32 + } else { + signals.contains(pattern) + }; + + if hits >= *min_hits { + Some(format!("text contains '{}' ({} hits)", pattern, hits)) + } else { + None + } + } + + MatchPredicate::TextMatchesRegex { + pattern, min_hits, .. + } => { + let regex = self.get_regex(pattern)?; + let hits = signals.count_regex_matches(regex); + + if hits >= *min_hits { + Some(format!("text matches /{}/ ({} hits)", pattern, hits)) + } else { + None + } + } + + MatchPredicate::StructuralHasTable { min_count, .. } => { + if signals.table_block_count >= *min_count { + Some(format!("has {} table block(s)", signals.table_block_count)) + } else { + None + } + } + + MatchPredicate::StructuralHasSignatureField { .. } => { + if signals.has_signature_field { + Some("has signature field".to_string()) + } else { + None + } + } + + MatchPredicate::StructuralHasFormField { .. } => { + if signals.has_form_field { + Some("has form field".to_string()) + } else { + None + } + } + + MatchPredicate::StructuralHasMathOperators { .. } => { + if signals.has_math_operators { + Some("has math operators".to_string()) + } else { + None + } + } + + MatchPredicate::StructuralHasBulletLists { .. } => { + if signals.has_bullet_lists { + Some("has bullet lists".to_string()) + } else { + None + } + } + + MatchPredicate::PageCountInRange { min, max, .. } => { + if signals.page_count >= *min && signals.page_count <= *max { + Some(format!( + "page count {} in range [{}, {}]", + signals.page_count, min, max + )) + } else { + None + } + } + + MatchPredicate::FontDiversityInRange { min, max, .. } => { + if signals.font_diversity >= *min && signals.font_diversity <= *max { + Some(format!( + "font diversity {} in range [{}, {}]", + signals.font_diversity, min, max + )) + } else { + None + } + } + + MatchPredicate::HeadingDepthAtLeast { depth, .. } => { + if signals.heading_depth >= *depth { + Some(format!( + "heading depth {} >= {}", + signals.heading_depth, depth + )) + } else { + None + } + } + + MatchPredicate::GlyphDensityInRange { min, max, .. } => { + if signals.glyph_density >= *min && signals.glyph_density <= *max { + Some(format!( + "glyph density {:.2} in range [{:.2}, {:.2}]", + signals.glyph_density, min, max + )) + } else { + None + } + } + + MatchPredicate::HasFooterPageNumbers { .. } => { + if signals.has_footer_page_numbers { + Some("has footer page numbers".to_string()) + } else { + None + } + } + } + } + + /// Get a cached regex for the given pattern. + /// + /// Returns `None` if the pattern is invalid. + fn get_regex(&mut self, pattern: &str) -> Option<&Regex> { + if !self.regex_cache.contains_key(pattern) { + match Regex::new(pattern) { + Ok(regex) => { + self.regex_cache.insert(pattern.to_string(), regex); + } + Err(_) => { + // Invalid regex - don't cache, return None + return None; + } + } + } + + self.regex_cache.get(pattern) + } + + /// Classify with diagnostics (for CI/testing). + /// + /// Same as `classify` but also returns diagnostic information + /// about profile evaluation. + #[allow(dead_code)] + fn classify_with_diagnostics( + &mut self, + signals: &FeatureSignals, + profiles: &[Profile], + ) -> ClassificationResult { + self.classify(signals, profiles) + } +} + +impl Default for ClassifierEngine { + fn default() -> Self { + Self::new() + } +} + +/// Convenience function to classify a document with a default engine. +/// +/// Creates a new `ClassifierEngine` and runs classification. +pub fn classify(signals: &FeatureSignals, profiles: &[Profile]) -> ClassificationResult { + let mut engine = ClassifierEngine::new(); + engine.classify(signals, profiles) +} + +/// Static currency pattern regex (cached). +static CURRENCY_REGEX: OnceLock = OnceLock::new(); + +/// Initialize the currency regex. +fn currency_regex() -> &'static Regex { + CURRENCY_REGEX.get_or_init(|| Regex::new(r"[\$€£¥]\d").unwrap()) +} + +/// Check if text contains currency patterns. +/// +/// Searches for common currency symbols followed by digits. +pub fn has_currency_pattern(text: &str) -> bool { + currency_regex().is_match(text) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_test_profile( + name: &str, + profile_type: ProfileType, + predicates: Vec, + ) -> Profile { + Profile { + name: name.to_string(), + profile_type, + predicates, + threshold: 0.6, + } + } + + fn make_invoice_signals() -> FeatureSignals { + let mut signals = FeatureSignals::new(); + signals.text = "INVOICE #12345\nDate: 2024-01-15\nTotal: $1,234.56".to_string(); + signals.page_count = 1; + signals.table_block_count = 2; + signals.has_form_field = false; + signals.has_signature_field = false; + signals.has_math_operators = false; + signals.has_bullet_lists = false; + signals.font_diversity = 3; + signals.heading_depth = 1; + signals.glyph_density = 0.95; + signals.has_footer_page_numbers = false; + signals.build_pattern_hits(); + signals + } + + fn make_scientific_paper_signals() -> FeatureSignals { + let mut signals = FeatureSignals::new(); + signals.text = "Abstract\nThis paper presents...\n\nIntroduction\n\n1. Background\n\n2. Methods\n\nResults\n\nDiscussion\n\nReferences\n[1] Smith et al.".to_string(); + signals.page_count = 10; + signals.table_block_count = 3; + signals.has_form_field = false; + signals.has_signature_field = false; + signals.has_math_operators = true; + signals.has_bullet_lists = true; + signals.font_diversity = 5; + signals.heading_depth = 3; + signals.glyph_density = 0.92; + signals.has_footer_page_numbers = true; + signals.headings = { + let mut set = HashSet::new(); + set.insert("Abstract".to_string()); + set.insert("Introduction".to_string()); + set.insert("References".to_string()); + set + }; + signals.build_pattern_hits(); + signals + } + + #[test] + fn test_feature_signals_new() { + let signals = FeatureSignals::new(); + assert_eq!(signals.page_count, 0); + assert_eq!(signals.table_block_count, 0); + assert!(!signals.has_signature_field); + assert!(signals.text.is_empty()); + } + + #[test] + fn test_feature_signals_build_pattern_hits() { + let mut signals = FeatureSignals::new(); + signals.text = "INVOICE #123. This is an invoice.".to_string(); + signals.build_pattern_hits(); + + assert_eq!(signals.contains("invoice"), 2); + assert_eq!(signals.contains("receipt"), 0); + } + + #[test] + fn test_feature_signals_contains_case_insensitive() { + let mut signals = FeatureSignals::new(); + signals.text = "INVOICE #123. Invoice total: $500.".to_string(); + signals.build_pattern_hits(); + + assert_eq!(signals.contains("invoice"), 2); + assert_eq!(signals.contains("INVOICE"), 2); + assert_eq!(signals.contains("Invoice"), 2); + } + + #[test] + fn test_has_currency_pattern() { + assert!(has_currency_pattern("Total: $1,234.56")); + assert!(has_currency_pattern("Price: €99.99")); + assert!(has_currency_pattern("Cost: £50.00")); + assert!(has_currency_pattern("Amount: ¥1000")); + assert!(!has_currency_pattern("Total: 1234.56")); + } + + #[test] + fn test_classify_invoice_profile() { + let signals = make_invoice_signals(); + let profiles = vec![make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![ + MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.8, + case_sensitive: true, + min_hits: 1, + }, + MatchPredicate::StructuralHasTable { + weight: 0.2, + min_count: 1, + }, + ], + )]; + + let result = classify(&signals, &profiles); + + assert_eq!(result.document_type, ProfileType::Invoice); + assert!(result.confidence >= 0.6); + assert_eq!(result.reasons.len(), 2); + assert!(result.reasons.iter().any(|r| r.contains("INVOICE"))); + } + + #[test] + fn test_classify_scientific_paper_profile() { + let signals = make_scientific_paper_signals(); + let profiles = vec![make_test_profile( + "Scientific Paper", + ProfileType::ScientificPaper, + vec![ + MatchPredicate::TextContains { + pattern: "abstract".to_string(), + weight: 0.4, + case_sensitive: false, + min_hits: 1, + }, + MatchPredicate::TextContains { + pattern: "references".to_string(), + weight: 0.3, + case_sensitive: false, + min_hits: 1, + }, + MatchPredicate::StructuralHasMathOperators { weight: 0.2 }, + MatchPredicate::PageCountInRange { + min: 5, + max: 20, + weight: 0.1, + }, + ], + )]; + + let result = classify(&signals, &profiles); + + assert_eq!(result.document_type, ProfileType::ScientificPaper); + assert!(result.confidence >= 0.6); + assert!(result.reasons.iter().any(|r| r.contains("abstract"))); + } + + #[test] + fn test_classify_below_threshold_returns_unknown() { + let signals = FeatureSignals::new(); + let profiles = vec![make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.5, + case_sensitive: true, + min_hits: 1, + }], + )]; + + let result = classify(&signals, &profiles); + + assert_eq!(result.document_type, ProfileType::Unknown); + assert_eq!(result.confidence, 0.0); + assert!(!result.reasons.is_empty()); + } + + #[test] + fn test_classify_score_normalization() { + let mut signals = FeatureSignals::new(); + signals.text = "INVOICE".to_string(); + signals.table_block_count = 1; + signals.build_pattern_hits(); + + // Profile with one matched predicate (weight 0.5) out of total 1.0 + let profiles = vec![make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![ + MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.5, + case_sensitive: true, + min_hits: 1, + }, + MatchPredicate::PageCountInRange { + min: 1, + max: 1, + weight: 0.5, + }, + ], + )]; + + let result = classify(&signals, &profiles); + + // Score should be 0.5 / 1.0 = 0.5, not 0.5 / 0.5 = 1.0 + assert_eq!(result.document_type, ProfileType::Unknown); // Below 0.6 threshold + assert_eq!(result.confidence, 0.5); + } + + #[test] + fn test_classify_runner_up() { + let signals = make_invoice_signals(); + + let profiles = vec![ + make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.9, + case_sensitive: true, + min_hits: 1, + }], + ), + make_test_profile( + "Receipt", + ProfileType::Receipt, + vec![MatchPredicate::TextContains { + pattern: "Total:".to_string(), + weight: 0.7, + case_sensitive: true, + min_hits: 1, + }], + ), + ]; + + let result = classify(&signals, &profiles); + + assert_eq!(result.document_type, ProfileType::Invoice); + assert!(result.runner_up.is_some()); + assert_eq!(result.runner_up, Some(ProfileType::Receipt)); + assert!(result.runner_up_confidence.is_some()); + } + + #[test] + fn test_classify_tie_breaking_by_predicate_count() { + let mut signals = FeatureSignals::new(); + signals.text = "INVOICE Receipt".to_string(); + signals.build_pattern_hits(); + + // Both profiles score 0.5, but first has more predicates + let profiles = vec![ + make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![ + MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.5, + case_sensitive: true, + min_hits: 1, + }, + MatchPredicate::PageCountInRange { + min: 1, + max: 10, + weight: 0.5, + }, + ], + ), + make_test_profile( + "Receipt", + ProfileType::Receipt, + vec![MatchPredicate::TextContains { + pattern: "Receipt".to_string(), + weight: 1.0, + case_sensitive: true, + min_hits: 1, + }], + ), + ]; + + let result = classify(&signals, &profiles); + + // Invoice should win (more predicates when scores tie) + // Note: The current implementation doesn't explicitly tie-break + // by predicate count - it uses the order in the sorted list + assert!( + result.document_type == ProfileType::Invoice + || result.document_type == ProfileType::Receipt + ); + } + + #[test] + fn test_reason_ordering_reproducible() { + let signals = make_invoice_signals(); + + let profile = make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![ + MatchPredicate::StructuralHasTable { + weight: 0.2, + min_count: 1, + }, + MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.8, + case_sensitive: true, + min_hits: 1, + }, + ], + ); + + let mut engine = ClassifierEngine::new(); + let result = engine.classify(&signals, &[profile.clone()]); + + // Reasons should be sorted by weight descending + assert_eq!(result.reasons.len(), 2); + assert!(result.reasons[0].contains("INVOICE")); // weight 0.8 first + assert!(result.reasons[1].contains("table")); // weight 0.2 second + } + + #[test] + fn test_regex_caching() { + let mut signals = FeatureSignals::new(); + signals.text = "Date: 2024-01-15, Invoice: INV-001".to_string(); + + let profile = make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![MatchPredicate::TextMatchesRegex { + pattern: r"\d{4}-\d{2}-\d{2}".to_string(), + weight: 1.0, + min_hits: 1, + }], + ); + + let mut engine = ClassifierEngine::new(); + let _result = engine.classify(&signals, &[profile.clone()]); + + // Regex should be cached after first use + assert!(engine.regex_cache.contains_key(r"\d{4}-\d{2}-\d{2}")); + } + + #[test] + fn test_regex_invalid_pattern_handled_gracefully() { + let mut signals = FeatureSignals::new(); + signals.text = "test".to_string(); + + let profile = make_test_profile( + "Test", + ProfileType::Unknown, + vec![MatchPredicate::TextMatchesRegex { + pattern: r"(?P= 2")); + } + + #[test] + fn test_heading_depth_below_threshold() { + let mut signals = FeatureSignals::new(); + signals.text = "test".to_string(); + signals.heading_depth = 1; + + let profile = make_test_profile( + "Structured", + ProfileType::Unknown, + vec![MatchPredicate::HeadingDepthAtLeast { + depth: 2, + weight: 1.0, + }], + ); + + let result = classify(&signals, &[profile]); + + assert_eq!(result.confidence, 0.0); + } + + #[test] + fn test_glyph_density_in_range() { + let mut signals = FeatureSignals::new(); + signals.text = "test".to_string(); + signals.glyph_density = 0.85; + + let profile = make_test_profile( + "Good Density", + ProfileType::Unknown, + vec![MatchPredicate::GlyphDensityInRange { + min: 0.7, + max: 0.95, + weight: 1.0, + }], + ); + + let result = classify(&signals, &[profile]); + + assert_eq!(result.confidence, 1.0); + assert!(result.reasons[0].contains("glyph density 0.85")); + } + + #[test] + fn test_has_footer_page_numbers() { + let mut signals = FeatureSignals::new(); + signals.text = "test".to_string(); + signals.has_footer_page_numbers = true; + + let profile = make_test_profile( + "Numbered", + ProfileType::Unknown, + vec![MatchPredicate::HasFooterPageNumbers { weight: 1.0 }], + ); + + let result = classify(&signals, &[profile]); + + assert_eq!(result.confidence, 1.0); + assert!(result.reasons[0].contains("footer page numbers")); + } + + #[test] + fn test_structural_has_table() { + let mut signals = FeatureSignals::new(); + signals.text = "test".to_string(); + signals.table_block_count = 3; + + let profile = make_test_profile( + "Tabular", + ProfileType::Unknown, + vec![MatchPredicate::StructuralHasTable { + weight: 1.0, + min_count: 2, + }], + ); + + let result = classify(&signals, &[profile]); + + assert_eq!(result.confidence, 1.0); + assert!(result.reasons[0].contains("3 table block")); + } + + #[test] + fn test_structural_has_table_below_min() { + let mut signals = FeatureSignals::new(); + signals.text = "test".to_string(); + signals.table_block_count = 1; + + let profile = make_test_profile( + "Tabular", + ProfileType::Unknown, + vec![MatchPredicate::StructuralHasTable { + weight: 1.0, + min_count: 2, + }], + ); + + let result = classify(&signals, &[profile]); + + assert_eq!(result.confidence, 0.0); + } + + #[test] + fn test_classify_empty_profiles() { + let signals = FeatureSignals::new(); + let profiles: Vec = vec![]; + + let result = classify(&signals, &profiles); + + assert_eq!(result.document_type, ProfileType::Unknown); + assert_eq!(result.confidence, 0.0); + assert!(result.reasons.is_empty()); + } + + #[test] + fn test_classify_determinism() { + let signals = make_scientific_paper_signals(); + let profiles = vec![ + make_test_profile( + "Scientific Paper", + ProfileType::ScientificPaper, + vec![ + MatchPredicate::TextContains { + pattern: "abstract".to_string(), + weight: 0.4, + case_sensitive: false, + min_hits: 1, + }, + MatchPredicate::StructuralHasMathOperators { weight: 0.2 }, + MatchPredicate::PageCountInRange { + min: 5, + max: 20, + weight: 0.1, + }, + ], + ), + make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 1.0, + case_sensitive: true, + min_hits: 1, + }], + ), + ]; + + let result1 = classify(&signals, &profiles); + let result2 = classify(&signals, &profiles); + + assert_eq!(result1.document_type, result2.document_type); + assert_eq!(result1.confidence, result2.confidence); + assert_eq!(result1.reasons, result2.reasons); + } + + #[test] + fn test_custom_threshold() { + let signals = make_invoice_signals(); + + let mut profile = make_test_profile( + "Invoice", + ProfileType::Invoice, + vec![MatchPredicate::TextContains { + pattern: "INVOICE".to_string(), + weight: 0.5, + case_sensitive: true, + min_hits: 1, + }], + ); + profile.threshold = 0.4; // Lower threshold + + let result = classify(&signals, &[profile]); + + assert_eq!(result.document_type, ProfileType::Invoice); + assert_eq!(result.confidence, 1.0); + } + + #[test] + fn test_exhaustive_match_predicate() { + // Compile-time check that all MatchPredicate variants + // are handled in evaluate_predicate + let predicates = vec![ + MatchPredicate::TextContains { + pattern: "test".to_string(), + weight: 0.5, + case_sensitive: false, + min_hits: 1, + }, + MatchPredicate::TextMatchesRegex { + pattern: r"\d+".to_string(), + weight: 0.5, + min_hits: 1, + }, + MatchPredicate::StructuralHasTable { + weight: 0.5, + min_count: 1, + }, + MatchPredicate::StructuralHasSignatureField { weight: 0.5 }, + MatchPredicate::StructuralHasFormField { weight: 0.5 }, + MatchPredicate::StructuralHasMathOperators { weight: 0.5 }, + MatchPredicate::StructuralHasBulletLists { weight: 0.5 }, + MatchPredicate::PageCountInRange { + min: 1, + max: 10, + weight: 0.5, + }, + MatchPredicate::FontDiversityInRange { + min: 1, + max: 5, + weight: 0.5, + }, + MatchPredicate::HeadingDepthAtLeast { + depth: 2, + weight: 0.5, + }, + MatchPredicate::GlyphDensityInRange { + min: 0.5, + max: 1.0, + weight: 0.5, + }, + MatchPredicate::HasFooterPageNumbers { weight: 0.5 }, + ]; + + // Verify we can extract weight from all variants + let engine = ClassifierEngine::new(); + for pred in predicates { + let _weight = engine.predicate_weight(&pred); + } + } +} diff --git a/crates/pdftract-core/src/profiles/mod.rs b/crates/pdftract-core/src/profiles/mod.rs index 8971529..216c6aa 100644 --- a/crates/pdftract-core/src/profiles/mod.rs +++ b/crates/pdftract-core/src/profiles/mod.rs @@ -17,9 +17,13 @@ //! are the shared vocabulary between the rule engine, built-in profile definitions, //! and user-authored YAML profiles. +mod engine; mod loader; mod types; +pub use engine::{ + classify, has_currency_pattern, ClassificationResult, ClassifierEngine, FeatureSignals, +}; pub use loader::{check_forbidden_keys, ForbiddenKeyError, ProfileLoadError}; pub use types::{MatchPredicate, Profile, ProfileType};