Skip to content

Commit

Permalink
fix: Ensures BaseSubscriber.makeRequest is called on context in Polla…
Browse files Browse the repository at this point in the history
…bleSubscriber (#7212)

* fix: Ensures BaseSubscriber.makeRequest is called on context
  • Loading branch information
AlanConfluent committed Mar 13, 2021
1 parent 53d8263 commit da67bd9
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.reactivestreams.Subscription;

public class PollableSubscriber extends BaseSubscriber<Row> {

private static final int REQUEST_BATCH_SIZE = 100;
// 100ms in ns
private static final long MAX_POLL_NANOS = TimeUnit.MILLISECONDS.toNanos(100);
static final int REQUEST_BATCH_SIZE = 100;

private final BlockingQueue<Row> queue = new LinkedBlockingQueue<>();
private final Consumer<Throwable> errorHandler;
private int tokens;
private final AtomicInteger tokens = new AtomicInteger(0);
private volatile boolean complete;
private volatile boolean closed;
private volatile boolean failed;
Expand Down Expand Up @@ -86,8 +86,8 @@ public synchronized Row poll(final Duration timeout) {
try {
final Row row = queue.poll(pollTime, TimeUnit.NANOSECONDS);
if (row != null) {
tokens--;
checkRequestTokens();
tokens.decrementAndGet();
context.runOnContext(v -> checkRequestTokens());
return row;
} else if (complete) {
// If complete, close once the queue has been emptied
Expand All @@ -110,8 +110,7 @@ synchronized boolean isClosed() {
}

private void checkRequestTokens() {
if (tokens == 0) {
tokens += REQUEST_BATCH_SIZE;
if (tokens.compareAndSet(0, REQUEST_BATCH_SIZE)) {
makeRequest(REQUEST_BATCH_SIZE);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2021 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.api.client.impl;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.api.client.Row;
import io.confluent.ksql.reactive.BufferedPublisher;
import io.vertx.core.Context;
import io.vertx.core.Vertx;
import io.vertx.core.json.JsonArray;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.reactivestreams.Publisher;

public class PollableSubscriberTest {
private static Duration POLL_DURATION = Duration.ofMillis(100);
private static String COLUMN_NAME = "id";

private Publisher<Row> publisher;

private Vertx vertx;
private Throwable throwable;
private Context context;
private PollableSubscriber pollableSubscriber;

@Before
public void setUp() {
vertx = Vertx.vertx();
context = vertx.getOrCreateContext();
pollableSubscriber = new PollableSubscriber(context, t -> throwable = t);
}

@Test
public void shouldPollSingleBatch() {
shouldPollRows(PollableSubscriber.REQUEST_BATCH_SIZE - 10);
}

@Test
public void shouldPollMultiBatch() {
shouldPollRows(PollableSubscriber.REQUEST_BATCH_SIZE + 10);
}

@Test
public void shouldSetError() {
publisher = new BufferedPublisher<Row>(context, ImmutableList.of()) {
@Override
protected void maybeSend() {
sendError(new RuntimeException("Error!"));
}
};

publisher.subscribe(pollableSubscriber);

Row row = pollableSubscriber.poll(POLL_DURATION);
assertThat(row, is(nullValue()));
assertThat(throwable, is(notNullValue()));
assertThat(throwable.getMessage(), is("Error!"));
}

private void shouldPollRows(int numRows) {
final List<Row> rows = new ArrayList<>();
for (int i = 0; i < numRows; i++) {
rows.add(createRow(i));
}

publisher = new BufferedPublisher<>(context, rows);

publisher.subscribe(pollableSubscriber);

Row row = pollableSubscriber.poll(POLL_DURATION);
int i = 0;
for (; row != null; i++) {
Long col1 = row.getLong(COLUMN_NAME);
assertThat(col1, is((long) i));
row = pollableSubscriber.poll(POLL_DURATION);
}
assertThat(i, is(numRows));
assertThat(throwable, is(nullValue()));
}

private Row createRow(long id) {
return new RowImpl(
ImmutableList.of(COLUMN_NAME),
ImmutableList.of(new ColumnTypeImpl("BIGINT")),
new JsonArray().add(id),
ImmutableMap.of(COLUMN_NAME, 1));
}
}

0 comments on commit da67bd9

Please sign in to comment.