pdftract/crates/pdftract-core/src/source/mmap.rs
jedarden ba5d101840 test(pdftract-1uhee): fix MmapSource test assertions
- test_open_valid_file: byte string is 22 bytes, not 20
- test_seek_from_end: seeking -2 from end of "Hello" gives "lo", not "el"

The MmapSource implementation was already complete with all acceptance
criteria met:
- open() returns Ok/Err appropriately
- read_range() with bounds checking
- len() matches file size
- Read+Seek trait implementations
- Send + Sync for concurrent access
- MADV_SEQUENTIAL via advise_sequential()

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 02:29:42 -04:00

459 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Memory-mapped PDF source implementation.
//!
//! This module provides `MmapSource`, a `PdfSource` backed by a `memmap2`
//! memory map of a local file. `PdfSource::prefetch` applies
//! `madvise(MADV_SEQUENTIAL)` to the requested range (intended for content
//! stream reads) to hint the OS to read ahead. This is the default source for
//! local files when mmap succeeds.
use crate::source::PdfSource;
use bytes::Bytes;
use memmap2::{Mmap, MmapOptions};
use std::fs::File;
use std::io::{self, Cursor, Read, Seek, SeekFrom};
use std::path::Path;
/// Memory-mapped PDF source.
///
/// `MmapSource` is the default source for local files. It uses `memmap2` to
/// map the file into memory, allowing the OS to manage paging via the page
/// cache. This avoids allocating anonymous RSS for the entire file and enables
/// on-demand loading of only the portions of the file that are actually accessed.
///
/// `PdfSource::read_range` works through a shared reference. The `Read` and
/// `Seek` implementations keep their own position and need `&mut self`.
///
/// # Safety
///
/// The memory map is only sound while the underlying file is not truncated or
/// modified, by this process or any other, for the lifetime of the map. The
/// `File` handle is *not* stored in this struct. It is closed as soon as the
/// mapping has been created, and the mapping stays valid without it. Keeping
/// the handle open would not prevent truncation either. Callers must only map
/// files they expect to stay unchanged while they are being parsed.
///
/// # Performance
///
/// For 100 MB–1 GB PDFs, mmap is expected to be 5–10× faster than
/// `read()`-based ingestion thanks to OS-managed paging and on-demand loading.
/// Note that `read_range` still copies the requested bytes; only `as_slice`
/// gives direct access to the mapping.
pub struct MmapSource {
    /// Read-only mapping of the whole file.
    mmap: Mmap,
    /// Position used by the `Read`/`Seek` implementations. Only the cursor's
    /// position is used, and it may point past the end of `mmap`.
    cursor: Cursor<u64>,
}
impl MmapSource {
    /// Open a PDF file using memory-mapped I/O.
    ///
    /// The file handle is only needed to create the mapping. It is dropped
    /// before this function returns, and the returned source holds just the
    /// mapping.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or memory-mapped.
    /// This includes:
    /// - File not found
    /// - Permission denied
    /// - File too large to address (near address space limit)
    /// - Kernel refuses mmap (e.g., certain FUSE mounts, `/proc`, named pipes)
    ///
    /// Callers should fall back to `FileSource` on mmap failure.
    pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let file = File::open(path.as_ref())?;
        // SAFETY: memmap2 requires that the mapped file is not modified or
        // truncated (by this or any other process) while the map is alive.
        // Nothing here enforces that. `file` is dropped when this function
        // returns, and an open handle would not block truncation anyway. We
        // accept this on the assumption that input PDFs are not rewritten
        // while they are being parsed.
        let mmap = unsafe { MmapOptions::new().map(&file)? };
        Ok(Self {
            mmap,
            cursor: Cursor::new(0),
        })
    }

    /// Apply `MADV_SEQUENTIAL` advice to a range.
    ///
    /// This hints to the OS that the specified range will be accessed
    /// sequentially, allowing for aggressive readahead and prefetch.
    /// Use this for content stream reads.
    ///
    /// # Parameters
    ///
    /// - `offset`: Byte offset of the range start
    /// - `length`: Length of the range in bytes
    ///
    /// # Errors
    ///
    /// - `InvalidInput` ("overflow") if `offset` does not fit in `usize` or
    ///   `offset + length` overflows.
    /// - `InvalidInput` ("range extends beyond EOF") if the range ends past
    ///   the end of the mapping.
    /// - Any error reported by the underlying `madvise` call, passed through
    ///   unchanged.
    ///
    /// A zero-length range inside the mapping is a no-op and returns `Ok(())`.
    pub fn advise_sequential(&self, offset: u64, length: usize) -> io::Result<()> {
        use memmap2::Advice;
        // `try_from` rather than `as`: on 32-bit targets a large offset would
        // otherwise be truncated and advise an unrelated range.
        let start = usize::try_from(offset)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "overflow"))?;
        let end = start
            .checked_add(length)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "overflow"))?;
        if end > self.mmap.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "range extends beyond EOF",
            ));
        }
        // Nothing to advise. Skipping the call also avoids running madvise
        // against the placeholder pointer of an empty mapping.
        if length == 0 {
            return Ok(());
        }
        // Propagate the OS error as-is so callers can see its real kind.
        self.mmap.advise_range(Advice::Sequential, start, length)
    }

    /// Get the underlying mmap reference for direct access.
    ///
    /// This is useful for advanced use cases that need direct slice access
    /// without the copy performed by `read_range`.
    pub fn as_slice(&self) -> &[u8] {
        &self.mmap
    }
}
impl PdfSource for MmapSource {
    /// Size of the mapped file in bytes.
    fn len(&self) -> u64 {
        self.mmap.len() as u64
    }

