add REAPI batch APIs, take pantsbuild#2
Tom Dyas committed Jul 12, 2022
1 parent 370ca9d commit 06c66f6
Showing 4 changed files with 143 additions and 15 deletions.
88 changes: 73 additions & 15 deletions src/rust/engine/fs/store/src/remote.rs
@@ -18,8 +18,8 @@ use protos::gen::build::bazel::remote::execution::v2 as remexec;
use protos::gen::google::bytestream::byte_stream_client::ByteStreamClient;
use remexec::{
capabilities_client::CapabilitiesClient,
content_addressable_storage_client::ContentAddressableStorageClient, BatchUpdateBlobsRequest,
ServerCapabilities,
content_addressable_storage_client::ContentAddressableStorageClient, BatchReadBlobsRequest,
BatchUpdateBlobsRequest, ServerCapabilities,
};
use tonic::{Code, Request, Status};
use workunit_store::{in_workunit, ObservationMetric};
@@ -203,16 +203,7 @@ impl ByteStore {
})
}

async fn store_bytes_source<ByteSource>(
&self,
digest: Digest,
bytes: ByteSource,
) -> Result<(), ByteStoreError>
where
ByteSource: Fn(Range<usize>) -> Bytes + Send + Sync + 'static,
{
let len = digest.size_bytes;

async fn len_is_allowed_for_batch_api(&self, len: usize) -> Result<bool, ByteStoreError> {
let max_batch_total_size_bytes = {
let capabilities = self.get_capabilities().await?;

@@ -226,14 +217,25 @@ impl ByteStore {
let batch_api_allowed_by_local_config = len <= self.batch_api_size_limit;
let batch_api_allowed_by_server_config =
max_batch_total_size_bytes == 0 || len < max_batch_total_size_bytes;
if batch_api_allowed_by_local_config && batch_api_allowed_by_server_config {
Ok(batch_api_allowed_by_local_config && batch_api_allowed_by_server_config)
}

async fn store_bytes_source<ByteSource>(
&self,
digest: Digest,
bytes: ByteSource,
) -> Result<(), ByteStoreError>
where
ByteSource: Fn(Range<usize>) -> Bytes + Send + Sync + 'static,
{
if self.len_is_allowed_for_batch_api(digest.size_bytes).await? {
self.store_bytes_source_batch(digest, bytes).await
} else {
self.store_bytes_source_stream(digest, bytes).await
}
}

async fn store_bytes_source_batch<ByteSource>(
pub(crate) async fn store_bytes_source_batch<ByteSource>(
&self,
digest: Digest,
bytes: ByteSource,
@@ -257,7 +259,7 @@
Ok(())
}

async fn store_bytes_source_stream<ByteSource>(
pub(crate) async fn store_bytes_source_stream<ByteSource>(
&self,
digest: Digest,
bytes: ByteSource,
@@ -339,6 +341,62 @@ impl ByteStore {
&self,
digest: Digest,
f: F,
) -> Result<Option<T>, ByteStoreError> {
if self.len_is_allowed_for_batch_api(digest.size_bytes).await? {
self.load_bytes_with_batch(digest, f).await
} else {
self.load_bytes_with_stream(digest, f).await
}
}

pub(crate) async fn load_bytes_with_batch<
T: Send + 'static,
F: Fn(Bytes) -> Result<T, String> + Send + Sync + Clone + 'static,
>(
&self,
digest: Digest,
f: F,
) -> Result<Option<T>, ByteStoreError> {
let request = BatchReadBlobsRequest {
instance_name: self.instance_name.clone().unwrap_or_default(),
digests: vec![digest.into()],
};
let mut client = self.cas_client.as_ref().clone();
let response = client
.batch_read_blobs(request)
.await
.map_err(ByteStoreError::Grpc)?;

let response = response.into_inner();
if response.responses.len() != 1 {
return Err(ByteStoreError::Other(
format!(
"Response from remote store for BatchReadBlobs API had inconsistent number of responses (got {}, expected 1)",
response.responses.len()
)
));
}

let blob_response = response.responses.into_iter().next().unwrap();
let rpc_status = blob_response.status.unwrap_or_default();
let status = Status::from(rpc_status);
match status.code() {
Code::Ok => {
let result = f(blob_response.data);
result.map(Some).map_err(ByteStoreError::Other)
}
Code::NotFound => Ok(None),
_ => Err(ByteStoreError::Grpc(status)),
}
}

pub(crate) async fn load_bytes_with_stream<
T: Send + 'static,
F: Fn(Bytes) -> Result<T, String> + Send + Sync + Clone + 'static,
>(
&self,
digest: Digest,
f: F,
) -> Result<Option<T>, ByteStoreError> {
let start = Instant::now();
let store = self.clone();
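
Note on the new control flow in remote.rs: a blob only goes through the REAPI batch RPCs when it fits both the locally configured batch_api_size_limit and the server-advertised max_batch_total_size_bytes from CacheCapabilities (where 0 means the server advertises no limit); otherwise the existing ByteStream path is used. A minimal, self-contained sketch of that predicate, with hypothetical constants standing in for the real config and the gRPC capabilities lookup:

/// Sketch of len_is_allowed_for_batch_api; the real method fetches
/// ServerCapabilities over gRPC and returns Result<bool, ByteStoreError>.
fn len_is_allowed_for_batch_api(
  len: usize,
  local_batch_api_size_limit: usize,        // stand-in for self.batch_api_size_limit
  server_max_batch_total_size_bytes: usize, // from CacheCapabilities; 0 == no server limit
) -> bool {
  let allowed_by_local_config = len <= local_batch_api_size_limit;
  let allowed_by_server_config =
    server_max_batch_total_size_bytes == 0 || len < server_max_batch_total_size_bytes;
  allowed_by_local_config && allowed_by_server_config
}

fn main() {
  // With a (hypothetical) 4 MiB local limit and no server limit, a 1 KiB blob batches...
  assert!(len_is_allowed_for_batch_api(1024, 4 * 1024 * 1024, 0));
  // ...while a blob at the server's advertised limit falls back to the ByteStream API.
  assert!(!len_is_allowed_for_batch_api(2 * 1024 * 1024, 4 * 1024 * 1024, 2 * 1024 * 1024));
}
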
57 changes: 57 additions & 0 deletions src/rust/engine/fs/store/src/remote_tests.rs
@@ -307,6 +307,63 @@ async fn list_missing_digests_error() {
);
}

#[tokio::test]
async fn load_via_each_api() {
let _ = WorkunitStore::setup_for_tests();
let cas = StubCAS::empty();

let data = TestData::roland();
{
let mut blobs = cas.blobs.lock();
blobs.insert(data.fingerprint(), data.bytes());
}

let store = new_byte_store(&cas);
let result_batch = store
.load_bytes_with_batch(data.digest(), |b| Ok(b))
.await
.unwrap()
.unwrap();
let result_stream = store
.load_bytes_with_stream(data.digest(), |b| Ok(b))
.await
.unwrap()
.unwrap();
assert_eq!(result_batch, data.bytes());
assert_eq!(result_stream, data.bytes());
}

#[tokio::test]
async fn store_via_each_api() {
let _ = WorkunitStore::setup_for_tests();
let cas = StubCAS::empty();

let data = TestData::roland();
let store = new_byte_store(&cas);

let bytes = data.bytes();
let _ = store
.store_bytes_source_batch(data.digest(), move |r| bytes.slice(r))
.await
.unwrap();
{
let mut blobs = cas.blobs.lock();
assert_eq!(*blobs.get(&data.digest().hash).unwrap(), data.bytes());
blobs.clear();
}

let bytes = data.bytes();
let _ = store
.store_bytes_source_stream(data.digest(), move |r| bytes.slice(r))
.await
.unwrap();
{
let mut blobs = cas.blobs.lock();
assert_eq!(*blobs.get(&data.digest().hash).unwrap(), data.bytes());
blobs.clear();
}
}

fn new_byte_store(cas: &StubCAS) -> ByteStore {
ByteStore::new(
&cas.address(),
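
The tests above exercise the ByteSource shape used throughout remote.rs: the store takes a closure Fn(Range<usize>) -> Bytes rather than an owned buffer, so both upload paths can ask for whatever sub-range they need, and with bytes::Bytes those views are cheap reference-counted slices rather than copies. A small illustration of why move |r| bytes.slice(r) satisfies that bound (the payload contents here are illustrative):

use std::ops::Range;

use bytes::Bytes;

fn main() {
  let payload = Bytes::from_static(b"European Burmese");

  // A ByteSource in the sense used by store_bytes_source: a zero-copy view
  // of the payload for whatever range the caller requests.
  let source = move |r: Range<usize>| payload.slice(r);

  assert_eq!(source(0..8), Bytes::from_static(b"European"));
  assert_eq!(source(9..16), Bytes::from_static(b"Burmese"));
}
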
8 changes: 8 additions & 0 deletions src/rust/engine/protos/src/conversions.rs
@@ -1,3 +1,5 @@
use tonic::{Code, Status};

impl<'a> From<&'a hashing::Digest> for crate::gen::build::bazel::remote::execution::v2::Digest {
fn from(d: &'a hashing::Digest) -> Self {
Self {
@@ -51,3 +53,9 @@ pub fn require_digest<
None => Err("Protocol violation: Digest missing from a Remote Execution API protobuf.".into()),
}
}

impl From<crate::gen::google::rpc::Status> for Status {
fn from(rpc_status: crate::gen::google::rpc::Status) -> Self {
Status::new(Code::from_i32(rpc_status.code), rpc_status.message)
}
}
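
This From impl is what lets load_bytes_with_batch treat the per-blob google.rpc.Status embedded in a BatchReadBlobsResponse like any other gRPC status. A rough stand-alone sketch of the same mapping, with a hand-rolled struct standing in for the generated google.rpc.Status message (only code and message matter here; details is ignored):

use tonic::{Code, Status};

// Hypothetical stand-in for crate::gen::google::rpc::Status.
struct RpcStatus {
  code: i32,
  message: String,
}

fn to_tonic_status(rpc_status: RpcStatus) -> Status {
  // Mirrors the impl above: reinterpret the numeric code, keep the message.
  Status::new(Code::from_i32(rpc_status.code), rpc_status.message)
}

fn main() {
  // NOT_FOUND is 5 in the gRPC status numbering, so a per-blob "not found"
  // from BatchReadBlobs lands in the branch of load_bytes_with_batch that returns Ok(None).
  let status = to_tonic_status(RpcStatus {
    code: 5,
    message: "blob not found".to_owned(),
  });
  assert_eq!(status.code(), Code::NotFound);
}
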
5 changes: 5 additions & 0 deletions src/rust/engine/testutil/mock/src/cas_service.rs
@@ -446,6 +446,11 @@ impl ContentAddressableStorage for StubCASResponder {
&self,
request: Request<BatchReadBlobsRequest>,
) -> Result<Response<BatchReadBlobsResponse>, Status> {
{
let mut request_count = self.read_request_count.lock();
*request_count += 1;
}

check_auth!(self, request);

if self.always_errors {
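
The extra braces around the counter update in the stub are deliberate: the guard is dropped as soon as the count is bumped, so the lock is not held across the rest of the handler (the stub appears to use a parking_lot-style mutex whose lock() does not return a Result). A minimal sketch of the same scoping pattern with a plain std::sync::Mutex:

use std::sync::Mutex;

fn main() {
  let read_request_count = Mutex::new(0usize);

  // Scope the guard so it is released immediately after the increment,
  // before any further work the handler does.
  {
    let mut request_count = read_request_count.lock().unwrap();
    *request_count += 1;
  }

  // The lock is free again here; a test can read the counter to see how many
  // BatchReadBlobs requests the stub served.
  assert_eq!(*read_request_count.lock().unwrap(), 1);
}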
