Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify JDBC result set iterator implementation #23713

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import io.trino.client.Column;
import io.trino.client.IntervalDayTime;
import io.trino.client.IntervalYearMonth;
import io.trino.client.QueryError;
import io.trino.client.QueryStatusInfo;
import io.trino.jdbc.ColumnInfo.Nullable;
import io.trino.jdbc.TypeConversions.NoConversionRegisteredException;
import org.joda.time.DateTimeZone;
Expand Down Expand Up @@ -62,7 +60,6 @@
import java.time.ZonedDateTime;
import java.util.Calendar;
import java.util.GregorianCalendar;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -184,7 +181,7 @@ abstract class AbstractTrinoResultSet
return result;
})
.build();
protected final Iterator<List<Object>> results;
protected final CancellableIterator<List<Object>> results;
private final Map<String, Integer> fieldMap;
private final List<ColumnInfo> columnInfoList;
private final ResultSetMetaData resultSetMetaData;
Expand All @@ -193,7 +190,9 @@ abstract class AbstractTrinoResultSet
private final AtomicBoolean wasNull = new AtomicBoolean();
private final Optional<Statement> statement;

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, Iterator<List<Object>> results)
private final AtomicBoolean closed = new AtomicBoolean();

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, CancellableIterator<List<Object>> results)
{
this.statement = requireNonNull(statement, "statement is null");
requireNonNull(columns, "columns is null");
Expand Down Expand Up @@ -1827,6 +1826,15 @@ public <T> T getObject(String columnLabel, Class<T> type)
return getObject(columnIndex(columnLabel), type);
}

@Override
public void close()
throws SQLException
{
if (closed.compareAndSet(false, true)) {
results.cancel();
}
}

@SuppressWarnings("unchecked")
@Override
public <T> T unwrap(Class<T> iface)
Expand Down Expand Up @@ -1929,14 +1937,6 @@ private static Optional<BigDecimal> toBigDecimal(String value)
}
}

static SQLException resultsException(QueryStatusInfo results)
{
QueryError error = requireNonNull(results.getError());
String message = format("Query failed (#%s): %s", results.getId(), error.getMessage());
Throwable cause = (error.getFailureInfo() == null) ? null : error.getFailureInfo().toException();
return new SQLException(message, error.getSqlState(), error.getErrorCode(), cause);
}

private static Map<String, Integer> getFieldMap(List<Column> columns)
{
Map<String, Integer> map = Maps.newHashMapWithExpectedSize(columns.size());
Expand Down
169 changes: 169 additions & 0 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/AsyncResultIterator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.AbstractIterator;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.trino.client.QueryStatusInfo;
import io.trino.client.StatementClient;

import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.function.Consumer;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static io.trino.jdbc.ResultUtils.resultsException;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;

public class AsyncResultIterator
extends AbstractIterator<List<Object>>
implements CancellableIterator<List<Object>>
{
private static final int MAX_QUEUED_ROWS = 50_000;
private static final ExecutorService executorService = newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build());

private final StatementClient client;
private final BlockingQueue<List<Object>> rowQueue;
// Semaphore to indicate that some data is ready.
// Each permit represents a row of data (or that the underlying iterator is exhausted).
private final Semaphore semaphore = new Semaphore(0);
private final Future<?> future;

private volatile boolean cancelled;
private volatile boolean finished;

AsyncResultIterator(StatementClient client, Consumer<QueryStats> progressCallback, WarningsManager warningsManager, Optional<BlockingQueue<List<Object>>> queue)
{
requireNonNull(progressCallback, "progressCallback is null");
requireNonNull(warningsManager, "warningsManager is null");

this.client = client;
this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS));
this.cancelled = false;
this.finished = false;
this.future = executorService.submit(() -> {
try {
do {
QueryStatusInfo results = client.currentStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
Iterable<List<Object>> data = client.currentData().getData();
if (data != null) {
for (List<Object> row : data) {
rowQueue.put(row);
semaphore.release();
}
}
}
while (!cancelled && client.advance());

verify(client.isFinished());
QueryStatusInfo results = client.finalStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
if (results.getError() != null) {
throw new RuntimeException(resultsException(results));
}
}
catch (CancellationException | InterruptedException e) {
close();
throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e));
}
finally {
finished = true;
semaphore.release();
}
});
}

