Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial file retrieval when using GridFS stream download #874

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/cursor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl<T> Cursor<T> {
}

/// Whether this cursor has any additional items to return.
#[allow(dead_code)]
pub(crate) fn has_next(&self) -> bool {
!self.is_exhausted()
|| !self
Expand Down
9 changes: 9 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,15 @@ pub enum GridFsErrorKind {
/// [`GridFsUploadStream`](crate::gridfs::GridFsUploadStream) while a write was still in
/// progress.
WriteInProgress,

/// The requested partial download range is invalid: `start` is greater than `end`.
InvalidPartialDownloadRange { start: u64, end: u64 },

/// The requested partial download range is invalid: `start` or `end` is greater than the
/// file's length.
PartialDownloadRangeOutOfBounds {
out_of_bounds_value: u64,
file_length: u64,
},
}

/// An identifier for a file stored in a GridFS bucket.
Expand Down
140 changes: 121 additions & 19 deletions src/gridfs/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ use super::{options::GridFsDownloadByNameOptions, Chunk, FilesCollectionDocument
use crate::{
bson::{doc, Bson},
error::{ErrorKind, GridFsErrorKind, GridFsFileIdentifier, Result},
gridfs::GridFsDownloadByIdOptions,
options::{FindOneOptions, FindOptions},
Collection,
Cursor,
};

struct DownloadRange(Option<u64>, Option<u64>);

// Utility functions for finding files within the bucket.
impl GridFsBucket {
async fn find_file_by_id(&self, id: &Bson) -> Result<FilesCollectionDocument> {
Expand Down Expand Up @@ -214,7 +217,7 @@ impl GridFsBucket {
/// use futures_util::io::AsyncReadExt;
///
/// let mut buf = Vec::new();
/// let mut download_stream = bucket.open_download_stream(id).await?;
/// let mut download_stream = bucket.open_download_stream(id, None).await?;
/// download_stream.read_to_end(&mut buf).await?;
/// # Ok(())
/// # }
Expand All @@ -228,15 +231,18 @@ impl GridFsBucket {
/// # async fn compat_example(bucket: GridFsBucket, id: Bson) -> Result<()> {
/// use tokio_util::compat::FuturesAsyncReadCompatExt;
///
/// let futures_upload_stream = bucket.open_download_stream(id).await?;
/// let futures_upload_stream = bucket.open_download_stream(id, None).await?;
/// let tokio_upload_stream = futures_upload_stream.compat();
/// # Ok(())
/// # }
/// ```
pub struct GridFsDownloadStream {
    // Current phase of the download: idle with buffered bytes and a live
    // cursor, waiting on an in-flight chunk fetch, or done.
    state: State,
    // Index of the next chunk to be read from the chunks collection.
    current_n: u32,
    // Index one past the last chunk this stream will read — derived from the
    // requested end offset when one was given, otherwise the file's total
    // chunk count.
    total_n: u32,
    // The files-collection metadata document for the file being downloaded.
    file: FilesCollectionDocument,
    // Bytes to discard from the front of the first fetched chunk so that
    // reading begins at the requested start offset; reset to 0 once the
    // first fetch has been issued.
    to_skip: u64,
    // Bytes remaining to be handed to the reader; the stream transitions to
    // Done when this reaches 0.
    to_take: u64,
}

type GetBytesFuture = BoxFuture<'static, Result<(Vec<u8>, Box<Cursor<Chunk<'static>>>)>>;
Expand Down Expand Up @@ -264,25 +270,71 @@ impl State {
}
}

/// Verifies that an optional partial-download range endpoint does not lie
/// past the end of the file.
///
/// Returns a `PartialDownloadRangeOutOfBounds` error when the endpoint
/// exceeds `file_length`; a `None` endpoint is always valid.
fn validate_range_value(range_value: Option<u64>, file_length: u64) -> Result<()> {
    match range_value {
        Some(value) if value > file_length => Err(ErrorKind::GridFs(
            GridFsErrorKind::PartialDownloadRangeOutOfBounds {
                file_length,
                out_of_bounds_value: value,
            },
        )
        .into()),
        _ => Ok(()),
    }
}

impl GridFsDownloadStream {
async fn new(
file: FilesCollectionDocument,
chunks: &Collection<Chunk<'static>>,
range: DownloadRange,
) -> Result<Self> {
let initial_state = if file.length == 0 {
validate_range_value(range.0, file.length)?;
validate_range_value(range.1, file.length)?;

let is_empty_range = match range {
DownloadRange(Some(start), Some(end)) => start == end,
_ => false,
};

let to_skip = range.0.unwrap_or(0);
let to_take = range.1.unwrap_or(file.length) - to_skip;
let chunk_size = file.chunk_size_bytes as u64;
let chunks_to_skip = to_skip / chunk_size;
let total_chunks = range
.1
.map(|end| end / chunk_size + u64::from(end % chunk_size != 0));

let initial_state = if file.length == 0 || is_empty_range {
State::Done
} else {
let options = FindOptions::builder().sort(doc! { "n": 1 }).build();
let cursor = chunks.find(doc! { "files_id": &file.id }, options).await?;
let options = FindOptions::builder()
.sort(doc! { "n": 1 })
.limit(total_chunks.map(|end| (end - chunks_to_skip) as i64))
.build();
let cursor = chunks
.find(
doc! { "files_id": &file.id, "n": { "$gte": chunks_to_skip as i64 } },
options,
)
.await?;

State::Idle(Some(Idle {
buffer: Vec::new(),
cursor: Box::new(cursor),
}))
};

Ok(Self {
state: initial_state,
current_n: 0,
current_n: chunks_to_skip as u32,
total_n: total_chunks.map(|value| value as u32).unwrap_or(file.n()),
file,
to_skip: to_skip % chunk_size,
to_take,
})
}
}
Expand All @@ -303,12 +355,12 @@ impl AsyncRead for GridFsDownloadStream {
Ok((buffer, cursor))
} else {
let chunks_in_buf = FilesCollectionDocument::n_from_vals(
buf.len() as u64,
stream.to_skip + buf.len() as u64,
stream.file.chunk_size_bytes,
);
// We should read from current_n to chunks_in_buf + current_n, or, if that would
// exceed the total number of chunks in the file, to the last chunk in the file.
let final_n = std::cmp::min(chunks_in_buf + stream.current_n, stream.file.n());
let final_n = std::cmp::min(chunks_in_buf + stream.current_n, stream.total_n);
let n_range = stream.current_n..final_n;

stream.current_n = final_n;
Expand All @@ -320,10 +372,13 @@ impl AsyncRead for GridFsDownloadStream {
n_range,
stream.file.chunk_size_bytes,
stream.file.length,
stream.to_skip,
)
.boxed(),
);

stream.to_skip = 0;

match new_future.poll_unpin(cx) {
Poll::Ready(result) => result,
Poll::Pending => return Poll::Pending,
Expand All @@ -340,13 +395,19 @@ impl AsyncRead for GridFsDownloadStream {

match result {
Ok((mut buffer, cursor)) => {
let bytes_to_write = std::cmp::min(buffer.len(), buf.len());
let mut bytes_to_write = std::cmp::min(buffer.len(), buf.len());

if bytes_to_write as u64 > stream.to_take {
bytes_to_write = stream.to_take as usize;
}

buf[..bytes_to_write].copy_from_slice(buffer.drain(0..bytes_to_write).as_slice());
stream.to_take -= bytes_to_write as u64;

stream.state = if !buffer.is_empty() || cursor.has_next() {
State::Idle(Some(Idle { buffer, cursor }))
} else {
stream.state = if stream.to_take == 0 {
State::Done
} else {
State::Idle(Some(Idle { buffer, cursor }))
};

Poll::Ready(Ok(bytes_to_write))
Expand All @@ -365,6 +426,7 @@ async fn get_bytes(
n_range: Range<u32>,
chunk_size_bytes: u32,
file_len: u64,
mut to_skip: u64,
) -> Result<(Vec<u8>, Box<Cursor<Chunk<'static>>>)> {
for n in n_range {
if !cursor.advance().await? {
Expand All @@ -389,19 +451,53 @@ async fn get_bytes(
.into());
}

buffer.extend_from_slice(chunk_bytes);
if to_skip >= chunk_bytes.len() as u64 {
to_skip -= chunk_bytes.len() as u64;
} else if to_skip > 0 {
buffer.extend_from_slice(&chunk_bytes[to_skip as usize..]);
to_skip = 0;
} else {
buffer.extend_from_slice(chunk_bytes);
}
}

Ok((buffer, cursor))
}

/// Builds a `DownloadRange` from optional start and end byte offsets.
///
/// When both offsets are present, the start must not exceed the end;
/// otherwise an `InvalidPartialDownloadRange` error is returned. Ranges with
/// one or both endpoints absent are accepted as-is.
fn create_download_range(start: Option<u64>, end: Option<u64>) -> Result<DownloadRange> {
    if let (Some(range_start), Some(range_end)) = (start, end) {
        if range_start > range_end {
            return Err(ErrorKind::GridFs(
                GridFsErrorKind::InvalidPartialDownloadRange {
                    start: range_start,
                    end: range_end,
                },
            )
            .into());
        }
    }
    Ok(DownloadRange(start, end))
}

// User functions for creating download streams.
impl GridFsBucket {
/// Opens and returns a [`GridFsDownloadStream`] from which the application can read
/// the contents of the stored file specified by `id`.
pub async fn open_download_stream(&self, id: Bson) -> Result<GridFsDownloadStream> {
pub async fn open_download_stream(
&self,
id: Bson,
options: impl Into<Option<GridFsDownloadByIdOptions>>,
) -> Result<GridFsDownloadStream> {
let options: Option<GridFsDownloadByIdOptions> = options.into();
let file = self.find_file_by_id(&id).await?;
GridFsDownloadStream::new(file, self.chunks()).await

let range = create_download_range(
options.as_ref().and_then(|options| options.start),
options.as_ref().and_then(|options| options.end),
)?;

GridFsDownloadStream::new(file, self.chunks(), range).await
}

/// Opens and returns a [`GridFsDownloadStream`] from which the application can read
Expand All @@ -416,9 +512,15 @@ impl GridFsBucket {
filename: impl AsRef<str>,
options: impl Into<Option<GridFsDownloadByNameOptions>>,
) -> Result<GridFsDownloadStream> {
let file = self
.find_file_by_name(filename.as_ref(), options.into())
.await?;
GridFsDownloadStream::new(file, self.chunks()).await
let options: Option<GridFsDownloadByNameOptions> = options.into();

let range = create_download_range(
options.as_ref().and_then(|options| options.start),
options.as_ref().and_then(|options| options.end),
)?;

let file = self.find_file_by_name(filename.as_ref(), options).await?;

GridFsDownloadStream::new(file, self.chunks(), range).await
}
}
21 changes: 21 additions & 0 deletions src/gridfs/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ pub struct GridFsUploadOptions {
pub metadata: Option<Document>,
}

/// Contains the options for downloading a file from a [`GridFsBucket`](crate::gridfs::GridFsBucket)
/// by id.
#[derive(Clone, Debug, Default, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[non_exhaustive]
pub struct GridFsDownloadByIdOptions {
    /// 0-indexed, non-negative byte offset from the beginning of the file at
    /// which the download should start (inclusive). If unset, the download
    /// starts at the beginning of the file.
    pub start: Option<u64>,

    /// 0-indexed, non-negative byte offset marking the end of the file
    /// contents to be returned by the stream. This offset is exclusive: the
    /// byte at `end` itself is not returned. If unset, the download continues
    /// to the end of the file.
    pub end: Option<u64>,
}

/// Contains the options for downloading a file from a [`GridFsBucket`](crate::gridfs::GridFsBucket)
/// by name.
#[derive(Clone, Debug, Default, Deserialize, TypedBuilder)]
Expand All @@ -60,6 +74,13 @@ pub struct GridFsDownloadByNameOptions {
/// -2 = the second most recent revision
/// -1 = the most recent revision
pub revision: Option<i32>,

/// 0-indexed non-negative byte offset from the beginning of the file.
pub start: Option<u64>,

/// 0-indexed, non-negative byte offset marking the end of the file contents to be returned
/// by the stream. This offset is exclusive: the byte at `end` itself is not returned.
pub end: Option<u64>,
}

/// Contains the options for finding
Expand Down
12 changes: 9 additions & 3 deletions src/sync/gridfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::{
};

pub use crate::gridfs::FilesCollectionDocument;
use crate::gridfs::GridFsDownloadByIdOptions;

/// A `GridFsBucket` provides the functionality for storing and retrieving binary BSON data that
/// exceeds the 16 MiB size limit of a MongoDB document. Users may upload and download large amounts
Expand Down Expand Up @@ -98,7 +99,7 @@ impl GridFsBucket {
/// use std::io::Read;
///
/// let mut buf = Vec::new();
/// let mut download_stream = bucket.open_download_stream(id)?;
/// let mut download_stream = bucket.open_download_stream(id, None)?;
/// download_stream.read_to_end(&mut buf)?;
/// # Ok(())
/// # }
Expand All @@ -123,8 +124,13 @@ impl GridFsDownloadStream {
impl GridFsBucket {
/// Opens and returns a [`GridFsDownloadStream`] from which the application can read
/// the contents of the stored file specified by `id`.
pub fn open_download_stream(&self, id: Bson) -> Result<GridFsDownloadStream> {
runtime::block_on(self.async_bucket.open_download_stream(id)).map(GridFsDownloadStream::new)
/// Opens and returns a [`GridFsDownloadStream`] from which the application
/// can read the contents of the stored file specified by `id`, blocking the
/// current thread while the underlying async stream is created.
pub fn open_download_stream(
    &self,
    id: Bson,
    options: impl Into<Option<GridFsDownloadByIdOptions>>,
) -> Result<GridFsDownloadStream> {
    let async_stream =
        runtime::block_on(self.async_bucket.open_download_stream(id, options))?;
    Ok(GridFsDownloadStream::new(async_stream))
}

/// Opens and returns a [`GridFsDownloadStream`] from which the application can read
Expand Down
2 changes: 1 addition & 1 deletion src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ fn gridfs() {
upload_stream.close().unwrap();

let mut download_stream = bucket
.open_download_stream(upload_stream.id().clone())
.open_download_stream(upload_stream.id().clone(), None)
.unwrap();
download_stream.read_to_end(&mut download).unwrap();

Expand Down
Loading