[ENH] multipart S3 file uploads #2590

Merged: 10 commits, Aug 1, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions rust/worker/Cargo.toml
@@ -66,6 +66,7 @@ rayon = "1.8.0"
criterion = "0.3"
random-port = "0.1.1"
serial_test = "3.1.1"
rand_xorshift = "0.3.0"
Contributor Author:
insecure but much faster randomness for tests
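
A minimal sketch of the idea (illustrative, not from this PR): a seeded XorShiftRng is deterministic across runs and skips OS entropy, at the cost of being cryptographically insecure, which is fine for generating test data.

use rand::{Rng, SeedableRng};
use rand_xorshift::XorShiftRng;

// Fast, deterministic test-data generator; not suitable for anything
// security-sensitive.
fn test_bytes(n: usize) -> Vec<u8> {
    let mut rng = XorShiftRng::seed_from_u64(0);
    let mut buf = vec![0u8; n];
    rng.fill(&mut buf[..]);
    buf
}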


[build-dependencies]
tonic-build = "0.10"
196 changes: 175 additions & 21 deletions rust/worker/src/storage/s3.rs
@@ -21,11 +21,17 @@ use aws_sdk_s3::error::ProvideErrorMetadata;
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::operation::create_bucket::CreateBucketError;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::CompletedMultipartUpload;
use aws_sdk_s3::types::CompletedPart;
use aws_smithy_types::byte_stream::Length;
use futures::Stream;
use std::clone::Clone;
use std::time::Duration;
use thiserror::Error;

// todo: make this more principled
const MULTIPART_UPLOAD_CHUNK_SIZE: u64 = 1024 * 1024 * 32;
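// (S3 allows at most 10,000 parts per upload and requires every part except
// the last to be at least 5 MiB, so 32 MiB parts satisfy the minimum and cap
// a single object at roughly 312 GiB.)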

#[derive(Clone)]
pub(crate) struct S3Storage {
bucket: String,
@@ -161,13 +167,88 @@ impl S3Storage {
}

pub(crate) async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
let bytestream = ByteStream::from_path(path).await;
match bytestream {
Ok(bytestream) => return self.put_bytestream(key, bytestream).await,
Err(e) => {
return Err(S3PutError::S3PutError(e.to_string()));
let file_size = tokio::fs::metadata(path)
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?
.len();
let mut chunk_count = (file_size / MULTIPART_UPLOAD_CHUNK_SIZE) + 1;
let mut size_of_last_chunk = file_size % MULTIPART_UPLOAD_CHUNK_SIZE;
if size_of_last_chunk == 0 {
size_of_last_chunk = MULTIPART_UPLOAD_CHUNK_SIZE;
chunk_count -= 1;
}
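// e.g. with 32 MiB chunks, a 70 MiB file gives chunk_count = 3 with a
// 6 MiB final chunk, while an exactly 64 MiB file gives chunk_count = 2
// with a full-sized final chunk.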

let upload_id = match self
.client
.create_multipart_upload()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?
.upload_id
{
Some(upload_id) => upload_id,
None => {
return Err(S3PutError::S3PutError(
"Multipart upload creation response missing upload ID".to_string(),
));
}
};
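// The upload ID returned by CreateMultipartUpload ties the subsequent
// UploadPart and CompleteMultipartUpload calls to this upload.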

let mut upload_parts = Vec::new();
for chunk_index in 0..chunk_count {
let this_chunk = if chunk_count - 1 == chunk_index {
size_of_last_chunk
} else {
MULTIPART_UPLOAD_CHUNK_SIZE
};

let stream = ByteStream::read_from()
.path(path)
.offset(chunk_index * MULTIPART_UPLOAD_CHUNK_SIZE)
.length(Length::Exact(this_chunk))
.build()
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?;

// Chunk indices start at 0, but S3 part numbers start at 1.
let part_number = (chunk_index as i32) + 1;
let upload_part_res = self
.client
.upload_part()
.key(key)
.bucket(&self.bucket)
.upload_id(&upload_id)
.body(stream)
.part_number(part_number)
.send()
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?;

upload_parts.push(
CompletedPart::builder()
.e_tag(upload_part_res.e_tag.unwrap_or_default())
.part_number(part_number)
.build(),
);
}

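// CompleteMultipartUpload assembles the uploaded parts, in ascending
// part-number order, into the final object.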
self.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(key)
.multipart_upload(
CompletedMultipartUpload::builder()
.set_parts(Some(upload_parts))
.build(),
)
.upload_id(&upload_id)
.send()
.await
.map_err(|err| S3PutError::S3PutError(err.to_string()))?;

Ok(())
}

async fn put_bytestream(&self, key: &str, bytestream: ByteStream) -> Result<(), S3PutError> {
@@ -289,13 +370,14 @@ impl Configurable<StorageConfig> for S3Storage {

#[cfg(test)]
mod tests {
use std::io::Write;

use super::*;
use tempfile::tempdir;
use tokio::io::AsyncReadExt;
use futures::StreamExt;
use rand::{Rng, SeedableRng};
use tempfile::{tempdir, NamedTempFile};

#[tokio::test]
#[cfg(CHROMA_KUBERNETES_INTEGRATION)]
async fn test_get() {
fn get_s3_client() -> aws_sdk_s3::Client {
// Set up credentials assuming minio is running locally
let cred = aws_sdk_s3::config::Credentials::new(
"minio",
@@ -313,26 +395,98 @@ mod tests {
.region(aws_sdk_s3::config::Region::new("us-east-1"))
.force_path_style(true)
.build();
let client = aws_sdk_s3::Client::from_conf(config);

aws_sdk_s3::Client::from_conf(config)
}

#[tokio::test]
#[cfg(CHROMA_KUBERNETES_INTEGRATION)]
async fn test_put_get_key() {
let client = get_s3_client();

let storage = S3Storage {
bucket: "test".to_string(),
client,
};
storage.create_bucket().await.unwrap();

// Write some data to a test file, put it in s3, get it back and verify its contents
let tmp_dir = tempdir().unwrap();
let persist_path = tmp_dir.path().to_str().unwrap().to_string();

let test_data = "test data";
let test_file_in = format!("{}/test_file_in", persist_path);
std::fs::write(&test_file_in, test_data).unwrap();
storage.put_file("test", &test_file_in).await.unwrap();
let mut bytes = storage.get("test").await.unwrap();
storage
.put_bytes("test", test_data.as_bytes().to_vec())
.await
.unwrap();

let mut buf = String::new();
bytes.read_to_string(&mut buf).await.unwrap();
let mut stream = storage.get("test").await.unwrap();

let mut buf = Vec::new();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(data) => {
buf.extend_from_slice(&data);
}
Err(e) => {
panic!("Error reading stream: {}", e);
}
}
}

let buf = String::from_utf8(buf).unwrap();
assert_eq!(buf, test_data);
}

async fn test_put_file(file_size: usize) {
let client = get_s3_client();

let storage = S3Storage {
bucket: "test".to_string(),
client,
};
storage.create_bucket().await.unwrap();

let mut temp_file = NamedTempFile::new().unwrap();

let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(0);
let mut remaining_file_size = file_size;

while remaining_file_size > 0 {
let chunk_size = std::cmp::min(remaining_file_size, 4096);
let mut chunk = vec![0u8; chunk_size];
rng.try_fill(&mut chunk[..]).unwrap();
temp_file.write_all(&chunk).unwrap();
remaining_file_size -= chunk_size;
}

storage
.put_file("test", &temp_file.path().to_str().unwrap())
.await
.unwrap();

let mut stream = storage.get("test").await.unwrap();

let mut buf = Vec::new();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(data) => {
buf.extend_from_slice(&data);
}
Err(e) => {
panic!("Error reading stream: {}", e);
}
}
}

let file_contents = std::fs::read(temp_file.path()).unwrap();
assert_eq!(buf, file_contents);
}

#[tokio::test]
#[cfg(CHROMA_KUBERNETES_INTEGRATION)]
async fn test_put_file_scenarios() {
Contributor Author:
I wanted to use proptest but apparently proptest doesn't work for async functions yet :/

I guess I could block_on if that's preferred
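
For reference, a rough sketch of the block_on approach (hypothetical, not part of this PR) that drives the async helper from a proptest property:

use proptest::prelude::*;

proptest! {
    #[test]
    fn put_file_roundtrip(file_size in 1usize..(MULTIPART_UPLOAD_CHUNK_SIZE as usize * 3)) {
        // Build a small runtime per case and block on the async test body.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(test_put_file(file_size));
    }
}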

Collaborator:

I think some unit tests are good enough here

// Under chunk size
test_put_file(1024).await;
// At chunk size
test_put_file(MULTIPART_UPLOAD_CHUNK_SIZE as usize).await;
// Over chunk size
test_put_file((MULTIPART_UPLOAD_CHUNK_SIZE as f64 * 2.5) as usize).await;
}
}