Skip to content

Commit

Permalink
Simplify JDBC result set iterator implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Oct 9, 2024
1 parent 5739e17 commit 1c775ac
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 364 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.jdbc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -62,7 +63,6 @@
import java.time.ZonedDateTime;
import java.util.Calendar;
import java.util.GregorianCalendar;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -184,7 +184,7 @@ abstract class AbstractTrinoResultSet
return result;
})
.build();
protected final Iterator<List<Object>> results;
protected final CancellableIterator<List<Object>> results;
private final Map<String, Integer> fieldMap;
private final List<ColumnInfo> columnInfoList;
private final ResultSetMetaData resultSetMetaData;
Expand All @@ -193,7 +193,9 @@ abstract class AbstractTrinoResultSet
private final AtomicBoolean wasNull = new AtomicBoolean();
private final Optional<Statement> statement;

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, Iterator<List<Object>> results)
private boolean closed;

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, CancellableIterator<List<Object>> results)
{
this.statement = requireNonNull(statement, "statement is null");
requireNonNull(columns, "columns is null");
Expand Down Expand Up @@ -1827,6 +1829,19 @@ public <T> T getObject(String columnLabel, Class<T> type)
return getObject(columnIndex(columnLabel), type);
}

@Override
public void close()
throws SQLException
{
synchronized (this) {
if (closed) {
return;
}
closed = true;
}
results.cancel();
}

@SuppressWarnings("unchecked")
@Override
public <T> T unwrap(Class<T> iface)
Expand Down Expand Up @@ -2210,4 +2225,10 @@ public ParsedTimestamp(int year, int month, int day, int hour, int minute, int s
this.timezone = requireNonNull(timezone, "timezone is null");
}
}

abstract static class CancellableIterator<T>
extends AbstractIterator<T>
{
abstract void cancel();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
import io.trino.client.Column;

import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.Objects.requireNonNull;

public class InMemoryTrinoResultSet
extends AbstractTrinoResultSet
{
private final AtomicBoolean closed = new AtomicBoolean();

public InMemoryTrinoResultSet(List<Column> columns, List<List<Object>> results)
{
super(Optional.empty(), columns, results.iterator());
super(Optional.empty(), columns, new ListIterator(results));
}

@Override
Expand All @@ -43,4 +46,29 @@ public boolean isClosed()
{
return closed.get();
}

private static class ListIterator
extends CancellableIterator<List<Object>>
{
private final Iterator<List<Object>> iterator;

public ListIterator(List<List<Object>> results)
{
this.iterator = requireNonNull(results, "results is null").iterator();
}

@Override
protected List<Object> computeNext()
{
if (iterator.hasNext()) {
return iterator.next();
}
return endOfData();
}

@Override
void cancel()
{
}
}
}
140 changes: 67 additions & 73 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
package io.trino.jdbc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Streams;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.trino.client.Column;
Expand All @@ -24,17 +22,16 @@

import java.sql.SQLException;
import java.sql.Statement;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.function.Consumer;
import java.util.stream.Stream;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -68,7 +65,7 @@ private TrinoResultSet(Statement statement, StatementClient client, List<Column>
super(
Optional.of(requireNonNull(statement, "statement is null")),
columns,
new AsyncIterator<>(flatten(new ResultsPageIterator(requireNonNull(client, "client is null"), progressCallback, warningsManager), maxRows), client));
limitingIterator(new AsyncResultsPageIterator(requireNonNull(client, "client is null"), progressCallback, warningsManager, Optional.empty()), maxRows));

this.statement = statement;
this.client = requireNonNull(client, "client is null");
Expand Down Expand Up @@ -115,7 +112,7 @@ public void close()
closeStatement = closeStatementOnClose;
}

((AsyncIterator<?>) results).cancel();
super.close();
client.close();
if (closeStatement) {
statement.close();
Expand All @@ -134,54 +131,58 @@ void partialCancel()
client.cancelLeafStage();
}

private static <T> Iterator<T> flatten(Iterator<Iterable<T>> iterator, long maxRows)
{
Stream<T> stream = Streams.stream(iterator)
.flatMap(Streams::stream);
if (maxRows > 0) {
stream = stream.limit(maxRows);
}
return stream.iterator();
}

@VisibleForTesting
static class AsyncIterator<T>
extends AbstractIterator<T>
static class AsyncResultsPageIterator
extends CancellableIterator<List<Object>>
{
private static final int MAX_QUEUED_ROWS = 50_000;
private static final ExecutorService executorService = newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build());