@Override
public void cancel()
{
synchronized (this) {
if (cancelled) {
return;
}
cancelled = true;
}
future.cancel(true);
close();
}

private void close()
{
// When thread interruption is mis-handled by underlying implementation of `client`, the thread which
// is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish
// its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks.
client.close();
rowQueue.clear();
}

@VisibleForTesting
Future<?> getFuture()
{
return future;
}

@VisibleForTesting
boolean isBackgroundThreadFinished()
{
return finished;
}

@Override
protected List<Object> computeNext()
{
try {
semaphore.acquire();
}
catch (InterruptedException e) {
handleInterrupt(e);
}
if (rowQueue.isEmpty()) {
// If we got here and the queue is empty the thread fetching from the underlying iterator is done.
// Wait for Future to marked done and check status.
try {
future.get();
}
catch (InterruptedException e) {
handleInterrupt(e);
}
catch (ExecutionException e) {
throwIfUnchecked(e.getCause());
throw new RuntimeException(e.getCause());
}
return endOfData();
}
return rowQueue.poll();
}

private void handleInterrupt(InterruptedException e)
{
cancel();
Thread.currentThread().interrupt();
throw new RuntimeException(new SQLException("Interrupted", e));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import java.util.Iterator;

interface CancellableIterator<T>
extends Iterator<T>
{
void cancel();

static <T> CancellableIterator<T> wrap(Iterator<T> iterator)
{
return new CancellableIterator<T>() {
@Override
public void cancel()
{
// noop
}

@Override
public boolean hasNext()
{
return iterator.hasNext();
}

@Override
public T next()
{
return iterator.next();
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import com.google.common.collect.AbstractIterator;

import static java.util.Objects.requireNonNull;

public class CancellableLimitingIterator<T>
extends AbstractIterator<T>
implements CancellableIterator<T>
{
private final long maxRows;
private final CancellableIterator<T> delegate;
private long currentRow;

CancellableLimitingIterator(CancellableIterator<T> delegate, long maxRows)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.maxRows = maxRows;
}

@Override
public void cancel()
{
delegate.cancel();
}

@Override
protected T computeNext()
{
if (maxRows > 0 && currentRow >= maxRows) {
cancel();
return endOfData();
}
currentRow++;
if (delegate.hasNext()) {
return delegate.next();
}
return endOfData();
}

static <T> CancellableIterator<T> limit(CancellableIterator<T> delegate, long maxRows)
{
return new CancellableLimitingIterator<>(delegate, maxRows);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.trino.jdbc.CancellableIterator.wrap;

public class InMemoryTrinoResultSet
extends AbstractTrinoResultSet
{
private final AtomicBoolean closed = new AtomicBoolean();

public InMemoryTrinoResultSet(List<Column> columns, List<List<Object>> results)
{
super(Optional.empty(), columns, results.iterator());
super(Optional.empty(), columns, wrap(results.iterator()));
}

@Override
Expand Down
35 changes: 35 additions & 0 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/ResultUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import io.trino.client.QueryError;
import io.trino.client.QueryStatusInfo;

import java.sql.SQLException;

import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

class ResultUtils
{
private ResultUtils() {}

static SQLException resultsException(QueryStatusInfo results)
{
QueryError error = requireNonNull(results.getError());
String message = format("Query failed (#%s): %s", results.getId(), error.getMessage());
Throwable cause = (error.getFailureInfo() == null) ? null : error.getFailureInfo().toException();
return new SQLException(message, error.getSqlState(), error.getErrorCode(), cause);
}
}
Loading