    /// Copy `length` bytes starting at `offset` out of the mapping.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` ("overflow") if `offset` does not fit in `usize` or
    ///   `offset + length` overflows.
    /// - `UnexpectedEof` ("read past EOF") if the range ends past the end of
    ///   the file.
    fn read_range(&self, offset: u64, length: usize) -> io::Result<Bytes> {
        // `try_from` rather than `as`: on 32-bit targets a large offset would
        // otherwise be truncated and the read would come from the wrong place.
        let start = usize::try_from(offset)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "overflow"))?;
        let end = start
            .checked_add(length)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "overflow"))?;
        let slice = self
            .mmap
            .get(start..end)
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "read past EOF"))?;
        // The bytes are copied because `Bytes` must own or outlive its data,
        // while this slice only borrows `self.mmap` for the duration of the call.
        Ok(Bytes::copy_from_slice(slice))
    }

    /// Best-effort readahead hint: applies `MADV_SEQUENTIAL` to the range.
    ///
    /// Errors (out-of-range requests, `madvise` failures) are deliberately
    /// ignored because this is only a performance hint.
    fn prefetch(&self, offset: u64, length: usize) {
        let _ = self.advise_sequential(offset, length);
    }
}
impl Read for MmapSource {
    /// Copy bytes from the current position into `buf` and advance past them.
    ///
    /// Returns `Ok(0)` when the position is at or beyond the end of the file
    /// (for example after seeking past EOF). This method never fails.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let pos = self.cursor.position();
        // A position that does not fit in `usize` (possible on 32-bit targets)
        // is necessarily past the end of the mapping. Treat it as EOF instead
        // of truncating it with `as` and reading from a wrapped offset.
        let remaining: &[u8] = match usize::try_from(pos) {
            Ok(start) => self.mmap.get(start..).unwrap_or_default(),
            Err(_) => &[],
        };
        let n = buf.len().min(remaining.len());
        buf[..n].copy_from_slice(&remaining[..n]);
        self.cursor.set_position(pos + n as u64);
        Ok(n)
    }
}
impl Seek for MmapSource {
    /// Move the `Read` position, following `std::io::Cursor` semantics.
    ///
    /// Seeking past the end of the file is allowed; later reads return 0 bytes.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` ("seek before start") if the target would be negative.
    /// - `InvalidInput` ("seek position overflow") if the target exceeds `u64::MAX`.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let (base, delta) = match pos {
            SeekFrom::Start(n) => (n, 0),
            SeekFrom::End(n) => (self.mmap.len() as u64, n),
            SeekFrom::Current(n) => (self.cursor.position(), n),
        };
        // Checked u64 arithmetic. Casting to i64 and adding (as before) would
        // panic on overflow in debug builds, and would reject `Start(n)` for
        // n > i64::MAX.
        let magnitude = delta.unsigned_abs();
        let new_pos = if delta < 0 {
            base.checked_sub(magnitude).ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "seek before start")
            })?
        } else {
            base.checked_add(magnitude).ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "seek position overflow")
            })?
        };
        self.cursor.set_position(new_pos);
        Ok(new_pos)
    }

    /// Current `Read` position. This never fails.
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.cursor.position())
    }
}
// SAFETY: Mmap is Send + Sync (the underlying bytes are immutable after mapping)
unsafe impl Send for MmapSource {}
unsafe impl Sync for MmapSource {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::Arc;
    use std::thread;
    use tempfile::NamedTempFile;

