pdftract/crates/pdftract-cli/src/classify.rs
jedarden 6000c654ce fix: resolve compilation errors across codebase
- Fixed missing fields in BlockJson, SpanJson, ExtractionOptions initializations
- Added feature gates to ocr_integration tests for conditional compilation
- Fixed McpServerState::new calls to include audit writer argument
- Fixed CCITTFaxDecoder::decode calls to use instance method
- Fixed type casts for ObjRef::new calls
- Fixed serde_json::Value method calls (is_some -> !is_null)
- Fixed ProfileType test feature gates
- Worked around lifetime issues in schema roundtrip tests

These changes fix numerous compilation errors that were blocking the
codebase from building. The main library and tests now compile successfully.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 08:38:04 -04:00

269 lines
9.2 KiB
Rust

//! Document type classification CLI subcommand.
//!
//! This module implements the `pdftract classify` command that classifies
//! a PDF document type without performing full extraction.
use anyhow::{Context, Result};
use pdftract_core::extract::extract_pdf;
use pdftract_core::options::ExtractionOptions;
use serde::Serialize;
use std::path::{Path, PathBuf};
// The profiles feature must be enabled for classification
#[cfg(feature = "profiles")]
use pdftract_core::profiles::{
classify, extract_signals_from_results, load_builtins, FeatureSignals, ProfileType,
};
/// JSON-serializable result of classifying a single PDF document.
///
/// The two `runner_up*` fields are left out of the serialized JSON entirely
/// when they are `None`.
#[derive(Debug, Serialize)]
pub struct ClassificationOutput {
    /// Snake-case name of the winning document type (e.g. `"invoice"`).
    document_type: String,
    /// Confidence score reported by the classifier for `document_type`.
    confidence: f32,
    /// Human-readable reasons supporting the classification.
    reasons: Vec<String>,
    /// Snake-case name of the second-best document type, if one was reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    runner_up: Option<String>,
    /// Confidence score of `runner_up`, if one was reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    runner_up_confidence: Option<f32>,
}
/// Options accepted by `pdftract classify`.
pub struct ClassifyArgs {
    /// PDF document to classify.
    pub input: PathBuf,
    /// Directory of additional, user-supplied profiles (optional).
    pub profiles_dir: Option<PathBuf>,
    /// Emit indented, multi-line JSON instead of a single line.
    pub pretty: bool,
    /// Maximum number of reasons to report; `0` keeps every reason.
    pub top_k: usize,
    /// Treat an `unknown` document type as a failure (exit code 1).
    pub exit_on_unknown: bool,
}
/// Run classification on a PDF file.
///
/// The function:
/// 1. Validates the input path.
/// 2. Loads the built-in profiles plus any custom profiles from
///    `args.profiles_dir`.
/// 3. Extracts the PDF with default [`ExtractionOptions`].
/// 4. Derives feature signals from the extracted pages.
/// 5. Scores those signals against the loaded profiles.
///
/// # Errors
///
/// Returns an error if:
/// - the input file does not exist;
/// - the profiles directory is invalid or its profiles fail to load;
/// - no profiles are available;
/// - extraction fails;
/// - `args.exit_on_unknown` is set and the document type is
///   [`ProfileType::Unknown`].
#[cfg(feature = "profiles")]
pub fn run_classify(args: ClassifyArgs) -> Result<ClassificationOutput> {
    // Validate input file exists
    if !args.input.exists() {
        anyhow::bail!("Input file not found: {}", args.input.display());
    }

    // Validate and canonicalize the profiles directory (if any) before doing real work.
    let profiles_dir = args
        .profiles_dir
        .as_deref()
        .map(canonicalize_profiles_dir)
        .transpose()?;

    // Built-in profiles first, then any custom ones on top.
    let mut profiles = load_builtins();
    if let Some(dir) = profiles_dir.as_deref() {
        profiles.extend(load_custom_profiles(dir)?);
    }
    if profiles.is_empty() {
        anyhow::bail!("No profiles available. Built-in profiles may not be enabled.");
    }

    // Perform extraction with minimal options (fast path for classification)
    let options = ExtractionOptions::default();
    let mut result =
        extract_pdf(&args.input, &options).context("Failed to extract PDF for classification")?;

    // Check for form fields and signature fields
    let has_signature_field = !result.signatures.is_empty();
    let has_form_field = !result.form_fields.is_empty();

    // `result` is not used after this point. Move each page's blocks and spans
    // out instead of deep-cloning the whole document.
    let page_data: Vec<(Vec<_>, Vec<_>)> = result
        .pages
        .iter_mut()
        .map(|p| (std::mem::take(&mut p.blocks), std::mem::take(&mut p.spans)))
        .collect();

    let signals = extract_signals_from_results(&page_data, has_signature_field, has_form_field);
    let classification = classify(&signals, &profiles);

    if args.exit_on_unknown && classification.document_type == ProfileType::Unknown {
        anyhow::bail!(
            "Document type is unknown (confidence: {:.2})",
            classification.confidence
        );
    }

    // Keep only the first `top_k` reasons (0 = keep all). `truncate` is a no-op when
    // `top_k` is at least the length, so no clone or bounds check is needed.
    let mut reasons = classification.reasons;
    if args.top_k > 0 {
        reasons.truncate(args.top_k);
    }

    Ok(ClassificationOutput {
        document_type: profile_type_to_string(classification.document_type),
        confidence: classification.confidence,
        reasons,
        runner_up: classification.runner_up.map(profile_type_to_string),
        runner_up_confidence: classification.runner_up_confidence,
    })
}
/// Fallback `run_classify` compiled when the `profiles` feature is disabled.
///
/// # Errors
///
/// Always returns an error telling the user how to rebuild with classification support.
#[cfg(not(feature = "profiles"))]
pub fn run_classify(_args: ClassifyArgs) -> Result<ClassificationOutput> {
    Err(anyhow::anyhow!(
        "Classification requires the 'profiles' feature to be enabled. Build pdftract with: --features profiles"
    ))
}
/// Render a [`ClassificationOutput`] as a JSON string.
///
/// With `pretty` set, the JSON is indented over multiple lines. Otherwise it is
/// a single compact line. A serialization failure is not propagated; the
/// literal `"{}"` is returned instead.
pub fn format_json(output: &ClassificationOutput, pretty: bool) -> String {
    let rendered = if pretty {
        serde_json::to_string_pretty(output)
    } else {
        serde_json::to_string(output)
    };
    rendered.unwrap_or_else(|_| String::from("{}"))
}
/// Map a [`ProfileType`] to the snake-case name used in JSON output.
#[cfg(feature = "profiles")]
fn profile_type_to_string(profile_type: ProfileType) -> String {
    let name = match profile_type {
        ProfileType::Invoice => "invoice",
        ProfileType::Receipt => "receipt",
        ProfileType::Contract => "contract",
        ProfileType::ScientificPaper => "scientific_paper",
        ProfileType::SlideDeck => "slide_deck",
        ProfileType::Form => "form",
        ProfileType::BankStatement => "bank_statement",
        ProfileType::LegalFiling => "legal_filing",
        ProfileType::BookChapter => "book_chapter",
        ProfileType::Unknown => "unknown",
    };
    name.to_owned()
}
/// Canonicalize and validate profiles directory path.
///
/// Ensures the directory exists and does not escape the current working directory
/// (path traversal protection).
fn canonicalize_profiles_dir(dir: &Path) -> Result<PathBuf> {
// Canonicalize the path
let canonical = dir.canonicalize().context(format!(
"Failed to canonicalize profiles directory: {}",
dir.display()
))?;
// Check that it exists and is a directory
if !canonical.exists() {
anyhow::bail!("Profiles directory does not exist: {}", canonical.display());
}
if !canonical.is_dir() {
anyhow::bail!("Profiles path is not a directory: {}", canonical.display());
}
// Path traversal protection: ensure the canonical path doesn't escape CWD
let cwd = std::env::current_dir().context("Failed to get current working directory")?;
// Check if canonical starts with cwd (allowing for symlink resolution differences)
if !canonical.starts_with(&cwd) {
anyhow::bail!(
"Profiles directory escapes current working directory: {}",
canonical.display()
);
}
Ok(canonical)
}
/// Load custom profiles from `dir`.
///
/// This delegates to `pdftract_core::profiles::load_profiles_from_dir`.
/// NOTE(review): that loader is documented as accepting either a directory
/// (loading its `*.yaml` files) or a single profile file; confirm against
/// `pdftract_core`. The path handed in by [`run_classify`] has already passed
/// `canonicalize_profiles_dir`, so it is always a directory there.
///
/// # Errors
///
/// Returns an error carrying the loader's message if loading fails.
#[cfg(feature = "profiles")]
fn load_custom_profiles(dir: &Path) -> Result<Vec<pdftract_core::profiles::Profile>> {
    // The loader's error type is only required to implement `Display` here.
    pdftract_core::profiles::load_profiles_from_dir(dir)
        .map_err(|e| anyhow::anyhow!("Failed to load profiles: {}", e))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classification_output_serialization() {
        let output = ClassificationOutput {
            document_type: "invoice".to_string(),
            confidence: 0.87,
            reasons: vec![
                "text contains 'INVOICE' (1 hits)".to_string(),
                "has 2 table block(s)".to_string(),
            ],
            runner_up: Some("receipt".to_string()),
            runner_up_confidence: Some(0.42),
        };
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("\"document_type\":\"invoice\""));
        assert!(json.contains("\"confidence\":0.87"));
        assert!(json.contains(
            "\"reasons\":[\"text contains 'INVOICE' (1 hits)\",\"has 2 table block(s)\"]"
        ));
        assert!(json.contains("\"runner_up\":\"receipt\""));
        assert!(json.contains("\"runner_up_confidence\":0.42"));
    }

    #[test]
    fn test_optional_fields_omitted_when_none() {
        // `skip_serializing_if = "Option::is_none"` must drop both runner-up keys.
        let output = ClassificationOutput {
            document_type: "unknown".to_string(),
            confidence: 0.0,
            reasons: Vec::new(),
            runner_up: None,
            runner_up_confidence: None,
        };
        let json = format_json(&output, false);
        assert!(!json.contains("runner_up"));
        assert!(json.contains("\"reasons\":[]"));
        assert!(json.contains("\"document_type\":\"unknown\""));
    }

    #[test]
    fn test_format_json_pretty() {
        let output = ClassificationOutput {
            document_type: "invoice".to_string(),
            confidence: 0.87,
            reasons: vec!["test reason".to_string()],
            runner_up: None,
            runner_up_confidence: None,
        };
        let pretty = format_json(&output, true);
        let compact = format_json(&output, false);
        assert!(pretty.len() > compact.len());
        assert!(pretty.contains('\n'));
        assert!(!compact.contains('\n'));
    }

    #[test]
    #[cfg(feature = "profiles")]
    fn test_profile_type_to_string() {
        assert_eq!(profile_type_to_string(ProfileType::Invoice), "invoice");
        assert_eq!(profile_type_to_string(ProfileType::Receipt), "receipt");
        assert_eq!(profile_type_to_string(ProfileType::Contract), "contract");
        assert_eq!(
            profile_type_to_string(ProfileType::ScientificPaper),
            "scientific_paper"
        );
        assert_eq!(profile_type_to_string(ProfileType::SlideDeck), "slide_deck");
        assert_eq!(profile_type_to_string(ProfileType::Form), "form");
        assert_eq!(
            profile_type_to_string(ProfileType::BankStatement),
            "bank_statement"
        );
        assert_eq!(
            profile_type_to_string(ProfileType::LegalFiling),
            "legal_filing"
        );
        assert_eq!(
            profile_type_to_string(ProfileType::BookChapter),
            "book_chapter"
        );
        assert_eq!(profile_type_to_string(ProfileType::Unknown), "unknown");
    }
}