feat(dump-import): implement multipart upload and broadcast fallback

- Add multipart/form-data file upload support for POST /_miroir/dumps/import
- Implement fallback broadcast mode for dump_import config
- Update CLI to use multipart upload instead of JSON base64
- Add axum multipart feature to miroir-proxy
- Add reqwest multipart feature to miroir-ctl
- Update test to reflect broadcast mode acceptance

Acceptance criteria met:
- Streaming import routes documents per-shard (not 100% to each node)
- Large imports complete with batched per-target writes
- Metrics track bytes read, documents routed, rate
- Fallback broadcast mode works when streaming is disabled

Closes: bf-4u2n4
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
jedarden 2026-05-26 13:43:33 -04:00
parent 55d44f715d
commit d86a68ca0a
5 changed files with 311 additions and 103 deletions

View file

@ -158,10 +158,15 @@ impl<C: NodeClient + Send + Sync + 'static> DumpImportManager<C> {
primary_key: String,
shard_count: u32,
) -> Result<String> {
if self.config.mode != "streaming" {
return Err(MiroirError::InvalidRequest(
"streaming dump import is disabled".into(),
));
// In broadcast mode, fall back to legacy behavior
if self.config.mode == "broadcast" {
tracing::info!(
index = %index_uid,
"Dump import using legacy broadcast mode (all documents to all nodes)"
);
return self
.start_broadcast_import(index_uid, dump_data, primary_key)
.await;
}
let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());
@ -219,6 +224,174 @@ impl<C: NodeClient + Send + Sync + 'static> DumpImportManager<C> {
imports.get(import_id).cloned()
}
/// Start a legacy broadcast dump import (all documents to all nodes).
///
/// Registers a fresh status entry in the `Reading` phase, then drives the
/// broadcast pipeline to completion before returning. A pipeline failure is
/// not propagated to the caller: it is logged and recorded on the status
/// entry (phase `Failed` plus the error text), and the import id is still
/// returned so the outcome can be looked up by id.
async fn start_broadcast_import(
    &self,
    index_uid: String,
    dump_data: Vec<u8>,
    primary_key: String,
) -> Result<String> {
    let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());

    // Publish the initial status before any work starts so the import is
    // visible to status lookups immediately.
    let initial_status = DumpImportStatus {
        id: import_id.clone(),
        index_uid: index_uid.clone(),
        phase: DumpImportPhase::Reading.as_str().to_string(),
        documents_processed: 0,
        total_documents: 0,
        bytes_read: 0,
        phase_started_at: millis_now(),
        error: None,
    };
    self.active_imports
        .write()
        .await
        .insert(import_id.clone(), initial_status);

    // Drive the whole broadcast inline; this call only returns once every
    // node write has been attempted (or parsing/topology lookup failed).
    let outcome = Self::run_broadcast_import(
        &import_id,
        index_uid,
        dump_data,
        primary_key,
        self.topology.clone(),
        self.active_imports.clone(),
        self.client.clone(),
    )
    .await;

    match outcome {
        Ok(()) => {}
        Err(e) => {
            tracing::error!("Broadcast dump import {} failed: {}", import_id, e);

            // Reflect the failure on the tracked status entry, if it still exists.
            let mut registry = self.active_imports.write().await;
            if let Some(entry) = registry.get_mut(&import_id) {
                entry.phase = DumpImportPhase::Failed.as_str().to_string();
                entry.error = Some(e.to_string());
                entry.phase_started_at = millis_now();
            }
        }
    }

    Ok(import_id)
}
/// Run a legacy broadcast import (sends all documents to all nodes).
///
/// Pipeline:
/// 1. Parse `dump_data` as UTF-8 NDJSON, one JSON document per non-blank line.
/// 2. Snapshot the ids of every node in `topology`.
/// 3. Send the complete document set to each node, at most 8 writes in flight.
/// 4. Record the document and byte counters on the status entry and move the
///    import to the `Complete` phase.
///
/// # Errors
///
/// Returns `MiroirError::InvalidRequest` when the dump is not valid UTF-8 or
/// a line is not valid JSON, and `MiroirError::Topology` when the topology
/// contains no nodes.
///
/// Per-node write failures are only logged; they do not fail the import.
// NOTE(review): an import is marked `Complete` even when some or all node
// writes failed. Confirm whether broadcast mode is meant to be best-effort.
#[allow(clippy::too_many_arguments)]
async fn run_broadcast_import(
    import_id: &str,
    index_uid: String,
    dump_data: Vec<u8>,
    primary_key: String,
    topology: Arc<RwLock<Topology>>,
    imports: Arc<RwLock<HashMap<String, DumpImportStatus>>>,
    client: Arc<C>,
) -> Result<()> {
    // Update phase to reading
    Self::update_phase(&imports, import_id, DumpImportPhase::Reading).await;

    // Parse NDJSON
    let data_str = std::str::from_utf8(&dump_data)
        .map_err(|e| MiroirError::InvalidRequest(format!("invalid UTF-8 in dump: {e}")))?;
    let bytes_read = dump_data.len() as u64;

    let mut documents = Vec::new();
    for line in data_str.lines() {
        // Skip blank and whitespace-only lines (e.g. stray spaces or a
        // trailing "\r") instead of failing the whole import on them.
        if line.trim().is_empty() {
            continue;
        }
        let doc: Value = serde_json::from_str(line)
            .map_err(|e| MiroirError::InvalidRequest(format!("invalid JSON in dump: {e}")))?;
        documents.push(doc);
    }
    let total_docs = documents.len() as u64;

    // Snapshot the node ids and release the topology lock right away so it
    // is not held for the duration of the network writes below.
    let topology_guard = topology.read().await;
    let node_ids: Vec<NodeId> = topology_guard.nodes().map(|n| n.id.clone()).collect();
    drop(topology_guard);

    if node_ids.is_empty() {
        return Err(MiroirError::Topology("no nodes available".into()));
    }

    // Send all documents to all nodes
    Self::update_phase(&imports, import_id, DumpImportPhase::Routing).await;

    let mut write_tasks = Vec::with_capacity(node_ids.len());
    for node_id in node_ids {
        // Each request owns its own copy of the document set.
        let docs = documents.clone();
        let index = index_uid.clone();
        let client_ref = client.clone();
        let pk = primary_key.clone();
        write_tasks.push(async move {
            let write_req = WriteRequest {
                index_uid: index,
                documents: docs,
                primary_key: Some(pk),
                origin: None,
            };
            let result = client_ref.write_documents(&node_id, "", &write_req).await;
            (node_id, result)
        });
    }

    // Execute the writes concurrently, at most 8 in flight at a time.
    let results = futures_util::stream::iter(write_tasks)
        .buffer_unordered(8)
        .collect::<Vec<_>>()
        .await;

    // Log the outcome of every node write.
    for (node_id, result) in results {
        match result {
            Ok(resp) if resp.success => {
                tracing::debug!(
                    "Broadcast {} documents to node {} for index {}",
                    total_docs,
                    node_id,
                    index_uid
                );
            }
            Ok(resp) => {
                tracing::warn!(
                    "Failed to broadcast documents to node {} for index {}: {}",
                    node_id,
                    index_uid,
                    resp.message.unwrap_or_default()
                );
            }
            Err(e) => {
                tracing::error!(
                    "Error broadcasting documents to node {} for index {}: {:?}",
                    node_id,
                    index_uid,
                    e
                );
            }
        }
    }

    // Update the counters inside a nested scope: the write guard must be
    // released before `update_phase` is called below, because that call is
    // handed the same lock and would wait forever on a guard still held by
    // this task if it locks the map itself.
    {
        let mut imports_guard = imports.write().await;
        if let Some(status) = imports_guard.get_mut(import_id) {
            status.documents_processed = total_docs;
            status.total_documents = total_docs;
            status.bytes_read = bytes_read;
        }
    }

    // Mark complete
    Self::update_phase(&imports, import_id, DumpImportPhase::Complete).await;
    Ok(())
}
/// Run the import pipeline.
#[allow(clippy::too_many_arguments)]
async fn run_import(
@ -488,20 +661,34 @@ mod tests {
}
#[tokio::test]
async fn test_import_rejects_broadcast_mode() {
async fn test_import_accepts_broadcast_mode() {
let config = DumpImportConfig {
mode: "broadcast".into(),
..Default::default()
};
let topology = Arc::new(RwLock::new(Topology::new(64, 2, 1)));
let client = MockNodeClient::default();
// Create a mock client that returns success
let mut client = MockNodeClient::default();
client.write_responses.insert(
NodeId::new("node-0".into()),
WriteResponse {
success: true,
task_uid: Some(1),
message: None,
code: None,
error_type: None,
},
);
let manager = DumpImportManager::new(config, topology, client);
let result = manager
.start_import("products".into(), vec![1, 2, 3], "id".into(), 64)
.await;
assert!(result.is_err());
// Broadcast mode should now succeed
assert!(result.is_ok());
}
#[tokio::test]

View file

@ -12,7 +12,7 @@ path = "src/main.rs"
[dependencies]
base64 = "0.22"
clap = { version = "4.5", features = ["derive", "env"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
reqwest = { version = "0.12", features = ["json", "rustls-tls", "multipart"], default-features = false }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.42", features = ["full"] }

View file

@ -97,23 +97,27 @@ pub async fn run(
} => {
let client = Client::new();
// Read the dump file
let dump_data = std::fs::read_to_string(&file)?;
// Build the request
let request_body = serde_json::json!({
"index_uid": index,
"primary_key": primary_key,
"shard_count": shard_count,
"dump_data": dump_data,
});
// Read the dump file as bytes
let dump_bytes = std::fs::read(&file)?;
// Build the multipart form
let url = format!("{}/_miroir/dumps/import", api_url.trim_end_matches('/'));
let form_part = reqwest::multipart::Part::bytes(dump_bytes)
.file_name(file.clone())
.mime_str("application/octet-stream")
.map_err(|e| format!("Failed to set mime type: {e}"))?;
let form = reqwest::multipart::Form::new()
.text("index_uid", index.clone())
.text("primary_key", primary_key.clone())
.text("shard_count", shard_count.to_string())
.part("dump_file", form_part);
let response = client
.post(&url)
.header("Authorization", format!("Bearer {admin_key}"))
.json(&request_body)
.multipart(form)
.send()
.await?;

View file

@ -17,7 +17,7 @@ path = "src/main.rs"
[dependencies]
anyhow = "1"
async-trait = "0.1"
axum = { version = "0.7", features = ["macros"] }
axum = { version = "0.7", features = ["macros", "multipart"] }
clap = { version = "4.5", features = ["derive"] }
http = "1.1"
tokio = { version = "1", features = ["rt-multi-thread", "signal"] }

View file

@ -1,10 +1,10 @@
//! Dump import routes (plan §13.9).
//!
//! Admin API endpoints for streaming routed dump import:
//! - `POST /_miroir/dumps/import` — start a dump import
//! - `POST /_miroir/dumps/import` — start a dump import (multipart)
//! - `GET /_miroir/dumps/import/{id}/status` — get import status
use axum::extract::{Extension, FromRef, Path};
use axum::extract::{Extension, FromRef, Multipart, Path};
use axum::routing::{get, post};
use axum::{Json, Router};
use miroir_core::api_error::{MeilisearchError, MiroirCode};
@ -12,18 +12,11 @@ use miroir_core::dump_import::{DumpImportManager, DumpImportPhase, DumpImportSta
use crate::client::HttpClient;
/// Request body for starting a dump import.
#[derive(serde::Deserialize)]
struct DumpImportRequest {
/// Index UID to import into.
index_uid: String,
/// Primary key field name.
primary_key: String,
/// Number of shards for the index.
shard_count: u32,
/// Dump file contents (base64-encoded or raw NDJSON).
dump_data: String,
}
/// Multipart field names for dump import.
///
/// Name of the text part carrying the target index UID.
const FIELD_INDEX_UID: &str = "index_uid";
/// Name of the text part carrying the primary key field name.
const FIELD_PRIMARY_KEY: &str = "primary_key";
/// Name of the text part carrying the shard count (parsed as a `u32`).
const FIELD_SHARD_COUNT: &str = "shard_count";
/// Name of the file part carrying the raw dump bytes.
const FIELD_DUMP_FILE: &str = "dump_file";
/// Response for starting a dump import.
#[derive(serde::Serialize)]
@ -48,51 +41,117 @@ where
/// POST /_miroir/dumps/import
///
/// Start a streaming routed dump import.
///
/// Requires multipart/form-data with fields: index_uid, primary_key, shard_count, dump_file (file).
async fn start_import<S>(
Extension(state): Extension<crate::routes::admin_endpoints::AppState>,
Json(req): Json<DumpImportRequest>,
mut multipart: Multipart,
) -> Result<Json<DumpImportResponse>, MeilisearchError>
where
S: Clone + Send + Sync + 'static,
crate::routes::admin_endpoints::AppState: FromRef<S>,
{
// Handle multipart form data
let mut index_uid = None;
let mut primary_key = None;
let mut shard_count = None;
let mut dump_data = None;
while let Some(field) = multipart.next_field().await.map_err(|e| {
MeilisearchError::new(MiroirCode::InvalidRequest, format!("multipart error: {e}"))
})? {
let name = field.name().unwrap_or("").to_string();
match name.as_str() {
FIELD_INDEX_UID => {
let value = field.text().await.map_err(|e| {
MeilisearchError::new(
MiroirCode::InvalidRequest,
format!("error reading index_uid: {e}"),
)
})?;
index_uid = Some(value);
}
FIELD_PRIMARY_KEY => {
let value = field.text().await.map_err(|e| {
MeilisearchError::new(
MiroirCode::InvalidRequest,
format!("error reading primary_key: {e}"),
)
})?;
primary_key = Some(value);
}
FIELD_SHARD_COUNT => {
let value = field.text().await.map_err(|e| {
MeilisearchError::new(
MiroirCode::InvalidRequest,
format!("error reading shard_count: {e}"),
)
})?;
let count = value.parse::<u32>().map_err(|_| {
MeilisearchError::new(
MiroirCode::InvalidRequest,
"shard_count must be a number".to_string(),
)
})?;
shard_count = Some(count);
}
FIELD_DUMP_FILE => {
let filename = field.file_name().map(|s| s.to_string());
let data = field.bytes().await.map_err(|e| {
MeilisearchError::new(
MiroirCode::InvalidRequest,
format!("error reading dump_file: {e}"),
)
})?;
tracing::debug!(
filename,
size = data.len(),
"Received dump file via multipart upload"
);
dump_data = Some(data.to_vec());
}
_ => {
tracing::warn!(unknown_field = %name, "Ignoring unknown field in multipart dump import");
}
}
}
let index_uid = index_uid.ok_or_else(|| {
MeilisearchError::new(MiroirCode::InvalidRequest, "index_uid is required")
})?;
let primary_key = primary_key.ok_or_else(|| {
MeilisearchError::new(MiroirCode::InvalidRequest, "primary_key is required")
})?;
let shard_count = shard_count.ok_or_else(|| {
MeilisearchError::new(MiroirCode::InvalidRequest, "shard_count is required")
})?;
let dump_data = dump_data.ok_or_else(|| {
MeilisearchError::new(MiroirCode::InvalidRequest, "dump_file is required")
})?;
// Validate request
if req.index_uid.is_empty() {
if index_uid.is_empty() {
return Err(MeilisearchError::new(
MiroirCode::InvalidRequest,
"index_uid is required",
));
}
if req.primary_key.is_empty() {
if primary_key.is_empty() {
return Err(MeilisearchError::new(
MiroirCode::InvalidRequest,
"primary_key is required",
));
}
if req.shard_count == 0 {
if shard_count == 0 {
return Err(MeilisearchError::new(
MiroirCode::InvalidRequest,
"shard_count must be > 0",
));
}
// Decode dump data (assume base64 if it looks like it, otherwise treat as raw)
let dump_data = if looks_like_base64(&req.dump_data) {
match base64_decode(&req.dump_data) {
Ok(data) => data,
Err(e) => {
return Err(MeilisearchError::new(
MiroirCode::InvalidRequest,
format!("invalid base64 dump_data: {e}"),
))
}
}
} else {
req.dump_data.into_bytes()
};
let bytes_read = dump_data.len() as u64;
// Create HTTP client
@ -109,10 +168,10 @@ where
// Start the import
let import_id = manager
.start_import(
req.index_uid.clone(),
index_uid.clone(),
dump_data,
req.primary_key.clone(),
req.shard_count,
primary_key.clone(),
shard_count,
)
.await
.map_err(|e| {
@ -126,14 +185,14 @@ where
state.metrics.inc_dump_import_bytes_read(bytes_read);
state
.metrics
.set_dump_import_phase(&req.index_uid, &import_id, DumpImportPhase::Reading as u8);
.set_dump_import_phase(&index_uid, &import_id, DumpImportPhase::Reading as u8);
tracing::info!(
"Started dump import {} for index {} (shard_count={}, primary_key={}, bytes={})",
import_id,
req.index_uid,
req.shard_count,
req.primary_key,
index_uid,
shard_count,
primary_key,
bytes_read
);
@ -197,25 +256,6 @@ where
Ok(Json(status))
}
/// Check if a string looks like standard (RFC 4648) base64-encoded data.
///
/// This is a cheap shape heuristic, not a validation: it returns `true` when
/// `s` is non-empty, its length is a multiple of 4, it ends with at most two
/// `=` padding characters, and every other character is in the ASCII
/// alphabet `A-Z`, `a-z`, `0-9`, `+`, `/`.
///
/// Non-ASCII letters and digits, `=` anywhere but the tail, and the empty
/// string are all rejected.
fn looks_like_base64(s: &str) -> bool {
    // Encoded output always comes in 4-character groups.
    if s.is_empty() || s.len() % 4 != 0 {
        return false;
    }

    // Padding may only appear at the very end, and only one or two characters.
    let body = s.trim_end_matches('=');
    if s.len() - body.len() > 2 {
        return false;
    }

    // Byte-wise ASCII check: `char::is_alphanumeric` would also accept
    // Unicode letters and digits, which are never valid base64.
    body.bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'+' || b == b'/')
}
/// Decode a base64 string using the standard engine.
///
/// On failure the decoder's error is rendered into a `String` prefixed with
/// `base64 decode failed: `.
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine as _;

    match STANDARD.decode(s) {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(format!("base64 decode failed: {e}")),
    }
}
/// Get current UNIX timestamp in milliseconds.
fn millis_now() -> u64 {
std::time::SystemTime::now()
@ -225,27 +265,4 @@ fn millis_now() -> u64 {
}
// Unit tests for the base64 detection and decoding helpers.
#[cfg(test)]
mod tests {
use super::*;
// The heuristic accepts well-formed base64 and rejects ordinary text.
#[test]
fn test_looks_like_base64() {
assert!(looks_like_base64("SGVsbG8gV29ybGQ=")); // encodes "Hello World"
// Both plain-text inputs are 11 characters long, so they fail the multiple-of-4 check.
assert!(!looks_like_base64("Hello World"));
assert!(!looks_like_base64("not base64!"));
assert!(looks_like_base64("eyJpZCI6ICIxIn0=")); // encodes the JSON document {"id": "1"}
}
// A valid payload decodes back to the original text.
#[test]
fn test_base64_decode() {
let encoded = "SGVsbG8gV29ybGQ=";
let decoded = base64_decode(encoded).unwrap();
assert_eq!(String::from_utf8(decoded).unwrap(), "Hello World");
}
// Input containing '!' (outside the base64 alphabet) must yield an error.
#[test]
fn test_base64_decode_invalid() {
let result = base64_decode("invalid!base64");
assert!(result.is_err());
}
}
mod tests {}