private final StatementClient client;
private final BlockingQueue<T> rowQueue;
private final BlockingQueue<List<Object>> rowQueue;
// Semaphore to indicate that some data is ready.
// Each permit represents a row of data (or that the underlying iterator is exhausted).
private final Semaphore semaphore = new Semaphore(0);
private final Future<?> future;

private volatile boolean cancelled;
private volatile boolean finished;

public AsyncIterator(Iterator<T> dataIterator, StatementClient client)
AsyncResultsPageIterator(StatementClient client, Consumer<QueryStats> progressCallback, WarningsManager warningsManager, Optional<BlockingQueue<List<Object>>> queue)
{
this(dataIterator, client, Optional.empty());
}
requireNonNull(progressCallback, "progressCallback is null");
requireNonNull(warningsManager, "warningsManager is null");

@VisibleForTesting
AsyncIterator(Iterator<T> dataIterator, StatementClient client, Optional<BlockingQueue<T>> queue)
{
requireNonNull(dataIterator, "dataIterator is null");
this.client = client;
this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS));
this.cancelled = false;
this.finished = false;
this.future = executorService.submit(() -> {
try {
while (!cancelled && dataIterator.hasNext()) {
rowQueue.put(dataIterator.next());
semaphore.release();
do {
QueryStatusInfo results = client.currentStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
Iterable<List<Object>> data = client.currentData().getData();
if (data != null) {
for (List<Object> row : data) {
rowQueue.put(row);
semaphore.release();
}
}
}
while (!cancelled && client.advance());

verify(client.isFinished());
QueryStatusInfo results = client.finalStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
if (results.getError() != null) {
throw new RuntimeException(resultsException(results));
}
}
catch (InterruptedException e) {
catch (CancellationException | InterruptedException e) {
client.close();
rowQueue.clear();
throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e));
Expand All @@ -193,9 +194,15 @@ public AsyncIterator(Iterator<T> dataIterator, StatementClient client)
});
}

@Override
public void cancel()
{
cancelled = true;
synchronized (this) {
if (cancelled) {
return;
}
cancelled = true;
}
future.cancel(true);
// When thread interruption is mis-handled by underlying implementation of `client`, the thread which
// is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish
Expand All @@ -217,7 +224,7 @@ boolean isBackgroundThreadFinished()
}

@Override
protected T computeNext()
protected List<Object> computeNext()
{
try {
semaphore.acquire();
Expand Down Expand Up @@ -251,47 +258,6 @@ private void handleInterrupt(InterruptedException e)
}
}

private static class ResultsPageIterator
extends AbstractIterator<Iterable<List<Object>>>
{
private final StatementClient client;
private final Consumer<QueryStats> progressCallback;
private final WarningsManager warningsManager;

private ResultsPageIterator(StatementClient client, Consumer<QueryStats> progressCallback, WarningsManager warningsManager)
{
this.client = requireNonNull(client, "client is null");
this.progressCallback = requireNonNull(progressCallback, "progressCallback is null");
this.warningsManager = requireNonNull(warningsManager, "warningsManager is null");
}

@Override
protected Iterable<List<Object>> computeNext()
{
while (client.isRunning()) {
QueryStatusInfo results = client.currentStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
Iterable<List<Object>> data = client.currentData().getData();
if (!client.advance() && data == null) {
break; // No more rows, query finished
}
if (data != null) {
return data;
}
}

verify(client.isFinished());
QueryStatusInfo results = client.finalStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
if (results.getError() != null) {
throw new RuntimeException(resultsException(results));
}
return endOfData();
}
}

private static List<Column> getColumns(StatementClient client, Consumer<QueryStats> progressCallback)
throws SQLException
{
Expand All @@ -312,4 +278,32 @@ private static List<Column> getColumns(StatementClient client, Consumer<QuerySta
}
throw resultsException(results);
}

private static CancellableIterator<List<Object>> limitingIterator(CancellableIterator<List<Object>> iterator, long maxRows)
{
return new CancellableIterator<List<Object>>()
{
private long currentRow;

@Override
void cancel()
{
iterator.cancel();
}

@Override
protected List<Object> computeNext()
{
if (maxRows > 0 && currentRow >= maxRows) {
cancel();
return endOfData();
}
currentRow++;
if (iterator.hasNext()) {
return iterator.next();
}
return endOfData();
}
};
}
}
Loading

0 comments on commit 1c775ac

Please sign in to comment.