[ENH] multipart S3 file uploads #2590
Changes from 2 commits
```diff
@@ -21,11 +21,17 @@ use aws_sdk_s3::error::ProvideErrorMetadata;
 use aws_sdk_s3::error::SdkError;
 use aws_sdk_s3::operation::create_bucket::CreateBucketError;
 use aws_sdk_s3::primitives::ByteStream;
+use aws_sdk_s3::types::CompletedMultipartUpload;
+use aws_sdk_s3::types::CompletedPart;
+use aws_smithy_types::byte_stream::Length;
 use futures::Stream;
 use std::clone::Clone;
 use std::time::Duration;
 use thiserror::Error;
 
+// todo: make this more principled
+const MULTIPART_UPLOAD_CHUNK_SIZE: u64 = 1024 * 1024 * 32;
+
 #[derive(Clone)]
 pub(crate) struct S3Storage {
     bucket: String,
```
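A note on the `// todo` above: S3's documented multipart limits are presumably what would make this "principled" (every part except the last must be at least 5 MiB, and an upload may have at most 10,000 parts, so a fixed 32 MiB part size caps objects at roughly 312 GiB). The chunk arithmetic `put_file` uses below is ceiling division with a special case for exact multiples; a minimal standalone sketch of that math (hypothetical helper, not part of the diff):

```rust
const MULTIPART_UPLOAD_CHUNK_SIZE: u64 = 1024 * 1024 * 32; // 32 MiB, as in the PR

/// Returns (number of parts, size of the final part) for a given file size.
fn chunk_layout(file_size: u64) -> (u64, u64) {
    let mut chunk_count = (file_size / MULTIPART_UPLOAD_CHUNK_SIZE) + 1;
    let mut size_of_last_chunk = file_size % MULTIPART_UPLOAD_CHUNK_SIZE;
    // An exact multiple of the chunk size would otherwise yield an empty
    // trailing part, so fold it into the previous one.
    if size_of_last_chunk == 0 {
        size_of_last_chunk = MULTIPART_UPLOAD_CHUNK_SIZE;
        chunk_count -= 1;
    }
    (chunk_count, size_of_last_chunk)
}

fn main() {
    // 80 MiB -> 32 + 32 + 16 MiB
    assert_eq!(chunk_layout(80 * 1024 * 1024), (3, 16 * 1024 * 1024));
    // 64 MiB -> exactly two full parts, no empty third part
    assert_eq!(chunk_layout(64 * 1024 * 1024), (2, 32 * 1024 * 1024));
}
```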
```diff
@@ -161,13 +167,88 @@ impl S3Storage {
     }
 
     pub(crate) async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
-        let bytestream = ByteStream::from_path(path).await;
-        match bytestream {
-            Ok(bytestream) => return self.put_bytestream(key, bytestream).await,
-            Err(e) => {
-                return Err(S3PutError::S3PutError(e.to_string()));
-            }
-        }
+        let file_size = tokio::fs::metadata(path)
+            .await
+            .map_err(|err| S3PutError::S3PutError(err.to_string()))?
+            .len();
+        let mut chunk_count = (file_size / MULTIPART_UPLOAD_CHUNK_SIZE) + 1;
+        let mut size_of_last_chunk = file_size % MULTIPART_UPLOAD_CHUNK_SIZE;
+        if size_of_last_chunk == 0 {
+            size_of_last_chunk = MULTIPART_UPLOAD_CHUNK_SIZE;
+            chunk_count -= 1;
+        }
+
+        let upload_id = match self
+            .client
+            .create_multipart_upload()
+            .bucket(&self.bucket)
+            .key(key)
+            .send()
+            .await
+            .map_err(|err| S3PutError::S3PutError(err.to_string()))?
+            .upload_id
+        {
+            Some(upload_id) => upload_id,
+            None => {
+                return Err(S3PutError::S3PutError(
+                    "Multipart upload creation response missing upload ID".to_string(),
+                ));
+            }
+        };
+
+        let mut upload_parts = Vec::new();
+        for chunk_index in 0..chunk_count {
+            let this_chunk = if chunk_count - 1 == chunk_index {
+                size_of_last_chunk
+            } else {
+                MULTIPART_UPLOAD_CHUNK_SIZE
+            };
+
+            let stream = ByteStream::read_from()
+                .path(path)
+                .offset(chunk_index * MULTIPART_UPLOAD_CHUNK_SIZE)
+                .length(Length::Exact(this_chunk))
+                .build()
+                .await
+                .map_err(|err| S3PutError::S3PutError(err.to_string()))?;
+
+            // Chunk index needs to start at 0, but part numbers start at 1.
+            let part_number = (chunk_index as i32) + 1;
+            let upload_part_res = self
+                .client
+                .upload_part()
+                .key(key)
+                .bucket(&self.bucket)
+                .upload_id(&upload_id)
+                .body(stream)
+                .part_number(part_number)
+                .send()
+                .await
+                .map_err(|err| S3PutError::S3PutError(err.to_string()))?;
+
+            upload_parts.push(
+                CompletedPart::builder()
+                    .e_tag(upload_part_res.e_tag.unwrap_or_default())
+                    .part_number(part_number)
+                    .build(),
+            );
+        }
+
+        self.client
+            .complete_multipart_upload()
+            .bucket(&self.bucket)
+            .key(key)
+            .multipart_upload(
+                CompletedMultipartUpload::builder()
+                    .set_parts(Some(upload_parts))
+                    .build(),
+            )
+            .upload_id(&upload_id)
+            .send()
+            .await
+            .map_err(|err| S3PutError::S3PutError(err.to_string()))?;
+
+        Ok(())
     }
 
     async fn put_bytestream(&self, key: &str, bytestream: ByteStream) -> Result<(), S3PutError> {
```
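One caveat worth flagging in review: every `map_err(...)?` after `create_multipart_upload` returns early without cleanup, so a failed part upload leaves an incomplete multipart upload on the server, and S3 retains (and bills for) the already-uploaded parts until they are aborted or expired by a lifecycle rule. A hedged sketch of the cleanup call the SDK provides for this (the fluent `abort_multipart_upload` API is real; the helper itself is hypothetical and not part of this PR):

```rust
/// Sketch only: abort an in-progress multipart upload so S3 discards its parts.
async fn abort_multipart(
    client: &aws_sdk_s3::Client,
    bucket: &str,
    key: &str,
    upload_id: &str,
) -> Result<(), String> {
    client
        .abort_multipart_upload()
        .bucket(bucket)
        .key(key)
        .upload_id(upload_id)
        .send()
        .await
        .map_err(|err| err.to_string())?;
    Ok(())
}
```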
```diff
@@ -289,13 +370,14 @@ impl Configurable<StorageConfig> for S3Storage {
 
 #[cfg(test)]
 mod tests {
+    use std::io::Write;
+
     use super::*;
-    use tempfile::tempdir;
-    use tokio::io::AsyncReadExt;
+    use futures::StreamExt;
+    use rand::{Rng, SeedableRng};
+    use tempfile::{tempdir, NamedTempFile};
 
-    #[tokio::test]
-    #[cfg(CHROMA_KUBERNETES_INTEGRATION)]
-    async fn test_get() {
+    fn get_s3_client() -> aws_sdk_s3::Client {
         // Set up credentials assuming minio is running locally
         let cred = aws_sdk_s3::config::Credentials::new(
             "minio",
```
```diff
@@ -313,26 +395,98 @@ mod tests {
             .region(aws_sdk_s3::config::Region::new("us-east-1"))
             .force_path_style(true)
             .build();
-        let client = aws_sdk_s3::Client::from_conf(config);
+
+        aws_sdk_s3::Client::from_conf(config)
+    }
+
+    #[tokio::test]
+    #[cfg(CHROMA_KUBERNETES_INTEGRATION)]
+    async fn test_put_get_key() {
+        let client = get_s3_client();
 
         let storage = S3Storage {
             bucket: "test".to_string(),
             client,
         };
         storage.create_bucket().await.unwrap();
 
-        // Write some data to a test file, put it in s3, get it back and verify its contents
-        let tmp_dir = tempdir().unwrap();
-        let persist_path = tmp_dir.path().to_str().unwrap().to_string();
-
         let test_data = "test data";
-        let test_file_in = format!("{}/test_file_in", persist_path);
-        std::fs::write(&test_file_in, test_data).unwrap();
-        storage.put_file("test", &test_file_in).await.unwrap();
-        let mut bytes = storage.get("test").await.unwrap();
+        storage
+            .put_bytes("test", test_data.as_bytes().to_vec())
+            .await
+            .unwrap();
 
-        let mut buf = String::new();
-        bytes.read_to_string(&mut buf).await.unwrap();
+        let mut stream = storage.get("test").await.unwrap();
+
+        let mut buf = Vec::new();
+        while let Some(chunk) = stream.next().await {
+            match chunk {
+                Ok(data) => {
+                    buf.extend_from_slice(&data);
+                }
+                Err(e) => {
+                    panic!("Error reading stream: {}", e);
+                }
+            }
+        }
+
+        let buf = String::from_utf8(buf).unwrap();
         assert_eq!(buf, test_data);
     }
+
+    async fn test_put_file(file_size: usize) {
+        let client = get_s3_client();
+
+        let storage = S3Storage {
+            bucket: "test".to_string(),
+            client,
+        };
+        storage.create_bucket().await.unwrap();
+
+        let mut temp_file = NamedTempFile::new().unwrap();
+
+        let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(0);
+        let mut remaining_file_size = file_size;
+
+        while remaining_file_size > 0 {
+            let chunk_size = std::cmp::min(remaining_file_size, 4096);
+            let mut chunk = vec![0u8; chunk_size];
+            rng.try_fill(&mut chunk[..]).unwrap();
+            temp_file.write_all(&chunk).unwrap();
+            remaining_file_size -= chunk_size;
+        }
+
+        storage
+            .put_file("test", temp_file.path().to_str().unwrap())
+            .await
+            .unwrap();
+
+        let mut stream = storage.get("test").await.unwrap();
+
+        let mut buf = Vec::new();
+        while let Some(chunk) = stream.next().await {
+            match chunk {
+                Ok(data) => {
+                    buf.extend_from_slice(&data);
+                }
+                Err(e) => {
+                    panic!("Error reading stream: {}", e);
+                }
+            }
+        }
+
+        let file_contents = std::fs::read(temp_file.path()).unwrap();
+        assert_eq!(buf, file_contents);
+    }
+
+    #[tokio::test]
+    #[cfg(CHROMA_KUBERNETES_INTEGRATION)]
+    async fn test_put_file_scenarios() {
+        // Under chunk size
+        test_put_file(1024).await;
+        // At chunk size
+        test_put_file(MULTIPART_UPLOAD_CHUNK_SIZE as usize).await;
+        // Over chunk size
+        test_put_file((MULTIPART_UPLOAD_CHUNK_SIZE as f64 * 2.5) as usize).await;
+    }
 }
```

Review discussion on `test_put_file_scenarios`:

> I wanted to use proptest but apparently proptest doesn't work for async functions yet :/ I guess I could

> I think some unit tests are good enough here
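On the proptest point: proptest's macros generate synchronous `#[test]` functions, so an async body has to be driven by an explicitly constructed runtime. One common workaround (not what the thread settled on; the reviewers opted for plain unit tests) is sketched below, as a hypothetical test assuming the `proptest` crate and Tokio's `rt-multi-thread` feature:

```rust
use proptest::prelude::*;

proptest! {
    // proptest generates a synchronous test; build a runtime inside it
    // and block on the async body.
    #[test]
    fn put_file_roundtrip(file_size in 1usize..=64 * 1024) {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async move {
            // the async assertions would go here, e.g.:
            // test_put_file(file_size).await;
            let _ = file_size;
        });
    }
}
```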
Review comment on the seeded RNG in `test_put_file`:

> insecure but much faster randomness for tests
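The seeded `XorShiftRng` trades cryptographic quality for speed, and seeding also makes the generated test file deterministic across runs. A minimal sketch of that property (standalone, assuming the `rand` and `rand_xorshift` crates the test already depends on):

```rust
use rand::{RngCore, SeedableRng};
use rand_xorshift::XorShiftRng;

fn main() {
    // Same seed => identical byte stream, so a failing test reproduces exactly.
    let mut a = XorShiftRng::seed_from_u64(0);
    let mut b = XorShiftRng::seed_from_u64(0);
    assert_eq!(a.next_u64(), b.next_u64());
}
```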