feat(dump-import): implement multipart upload and broadcast fallback
- Add multipart/form-data file upload support for POST /_miroir/dumps/import
- Implement fallback broadcast mode for dump_import config
- Update CLI to use multipart upload instead of JSON base64
- Add axum multipart feature to miroir-proxy
- Add reqwest multipart feature to miroir-ctl
- Update test to reflect broadcast mode acceptance

Acceptance criteria met:
- Streaming import routes documents per-shard (not 100% to each node)
- Large imports complete with batched per-target writes
- Metrics track bytes read, documents routed, rate
- Fallback broadcast mode works when streaming is disabled

Closes: bf-4u2n4
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
55d44f715d
commit
d86a68ca0a
5 changed files with 311 additions and 103 deletions
|
|
@ -158,10 +158,15 @@ impl<C: NodeClient + Send + Sync + 'static> DumpImportManager<C> {
|
|||
primary_key: String,
|
||||
shard_count: u32,
|
||||
) -> Result<String> {
|
||||
if self.config.mode != "streaming" {
|
||||
return Err(MiroirError::InvalidRequest(
|
||||
"streaming dump import is disabled".into(),
|
||||
));
|
||||
// In broadcast mode, fall back to legacy behavior
|
||||
if self.config.mode == "broadcast" {
|
||||
tracing::info!(
|
||||
index = %index_uid,
|
||||
"Dump import using legacy broadcast mode (all documents to all nodes)"
|
||||
);
|
||||
return self
|
||||
.start_broadcast_import(index_uid, dump_data, primary_key)
|
||||
.await;
|
||||
}
|
||||
|
||||
let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());
|
||||
|
|
@ -219,6 +224,174 @@ impl<C: NodeClient + Send + Sync + 'static> DumpImportManager<C> {
|
|||
imports.get(import_id).cloned()
|
||||
}
|
||||
|
||||
/// Start a legacy broadcast dump import (all documents to all nodes).
|
||||
async fn start_broadcast_import(
|
||||
&self,
|
||||
index_uid: String,
|
||||
dump_data: Vec<u8>,
|
||||
primary_key: String,
|
||||
) -> Result<String> {
|
||||
let import_id = format!("dump-{}-{}", index_uid, uuid::Uuid::new_v4());
|
||||
let now = millis_now();
|
||||
|
||||
// Create initial status
|
||||
let status = DumpImportStatus {
|
||||
id: import_id.clone(),
|
||||
index_uid: index_uid.clone(),
|
||||
phase: DumpImportPhase::Reading.as_str().to_string(),
|
||||
documents_processed: 0,
|
||||
total_documents: 0,
|
||||
bytes_read: 0,
|
||||
phase_started_at: now,
|
||||
error: None,
|
||||
};
|
||||
|
||||
{
|
||||
let mut imports = self.active_imports.write().await;
|
||||
imports.insert(import_id.clone(), status);
|
||||
}
|
||||
|
||||
// Run the broadcast import
|
||||
let result = Self::run_broadcast_import(
|
||||
&import_id,
|
||||
index_uid,
|
||||
dump_data,
|
||||
primary_key,
|
||||
self.topology.clone(),
|
||||
self.active_imports.clone(),
|
||||
self.client.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
tracing::error!("Broadcast dump import {} failed: {}", import_id, e);
|
||||
|
||||
// Update status to failed
|
||||
let mut imports = self.active_imports.write().await;
|
||||
if let Some(status) = imports.get_mut(&import_id) {
|
||||
status.phase = DumpImportPhase::Failed.as_str().to_string();
|
||||
status.error = Some(e.to_string());
|
||||
status.phase_started_at = millis_now();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(import_id)
|
||||
}
|
||||
|
||||
/// Run a legacy broadcast import (sends all documents to all nodes).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_broadcast_import(
|
||||
import_id: &str,
|
||||
index_uid: String,
|
||||
dump_data: Vec<u8>,
|
||||
primary_key: String,
|
||||
topology: Arc<RwLock<Topology>>,
|
||||
imports: Arc<RwLock<HashMap<String, DumpImportStatus>>>,
|
||||
client: Arc<C>,
|
||||
) -> Result<()> {
|
||||
// Update phase to reading
|
||||
Self::update_phase(&imports, import_id, DumpImportPhase::Reading).await;
|
||||
|
||||
// Parse NDJSON
|
||||
let data_str = std::str::from_utf8(&dump_data)
|
||||
.map_err(|e| MiroirError::InvalidRequest(format!("invalid UTF-8 in dump: {e}")))?;
|
||||
|
||||
let mut documents = Vec::new();
|
||||
let bytes_read = dump_data.len() as u64;
|
||||
|
||||
for line in data_str.lines() {
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let doc: Value = serde_json::from_str(line)
|
||||
.map_err(|e| MiroirError::InvalidRequest(format!("invalid JSON in dump: {e}")))?;
|
||||
documents.push(doc);
|
||||
}
|
||||
|
||||
let total_docs = documents.len() as u64;
|
||||
|
||||
// Get all nodes
|
||||
let topology = topology.read().await;
|
||||
let node_ids: Vec<NodeId> = topology.nodes().map(|n| n.id.clone()).collect();
|
||||
|
||||
if node_ids.is_empty() {
|
||||
return Err(MiroirError::Topology("no nodes available".into()));
|
||||
}
|
||||
|
||||
// Send all documents to all nodes
|
||||
Self::update_phase(&imports, import_id, DumpImportPhase::Routing).await;
|
||||
|
||||
let mut write_tasks = Vec::new();
|
||||
|
||||
for node_id in node_ids.clone() {
|
||||
let docs = documents.clone();
|
||||
let index = index_uid.clone();
|
||||
let client_ref = client.clone();
|
||||
let pk = primary_key.clone();
|
||||
|
||||
write_tasks.push(async move {
|
||||
let write_req = WriteRequest {
|
||||
index_uid: index.clone(),
|
||||
documents: docs,
|
||||
primary_key: Some(pk),
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let result = client_ref.write_documents(&node_id, "", &write_req).await;
|
||||
(node_id, result)
|
||||
});
|
||||
}
|
||||
|
||||
// Execute all writes in parallel
|
||||
let results = futures_util::stream::iter(write_tasks)
|
||||
.buffer_unordered(8)
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
// Check for errors
|
||||
for (node_id, result) in results {
|
||||
match result {
|
||||
Ok(resp) if resp.success => {
|
||||
tracing::debug!(
|
||||
"Broadcast {} documents to node {} for index {}",
|
||||
total_docs,
|
||||
node_id,
|
||||
index_uid
|
||||
);
|
||||
}
|
||||
Ok(resp) => {
|
||||
tracing::warn!(
|
||||
"Failed to broadcast documents to node {} for index {}: {}",
|
||||
node_id,
|
||||
index_uid,
|
||||
resp.message.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Error broadcasting documents to node {} for index {}: {:?}",
|
||||
node_id,
|
||||
index_uid,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update status
|
||||
let mut imports_guard = imports.write().await;
|
||||
if let Some(status) = imports_guard.get_mut(import_id) {
|
||||
status.documents_processed = total_docs;
|
||||
status.total_documents = total_docs;
|
||||
status.bytes_read = bytes_read;
|
||||
}
|
||||
|
||||
// Mark complete
|
||||
Self::update_phase(&imports, import_id, DumpImportPhase::Complete).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run the import pipeline.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_import(
|
||||
|
|
@ -488,20 +661,34 @@ mod tests {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_import_rejects_broadcast_mode() {
|
||||
async fn test_import_accepts_broadcast_mode() {
|
||||
let config = DumpImportConfig {
|
||||
mode: "broadcast".into(),
|
||||
..Default::default()
|
||||
};
|
||||
let topology = Arc::new(RwLock::new(Topology::new(64, 2, 1)));
|
||||
let client = MockNodeClient::default();
|
||||
|
||||
// Create a mock client that returns success
|
||||
let mut client = MockNodeClient::default();
|
||||
client.write_responses.insert(
|
||||
NodeId::new("node-0".into()),
|
||||
WriteResponse {
|
||||
success: true,
|
||||
task_uid: Some(1),
|
||||
message: None,
|
||||
code: None,
|
||||
error_type: None,
|
||||
},
|
||||
);
|
||||
|
||||
let manager = DumpImportManager::new(config, topology, client);
|
||||
|
||||
let result = manager
|
||||
.start_import("products".into(), vec![1, 2, 3], "id".into(), 64)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
// Broadcast mode should now succeed
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ path = "src/main.rs"
|
|||
[dependencies]
|
||||
base64 = "0.22"
|
||||
clap = { version = "4.5", features = ["derive", "env"] }
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls", "multipart"], default-features = false }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1.42", features = ["full"] }
|
||||
|
|
|
|||
|
|
@ -97,23 +97,27 @@ pub async fn run(
|
|||
} => {
|
||||
let client = Client::new();
|
||||
|
||||
// Read the dump file
|
||||
let dump_data = std::fs::read_to_string(&file)?;
|
||||
|
||||
// Build the request
|
||||
let request_body = serde_json::json!({
|
||||
"index_uid": index,
|
||||
"primary_key": primary_key,
|
||||
"shard_count": shard_count,
|
||||
"dump_data": dump_data,
|
||||
});
|
||||
// Read the dump file as bytes
|
||||
let dump_bytes = std::fs::read(&file)?;
|
||||
|
||||
// Build the multipart form
|
||||
let url = format!("{}/_miroir/dumps/import", api_url.trim_end_matches('/'));
|
||||
|
||||
let form_part = reqwest::multipart::Part::bytes(dump_bytes)
|
||||
.file_name(file.clone())
|
||||
.mime_str("application/octet-stream")
|
||||
.map_err(|e| format!("Failed to set mime type: {e}"))?;
|
||||
|
||||
let form = reqwest::multipart::Form::new()
|
||||
.text("index_uid", index.clone())
|
||||
.text("primary_key", primary_key.clone())
|
||||
.text("shard_count", shard_count.to_string())
|
||||
.part("dump_file", form_part);
|
||||
|
||||
let response = client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {admin_key}"))
|
||||
.json(&request_body)
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ path = "src/main.rs"
|
|||
[dependencies]
|
||||
anyhow = "1"
|
||||
async-trait = "0.1"
|
||||
axum = { version = "0.7", features = ["macros"] }
|
||||
axum = { version = "0.7", features = ["macros", "multipart"] }
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
http = "1.1"
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "signal"] }
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
//! Dump import routes (plan §13.9).
|
||||
//!
|
||||
//! Admin API endpoints for streaming routed dump import:
|
||||
//! - `POST /_miroir/dumps/import` — start a dump import
|
||||
//! - `POST /_miroir/dumps/import` — start a dump import (multipart)
|
||||
//! - `GET /_miroir/dumps/import/{id}/status` — get import status
|
||||
|
||||
use axum::extract::{Extension, FromRef, Path};
|
||||
use axum::extract::{Extension, FromRef, Multipart, Path};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use miroir_core::api_error::{MeilisearchError, MiroirCode};
|
||||
|
|
@ -12,18 +12,11 @@ use miroir_core::dump_import::{DumpImportManager, DumpImportPhase, DumpImportSta
|
|||
|
||||
use crate::client::HttpClient;
|
||||
|
||||
/// Request body for starting a dump import.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct DumpImportRequest {
|
||||
/// Index UID to import into.
|
||||
index_uid: String,
|
||||
/// Primary key field name.
|
||||
primary_key: String,
|
||||
/// Number of shards for the index.
|
||||
shard_count: u32,
|
||||
/// Dump file contents (base64-encoded or raw NDJSON).
|
||||
dump_data: String,
|
||||
}
|
||||
/// Multipart field names for dump import.
|
||||
const FIELD_INDEX_UID: &str = "index_uid";
|
||||
const FIELD_PRIMARY_KEY: &str = "primary_key";
|
||||
const FIELD_SHARD_COUNT: &str = "shard_count";
|
||||
const FIELD_DUMP_FILE: &str = "dump_file";
|
||||
|
||||
/// Response for starting a dump import.
|
||||
#[derive(serde::Serialize)]
|
||||
|
|
@ -48,51 +41,117 @@ where
|
|||
/// POST /_miroir/dumps/import
|
||||
///
|
||||
/// Start a streaming routed dump import.
|
||||
///
|
||||
/// Requires multipart/form-data with fields: index_uid, primary_key, shard_count, dump_file (file).
|
||||
async fn start_import<S>(
|
||||
Extension(state): Extension<crate::routes::admin_endpoints::AppState>,
|
||||
Json(req): Json<DumpImportRequest>,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<DumpImportResponse>, MeilisearchError>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
crate::routes::admin_endpoints::AppState: FromRef<S>,
|
||||
{
|
||||
// Handle multipart form data
|
||||
let mut index_uid = None;
|
||||
let mut primary_key = None;
|
||||
let mut shard_count = None;
|
||||
let mut dump_data = None;
|
||||
|
||||
while let Some(field) = multipart.next_field().await.map_err(|e| {
|
||||
MeilisearchError::new(MiroirCode::InvalidRequest, format!("multipart error: {e}"))
|
||||
})? {
|
||||
let name = field.name().unwrap_or("").to_string();
|
||||
|
||||
match name.as_str() {
|
||||
FIELD_INDEX_UID => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
format!("error reading index_uid: {e}"),
|
||||
)
|
||||
})?;
|
||||
index_uid = Some(value);
|
||||
}
|
||||
FIELD_PRIMARY_KEY => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
format!("error reading primary_key: {e}"),
|
||||
)
|
||||
})?;
|
||||
primary_key = Some(value);
|
||||
}
|
||||
FIELD_SHARD_COUNT => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
format!("error reading shard_count: {e}"),
|
||||
)
|
||||
})?;
|
||||
let count = value.parse::<u32>().map_err(|_| {
|
||||
MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
"shard_count must be a number".to_string(),
|
||||
)
|
||||
})?;
|
||||
shard_count = Some(count);
|
||||
}
|
||||
FIELD_DUMP_FILE => {
|
||||
let filename = field.file_name().map(|s| s.to_string());
|
||||
let data = field.bytes().await.map_err(|e| {
|
||||
MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
format!("error reading dump_file: {e}"),
|
||||
)
|
||||
})?;
|
||||
tracing::debug!(
|
||||
filename,
|
||||
size = data.len(),
|
||||
"Received dump file via multipart upload"
|
||||
);
|
||||
dump_data = Some(data.to_vec());
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(unknown_field = %name, "Ignoring unknown field in multipart dump import");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let index_uid = index_uid.ok_or_else(|| {
|
||||
MeilisearchError::new(MiroirCode::InvalidRequest, "index_uid is required")
|
||||
})?;
|
||||
let primary_key = primary_key.ok_or_else(|| {
|
||||
MeilisearchError::new(MiroirCode::InvalidRequest, "primary_key is required")
|
||||
})?;
|
||||
let shard_count = shard_count.ok_or_else(|| {
|
||||
MeilisearchError::new(MiroirCode::InvalidRequest, "shard_count is required")
|
||||
})?;
|
||||
let dump_data = dump_data.ok_or_else(|| {
|
||||
MeilisearchError::new(MiroirCode::InvalidRequest, "dump_file is required")
|
||||
})?;
|
||||
|
||||
// Validate request
|
||||
if req.index_uid.is_empty() {
|
||||
if index_uid.is_empty() {
|
||||
return Err(MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
"index_uid is required",
|
||||
));
|
||||
}
|
||||
|
||||
if req.primary_key.is_empty() {
|
||||
if primary_key.is_empty() {
|
||||
return Err(MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
"primary_key is required",
|
||||
));
|
||||
}
|
||||
|
||||
if req.shard_count == 0 {
|
||||
if shard_count == 0 {
|
||||
return Err(MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
"shard_count must be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
// Decode dump data (assume base64 if it looks like it, otherwise treat as raw)
|
||||
let dump_data = if looks_like_base64(&req.dump_data) {
|
||||
match base64_decode(&req.dump_data) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
return Err(MeilisearchError::new(
|
||||
MiroirCode::InvalidRequest,
|
||||
format!("invalid base64 dump_data: {e}"),
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
req.dump_data.into_bytes()
|
||||
};
|
||||
|
||||
let bytes_read = dump_data.len() as u64;
|
||||
|
||||
// Create HTTP client
|
||||
|
|
@ -109,10 +168,10 @@ where
|
|||
// Start the import
|
||||
let import_id = manager
|
||||
.start_import(
|
||||
req.index_uid.clone(),
|
||||
index_uid.clone(),
|
||||
dump_data,
|
||||
req.primary_key.clone(),
|
||||
req.shard_count,
|
||||
primary_key.clone(),
|
||||
shard_count,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
|
|
@ -126,14 +185,14 @@ where
|
|||
state.metrics.inc_dump_import_bytes_read(bytes_read);
|
||||
state
|
||||
.metrics
|
||||
.set_dump_import_phase(&req.index_uid, &import_id, DumpImportPhase::Reading as u8);
|
||||
.set_dump_import_phase(&index_uid, &import_id, DumpImportPhase::Reading as u8);
|
||||
|
||||
tracing::info!(
|
||||
"Started dump import {} for index {} (shard_count={}, primary_key={}, bytes={})",
|
||||
import_id,
|
||||
req.index_uid,
|
||||
req.shard_count,
|
||||
req.primary_key,
|
||||
index_uid,
|
||||
shard_count,
|
||||
primary_key,
|
||||
bytes_read
|
||||
);
|
||||
|
||||
|
|
@ -197,25 +256,6 @@ where
|
|||
Ok(Json(status))
|
||||
}
|
||||
|
||||
/// Check if a string looks like standard base64-encoded data.
///
/// Returns `true` only when all of the following hold:
/// - the length is a multiple of 4,
/// - `=` appears solely as trailing padding, and at most twice,
/// - every other character is an ASCII letter, ASCII digit, `+` or `/`.
///
/// The check is ASCII-only on purpose: `char::is_alphanumeric` also accepts
/// non-ASCII letters and digits (e.g. `é`), which are never part of the base64
/// alphabet and would make arbitrary Unicode text look like base64.
///
/// An empty string is reported as base64-like, since it trivially satisfies
/// every rule above.
fn looks_like_base64(s: &str) -> bool {
    // Base64 output is always a whole number of 4-character groups.
    if s.len() % 4 != 0 {
        return false;
    }

    // Padding may only appear at the very end, and never more than two characters.
    let body = s.trim_end_matches('=');
    if s.len() - body.len() > 2 {
        return false;
    }

    body.bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'+' || b == b'/')
}
|
||||
|
||||
/// Decode a base64 string.
|
||||
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD
|
||||
.decode(s)
|
||||
.map_err(|e| format!("base64 decode failed: {e}"))
|
||||
}
|
||||
|
||||
/// Get current UNIX timestamp in milliseconds.
|
||||
fn millis_now() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
|
|
@ -225,27 +265,4 @@ fn millis_now() -> u64 {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_looks_like_base64() {
|
||||
assert!(looks_like_base64("SGVsbG8gV29ybGQ=")); // "Hello World"
|
||||
assert!(!looks_like_base64("Hello World"));
|
||||
assert!(!looks_like_base64("not base64!"));
|
||||
assert!(looks_like_base64("eyJpZCI6ICIxIn0=")); // JSON
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_decode() {
|
||||
let encoded = "SGVsbG8gV29ybGQ=";
|
||||
let decoded = base64_decode(encoded).unwrap();
|
||||
assert_eq!(String::from_utf8(decoded).unwrap(), "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_decode_invalid() {
|
||||
let result = base64_decode("invalid!base64");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
mod tests {}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue