diff --git a/src/rust/engine/fs/store/src/remote.rs b/src/rust/engine/fs/store/src/remote.rs index e4db49cc353e..42dacdff0fc6 100644 --- a/src/rust/engine/fs/store/src/remote.rs +++ b/src/rust/engine/fs/store/src/remote.rs @@ -18,8 +18,8 @@ use protos::gen::build::bazel::remote::execution::v2 as remexec; use protos::gen::google::bytestream::byte_stream_client::ByteStreamClient; use remexec::{ capabilities_client::CapabilitiesClient, - content_addressable_storage_client::ContentAddressableStorageClient, BatchUpdateBlobsRequest, - ServerCapabilities, + content_addressable_storage_client::ContentAddressableStorageClient, BatchReadBlobsRequest, + BatchUpdateBlobsRequest, ServerCapabilities, }; use tonic::{Code, Request, Status}; use workunit_store::{in_workunit, ObservationMetric}; @@ -203,16 +203,7 @@ impl ByteStore { }) } - async fn store_bytes_source( - &self, - digest: Digest, - bytes: ByteSource, - ) -> Result<(), ByteStoreError> - where - ByteSource: Fn(Range) -> Bytes + Send + Sync + 'static, - { - let len = digest.size_bytes; - + async fn len_is_allowed_for_batch_api(&self, len: usize) -> Result { let max_batch_total_size_bytes = { let capabilities = self.get_capabilities().await?; @@ -226,14 +217,25 @@ impl ByteStore { let batch_api_allowed_by_local_config = len <= self.batch_api_size_limit; let batch_api_allowed_by_server_config = max_batch_total_size_bytes == 0 || len < max_batch_total_size_bytes; - if batch_api_allowed_by_local_config && batch_api_allowed_by_server_config { + Ok(batch_api_allowed_by_local_config && batch_api_allowed_by_server_config) + } + + async fn store_bytes_source( + &self, + digest: Digest, + bytes: ByteSource, + ) -> Result<(), ByteStoreError> + where + ByteSource: Fn(Range) -> Bytes + Send + Sync + 'static, + { + if self.len_is_allowed_for_batch_api(digest.size_bytes).await? { self.store_bytes_source_batch(digest, bytes).await } else { self.store_bytes_source_stream(digest, bytes).await } } - async fn store_bytes_source_batch( + pub(crate) async fn store_bytes_source_batch( &self, digest: Digest, bytes: ByteSource, @@ -257,7 +259,7 @@ impl ByteStore { Ok(()) } - async fn store_bytes_source_stream( + pub(crate) async fn store_bytes_source_stream( &self, digest: Digest, bytes: ByteSource, @@ -339,6 +341,62 @@ impl ByteStore { &self, digest: Digest, f: F, + ) -> Result, ByteStoreError> { + if self.len_is_allowed_for_batch_api(digest.size_bytes).await? { + self.load_bytes_with_batch(digest, f).await + } else { + self.load_bytes_with_stream(digest, f).await + } + } + + pub(crate) async fn load_bytes_with_batch< + T: Send + 'static, + F: Fn(Bytes) -> Result + Send + Sync + Clone + 'static, + >( + &self, + digest: Digest, + f: F, + ) -> Result, ByteStoreError> { + let request = BatchReadBlobsRequest { + instance_name: self.instance_name.clone().unwrap_or_default(), + digests: vec![digest.into()], + }; + let mut client = self.cas_client.as_ref().clone(); + let response = client + .batch_read_blobs(request) + .await + .map_err(ByteStoreError::Grpc)?; + + let response = response.into_inner(); + if response.responses.len() != 1 { + return Err(ByteStoreError::Other( + format!( + "Response from remote store for BatchReadBlobs API had inconsistent number of responses (got {}, expected 1)", + response.responses.len() + ) + )); + } + + let blob_response = response.responses.into_iter().next().unwrap(); + let rpc_status = blob_response.status.unwrap_or_default(); + let status = Status::from(rpc_status); + match status.code() { + Code::Ok => { + let result = f(blob_response.data); + result.map(Some).map_err(ByteStoreError::Other) + } + Code::NotFound => Ok(None), + _ => Err(ByteStoreError::Grpc(status)), + } + } + + pub(crate) async fn load_bytes_with_stream< + T: Send + 'static, + F: Fn(Bytes) -> Result + Send + Sync + Clone + 'static, + >( + &self, + digest: Digest, + f: F, ) -> Result, ByteStoreError> { let start = Instant::now(); let store = self.clone(); diff --git a/src/rust/engine/fs/store/src/remote_tests.rs b/src/rust/engine/fs/store/src/remote_tests.rs index 52363edc5436..583c33c30ede 100644 --- a/src/rust/engine/fs/store/src/remote_tests.rs +++ b/src/rust/engine/fs/store/src/remote_tests.rs @@ -307,6 +307,63 @@ async fn list_missing_digests_error() { ); } +#[tokio::test] +async fn load_via_each_api() { + let _ = WorkunitStore::setup_for_tests(); + let cas = StubCAS::empty(); + + let data = TestData::roland(); + { + let mut blobs = cas.blobs.lock(); + blobs.insert(data.fingerprint(), data.bytes()); + } + + let store = new_byte_store(&cas); + let result_batch = store + .load_bytes_with_batch(data.digest(), |b| Ok(b)) + .await + .unwrap() + .unwrap(); + let result_stream = store + .load_bytes_with_stream(data.digest(), |b| Ok(b)) + .await + .unwrap() + .unwrap(); + assert_eq!(result_batch, data.bytes()); + assert_eq!(result_stream, data.bytes()); +} + +#[tokio::test] +async fn store_via_each_api() { + let _ = WorkunitStore::setup_for_tests(); + let cas = StubCAS::empty(); + + let data = TestData::roland(); + let store = new_byte_store(&cas); + + let bytes = data.bytes(); + let _ = store + .store_bytes_source_batch(data.digest(), move |r| bytes.slice(r)) + .await + .unwrap(); + { + let mut blobs = cas.blobs.lock(); + assert_eq!(*blobs.get(&data.digest().hash).unwrap(), data.bytes()); + blobs.clear(); + } + + let bytes = data.bytes(); + let _ = store + .store_bytes_source_stream(data.digest(), move |r| bytes.slice(r)) + .await + .unwrap(); + { + let mut blobs = cas.blobs.lock(); + assert_eq!(*blobs.get(&data.digest().hash).unwrap(), data.bytes()); + blobs.clear(); + } +} + fn new_byte_store(cas: &StubCAS) -> ByteStore { ByteStore::new( &cas.address(), diff --git a/src/rust/engine/protos/src/conversions.rs b/src/rust/engine/protos/src/conversions.rs index 1e385d8cf008..c7fb43da86ea 100644 --- a/src/rust/engine/protos/src/conversions.rs +++ b/src/rust/engine/protos/src/conversions.rs @@ -1,3 +1,5 @@ +use tonic::{Code, Status}; + impl<'a> From<&'a hashing::Digest> for crate::gen::build::bazel::remote::execution::v2::Digest { fn from(d: &'a hashing::Digest) -> Self { Self { @@ -51,3 +53,9 @@ pub fn require_digest< None => Err("Protocol violation: Digest missing from a Remote Execution API protobuf.".into()), } } + +impl From for Status { + fn from(rpc_status: crate::gen::google::rpc::Status) -> Self { + Status::new(Code::from_i32(rpc_status.code), rpc_status.message) + } +} diff --git a/src/rust/engine/testutil/mock/src/cas_service.rs b/src/rust/engine/testutil/mock/src/cas_service.rs index de855a1ae36c..8a95f41fbe23 100644 --- a/src/rust/engine/testutil/mock/src/cas_service.rs +++ b/src/rust/engine/testutil/mock/src/cas_service.rs @@ -446,6 +446,11 @@ impl ContentAddressableStorage for StubCASResponder { &self, request: Request, ) -> Result, Status> { + { + let mut request_count = self.read_request_count.lock(); + *request_count += 1; + } + check_auth!(self, request); if self.always_errors {