diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java index 73e90cece4..9008bcc642 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java @@ -258,7 +258,7 @@ enum GetObjectsDepth { * * * - * + * * *
Field Name Field Type
db_schema_name utf8
db_schema_tables list[TABLE_SCHEMA]
db_schema_statistics list[STATISTICS_SCHEMA]
The definition of DB_SCHEMA_SCHEMA.
* @@ -268,7 +268,7 @@ enum GetObjectsDepth { * Field Name Field Type Comments * table_name utf8 not null * column_name utf8 (1) - * statistic_key int16 (2) + * statistic_key int16 not null (2) * statistic_value VALUE_SCHEMA not null * statistic_is_approximatebool not null (3) * The definition of STATISTICS_SCHEMA. @@ -291,7 +291,6 @@ enum GetObjectsDepth { * int64 int64 * uint64 uint64 * float64 float64 - * decimal256 decimal256 * binary binary * The definition of VALUE_SCHEMA. * @@ -314,6 +313,18 @@ default ArrowReader getStatistics( throw AdbcException.notImplemented("Connection does not support getStatistics()"); } + /** + * Get the names of additional statistics defined by this driver. + * + *

The result is an Arrow dataset with the following schema: + * + * + * + * + * + * + *
Field Name Field Type
statistic_name utf8 not null
statistic_key int16 not null
The definition of the GetStatistics result schema.
+ */ default ArrowReader getStatisticNames() throws AdbcException { throw AdbcException.notImplemented("Connection does not support getStatisticNames()"); } @@ -408,7 +419,7 @@ default String getCurrentDbSchema() throws AdbcException { * * @since ADBC API revision 1.1.0 */ - default void setCurrentDbSchema(String catalog) throws AdbcException { + default void setCurrentDbSchema(String dbSchema) throws AdbcException { throw AdbcException.notImplemented("Connection does not support current catalog"); } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java index 9386b88089..5e32fd1ed7 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcDriver.java @@ -26,14 +26,14 @@ public interface AdbcDriver { * * @since ADBC API revision 1.1.0 */ - AdbcOptionKey PARAM_PASSWORD = new AdbcOptionKey<>("password", String.class); + TypedKey PARAM_PASSWORD = new TypedKey<>("password", String.class); /** * The standard parameter name for a connection URI (type String). * * @since ADBC API revision 1.1.0 */ - AdbcOptionKey PARAM_URI = new AdbcOptionKey<>("uri", String.class); + TypedKey PARAM_URI = new TypedKey<>("uri", String.class); /** * The standard parameter name for a connection URL (type String). @@ -47,7 +47,7 @@ public interface AdbcDriver { * * @since ADBC API revision 1.1.0 */ - AdbcOptionKey PARAM_USERNAME = new AdbcOptionKey<>("username", String.class); + TypedKey PARAM_USERNAME = new TypedKey<>("username", String.class); /** The standard parameter name for SQL quirks configuration (type SqlQuirks). */ String PARAM_SQL_QUIRKS = "adbc.sql.quirks"; diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java index f909addf68..dce7570e3d 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcException.java @@ -16,7 +16,6 @@ */ package org.apache.arrow.adbc.core; -import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; @@ -37,7 +36,7 @@ public class AdbcException extends Exception { private final AdbcStatusCode status; private final String sqlState; private final int vendorCode; - private Collection details; + private Collection details; public AdbcException( String message, Throwable cause, AdbcStatusCode status, String sqlState, int vendorCode) { @@ -50,7 +49,7 @@ public AdbcException( AdbcStatusCode status, String sqlState, int vendorCode, - Collection details) { + Collection details) { super(message, cause); this.status = status; this.sqlState = sqlState; @@ -94,14 +93,14 @@ public int getVendorCode() { } /** - * Get extra driver-specific binary error details. + * Get extra driver-specific error details. * *

This allows drivers to return custom, structured error information (for example, JSON or * Protocol Buffers) that can be optionally parsed by clients, beyond the standard AdbcError * fields, without having to encode it in the error message. The encoding of the data is * driver-defined. */ - public Collection getDetails() { + public Collection getDetails() { return details; } @@ -115,7 +114,7 @@ public AdbcException withCause(Throwable cause) { /** * Copy this exception with different details (a convenience for use with the static factories). */ - public AdbcException withDetails(Collection details) { + public AdbcException withDetails(Collection details) { return new AdbcException(getMessage(), getCause(), status, sqlState, vendorCode, details); } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java index efd8eab77b..5a8e78b08f 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java @@ -27,7 +27,7 @@ public interface AdbcOptions { * @return The option value. * @param The option value type. */ - default T getOption(AdbcOptionKey key) throws AdbcException { + default T getOption(TypedKey key) throws AdbcException { throw AdbcException.notImplemented("Unsupported option " + key); } @@ -39,7 +39,7 @@ default T getOption(AdbcOptionKey key) throws AdbcException { * @param value The option value. * @param The option value type. */ - default void setOption(AdbcOptionKey key, T value) throws AdbcException { + default void setOption(TypedKey key, T value) throws AdbcException { throw AdbcException.notImplemented("Unsupported option " + key); } } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java index 27708e1bbf..07c7eab126 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java @@ -58,7 +58,7 @@ default void cancel() throws AdbcException { /** * Set a generic query option. * - * @deprecated Prefer {@link #setOption(AdbcOptionKey, Object)}. + * @deprecated Prefer {@link #setOption(TypedKey, Object)}. */ default void setOption(String key, Object value) throws AdbcException { throw AdbcException.notImplemented("Unsupported option " + key); diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java new file mode 100644 index 0000000000..13521fb82e --- /dev/null +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.core; + +import java.util.Objects; + +/** Additional details (not necessarily human-readable) contained in an {@link AdbcException}. */ +public class ErrorDetail { + private final String key; + private final Object value; + + public ErrorDetail(String key, Object value) { + this.key = Objects.requireNonNull(key); + this.value = Objects.requireNonNull(value); + } + + public String getKey() { + return key; + } + + public Object getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ErrorDetail that = (ErrorDetail) o; + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + @Override + public int hashCode() { + return Objects.hash(getKey(), getValue()); + } + + @Override + public String toString() { + return "ErrorDetail{" + "key='" + key + '\'' + ", value=" + value + '}'; + } +} diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java index a14c04c700..c1e5594b87 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java @@ -19,6 +19,8 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -30,10 +32,13 @@ private StandardSchemas() { throw new AssertionError("Do not instantiate this class"); } - private static final ArrowType INT16 = new ArrowType.Int(16, true); - private static final ArrowType INT32 = new ArrowType.Int(32, true); - private static final ArrowType INT64 = new ArrowType.Int(64, true); + private static final ArrowType INT16 = Types.MinorType.SMALLINT.getType(); + private static final ArrowType INT32 = Types.MinorType.INT.getType(); + private static final ArrowType INT64 = Types.MinorType.BIGINT.getType(); private static final ArrowType UINT32 = new ArrowType.Int(32, false); + private static final ArrowType UINT64 = new ArrowType.Int(64, false); + private static final ArrowType FLOAT64 = + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); /** The schema of the result set of {@link AdbcConnection#getInfo(int[])}}. */ public static final Schema GET_INFO_SCHEMA = @@ -83,11 +88,11 @@ private StandardSchemas() { Field.notNullable("constraint_type", ArrowType.Utf8.INSTANCE), new Field( "constraint_column_names", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList(Field.nullable("item", new ArrowType.Utf8()))), new Field( "constraint_column_usage", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), USAGE_SCHEMA)))); @@ -119,12 +124,12 @@ private StandardSchemas() { new Field("table_type", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "table_columns", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), COLUMN_SCHEMA))), new Field( "table_constraints", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( "item", FieldType.nullable(ArrowType.Struct.INSTANCE), CONSTRAINT_SCHEMA)))); @@ -134,20 +139,76 @@ private StandardSchemas() { new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "db_schema_tables", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), TABLE_SCHEMA)))); + /** + * The schema of the result of {@link AdbcConnection#getObjects(AdbcConnection.GetObjectsDepth, + * String, String, String, String[], String)}. + */ public static final Schema GET_OBJECTS_SCHEMA = new Schema( Arrays.asList( new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), new Field( "catalog_db_schemas", - FieldType.notNullable(ArrowType.List.INSTANCE), + FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList( new Field( "item", FieldType.nullable(ArrowType.Struct.INSTANCE), DB_SCHEMA_SCHEMA))))); + + public static final List STATISTICS_VALUE_SCHEMA = + Arrays.asList( + Field.nullable("int64", INT64), + Field.nullable("uint64", UINT64), + Field.nullable("float64", FLOAT64), + Field.nullable("binary", ArrowType.Binary.INSTANCE)); + + public static final List STATISTICS_SCHEMA = + Arrays.asList( + Field.notNullable("table_name", ArrowType.Utf8.INSTANCE), + Field.nullable("column_name", ArrowType.Utf8.INSTANCE), + Field.notNullable("statistic_key", INT16), + new Field( + "statistic_value", + FieldType.notNullable(new ArrowType.Union(UnionMode.Dense, new int[] {0, 1, 2, 3})), + STATISTICS_VALUE_SCHEMA), + Field.notNullable("statistic_is_approximate", ArrowType.Bool.INSTANCE)); + + public static final List STATISTICS_DB_SCHEMA_SCHEMA = + Arrays.asList( + new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), + new Field( + "db_schema_statistics", + FieldType.nullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field( + "item", FieldType.nullable(ArrowType.Struct.INSTANCE), STATISTICS_SCHEMA)))); + + /** + * The schema of the result of {@link AdbcConnection#getStatistics(String, String, String, + * boolean)}. + */ + public static final Schema GET_STATISTICS_SCHEMA = + new Schema( + Arrays.asList( + new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null), + new Field( + "catalog_db_schemas", + FieldType.nullable(ArrowType.List.INSTANCE), + Collections.singletonList( + new Field( + "item", + FieldType.nullable(ArrowType.Struct.INSTANCE), + STATISTICS_DB_SCHEMA_SCHEMA))))); + + /** The schema of the result of {@link AdbcConnection#getStatisticNames()}. */ + public static final Schema GET_STATISTIC_NAMES_SCHEMA = + new Schema( + Arrays.asList( + Field.notNullable("statistic_name", ArrowType.Utf8.INSTANCE), + Field.notNullable("statistic_name", INT16))); } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java index 5412c645c3..f5097f4413 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java @@ -32,39 +32,39 @@ public enum StandardStatistics { * *

For example, this is roughly the average length of a string for a string column. */ - AVERAGE_BYTE_WIDTH("adbc.statistic.byte_width", 0), + AVERAGE_BYTE_WIDTH("adbc.statistic.byte_width", (short) 0), /** * The distinct value count (NDV) statistic. The number of distinct values in the column. Value * type is int64 (when not approximate) or float64 (when approximate). */ - DISTINCT_COUNT("adbc.statistic.distinct_count", 1), + DISTINCT_COUNT("adbc.statistic.distinct_count", (short) 1), /** * The max byte width statistic. The maximum size in bytes of a row in the column. Value type is * int64 (when not approximate) or float64 (when approximate). * *

For example, this is the maximum length of a string for a string column. */ - MAX_BYTE_WIDTH("adbc.statistic.byte_width", 2), + MAX_BYTE_WIDTH("adbc.statistic.byte_width", (short) 2), /** The max value statistic. Value type is column-dependent. */ - MAX_VALUE_NAME("adbc.statistic.byte_width", 3), + MAX_VALUE("adbc.statistic.byte_width", (short) 3), /** The min value statistic. Value type is column-dependent. */ - MIN_VALUE_NAME("adbc.statistic.byte_width", 4), + MIN_VALUE("adbc.statistic.byte_width", (short) 4), /** * The null count statistic. The number of values that are null in the column. Value type is int64 * (when not approximate) or float64 (when approximate). */ - NULL_COUNT_NAME("adbc.statistic.null_count", 5), + NULL_COUNT("adbc.statistic.null_count", (short) 5), /** * The row count statistic. The number of rows in the column or table. Value type is int64 (when * not approximate) or float64 (when approximate). */ - ROW_COUNT_NAME("adbc.statistic.row_count", 6), + ROW_COUNT("adbc.statistic.row_count", (short) 6), ; private final String name; - private final int key; + private final short key; - StandardStatistics(String name, int key) { + StandardStatistics(String name, short key) { this.name = Objects.requireNonNull(name); this.key = key; } @@ -75,7 +75,7 @@ public String getName() { } /** Get the dictionary-encoded name. */ - public int getKey() { + public short getKey() { return key; } } diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptionKey.java b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java similarity index 78% rename from java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptionKey.java rename to java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java index d594703688..21523bb429 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptionKey.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java @@ -26,15 +26,33 @@ * @since ADBC API revision 1.1.0 * @param The option value type. */ -public final class AdbcOptionKey { +public final class TypedKey { private final String key; private final Class type; - public AdbcOptionKey(String key, Class type) { + public TypedKey(String key, Class type) { this.key = Objects.requireNonNull(key); this.type = Objects.requireNonNull(type); } + /** Get the option key. */ + public String getKey() { + return key; + } + + /** + * Get the option value (if it was set) and check the type. + * + * @throws ClassCastException if the value is of the wrong type. + */ + public T get(Map options) { + Object value = options.get(key); + if (value == null) { + return null; + } + return type.cast(value); + } + /** * Set this option in an options map (like for {@link AdbcDriver#open(Map)}. * @@ -53,7 +71,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) { return false; } - AdbcOptionKey that = (AdbcOptionKey) o; + TypedKey that = (TypedKey) o; return Objects.equals(key, that.key) && Objects.equals(type, that.type); } diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java index 43a6df99c9..d3f79889ec 100644 --- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java +++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java @@ -47,7 +47,7 @@ public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException String url = getFlightLocation(); final Map parameters = new HashMap<>(); - parameters.put(AdbcDriver.PARAM_URL, url); + AdbcDriver.PARAM_URI.set(parameters, url); return new FlightSqlDriver(allocator).open(parameters); } diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml index 432967963b..9b78b4da24 100644 --- a/java/driver/flight-sql/pom.xml +++ b/java/driver/flight-sql/pom.xml @@ -66,5 +66,17 @@ org.apache.arrow.adbc adbc-sql + + + + org.assertj + assertj-core + test + + + org.junit.jupiter + junit-jupiter + test + diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java index 30fc460b8e..5015ecfcfe 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java @@ -43,17 +43,22 @@ public class FlightSqlDriver implements AdbcDriver { @Override public AdbcDatabase open(Map parameters) throws AdbcException { - Object target = parameters.get("adbc.url"); - if (!(target instanceof String)) { - throw AdbcException.invalidArgument( - "[Flight SQL] Must provide String " + PARAM_URL + " parameter"); + String uri = PARAM_URI.get(parameters); + if (uri == null) { + Object target = parameters.get("adbc.url"); + if (!(target instanceof String)) { + throw AdbcException.invalidArgument( + "[Flight SQL] Must provide String " + PARAM_URI + " parameter"); + } + uri = (String) target; } + Location location; try { - location = new Location((String) target); + location = new Location(uri); } catch (URISyntaxException e) { throw AdbcException.invalidArgument( - String.format("[Flight SQL] Location %s is invalid: %s", target, e)) + String.format("[Flight SQL] Location %s is invalid: %s", uri, e)) .withCause(e); } Object quirks = parameters.get(PARAM_SQL_QUIRKS); diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java index cb6b3038f8..45b42df2ee 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java @@ -17,8 +17,11 @@ package org.apache.arrow.adbc.driver.flightsql; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.core.ErrorDetail; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; @@ -72,7 +75,24 @@ static AdbcStatusCode fromFlightStatusCode(FlightStatusCode code) { } static AdbcException fromFlightException(FlightRuntimeException e) { + List errorDetails = new ArrayList<>(); + for (String key : e.status().metadata().keys()) { + if (key.endsWith("-bin")) { + for (byte[] value : e.status().metadata().getAllByte(key)) { + errorDetails.add(new ErrorDetail(key, value)); + } + } else { + for (String value : e.status().metadata().getAll(key)) { + errorDetails.add(new ErrorDetail(key, value)); + } + } + } return new AdbcException( - e.getMessage(), e.getCause(), fromFlightStatusCode(e.status().code()), null, 0); + e.getMessage(), + e.getCause(), + fromFlightStatusCode(e.status().code()), + null, + 0, + errorDetails); } } diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java new file mode 100644 index 0000000000..c617f664fa --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.grpc.Metadata; +import io.grpc.Status; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.ErrorDetail; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.ErrorFlightMetadata; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Test that gRPC error details make it through. */ +class DetailsTest { + static BufferAllocator allocator; + static Producer producer; + static FlightServer server; + static AdbcDriver driver; + static AdbcDatabase database; + AdbcConnection connection; + AdbcStatement statement; + + @BeforeAll + static void beforeAll() throws Exception { + allocator = new RootAllocator(); + producer = new Producer(); + server = + FlightServer.builder() + .allocator(allocator) + .producer(producer) + .location(Location.forGrpcInsecure("localhost", 0)) + .build(); + server.start(); + driver = new FlightSqlDriver(allocator); + Map parameters = new HashMap<>(); + AdbcDriver.PARAM_URI.set( + parameters, Location.forGrpcInsecure("localhost", server.getPort()).getUri().toString()); + database = driver.open(parameters); + } + + @BeforeEach + void beforeEach() throws Exception { + connection = database.connect(); + statement = connection.createStatement(); + } + + @AfterEach + void afterEach() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterAll + static void afterAll() throws Exception { + AutoCloseables.close(database, server, allocator); + } + + @Test + void flightDetails() throws Exception { + statement.setSqlQuery("flight"); + + AdbcException exception = + assertThrows( + AdbcException.class, + () -> { + try (AdbcStatement.QueryResult result = statement.executeQuery()) {} + }); + + assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text")); + Optional binaryKey = + exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny(); + assertThat(binaryKey) + .get() + .extracting(ErrorDetail::getValue) + .isEqualTo("text".getBytes(StandardCharsets.UTF_8)); + } + + @Test + void grpcDetails() throws Exception { + statement.setSqlQuery("grpc"); + + AdbcException exception = + assertThrows( + AdbcException.class, + () -> { + try (AdbcStatement.QueryResult result = statement.executeQuery()) {} + }); + + assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text")); + Optional binaryKey = + exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny(); + assertThat(binaryKey) + .get() + .extracting(ErrorDetail::getValue) + .isEqualTo("text".getBytes(StandardCharsets.UTF_8)); + } + + static class Producer implements FlightSqlProducer { + Metadata.Key BINARY_KEY = Metadata.Key.of("x-foo-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata.Key TEXT_KEY = Metadata.Key.of("x-foo", Metadata.ASCII_STRING_MARSHALLER); + + @Override + public FlightInfo getFlightInfoStatement( + FlightSql.CommandStatementQuery commandStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + if (commandStatementQuery.getQuery().equals("flight")) { + // Using Flight path + ErrorFlightMetadata metadata = new ErrorFlightMetadata(); + metadata.insert("x-foo", "text"); + metadata.insert("x-foo-bin", "text".getBytes(StandardCharsets.UTF_8)); + throw CallStatus.UNKNOWN + .withDescription("Expected") + .withMetadata(metadata) + .toRuntimeException(); + } else if (commandStatementQuery.getQuery().equals("grpc")) { + // Using gRPC path + Metadata trailers = new Metadata(); + trailers.put(TEXT_KEY, "text"); + trailers.put(BINARY_KEY, "text".getBytes(StandardCharsets.UTF_8)); + throw Status.UNKNOWN.asRuntimeException(trailers); + } + + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + // No-op implementations + + @Override + public void createPreparedStatement( + FlightSql.ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) {} + + @Override + public void closePreparedStatement( + FlightSql.ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) {} + + @Override + public FlightInfo getFlightInfoPreparedStatement( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public SchemaResult getSchemaStatement( + FlightSql.CommandStatementQuery commandStatementQuery, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamStatement( + FlightSql.TicketStatementQuery ticketStatementQuery, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamPreparedStatement( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public Runnable acceptPutStatement( + FlightSql.CommandStatementUpdate commandStatementUpdate, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public Runnable acceptPutPreparedStatementUpdate( + FlightSql.CommandPreparedStatementUpdate commandPreparedStatementUpdate, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, + FlightStream flightStream, + StreamListener streamListener) { + return null; + } + + @Override + public FlightInfo getFlightInfoSqlInfo( + FlightSql.CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamSqlInfo( + FlightSql.CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTypeInfo( + FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTypeInfo( + FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoCatalogs( + FlightSql.CommandGetCatalogs commandGetCatalogs, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamCatalogs( + CallContext callContext, ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoSchemas( + FlightSql.CommandGetDbSchemas commandGetDbSchemas, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamSchemas( + FlightSql.CommandGetDbSchemas commandGetDbSchemas, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTables( + FlightSql.CommandGetTables commandGetTables, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTables( + FlightSql.CommandGetTables commandGetTables, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoTableTypes( + FlightSql.CommandGetTableTypes commandGetTableTypes, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamTableTypes( + CallContext callContext, ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoPrimaryKeys( + FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamPrimaryKeys( + FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public FlightInfo getFlightInfoExportedKeys( + FlightSql.CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public FlightInfo getFlightInfoImportedKeys( + FlightSql.CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public FlightInfo getFlightInfoCrossReference( + FlightSql.CommandGetCrossReference commandGetCrossReference, + CallContext callContext, + FlightDescriptor flightDescriptor) { + return null; + } + + @Override + public void getStreamExportedKeys( + FlightSql.CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamImportedKeys( + FlightSql.CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void getStreamCrossReference( + FlightSql.CommandGetCrossReference commandGetCrossReference, + CallContext callContext, + ServerStreamListener serverStreamListener) {} + + @Override + public void close() throws Exception {} + + @Override + public void listFlights( + CallContext callContext, Criteria criteria, StreamListener streamListener) {} + } +} diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java index ccce7db70d..fce9ff134d 100644 --- a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java +++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java @@ -37,6 +37,9 @@ public class PostgresqlQuirks extends SqlValidationQuirks { static final String POSTGRESQL_URL_ENV_VAR = "ADBC_JDBC_POSTGRESQL_URL"; static final String POSTGRESQL_USER_ENV_VAR = "ADBC_JDBC_POSTGRESQL_USER"; static final String POSTGRESQL_PASSWORD_ENV_VAR = "ADBC_JDBC_POSTGRESQL_PASSWORD"; + static final String POSTGRESQL_DATABASE_ENV_VAR = "ADBC_JDBC_POSTGRESQL_DATABASE"; + + String catalog = "postgres"; static String makeJdbcUrl() { final String postgresUrl = System.getenv(POSTGRESQL_URL_ENV_VAR); @@ -49,12 +52,21 @@ static String makeJdbcUrl() { return String.format("jdbc:postgresql://%s?user=%s&password=%s", postgresUrl, user, password); } + public Connection getJdbcConnection() throws SQLException { + return DriverManager.getConnection(makeJdbcUrl()); + } + @Override public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException { String url = makeJdbcUrl(); + final String catalog = System.getenv(POSTGRESQL_DATABASE_ENV_VAR); + Assumptions.assumeFalse( + catalog == null, "PostgreSQL catalog not found, set " + POSTGRESQL_DATABASE_ENV_VAR); + this.catalog = catalog; + final Map parameters = new HashMap<>(); - parameters.put(AdbcDriver.PARAM_URL, url); + AdbcDriver.PARAM_URI.set(parameters, url); parameters.put(JdbcDriver.PARAM_JDBC_QUIRKS, StandardJdbcQuirks.POSTGRESQL); return new JdbcDriver(allocator).open(parameters); } @@ -71,8 +83,12 @@ public void cleanupTable(String name) throws Exception { @Override public String defaultCatalog() { - // XXX: this should really come from configuration - return "postgres"; + return catalog; + } + + @Override + public String defaultDbSchema() { + return "public"; } @Override @@ -94,4 +110,9 @@ public TimeUnit defaultTimeUnit() { public TimeUnit defaultTimestampUnit() { return TimeUnit.MICROSECOND; } + + @Override + public boolean supportsCurrentCatalog() { + return true; + } } diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java new file mode 100644 index 0000000000..13ca0ee191 --- /dev/null +++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.jdbc.postgresql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.entry; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Statement; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.StandardStatistics; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class StatisticsTest { + static PostgresqlQuirks quirks; + + @BeforeAll + static void beforeAll() { + quirks = new PostgresqlQuirks(); + } + + @Test + void adbc() throws Exception { + try (Connection connection = quirks.getJdbcConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest"); + statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)"); + statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)"); + statement.executeUpdate("ANALYZE adbcpkeytest"); + } + + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase database = quirks.initDatabase(allocator); + AdbcConnection connection = database.connect(); + ArrowReader reader = connection.getStatistics(null, null, "adbcpkeytest", true)) { + assertThat(reader.loadNextBatch()).isTrue(); + VectorSchemaRoot vsr = reader.getVectorSchemaRoot(); + assertThat(vsr.getRowCount()).isEqualTo(1); + + ListVector catalogDbSchemas = (ListVector) vsr.getVector(1); + assertThat(catalogDbSchemas.getValueCount()).isEqualTo(1); + + StructVector catalogDbSchema = (StructVector) catalogDbSchemas.getDataVector(); + ListVector dbSchemaStatistics = (ListVector) catalogDbSchema.getVectorById(1); + assertThat(dbSchemaStatistics.getValueCount()).isEqualTo(1); + + @SuppressWarnings("unchecked") + Map statistic = (Map) dbSchemaStatistics.getObject(0).get(0); + assertThat(statistic) + .contains( + entry("table_name", new Text("adbcpkeytest")), + entry("statistic_key", StandardStatistics.DISTINCT_COUNT.getKey()), + entry("statistic_value", 3L)); + + assertThat(reader.loadNextBatch()).isFalse(); + } + } + + /** Validate what PostgreSQL does. */ + @Test + void jdbc() throws Exception { + try (Connection connection = quirks.getJdbcConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest"); + statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)"); + statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)"); + statement.executeUpdate("ANALYZE adbcpkeytest"); + + int count = 0; + try (ResultSet rs = + connection.getMetaData().getIndexInfo(null, null, "adbcpkeytest", false, true)) { + ResultSetMetaData rsmd = rs.getMetaData(); + while (rs.next()) { + // For debugging + for (int i = 1; i <= rsmd.getColumnCount(); i++) { + System.out.println(rsmd.getColumnName(i) + " => " + rs.getObject(i)); + } + System.out.println("==="); + + // TABLE_NAME + assertThat(rs.getString(3)).isEqualTo("adbcpkeytest"); + // TYPE + assertThat(rs.getShort(7)).isEqualTo(DatabaseMetaData.tableIndexOther); + // CARDINALITY + assertThat(rs.getLong(11)).isEqualTo(3); + + count++; + } + } + + assertThat(count).isEqualTo(1); + } + } +} diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java index 02c2ccac22..ae5ec226fc 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java @@ -25,10 +25,12 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.arrow.adbc.core.AdbcDriver; import org.apache.arrow.adbc.core.AdbcInfoCode; import org.apache.arrow.adbc.core.StandardSchemas; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -37,6 +39,7 @@ /** Helper class to track state needed to build up the info structure. */ final class InfoMetadataBuilder implements AutoCloseable { private static final byte STRING_VALUE_TYPE_ID = (byte) 0; + private static final byte BIGINT_VALUE_TYPE_ID = (byte) 2; private static final Map SUPPORTED_CODES = new HashMap<>(); private final Collection requestedCodes; private final DatabaseMetaData dbmd; @@ -45,6 +48,7 @@ final class InfoMetadataBuilder implements AutoCloseable { final UInt4Vector infoCodes; final DenseUnionVector infoValues; final VarCharVector stringValues; + final BigIntVector bigIntValues; @FunctionalInterface interface AddInfo { @@ -74,6 +78,11 @@ interface AddInfo { final String driverVersion = b.dbmd.getDriverVersion() + " (ADBC Driver Version 0.0.1)"; b.setStringValue(idx, driverVersion); }); + SUPPORTED_CODES.put( + AdbcInfoCode.DRIVER_ADBC_VERSION.getValue(), + (b, idx) -> { + b.setBigIntValue(idx, AdbcDriver.ADBC_VERSION_1_1_0); + }); } InfoMetadataBuilder(BufferAllocator allocator, Connection connection, int[] infoCodes) @@ -86,7 +95,18 @@ interface AddInfo { this.dbmd = connection.getMetaData(); this.infoCodes = (UInt4Vector) root.getVector(0); this.infoValues = (DenseUnionVector) root.getVector(1); - this.stringValues = this.infoValues.getVarCharVector((byte) 0); + this.stringValues = this.infoValues.getVarCharVector(STRING_VALUE_TYPE_ID); + this.bigIntValues = this.infoValues.getBigIntVector(BIGINT_VALUE_TYPE_ID); + } + + void setBigIntValue(int index, long value) { + infoValues.setValueCount(index + 1); + infoValues.setTypeId(index, BIGINT_VALUE_TYPE_ID); + bigIntValues.setSafe(index, value); + infoValues + .getOffsetBuffer() + .setInt((long) index * DenseUnionVector.OFFSET_WIDTH, bigIntValues.getValueCount()); + bigIntValues.setValueCount(bigIntValues.getValueCount() + 1); } void setStringValue(int index, final String value) { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java index 1ddbf1c88a..aba972a9a2 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java @@ -42,12 +42,7 @@ public class JdbcArrowReader extends ArrowReader { JdbcArrowReader(BufferAllocator allocator, ResultSet resultSet, Schema overrideSchema) throws AdbcException { super(allocator); - final JdbcToArrowConfig config = - new JdbcToArrowConfigBuilder() - .setAllocator(allocator) - .setCalendar(JdbcToArrowUtils.getUtcCalendar()) - .setTargetBatchSize(1024) - .build(); + final JdbcToArrowConfig config = makeJdbcConfig(allocator); try { this.delegate = JdbcToArrow.sqlToArrowVectorIterator(resultSet, config); } catch (SQLException e) { @@ -75,6 +70,14 @@ public class JdbcArrowReader extends ArrowReader { } } + static JdbcToArrowConfig makeJdbcConfig(BufferAllocator allocator) { + return new JdbcToArrowConfigBuilder() + .setAllocator(allocator) + .setCalendar(JdbcToArrowUtils.getUtcCalendar()) + .setTargetBatchSize(1024) + .build(); + } + @Override public boolean loadNextBatch() { if (!delegate.hasNext()) return false; diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java index 398ef6d42e..8f66c154fa 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java @@ -21,7 +21,9 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -29,15 +31,24 @@ import org.apache.arrow.adbc.core.BulkIngestMode; import org.apache.arrow.adbc.core.IsolationLevel; import org.apache.arrow.adbc.core.StandardSchemas; +import org.apache.arrow.adbc.core.StandardStatistics; import org.apache.arrow.adbc.driver.jdbc.adapter.JdbcFieldInfoExtra; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; public class JdbcConnection implements AdbcConnection { private final BufferAllocator allocator; @@ -117,6 +128,165 @@ public ArrowReader getObjects( } } + static final class Statistic { + String table; + String column; + short key; + long value; + boolean multiColumn = false; + } + + @Override + public ArrowReader getStatistics( + String catalogPattern, String dbSchemaPattern, String tableNamePattern, boolean approximate) + throws AdbcException { + if (tableNamePattern == null) { + throw AdbcException.notImplemented( + JdbcDriverUtil.prefixExceptionMessage("getStatistics: must supply table name")); + } + + try (final VectorSchemaRoot root = + VectorSchemaRoot.create(StandardSchemas.GET_STATISTICS_SCHEMA, allocator); + ResultSet rs = + connection + .getMetaData() + .getIndexInfo( + catalogPattern, + dbSchemaPattern, + tableNamePattern, /*unique*/ + false, + approximate)) { + // Build up the statistics in-memory and then return a constant reader. + // We have to read and sort the data first because the ordering is not by the catalog/etc. + + // {catalog: {schema: {index_name: statistic}}} + Map>> allStatistics = new HashMap<>(); + + while (rs.next()) { + String catalog = rs.getString(1); + String schema = rs.getString(2); + String table = rs.getString(3); + String index = rs.getString(6); + short statisticType = rs.getShort(7); + String column = rs.getString(9); + long cardinality = rs.getLong(11); + + if (!allStatistics.containsKey(catalog)) { + allStatistics.put(catalog, new HashMap<>()); + } + + Map> catalogStats = allStatistics.get(catalog); + if (!catalogStats.containsKey(schema)) { + catalogStats.put(schema, new HashMap<>()); + } + + Map schemaStats = catalogStats.get(schema); + Statistic statistic = schemaStats.getOrDefault(index, new Statistic()); + if (schemaStats.containsKey(index)) { + // Multi-column index, ignore it + statistic.multiColumn = true; + continue; + } + + statistic.column = column; + statistic.table = table; + statistic.key = + statisticType == DatabaseMetaData.tableIndexStatistic + ? StandardStatistics.ROW_COUNT.getKey() + : StandardStatistics.DISTINCT_COUNT.getKey(); + statistic.value = cardinality; + schemaStats.put(index, statistic); + } + + VarCharVector catalogNames = (VarCharVector) root.getVector(0); + ListVector catalogDbSchemas = (ListVector) root.getVector(1); + StructVector dbSchemas = (StructVector) catalogDbSchemas.getDataVector(); + VarCharVector dbSchemaNames = (VarCharVector) dbSchemas.getVectorById(0); + ListVector dbSchemaStatistics = (ListVector) dbSchemas.getVectorById(1); + StructVector statistics = (StructVector) dbSchemaStatistics.getDataVector(); + VarCharVector tableNames = (VarCharVector) statistics.getVectorById(0); + VarCharVector columnNames = (VarCharVector) statistics.getVectorById(1); + SmallIntVector statisticKeys = (SmallIntVector) statistics.getVectorById(2); + DenseUnionVector statisticValues = (DenseUnionVector) statistics.getVectorById(3); + BitVector statisticIsApproximate = (BitVector) statistics.getVectorById(4); + + // Build up the Arrow result + Text text = new Text(); + NullableBigIntHolder holder = new NullableBigIntHolder(); + int catalogIndex = 0; + int schemaIndex = 0; + int statisticIndex = 0; + for (String catalog : allStatistics.keySet()) { + Map> schemas = allStatistics.get(catalog); + + if (catalog == null) { + catalogNames.setNull(catalogIndex); + } else { + text.set(catalog); + catalogNames.setSafe(catalogIndex, text); + } + catalogDbSchemas.startNewValue(catalogIndex); + + int schemaCount = 0; + for (String schema : schemas.keySet()) { + if (schema == null) { + dbSchemaNames.setNull(schemaIndex); + } else { + text.set(schema); + dbSchemaNames.setSafe(schemaIndex, text); + } + + dbSchemaStatistics.startNewValue(schemaIndex); + + Map indices = schemas.get(schema); + int statisticCount = 0; + for (Statistic statistic : indices.values()) { + if (statistic.multiColumn) { + continue; + } + + text.set(statistic.table); + tableNames.setSafe(statisticIndex, text); + if (statistic.column == null) { + columnNames.setNull(statisticIndex); + } else { + text.set(statistic.column); + columnNames.setSafe(statisticIndex, text); + } + statisticKeys.setSafe(statisticIndex, statistic.key); + statisticValues.setTypeId(statisticIndex, (byte) 0); + holder.isSet = 1; + holder.value = statistic.value; + statisticValues.setSafe(statisticIndex, holder); + statisticIsApproximate.setSafe(statisticIndex, approximate ? 1 : 0); + + statistics.setIndexDefined(statisticIndex++); + statisticCount++; + } + + dbSchemaStatistics.endValue(schemaIndex, statisticCount); + + dbSchemas.setIndexDefined(schemaIndex++); + schemaCount++; + } + + catalogDbSchemas.endValue(catalogIndex, schemaCount); + catalogIndex++; + } + root.setRowCount(catalogIndex); + + return RootArrowReader.fromRoot(allocator, root); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public ArrowReader getStatisticNames() throws AdbcException { + // TODO: + return AdbcConnection.super.getStatisticNames(); + } + @Override public Schema getTableSchema(String catalog, String dbSchema, String tableName) throws AdbcException { @@ -211,6 +381,42 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException { } } + @Override + public String getCurrentCatalog() throws AdbcException { + try { + return connection.getCatalog(); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public void setCurrentCatalog(String catalog) throws AdbcException { + try { + connection.setCatalog(catalog); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public String getCurrentDbSchema() throws AdbcException { + try { + return connection.getSchema(); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + + @Override + public void setCurrentDbSchema(String dbSchema) throws AdbcException { + try { + connection.setSchema(dbSchema); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + @Override public boolean getReadOnly() throws AdbcException { try { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java index 95b3775f68..fd39e6d08b 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java @@ -30,6 +30,7 @@ import java.util.stream.LongStream; import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig; import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -263,6 +264,41 @@ public QueryResult executeQuery() throws AdbcException { return new QueryResult(/*affectedRows=*/ -1, reader); } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[JDBC] Call executeUpdate() for bulk operations"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[JDBC] Must setSqlQuery() first"); + } + try { + invalidatePriorQuery(); + final PreparedStatement preparedStatement; + final PreparedStatement ownedStatement; + if (statement instanceof PreparedStatement) { + preparedStatement = (PreparedStatement) statement; + if (bindRoot != null) { + JdbcParameterBinder.builder(preparedStatement, bindRoot).bindAll().build().next(); + } + ownedStatement = null; + } else { + // new statement + preparedStatement = connection.prepareStatement(sqlQuery); + ownedStatement = preparedStatement; + } + + final JdbcToArrowConfig config = JdbcArrowReader.makeJdbcConfig(allocator); + final Schema schema = + JdbcToArrowUtils.jdbcToArrowSchema(preparedStatement.getMetaData(), config); + if (ownedStatement != null) { + ownedStatement.close(); + } + return schema; + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException(e); + } + } + @Override public Schema getParameterSchema() throws AdbcException { if (statement instanceof PreparedStatement) { diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java index 9915636d69..54e6059046 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java @@ -18,6 +18,7 @@ package org.apache.arrow.adbc.driver.testsuite; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; import org.apache.arrow.adbc.core.AdbcConnection; import org.apache.arrow.adbc.core.AdbcDatabase; @@ -48,6 +49,19 @@ public void afterEach() throws Exception { AutoCloseables.close(connection, database, allocator); } + @Test + void currentCatalog() throws Exception { + assumeThat(quirks.supportsCurrentCatalog()).isTrue(); + + assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog()); + connection.setCurrentCatalog(quirks.defaultCatalog()); + assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog()); + + assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema()); + connection.setCurrentDbSchema(quirks.defaultDbSchema()); + assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema()); + } + @Test void multipleConnections() throws Exception { try (final AdbcConnection ignored = database.connect()) {} diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java index e7a1a5743a..4d9184a4bb 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java @@ -19,6 +19,7 @@ import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertField; import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertRoot; +import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertSchema; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -239,6 +240,62 @@ public void bulkIngestCreateConflict() throws Exception { } } + @Test + public void executeSchema() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String name = quirks.caseFoldColumnName("STRS"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT " + name + " FROM " + tableName); + final Schema actualSchema = stmt.executeSchema(); + assertSchema(actualSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(name, Types.MinorType.VARCHAR.getType())))); + } + } + + @Test + public void executeSchemaPrepared() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String name = quirks.caseFoldColumnName("STRS"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT " + name + " FROM " + tableName); + stmt.prepare(); + final Schema actualSchema = stmt.executeSchema(); + assertSchema(actualSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(name, Types.MinorType.VARCHAR.getType())))); + } + } + + @Test + public void executeSchemaParams() throws Exception { + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery("SELECT ? AS FOO"); + stmt.prepare(); + Schema actualSchema = stmt.executeSchema(); + // Actual type unknown + assertThat(actualSchema.getFields().size()).isEqualTo(1); + + final Schema schema = + new Schema( + Collections.singletonList( + Field.nullable( + quirks.caseFoldColumnName("foo"), Types.MinorType.VARCHAR.getType()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + ((VarCharVector) root.getVector(0)).setSafe(0, "foo".getBytes(StandardCharsets.UTF_8)); + root.setRowCount(1); + stmt.bind(root); + + actualSchema = stmt.executeSchema(); + assertSchema(actualSchema).isEqualTo(schema); + } + } + } + @Test public void prepareQuery() throws Exception { final Schema expectedSchema = util.ingestTableIntsStrs(allocator, connection, tableName); diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java index 120ecab255..a5da97f658 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java @@ -35,6 +35,11 @@ public void cleanupTable(String name) throws Exception {} /** Get the name of the default catalog. */ public abstract String defaultCatalog(); + /** Get the name of the default schema. */ + public String defaultDbSchema() { + return ""; + } + /** Normalize a table name. */ public String caseFoldTableName(String name) { return name; @@ -110,4 +115,8 @@ public ArrowType defaultTimeType() { public TimeUnit defaultTimestampUnit() { return TimeUnit.MILLISECOND; } + + public boolean supportsCurrentCatalog() { + return false; + } }