diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
index 5474f884233832..feef519888f57c 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
@@ -13,8 +13,6 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 import static com.google.devtools.build.lib.remote.util.RxFutures.toCompletable;
 import static com.google.devtools.build.lib.remote.util.RxFutures.toSingle;
@@ -26,8 +24,11 @@
 import build.bazel.remote.execution.v2.Directory;
 import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
+import com.google.devtools.build.lib.profiler.Profiler;
+import com.google.devtools.build.lib.profiler.SilentCloseable;
 import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
 import com.google.devtools.build.lib.remote.common.RemoteCacheClient;
 import com.google.devtools.build.lib.remote.merkletree.MerkleTree;
@@ -36,16 +37,21 @@
 import com.google.devtools.build.lib.remote.util.DigestUtil;
 import com.google.devtools.build.lib.remote.util.RxUtils.TransferResult;
 import com.google.protobuf.Message;
+import io.reactivex.rxjava3.annotations.NonNull;
+import io.reactivex.rxjava3.annotations.Nullable;
 import io.reactivex.rxjava3.core.Completable;
+import io.reactivex.rxjava3.core.CompletableEmitter;
+import io.reactivex.rxjava3.core.CompletableObserver;
 import io.reactivex.rxjava3.core.Flowable;
+import io.reactivex.rxjava3.core.Maybe;
 import io.reactivex.rxjava3.core.Single;
+import io.reactivex.rxjava3.disposables.Disposable;
 import io.reactivex.rxjava3.subjects.AsyncSubject;
 import java.io.IOException;
-import java.util.HashSet;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicBoolean;
-import javax.annotation.concurrent.GuardedBy;
+import java.util.stream.Collectors;
 
 /** A {@link RemoteCache} with additional functionality needed for remote execution.
 */
 public class RemoteExecutionCache extends RemoteCache {
@@ -85,13 +91,11 @@ public void ensureInputsPresent(
       return;
     }
 
-    MissingDigestFinder missingDigestFinder = new MissingDigestFinder(context, allDigests.size());
     Flowable<TransferResult> uploads =
-        Flowable.fromIterable(allDigests)
-            .flatMapSingle(
-                digest ->
-                    uploadBlobIfMissing(
-                        context, merkleTree, additionalInputs, force, missingDigestFinder, digest));
+        collectDigests(allDigests, force)
+            .flatMap(uploadTasks -> findMissingUploads(context, uploadTasks))
+            .flatMapPublisher(
+                digests -> uploadNewDigests(context, merkleTree, additionalInputs, digests));
 
     try {
       mergeBulkTransfer(uploads).blockingAwait();
@@ -105,36 +109,6 @@ public void ensureInputsPresent(
     }
   }
 
-  private Single<TransferResult> uploadBlobIfMissing(
-      RemoteActionExecutionContext context,
-      MerkleTree merkleTree,
-      Map<Digest, Message> additionalInputs,
-      boolean force,
-      MissingDigestFinder missingDigestFinder,
-      Digest digest) {
-    Completable upload =
-        casUploadCache.execute(
-            digest,
-            Completable.defer(
-                () ->
-                    // Only reach here if the digest is missing and is not being uploaded.
-                    missingDigestFinder
-                        .registerAndCount(digest)
-                        .flatMapCompletable(
-                            missingDigests -> {
-                              if (missingDigests.contains(digest)) {
-                                return toCompletable(
-                                    () -> uploadBlob(context, digest, merkleTree, additionalInputs),
-                                    directExecutor());
-                              } else {
-                                return Completable.complete();
-                              }
-                            })),
-            /* onIgnored= */ missingDigestFinder::count,
-            force);
-    return toTransferResult(upload);
-  }
-
   private ListenableFuture<Void> uploadBlob(
       RemoteActionExecutionContext context,
       Digest digest,
@@ -165,92 +139,152 @@ private ListenableFuture<Void> uploadBlob(
                           digest)));
   }
 
-  /**
-   * A missing digest finder that initiates the request when the internal counter reaches an
-   * expected count.
-   */
-  class MissingDigestFinder {
-    private final int expectedCount;
+  static class UploadTask {
+    Digest digest;
+    // If continuation is not null, we own this task and need to upload the digest if missing and
+    // then use continuation to notify downstream.
+    // Otherwise, the task is owned by others, we only care about the completion.
+ @Nullable CompletableEmitter continuation; + Completable completion; + @Nullable Disposable disposable; + } - private final AsyncSubject> digestsSubject; - private final Single> resultSingle; + private Single> collectDigests(Iterable allDigests, boolean force) { + return Single.using( + () -> Profiler.instance().profile("collect digests"), + ignored -> + Flowable.fromIterable(allDigests) + .flatMapMaybe(digest -> maybeCreateUploadTask(digest, force)) + .collect(Collectors.toList()), + SilentCloseable::close); + } - @GuardedBy("this") - private final Set digests; + private Maybe maybeCreateUploadTask(Digest digest, boolean force) { + return Maybe.create( + emitter -> { + AsyncSubject completion = AsyncSubject.create(); + UploadTask uploadTask = new UploadTask(); + uploadTask.digest = digest; + uploadTask.completion = + Completable.fromObservable( + completion.doOnDispose( + () -> { + if (uploadTask.disposable != null) { + uploadTask.disposable.dispose(); + } + })); + Completable upload = + casUploadCache.execute( + digest, + Completable.create( + continuation -> { + uploadTask.continuation = continuation; + emitter.onSuccess(uploadTask); + }), + () -> emitter.onSuccess(uploadTask), + emitter::onComplete, + force); + upload.subscribe( + new CompletableObserver() { + @Override + public void onSubscribe(@NonNull Disposable d) { + uploadTask.disposable = d; + } - @GuardedBy("this") - private int currentCount = 0; + @Override + public void onComplete() { + completion.onComplete(); + } - MissingDigestFinder(RemoteActionExecutionContext context, int expectedCount) { - checkArgument(expectedCount > 0, "expectedCount should be greater than 0"); - this.expectedCount = expectedCount; - this.digestsSubject = AsyncSubject.create(); - this.digests = new HashSet<>(); + @Override + public void onError(@NonNull Throwable e) { + completion.onError(e); + } + }); + }); + } - AtomicBoolean findMissingDigestsCalled = new AtomicBoolean(false); - this.resultSingle = - Single.fromObservable( - digestsSubject - .flatMapSingle( - digests -> { - boolean wasCalled = findMissingDigestsCalled.getAndSet(true); - // Make sure we don't have re-subscription caused by refCount() below. - checkState(!wasCalled, "FindMissingDigests is called more than once"); - return toSingle( - () -> findMissingDigests(context, digests), directExecutor()); - }) - // Use replay here because we could have a race condition that downstream hasn't - // been added to the subscription list (to receive the upstream result) while - // upstream is completed. - .replay(1) - .refCount()); - } + private Single> findMissingUploads( + RemoteActionExecutionContext context, Iterable newDigests) { + return Single.using( + () -> Profiler.instance().profile("findMissingDigests"), + ignored -> + toSingle( + () -> + findMissingDigests( + context, + Iterables.transform( + Iterables.filter( + newDigests, uploadTask -> uploadTask.continuation != null), + uploadTask -> uploadTask.digest)), + directExecutor()) + .doOnDispose( + () -> { + for (UploadTask uploadTask : newDigests) { + if (uploadTask.disposable != null) { + uploadTask.disposable.dispose(); + } + } + }) + .map( + missingDigests -> { + List result = new ArrayList<>(); + for (UploadTask uploadTask : newDigests) { + if (missingDigests.contains(uploadTask.digest)) { + result.add(uploadTask); + } else { + if (uploadTask.continuation != null) { + uploadTask.continuation.onComplete(); + } + } + } + return result; + }), + SilentCloseable::close); + } - /** - * Register the {@code digest} and increase the counter. 
- *
- * <p>
Returned Single cannot be subscribed more than once. - * - * @return Single that emits the result of the {@code FindMissingDigest} request. - */ - Single> registerAndCount(Digest digest) { - AtomicBoolean subscribed = new AtomicBoolean(false); - // count() will potentially trigger the findMissingDigests call. Adding and counting before - // returning the Single could introduce a race that the result of findMissingDigests is - // available but the consumer doesn't get it because it hasn't subscribed the returned - // Single. In this case, it subscribes after upstream is completed resulting a re-run of - // findMissingDigests (due to refCount()). - // - // Calling count() inside doOnSubscribe to ensure the consumer already subscribed to the - // returned Single to avoid a re-execution of findMissingDigests. - return resultSingle.doOnSubscribe( - d -> { - boolean wasSubscribed = subscribed.getAndSet(true); - checkState(!wasSubscribed, "Single is subscribed more than once"); - synchronized (this) { - digests.add(digest); - } - count(); - }); - } + private Flowable uploadNewDigests( + RemoteActionExecutionContext context, + MerkleTree merkleTree, + Map additionalInputs, + Iterable newDigests) { + return Flowable.using( + () -> Profiler.instance().profile("upload"), + ignored -> + Flowable.fromIterable(newDigests) + .flatMapSingle( + digest -> uploadNewDigest(context, merkleTree, additionalInputs, digest)), + SilentCloseable::close); + } - /** Increase the counter. */ - void count() { - ImmutableSet digestsResult = null; + private Single uploadNewDigest( + RemoteActionExecutionContext context, + MerkleTree merkleTree, + Map additionalInputs, + UploadTask uploadTask) { + CompletableEmitter continuation = uploadTask.continuation; + if (continuation != null) { + toCompletable( + () -> uploadBlob(context, uploadTask.digest, merkleTree, additionalInputs), + directExecutor()) + .subscribe( + new CompletableObserver() { + @Override + public void onSubscribe(@NonNull Disposable d) { + continuation.setDisposable(d); + } - synchronized (this) { - if (currentCount < expectedCount) { - currentCount++; - if (currentCount == expectedCount) { - digestsResult = ImmutableSet.copyOf(digests); - } - } - } + @Override + public void onComplete() { + continuation.onComplete(); + } - if (digestsResult != null) { - digestsSubject.onNext(digestsResult); - digestsSubject.onComplete(); - } + @Override + public void onError(@NonNull Throwable e) { + continuation.onError(e); + } + }); } + return toTransferResult(uploadTask.completion); } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index 31369ef4ee1eab..623dcde8db3201 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -257,10 +257,10 @@ public boolean isDisposed() { /** * Executes a task. * - * @see #execute(Object, Single, Action, boolean). + * @see #execute(Object, Single, Action, Action, boolean). */ public Single execute(KeyT key, Single task, boolean force) { - return execute(key, task, () -> {}, force); + return execute(key, task, () -> {}, () -> {}, force); } /** @@ -270,12 +270,17 @@ public Single execute(KeyT key, Single task, boolean force) { *
* <p>
If the cache is already shutdown, a {@link CancellationException} will be emitted. * * @param key identifies the task. - * @param onIgnored callback called when provided task is ignored. + * @param onAlreadyFinished callback called when provided task is already finished. * @param force re-execute a finished task if set to {@code true}. * @return a {@link Single} which turns to completed once the task is finished or propagates the * error if any. */ - public Single execute(KeyT key, Single task, Action onIgnored, boolean force) { + public Single execute( + KeyT key, + Single task, + Action onAlreadyRunning, + Action onAlreadyFinished, + boolean force) { return Single.create( emitter -> { synchronized (lock) { @@ -285,7 +290,7 @@ public Single execute(KeyT key, Single task, Action onIgnored, b } if (!force && finished.containsKey(key)) { - onIgnored.run(); + onAlreadyFinished.run(); emitter.onSuccess(finished.get(key)); return; } @@ -294,7 +299,7 @@ public Single execute(KeyT key, Single task, Action onIgnored, b Execution execution = inProgress.get(key); if (execution != null) { - onIgnored.run(); + onAlreadyRunning.run(); } else { execution = new Execution(key, task); inProgress.put(key, execution); @@ -445,13 +450,23 @@ public Completable executeIfNot(KeyT key, Completable task) { /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */ public Completable execute(KeyT key, Completable task, boolean force) { - return execute(key, task, () -> {}, force); + return execute(key, task, () -> {}, () -> {}, force); } /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */ - public Completable execute(KeyT key, Completable task, Action onIgnored, boolean force) { + public Completable execute( + KeyT key, + Completable task, + Action onAlreadyRunning, + Action onAlreadyFinished, + boolean force) { return Completable.fromSingle( - cache.execute(key, task.toSingleDefault(Optional.empty()), onIgnored, force)); + cache.execute( + key, + task.toSingleDefault(Optional.empty()), + onAlreadyRunning, + onAlreadyFinished, + force)); } /** Returns a set of keys for tasks which is finished. 
*/ diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java index 55b5e7ca44e809..0697e6639a64d5 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java @@ -63,6 +63,8 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; @@ -294,15 +296,17 @@ public void upload_failedUploads_doNotDeduplicate() throws Exception { } @Test - public void ensureInputsPresent_interrupted_cancelInProgressUploadTasks() throws Exception { + public void ensureInputsPresent_interruptedDuringFindMissingDigests_cancelInProgressUploadTasks() + throws Exception { // arrange InMemoryRemoteCache remoteCache = spy(newRemoteCache()); + SettableFuture future = SettableFuture.create(); CountDownLatch findMissingDigestsCalled = new CountDownLatch(1); doAnswer( invocationOnMock -> { findMissingDigestsCalled.countDown(); - return SettableFuture.create(); + return future; }) .when(remoteCache) .findMissingDigests(any(), any()); @@ -336,6 +340,142 @@ public void ensureInputsPresent_interrupted_cancelInProgressUploadTasks() throws // assert assertThat(remoteCache.casUploadCache.getInProgressTasks()).isEmpty(); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty(); + assertThat(future.isCancelled()).isTrue(); + } + + @Test + public void ensureInputsPresent_interruptedDuringUploadBlobs_cancelInProgressUploadTasks() + throws Exception { + // arrange + RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient()); + RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol)); + + List> futures = new ArrayList<>(); + CountDownLatch uploadBlobCalls = new CountDownLatch(2); + doAnswer( + invocationOnMock -> { + uploadBlobCalls.countDown(); + SettableFuture future = SettableFuture.create(); + futures.add(future); + return future; + }) + .when(cacheProtocol) + .uploadBlob(any(), any(), any()); + doAnswer( + invocationOnMock -> { + uploadBlobCalls.countDown(); + SettableFuture future = SettableFuture.create(); + futures.add(future); + return future; + }) + .when(cacheProtocol) + .uploadFile(any(), any(), any()); + + Path path = fs.getPath("/execroot/foo"); + FileSystemUtils.writeContentAsLatin1(path, "bar"); + SortedMap inputs = new TreeMap<>(); + inputs.put(PathFragment.create("foo"), path); + MerkleTree merkleTree = MerkleTree.build(inputs, digestUtil); + + CountDownLatch ensureInputsPresentReturned = new CountDownLatch(1); + Thread thread = + new Thread( + () -> { + try { + remoteCache.ensureInputsPresent(context, merkleTree, ImmutableMap.of(), false); + } catch (IOException | InterruptedException ignored) { + // ignored + } finally { + ensureInputsPresentReturned.countDown(); + } + }); + + // act + thread.start(); + uploadBlobCalls.await(); + assertThat(futures.size()).isEqualTo(2); + assertThat(remoteCache.casUploadCache.getInProgressTasks()).isNotEmpty(); + + thread.interrupt(); + ensureInputsPresentReturned.await(); + + // assert + assertThat(remoteCache.casUploadCache.getInProgressTasks()).isEmpty(); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty(); + for (SettableFuture future : futures) { + assertThat(future.isCancelled()).isTrue(); + } + } + + @Test + public void 
ensureInputsPresent_multipleConsumers_interruptedOne_keepInProgressUploadTasks() + throws Exception { + // arrange + RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient()); + RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol)); + + List> futures = new ArrayList<>(); + CountDownLatch uploadBlobCalls = new CountDownLatch(2); + doAnswer( + invocationOnMock -> { + uploadBlobCalls.countDown(); + SettableFuture future = SettableFuture.create(); + futures.add(future); + return future; + }) + .when(cacheProtocol) + .uploadBlob(any(), any(), any()); + doAnswer( + invocationOnMock -> { + uploadBlobCalls.countDown(); + SettableFuture future = SettableFuture.create(); + futures.add(future); + return future; + }) + .when(cacheProtocol) + .uploadFile(any(), any(), any()); + + Path path = fs.getPath("/execroot/foo"); + FileSystemUtils.writeContentAsLatin1(path, "bar"); + SortedMap inputs = new TreeMap<>(); + inputs.put(PathFragment.create("foo"), path); + MerkleTree merkleTree = MerkleTree.build(inputs, digestUtil); + + CountDownLatch ensureInputsPresentReturned = new CountDownLatch(2); + CountDownLatch ensureInterrupted = new CountDownLatch(1); + Runnable work = + () -> { + try { + remoteCache.ensureInputsPresent(context, merkleTree, ImmutableMap.of(), false); + } catch (IOException ignored) { + // ignored + } catch (InterruptedException e) { + ensureInterrupted.countDown(); + } finally { + ensureInputsPresentReturned.countDown(); + } + }; + Thread thread1 = new Thread(work); + Thread thread2 = new Thread(work); + + // act + thread1.start(); + thread2.start(); + uploadBlobCalls.await(); + assertThat(futures.size()).isEqualTo(2); + assertThat(remoteCache.casUploadCache.getInProgressTasks()).isNotEmpty(); + + thread1.interrupt(); + ensureInterrupted.await(); + ensureInputsPresentReturned.await(); + + // assert + assertThat(remoteCache.casUploadCache.getInProgressTasks().size()).isEqualTo(2); + assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty(); + for (SettableFuture future : futures) { + assertThat(future.isCancelled()).isFalse(); + } } @Test @@ -364,4 +504,9 @@ private InMemoryRemoteCache newRemoteCache() { private RemoteCache newRemoteCache(RemoteCacheClient remoteCacheClient) { return new RemoteCache(remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil); } + + private RemoteExecutionCache newRemoteExecutionCache(RemoteCacheClient remoteCacheClient) { + return new RemoteExecutionCache( + remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil); + } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java index 8925640c11ccbc..1fb0fe969ef8ee 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java @@ -38,7 +38,7 @@ import java.util.stream.Collectors; /** A {@link RemoteCacheClient} that stores its contents in memory. */ -public final class InMemoryCacheClient implements RemoteCacheClient { +public class InMemoryCacheClient implements RemoteCacheClient { private final ListeningExecutorService executorService = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(100));
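
The sketches below restate, outside of Bazel's code base, the main patterns this change relies on. Names in them are invented for illustration unless they also appear in the diff above.

The core of the RemoteExecutionCache change is a three-stage pipeline in ensureInputsPresent: collect the candidate digests, issue one batched FindMissingDigests query, then upload only the blobs that turn out to be missing. A minimal RxJava sketch of that shape, where a digest is just a String and the two remote calls are stubs rather than Bazel's API:

import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Single;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public final class BatchedUploadSketch {
  // Stand-in for FindMissingDigests: one batched query instead of one call per digest.
  static Single<Set<String>> findMissing(List<String> digests) {
    return Single.fromCallable(() -> new HashSet<>(digests)); // pretend everything is missing
  }

  // Stand-in for uploading a single blob.
  static Completable upload(String digest) {
    return Completable.fromAction(() -> System.out.println("uploading " + digest));
  }

  public static void main(String[] args) {
    Flowable<String> allDigests = Flowable.just("d1", "d2", "d3");

    Completable pipeline =
        allDigests
            .toList()                                 // stage 1: collect the candidate digests
            .flatMap(digests -> findMissing(digests)) // stage 2: one batched missing-blob query
            .flatMapPublisher(Flowable::fromIterable) // stage 3: upload only what is missing
            .flatMapCompletable(BatchedUploadSketch::upload);

    pipeline.blockingAwait();
  }
}

Compared with the removed MissingDigestFinder, which counted subscribers to decide when to fire the query, the batching point here is simply the toList boundary.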
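
maybeCreateUploadTask hands ownership of an upload to exactly one caller: the body of Completable.create only runs when casUploadCache subscribes the new task, and the captured CompletableEmitter (the "continuation") is how that owner later reports the result, while an AsyncSubject gives every other caller a completion signal to wait on. A minimal sketch of that hand-off, outside of any cache:

import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.CompletableEmitter;
import io.reactivex.rxjava3.subjects.AsyncSubject;
import java.util.concurrent.atomic.AtomicReference;

public final class ContinuationSketch {
  public static void main(String[] args) {
    AtomicReference<CompletableEmitter> continuation = new AtomicReference<>();
    AsyncSubject<Object> completion = AsyncSubject.create();

    Completable task =
        Completable.create(
            emitter -> {
              // We now "own" the task; keep the emitter so the real work can finish it later.
              continuation.set(emitter);
            });

    // Owner side: subscribing runs the create body and captures the continuation.
    task.subscribe(() -> {}, err -> {});

    // Shared side: anyone can wait on the AsyncSubject-backed completion.
    Completable shared = Completable.fromObservable(completion);
    shared.subscribe(() -> System.out.println("upload finished"), Throwable::printStackTrace);

    // Later, when the underlying upload completes:
    continuation.get().onComplete(); // notifies the owner's subscriber
    completion.onComplete();         // notifies everyone waiting on the shared completion
  }
}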
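
Each new helper (collectDigests, findMissingUploads, uploadNewDigests) wraps its work in Single.using or Flowable.using so that a profiler span is opened on subscription and closed when the stage terminates or is disposed. A sketch of the same pattern with a made-up Span class standing in for Bazel's Profiler and SilentCloseable:

import io.reactivex.rxjava3.core.Single;

public final class ProfiledStageSketch {
  static final class Span implements AutoCloseable {
    private final long start = System.nanoTime();
    private final String name;

    Span(String name) {
      this.name = name;
    }

    @Override
    public void close() {
      System.out.printf("%s took %d us%n", name, (System.nanoTime() - start) / 1000);
    }
  }

  public static void main(String[] args) {
    Single<Integer> profiled =
        Single.using(
            () -> new Span("collect digests"),         // acquire the span on subscription
            span -> Single.fromCallable(() -> 40 + 2), // the actual work
            Span::close);                              // close on success, error, or disposal

    System.out.println(profiled.blockingGet());
  }
}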
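
uploadNewDigest still drives the actual network call through a ListenableFuture and bridges it into Rx via toCompletable; the property the new tests depend on is that disposing the Rx side cancels the future. A rough, self-contained equivalent of such a bridge (the real RxFutures.toCompletable in Bazel takes a future supplier and an executor, so this is only the shape, not its signature):

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.disposables.Disposable;

public final class FutureBridgeSketch {
  static Completable toCompletable(ListenableFuture<Void> future) {
    return Completable.create(
        emitter -> {
          Futures.addCallback(
              future,
              new FutureCallback<Void>() {
                @Override
                public void onSuccess(Void result) {
                  emitter.onComplete();
                }

                @Override
                public void onFailure(Throwable t) {
                  emitter.tryOnError(t); // dropped silently if already disposed
                }
              },
              MoreExecutors.directExecutor());
          // Propagate disposal back to the future as a cancellation.
          emitter.setCancellable(() -> future.cancel(/* mayInterruptIfRunning= */ true));
        });
  }

  public static void main(String[] args) {
    SettableFuture<Void> future = SettableFuture.create();
    Disposable subscription = toCompletable(future).subscribe(() -> {}, err -> {});
    subscription.dispose();
    System.out.println("future cancelled: " + future.isCancelled()); // true
  }
}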
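
AsyncTaskCache.execute now takes two callbacks instead of a single onIgnored: onAlreadyRunning fires when another caller already owns the in-progress task, and onAlreadyFinished fires when a previous execution completed and force is false. A deliberately tiny cache showing when each callback fires; Bazel's real AsyncTaskCache additionally handles disposal, reference counting, and shutdown:

import io.reactivex.rxjava3.core.Completable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

final class TinyTaskCache {
  private final Map<String, Completable> inProgress = new HashMap<>();
  private final Set<String> finished = new HashSet<>();

  synchronized Completable execute(
      String key,
      Completable task,
      Runnable onAlreadyRunning,
      Runnable onAlreadyFinished,
      boolean force) {
    if (!force && finished.contains(key)) {
      onAlreadyFinished.run(); // a previous execution already completed
      return Completable.complete();
    }
    Completable running = inProgress.get(key);
    if (running != null) {
      onAlreadyRunning.run(); // someone else owns the task; just share its completion
      return running;
    }
    Completable shared =
        task.doOnComplete(
                () -> {
                  synchronized (this) {
                    inProgress.remove(key);
                    finished.add(key);
                  }
                })
            .cache(); // all later callers observe the same single execution
    inProgress.put(key, shared);
    return shared;
  }

  public static void main(String[] args) {
    TinyTaskCache cache = new TinyTaskCache();
    Completable upload = Completable.fromAction(() -> System.out.println("uploading d1"));

    Completable first = cache.execute("d1", upload, () -> {}, () -> {}, false);
    Completable second =
        cache.execute("d1", upload, () -> System.out.println("already running"), () -> {}, false);
    first.blockingAwait(); // prints "uploading d1"
    second.blockingAwait();

    cache.execute("d1", upload, () -> {}, () -> System.out.println("already finished"), false);
  }
}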
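
The two interruption tests follow the same shape: run ensureInputsPresent on a worker thread, wait on a latch until the remote call is in flight, interrupt the worker, then assert that the in-flight futures were cancelled and nothing is left in the upload cache. The mechanism is that blockingAwait disposes its subscription when the waiting thread is interrupted; a compressed sketch of just that mechanism:

import io.reactivex.rxjava3.core.Completable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public final class InterruptSketch {
  public static void main(String[] args) throws InterruptedException {
    CountDownLatch disposed = new CountDownLatch(1);

    Completable neverEndingUpload = Completable.never().doOnDispose(disposed::countDown);

    Thread worker =
        new Thread(
            () -> {
              try {
                neverEndingUpload.blockingAwait();
              } catch (RuntimeException e) {
                // blockingAwait rethrows the interrupt as an unchecked exception after disposing.
              }
            });

    worker.start();
    Thread.sleep(100); // crude stand-in for the CountDownLatch hand-shake used in the real tests
    worker.interrupt();

    System.out.println("upstream disposed: " + disposed.await(1, TimeUnit.SECONDS));
  }
}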
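
The multipleConsumers test pins down the opposite property: while a second caller is still subscribed to the same deduplicated upload, interrupting the first caller must not cancel the underlying work. RxJava's publish().refCount() has the same reference-counting shape, which the sketch below uses; AsyncTaskCache implements its own counting over the in-progress task map rather than using refCount directly:

import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.concurrent.TimeUnit;

public final class SharedTaskSketch {
  public static void main(String[] args) throws InterruptedException {
    Observable<Long> sharedUpload =
        Observable.timer(200, TimeUnit.MILLISECONDS)
            .doOnDispose(() -> System.out.println("underlying task cancelled"))
            .doOnComplete(() -> System.out.println("underlying task finished"))
            .publish()
            .refCount();

    Disposable first = sharedUpload.subscribe(v -> System.out.println("first consumer done"));
    Disposable second = sharedUpload.subscribe(v -> System.out.println("second consumer done"));

    first.dispose(); // one consumer goes away; the other keeps the task alive

    Thread.sleep(300); // prints "second consumer done" and "underlying task finished"
  }
}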