From da67bd91062a58e93842579f0919092245ab630f Mon Sep 17 00:00:00 2001 From: Alan Sheinberg <57688982+AlanConfluent@users.noreply.github.com> Date: Fri, 12 Mar 2021 17:07:43 -0800 Subject: [PATCH] fix: Ensures BaseSubscriber.makeRequest is called on context in PollableSubscriber (#7212) * fix: Ensures BaseSubscriber.makeRequest is called on context --- .../api/client/impl/PollableSubscriber.java | 13 +-- .../client/impl/PollableSubscriberTest.java | 110 ++++++++++++++++++ 2 files changed, 116 insertions(+), 7 deletions(-) create mode 100644 ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/impl/PollableSubscriberTest.java diff --git a/ksqldb-api-client/src/main/java/io/confluent/ksql/api/client/impl/PollableSubscriber.java b/ksqldb-api-client/src/main/java/io/confluent/ksql/api/client/impl/PollableSubscriber.java index 956c01508872..80a071959367 100644 --- a/ksqldb-api-client/src/main/java/io/confluent/ksql/api/client/impl/PollableSubscriber.java +++ b/ksqldb-api-client/src/main/java/io/confluent/ksql/api/client/impl/PollableSubscriber.java @@ -23,18 +23,18 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import org.reactivestreams.Subscription; public class PollableSubscriber extends BaseSubscriber { - - private static final int REQUEST_BATCH_SIZE = 100; // 100ms in ns private static final long MAX_POLL_NANOS = TimeUnit.MILLISECONDS.toNanos(100); + static final int REQUEST_BATCH_SIZE = 100; private final BlockingQueue queue = new LinkedBlockingQueue<>(); private final Consumer errorHandler; - private int tokens; + private final AtomicInteger tokens = new AtomicInteger(0); private volatile boolean complete; private volatile boolean closed; private volatile boolean failed; @@ -86,8 +86,8 @@ public synchronized Row poll(final Duration timeout) { try { final Row row = queue.poll(pollTime, TimeUnit.NANOSECONDS); if (row != null) { - tokens--; - checkRequestTokens(); + tokens.decrementAndGet(); + context.runOnContext(v -> checkRequestTokens()); return row; } else if (complete) { // If complete, close once the queue has been emptied @@ -110,8 +110,7 @@ synchronized boolean isClosed() { } private void checkRequestTokens() { - if (tokens == 0) { - tokens += REQUEST_BATCH_SIZE; + if (tokens.compareAndSet(0, REQUEST_BATCH_SIZE)) { makeRequest(REQUEST_BATCH_SIZE); } } diff --git a/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/impl/PollableSubscriberTest.java b/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/impl/PollableSubscriberTest.java new file mode 100644 index 000000000000..7e36c2f804b7 --- /dev/null +++ b/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/impl/PollableSubscriberTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.api.client.impl; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.api.client.Row; +import io.confluent.ksql.reactive.BufferedPublisher; +import io.vertx.core.Context; +import io.vertx.core.Vertx; +import io.vertx.core.json.JsonArray; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Publisher; + +public class PollableSubscriberTest { + private static Duration POLL_DURATION = Duration.ofMillis(100); + private static String COLUMN_NAME = "id"; + + private Publisher publisher; + + private Vertx vertx; + private Throwable throwable; + private Context context; + private PollableSubscriber pollableSubscriber; + + @Before + public void setUp() { + vertx = Vertx.vertx(); + context = vertx.getOrCreateContext(); + pollableSubscriber = new PollableSubscriber(context, t -> throwable = t); + } + + @Test + public void shouldPollSingleBatch() { + shouldPollRows(PollableSubscriber.REQUEST_BATCH_SIZE - 10); + } + + @Test + public void shouldPollMultiBatch() { + shouldPollRows(PollableSubscriber.REQUEST_BATCH_SIZE + 10); + } + + @Test + public void shouldSetError() { + publisher = new BufferedPublisher(context, ImmutableList.of()) { + @Override + protected void maybeSend() { + sendError(new RuntimeException("Error!")); + } + }; + + publisher.subscribe(pollableSubscriber); + + Row row = pollableSubscriber.poll(POLL_DURATION); + assertThat(row, is(nullValue())); + assertThat(throwable, is(notNullValue())); + assertThat(throwable.getMessage(), is("Error!")); + } + + private void shouldPollRows(int numRows) { + final List rows = new ArrayList<>(); + for (int i = 0; i < numRows; i++) { + rows.add(createRow(i)); + } + + publisher = new BufferedPublisher<>(context, rows); + + publisher.subscribe(pollableSubscriber); + + Row row = pollableSubscriber.poll(POLL_DURATION); + int i = 0; + for (; row != null; i++) { + Long col1 = row.getLong(COLUMN_NAME); + assertThat(col1, is((long) i)); + row = pollableSubscriber.poll(POLL_DURATION); + } + assertThat(i, is(numRows)); + assertThat(throwable, is(nullValue())); + } + + private Row createRow(long id) { + return new RowImpl( + ImmutableList.of(COLUMN_NAME), + ImmutableList.of(new ColumnTypeImpl("BIGINT")), + new JsonArray().add(id), + ImmutableMap.of(COLUMN_NAME, 1)); + } +}