    #[test]
    fn test_open_valid_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"%PDF-1.4\ntest content\n").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.len(), 22);
    }

    #[test]
    fn test_open_nonexistent_file() {
        let result = MmapSource::open("/nonexistent/path.pdf");
        assert!(result.is_err());
    }

    #[test]
    fn test_read_range() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"Hello, World!";
        temp_file.write_all(content).unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let bytes = source.read_range(0, 5).unwrap();
        assert_eq!(&bytes[..], b"Hello");
    }

    #[test]
    fn test_read_range_partial() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"Hello, World!";
        temp_file.write_all(content).unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let bytes = source.read_range(7, 5).unwrap();
        assert_eq!(&bytes[..], b"World");
    }

    #[test]
    fn test_read_range_zero_length_at_eof() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let bytes = source.read_range(4, 0).unwrap();
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_read_range_past_eof() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"Hello";
        temp_file.write_all(content).unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let result = source.read_range(0, 100);
        assert!(matches!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof));
    }

    #[test]
    fn test_read_range_overflow() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let result = source.read_range(u64::MAX, 10);
        assert!(matches!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput));
    }

    #[test]
    fn test_len_matches_file_size() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"0123456789";
        temp_file.write_all(content).unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.len(), 10);
    }

    #[test]
    fn test_is_empty() {
        let temp_file = NamedTempFile::new().unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.len(), 0);
    }

    #[test]
    fn test_read_trait() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"Hello, World!";
        temp_file.write_all(content).unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        let mut buf = [0u8; 5];
        source.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"Hello");
    }

    #[test]
    fn test_seek_trait() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"0123456789";
        temp_file.write_all(content).unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        source.seek(SeekFrom::Start(5)).unwrap();
        let mut buf = [0u8; 2];
        source.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"56");
    }

    #[test]
    fn test_seek_from_end() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"Hello";
        temp_file.write_all(content).unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        source.seek(SeekFrom::End(-2)).unwrap();
        let mut buf = [0u8; 2];
        source.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"lo");
    }

    #[test]
    fn test_seek_before_start() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        let result = source.seek(SeekFrom::End(-100));
        assert!(matches!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput));
    }

    #[test]
    fn test_read_after_seek_past_end() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        // Seeking past EOF is allowed; subsequent reads report EOF.
        assert_eq!(source.seek(SeekFrom::Start(100)).unwrap(), 100);
        let mut buf = [0u8; 4];
        assert_eq!(source.read(&mut buf).unwrap(), 0);
    }

    #[test]
    fn test_stream_position() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"0123456789").unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.stream_position().unwrap(), 0);
        source.seek(SeekFrom::Start(5)).unwrap();
        assert_eq!(source.stream_position().unwrap(), 5);
    }

    #[test]
    fn test_send_sync() {
        // Compile-time check that the source can be shared across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MmapSource>();

        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        // Runtime check of Send: move the source into another thread.
        thread::spawn(move || {
            assert_eq!(source.len(), 4);
        })
        .join()
        .unwrap();
    }

    #[test]
    fn test_sync_multiple_threads() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"0123456789";
        temp_file.write_all(content).unwrap();
        let source = Arc::new(MmapSource::open(temp_file.path()).unwrap());
        // Spawn multiple threads reading concurrently
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let source_clone = Arc::clone(&source);
                thread::spawn(move || {
                    let bytes = source_clone.read_range(i as u64, 2).unwrap();
                    bytes.to_vec()
                })
            })
            .collect();
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.join().unwrap();
            let expected = &content[i..i + 2];
            assert_eq!(&result[..], expected);
        }
    }

    #[test]
    fn test_advise_sequential() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"0123456789").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        // Should not error for valid range
        source.advise_sequential(0, 10).unwrap();
    }

    #[test]
    fn test_advise_sequential_past_eof() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let result = source.advise_sequential(0, 100);
        assert!(matches!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput));
    }

    #[test]
    fn test_advise_sequential_overflow() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        let result = source.advise_sequential(u64::MAX, 10);
        assert!(matches!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput));
    }

    #[test]
    fn test_prefetch() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"0123456789").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        // prefetch forwards to advise_sequential and ignores any error.
        source.prefetch(0, 10); // Should not panic
    }

    #[test]
    fn test_prefetch_out_of_range_is_ignored() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"test").unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        // advise_sequential would return an error here; prefetch must swallow it.
        source.prefetch(0, 100);
        source.prefetch(u64::MAX, 10);
    }

    #[test]
    fn test_read_mixed_with_seek() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"0123456789ABCDEFGHIJ";
        temp_file.write_all(content).unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        // Read some bytes
        let mut buf = [0u8; 3];
        source.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"012");
        // Seek to middle
        source.seek(SeekFrom::Start(10)).unwrap();
        // Read more
        source.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ABC");
        // Seek back
        source.seek(SeekFrom::Start(5)).unwrap();
        source.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"567");
    }

    #[test]
    fn test_as_slice() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let content = b"Hello, World!";
        temp_file.write_all(content).unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.as_slice(), content);
    }

    #[test]
    fn test_empty_file() {
        let temp_file = NamedTempFile::new().unwrap();
        let mut source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.len(), 0);
        let mut buf = [0u8; 10];
        let n = source.read(&mut buf).unwrap();
        assert_eq!(n, 0);
    }

    #[test]
    fn test_large_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let large_content = vec![b'X'; 100_000];
        temp_file.write_all(&large_content).unwrap();
        let source = MmapSource::open(temp_file.path()).unwrap();
        assert_eq!(source.len(), 100_000);
        let bytes = source.read_range(50_000, 1000).unwrap();
        assert_eq!(bytes.len(), 1000);
        assert!(bytes.iter().all(|&b| b == b'X'));
    }
}