diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java
index e642750305..73810097a9 100644
--- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java
+++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ReplayStrategies.java
@@ -23,14 +23,15 @@
 import java.util.Deque;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.atomic.AtomicLongFieldUpdater;
+import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
 import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
 import java.util.function.Consumer;
 import javax.annotation.Nullable;
 
 import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked;
 import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull;
-import static io.servicetalk.concurrent.internal.EmptySubscriptions.EMPTY_SUBSCRIPTION_NO_THROW;
+import static io.servicetalk.concurrent.internal.ConcurrentUtils.releaseLock;
+import static io.servicetalk.concurrent.internal.ConcurrentUtils.tryAcquireLock;
 import static java.util.Objects.requireNonNull;
 import static java.util.concurrent.TimeUnit.NANOSECONDS;
 import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater;
@@ -92,23 +93,24 @@ public void deliverAccumulation(final Consumer consumer) {
     }
 
     private static final class MostRecentTimeLimitedReplayAccumulator<T> implements ReplayAccumulator<T> {
-        @SuppressWarnings("rawtypes")
-        private static final AtomicLongFieldUpdater<MostRecentTimeLimitedReplayAccumulator> stateSizeUpdater =
-                AtomicLongFieldUpdater.newUpdater(MostRecentTimeLimitedReplayAccumulator.class, "stateSize");
+        private static final Cancellable CANCELLED = () -> { };
         @SuppressWarnings("rawtypes")
         private static final AtomicReferenceFieldUpdater<MostRecentTimeLimitedReplayAccumulator, Cancellable>
                 timerCancellableUpdater = newUpdater(MostRecentTimeLimitedReplayAccumulator.class, Cancellable.class,
                 "timerCancellable");
+        @SuppressWarnings("rawtypes")
+        private static final AtomicIntegerFieldUpdater<MostRecentTimeLimitedReplayAccumulator> queueLockUpdater =
+                AtomicIntegerFieldUpdater.newUpdater(MostRecentTimeLimitedReplayAccumulator.class, "queueLock");
+        @SuppressWarnings("rawtypes")
+        private static final AtomicIntegerFieldUpdater<MostRecentTimeLimitedReplayAccumulator> queueSizeUpdater =
+                AtomicIntegerFieldUpdater.newUpdater(MostRecentTimeLimitedReplayAccumulator.class, "queueSize");
         private final Executor executor;
         private final Queue<TimeStampSignal<T>> items;
         private final long ttlNanos;
         private final int maxItems;
-        /**
-         * Provide atomic state for size of {@link #items} and also for visibility between the threads consuming and
-         * producing. The atomically incrementing "state" ensures that any modifications from the producer thread
-         * are visible from the consumer thread and we never "miss" a timer schedule event if the queue becomes empty.
-         */
-        private volatile long stateSize;
+        private volatile int queueSize;
+        @SuppressWarnings("unused")
+        private volatile int queueLock;
         @Nullable
         private volatile Cancellable timerCancellable;
 
@@ -122,68 +124,102 @@ private static final class MostRecentTimeLimitedReplayAccumulator implements
             this.executor = requireNonNull(executor);
             this.ttlNanos = ttl.toNanos();
             this.maxItems = maxItems;
-            items = new ConcurrentLinkedQueue<>(); // SpMc
+            // SpSc, but needs iterator.
+            // accumulate is called on one thread (no concurrent access on this method).
+            // timerFire may be called on another thread.
+            items = new ConcurrentLinkedQueue<>();
         }
 
         @Override
         public void accumulate(@Nullable final T t) {
-            // We may exceed max items in the queue but this method isn't invoked concurrently, so we only go over by
-            // at most 1 item.
-            items.add(new TimeStampSignal<>(executor.currentTime(NANOSECONDS), t));
-            for (;;) {
-                final long currentStateSize = stateSize;
-                final int currentSize = getSize(currentStateSize);
-                final int nextState = getState(currentStateSize) + 1;
-                if (currentSize >= maxItems) {
-                    if (stateSizeUpdater.compareAndSet(this, currentStateSize,
-                            buildStateSize(nextState, currentSize))) {
+            long scheduleTimerNanos = -1;
+            final TimeStampSignal<T> signal = new TimeStampSignal<>(executor.currentTime(NANOSECONDS), t);
+            if (tryAcquireLock(queueLockUpdater, this)) {
+                for (;;) {
+                    final int qSize = queueSize;
+                    if (qSize < maxItems) {
+                        if (queueSizeUpdater.compareAndSet(this, qSize, qSize + 1)) {
+                            items.add(signal);
+                            if (qSize == 0) {
+                                scheduleTimerNanos = ttlNanos;
+                            }
+                            break;
+                        }
+                    } else if (queueSizeUpdater.compareAndSet(this, qSize, qSize)) {
+                        // Queue removal is only done while queueLock is acquired, so we don't need to worry about
+                        // the timer thread removing items concurrently.
                         items.poll();
+                        items.add(signal);
                         break;
                     }
-                } else if (stateSizeUpdater.compareAndSet(this, currentStateSize,
-                        buildStateSize(nextState, currentSize + 1))) {
-                    if (currentSize == 0) {
-                        schedulerTimer(ttlNanos);
-                    }
-                    break;
                 }
+                if (!releaseLock(queueLockUpdater, this)) {
+                    scheduleTimerNanos = tryDrainQueue();
+                }
+            } else {
+                queueSizeUpdater.incrementAndGet(this);
+                items.add(signal);
+                scheduleTimerNanos = tryDrainQueue();
+            }
+
+            if (scheduleTimerNanos >= 0) {
+                schedulerTimer(scheduleTimerNanos);
             }
         }
 
         @Override
         public void deliverAccumulation(final Consumer<T> consumer) {
+            int i = 0;
             for (TimeStampSignal<T> timeStampSignal : items) {
                 consumer.accept(timeStampSignal.signal);
+                // The queue size may be larger than maxItems if we weren't able to acquire the queueLock while adding.
+                // This is only a temporary condition while there is concurrent access between timer and accumulator.
+                // Guard against it here to preserve the invariant that we shouldn't deliver more than maxItems.
+                if (++i >= maxItems) {
+                    break;
+                }
             }
         }
 
         @Override
         public void cancelAccumulation() {
-            final Cancellable cancellable = timerCancellableUpdater.getAndSet(this, EMPTY_SUBSCRIPTION_NO_THROW);
+            // Stop the background timer and prevent it from being rescheduled. It is possible upstream may deliver
+            // more data but the queue size is bounded by maxItems and this method should only be called when upstream
+            // is cancelled, which should eventually dereference this object, making it eligible for GC (no memory leak).
+            final Cancellable cancellable = timerCancellableUpdater.getAndSet(this, CANCELLED);
             if (cancellable != null) {
                 cancellable.cancel();
             }
         }
 
-        private static int getSize(long stateSize) {
-            return (int) stateSize;
-        }
+        private long tryDrainQueue() {
+            long scheduleTimerNanos = -1;
+            boolean tryAcquire = true;
+            while (tryAcquire && tryAcquireLock(queueLockUpdater, this)) {
+                // Ensure the queue contains maxItems or fewer items.
+                for (;;) {
+                    final int qSize = queueSize;
+                    if (qSize <= maxItems) {
+                        break;
+                    } else if (queueSizeUpdater.compareAndSet(this, qSize, qSize - 1)) {
+                        items.poll();
+                    }
+                }
 
-        private static int getState(long stateSize) {
-            return (int) (stateSize >>> 32);
-        }
+                scheduleTimerNanos = doExpire();
 
-        private static long buildStateSize(int state, int size) {
-            return (((long) state) << 32) | size;
+                tryAcquire = !releaseLock(queueLockUpdater, this);
+            }
+            return scheduleTimerNanos;
         }
 
         private void schedulerTimer(long nanos) {
             for (;;) {
                 final Cancellable currentCancellable = timerCancellable;
-                if (currentCancellable == EMPTY_SUBSCRIPTION_NO_THROW) {
+                if (currentCancellable == CANCELLED) {
                     break;
                 } else {
-                    final Cancellable nextCancellable = executor.schedule(this::expireSignals, nanos, NANOSECONDS);
+                    final Cancellable nextCancellable = executor.schedule(this::timerFire, nanos, NANOSECONDS);
                     if (timerCancellableUpdater.compareAndSet(this, currentCancellable, nextCancellable)) {
                         // Current logic only has 1 timer outstanding at any give time so cancellation of
                         // the current cancellable shouldn't be necessary but do it for completeness.
@@ -198,37 +234,44 @@ private void schedulerTimer(long nanos) {
             }
         }
 
-        private void expireSignals() {
+        // lock must be held!
+        private long doExpire() {
             final long nanoTime = executor.currentTime(NANOSECONDS);
             TimeStampSignal<T> item;
             for (;;) {
-                // read stateSize before peek, so if we poll from the queue we are sure to see the correct
-                // state relative to items in the queue.
-                final long currentStateSize = stateSize;
+                final long delta;
                 item = items.peek();
                 if (item == null) {
-                    break;
-                } else if (nanoTime - item.timeStamp >= ttlNanos) {
-                    final int currentSize = getSize(currentStateSize);
-                    if (stateSizeUpdater.compareAndSet(this, currentStateSize,
-                            buildStateSize(getState(currentStateSize) + 1, currentSize - 1))) {
-                        // When we add: we add to the queue we add first, then CAS sizeState.
-                        // When we remove: we CAS the atomic state first, then poll.
-                        // This avoids removing a non-expired item because if the "add" thread is running faster and
-                        // already polled "item" the CAS will fail, and we will try again on the next loop iteration.
-                        items.poll();
-                        if (currentSize == 1) {
-                            // a new timer task will be scheduled after addition if this is the case. break to avoid
-                            // multiple timer tasks running concurrently.
-                            break;
-                        }
-                    }
+                    return -1;
+                } else if ((delta = nanoTime - item.timeStamp) >= ttlNanos) {
+                    final int qSize = queueSizeUpdater.decrementAndGet(this);
+                    assert qSize >= 0;
+                    // Removal is only done while holding the lock. This means we don't have to worry about the
+                    // accumulator thread running concurrently and removing the peeked item behind our back.
+                    items.poll();
                 } else {
-                    schedulerTimer(ttlNanos - (nanoTime - item.timeStamp));
-                    break; // elements sorted in increasing time, break when first non-expired entry found.
+                    // elements sorted in increasing time, break when first non-expired entry found.
+                    // delta may be negative if ttlNanos is small and this method sees newly added items while looping.
+                    return delta <= 0 ? ttlNanos : ttlNanos - (nanoTime - item.timeStamp);
                 }
             }
         }
+
+        private void timerFire() {
+            long scheduleTimerNanos;
+            if (tryAcquireLock(queueLockUpdater, this)) {
+                scheduleTimerNanos = doExpire();
+                if (!releaseLock(queueLockUpdater, this)) {
+                    scheduleTimerNanos = tryDrainQueue();
+                }
+            } else {
+                scheduleTimerNanos = tryDrainQueue();
+            }
+
+            if (scheduleTimerNanos >= 0) {
+                schedulerTimer(scheduleTimerNanos);
+            }
+        }
     }
 
     private static final class TimeStampSignal<T> {
diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java
index 235b6bac3d..e6285ecd58 100644
--- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java
+++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ReplayPublisherTest.java
@@ -15,6 +15,8 @@
  */
 package io.servicetalk.concurrent.api;
 
+import io.servicetalk.concurrent.Cancellable;
+import io.servicetalk.concurrent.PublisherSource;
 import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber;
 
 import org.junit.jupiter.api.AfterEach;
@@ -23,8 +25,13 @@
 import org.junit.jupiter.params.provider.ValueSource;
 
 import java.time.Duration;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.Future;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import javax.annotation.Nullable;
@@ -32,6 +39,7 @@
 import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
 import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
 import static java.time.Duration.ofMillis;
+import static java.time.Duration.ofNanos;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
@@ -209,6 +217,69 @@ void threeSubscribersTTL(boolean onError) {
         threeSubscribersTerminate(onError);
     }
 
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void concurrentTTL(boolean onError) throws Exception {
+        final Duration ttl = ofNanos(1);
+        final int queueLimit = Integer.MAX_VALUE;
+        Executor executor2 = Executors.newCachedThreadExecutor();
+        ScheduleQueueExecutor queueExecutor = new ScheduleQueueExecutor(executor2);
+        Publisher<Integer> publisher = source.replay(
+                ReplayStrategies.<Integer>historyTtlBuilder(2, ttl, queueExecutor)
+                        .queueLimitHint(queueLimit).build());
+        try {
+            toSource(publisher).subscribe(subscriber1);
+            toSource(publisher).subscribe(subscriber2);
+            subscriber1.awaitSubscription().request(Long.MAX_VALUE);
+            subscriber2.awaitSubscription().request(Long.MAX_VALUE);
+            subscription.awaitRequestN(queueLimit);
+            // The goal is to race onNext (which calls accumulate) with the timer expiration. We don't verify all the
+            // signals are delivered but instead verify that the timer and max elements are always enforced even after
+            // the concurrent operations.
+            for (int i = 0; i < 10000; ++i) {
+                source.onNext(1);
+                Thread.yield(); // Increase likelihood that timer expires some signals.
+            }
+
+            // Wait for the timer to expire all signals.
+            waitForReplayQueueToDrain(publisher);
+
+            queueExecutor.enableScheduleQueue();
+            source.onNext(2, 3);
+            toSource(publisher).subscribe(subscriber3);
+            subscriber3.awaitSubscription().request(Long.MAX_VALUE);
+            assertThat(subscriber3.takeOnNext(2), contains(2, 3));
+
+            // Test that advancing the timer past expiration still expires events and there were no race conditions
+            queueExecutor.drainScheduleQueue();
+            waitForReplayQueueToDrain(publisher);
+
+            // We don't consume signals for subscriber1 and subscriber2, so just test termination of subscriber3.
+            if (onError) {
+                source.onError(DELIBERATE_EXCEPTION);
+                assertThat(subscriber3.awaitOnError(), is(DELIBERATE_EXCEPTION));
+            } else {
+                source.onComplete();
+                subscriber3.awaitOnComplete();
+            }
+        } finally {
+            executor2.closeAsync().toFuture().get();
+        }
+    }
+
+    private void waitForReplayQueueToDrain(Publisher<Integer> publisher) throws InterruptedException {
+        boolean waitForAccumulatorToDrain;
+        do {
+            Thread.sleep(1);
+            TestPublisherSubscriber<Integer> subscriber5 = new TestPublisherSubscriber<>();
+            toSource(publisher).subscribe(subscriber5);
+            PublisherSource.Subscription subscription5 = subscriber5.awaitSubscription();
+            subscription5.request(Long.MAX_VALUE);
+            waitForAccumulatorToDrain = subscriber5.pollOnNext(10, MILLISECONDS) != null;
+            subscription5.cancel();
+        } while (waitForAccumulatorToDrain);
+    }
+
     @ParameterizedTest(name = "{displayName} [{index}] expectedSubscribers={0} expectedSum={1}")
     @CsvSource(value = {"500,500", "50,50", "50,500", "500,50"})
     void concurrentSubscribes(final int expectedSubscribers, final long expectedSum) throws Exception {
@@ -306,4 +377,82 @@ public void deliverAccumulation(final Consumer consumer) {
             }
         }
     }
+
+    private static final class ScheduleHolder implements Cancellable {
+        final Duration duration;
+        final Runnable task;
+        final AtomicBoolean isCancelled = new AtomicBoolean();
+
+        ScheduleHolder(final long duration, final TimeUnit unit, final Runnable task) {
+            this(Duration.ofNanos(unit.toNanos(duration)), task);
+        }
+
+        ScheduleHolder(final Duration duration, final Runnable task) {
+            this.duration = duration;
+            this.task = task;
+        }
+
+        @Override
+        public void cancel() {
+            isCancelled.set(true);
+        }
+    }
+
+    private static final class ScheduleQueueExecutor implements io.servicetalk.concurrent.Executor {
+        private final io.servicetalk.concurrent.Executor executor;
+        private final AtomicBoolean enableScheduleQueue = new AtomicBoolean();
+        private final Queue<ScheduleHolder> scheduleQueue = new ConcurrentLinkedQueue<>();
+
+        private ScheduleQueueExecutor(final io.servicetalk.concurrent.Executor executor) {
+            this.executor = executor;
+        }
+
+        void enableScheduleQueue() {
+            enableScheduleQueue.set(true);
+        }
+
+        void drainScheduleQueue() {
+            if (enableScheduleQueue.compareAndSet(true, false)) {
+                ScheduleHolder item;
+                while ((item = scheduleQueue.poll()) != null) {
+                    if (item.isCancelled.compareAndSet(false, true)) {
+                        executor.schedule(item.task, item.duration);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public long currentTime(final TimeUnit unit) {
+            return executor.currentTime(unit);
+        }
+
+        @Override
+        public Cancellable execute(final Runnable task) throws RejectedExecutionException {
+            return executor.execute(task);
+        }
+
+        @Override
+        public Cancellable schedule(final Runnable task, final long delay, final TimeUnit unit)
+                throws RejectedExecutionException {
+            if (enableScheduleQueue.get()) {
+                ScheduleHolder holder = new ScheduleHolder(delay, unit, task);
+                scheduleQueue.add(holder);
+                return holder;
+            } else {
+                return executor.schedule(task, delay, unit);
+            }
+        }
+
+        @Override
+        public Cancellable schedule(final Runnable task, final Duration delay) throws RejectedExecutionException {
+            if (enableScheduleQueue.get()) {
+                ScheduleHolder holder = new ScheduleHolder(delay, task);
+                scheduleQueue.add(holder);
+                return holder;
+            } else {
+                return executor.schedule(task, delay);
+            }
+        }
+    }
 }
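
Reviewer note: the accumulator above replaces the packed stateSize CAS loop with the tryAcquireLock/releaseLock helpers imported from io.servicetalk.concurrent.internal.ConcurrentUtils. Below is a minimal, self-contained sketch of that protocol as the diff appears to use it; the three-state encoding, the class and field names, and the drainLoop helper are illustrative assumptions, not the actual ConcurrentUtils source. The property the new code relies on is that the lock is granted to at most one thread at a time and that releaseLock returns false when another thread attempted to acquire the lock while it was held, which is why accumulate() and timerFire() fall back to tryDrainQueue() after a failed acquire or release.

import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

final class QueueLockSketch {
    // Assumed three-state encoding: 0 = idle, 1 = acquired, 2 = acquired with a pending (rejected) acquirer.
    private static final int IDLE = 0;
    private static final int ACQUIRED = 1;
    private static final int PENDING = 2;
    private static final AtomicIntegerFieldUpdater<QueueLockSketch> lockUpdater =
            AtomicIntegerFieldUpdater.newUpdater(QueueLockSketch.class, "lock");
    private volatile int lock;

    boolean tryAcquireLock() {
        for (;;) {
            final int prev = lockUpdater.get(this);
            if (prev == IDLE) {
                if (lockUpdater.compareAndSet(this, IDLE, ACQUIRED)) {
                    return true; // caller is now the exclusive owner
                }
            } else if (lockUpdater.compareAndSet(this, prev, PENDING)) {
                return false; // lock is busy; the owner will observe PENDING when it releases
            }
        }
    }

    boolean releaseLock() {
        // true -> clean release; false -> another thread tried to acquire while we held the lock,
        // so the caller should re-acquire and drain again.
        return lockUpdater.getAndSet(this, IDLE) == ACQUIRED;
    }

    // Mirrors the shape of tryDrainQueue() above: keep draining while releases report contention.
    void drainLoop(Runnable drain) {
        boolean tryAcquire = true;
        while (tryAcquire && tryAcquireLock()) {
            drain.run();                  // exclusive work, e.g. trimming/expiring the queue
            tryAcquire = !releaseLock();  // loop again if another thread raced with us
        }
    }
}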