From 85e1f51296f7f2854c41b9612338495c2fa42352 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 17 May 2023 09:17:10 -0400 Subject: [PATCH] feat(format): add AdbcStatementExecuteSchema Fixes #318. --- .pre-commit-config.yaml | 2 +- adbc.h | 21 +++++++ c/driver_manager/adbc_driver_manager.cc | 18 +++++- go/adbc/adbc.go | 13 +++++ go/adbc/driver/flightsql/flightsql_adbc.go | 16 ++++++ .../driver/flightsql/flightsql_adbc_test.go | 1 + .../driver/flightsql/flightsql_statement.go | 55 +++++++++++++++++++ go/adbc/validation/validation.go | 46 ++++++++++++++++ .../apache/arrow/adbc/core/AdbcStatement.java | 11 ++++ .../flightsql/FlightSqlStatementTest.java | 8 +++ .../driver/flightsql/FlightSqlStatement.java | 12 ++++ .../arrow/adbc/driver/jdbc/JdbcStatement.java | 35 ++++++++++++ .../testsuite/AbstractStatementTest.java | 32 +++++++++++ 13 files changed, 268 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aacf930e7e..e01ba20d50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: - "--linelength=90" - "--verbose=2" - repo: https://github.com/golangci/golangci-lint - rev: v1.49.0 + rev: v1.52.2 hooks: - id: golangci-lint entry: bash -c 'cd go/adbc && golangci-lint run --fix --timeout 5m' diff --git a/adbc.h b/adbc.h index 27450faf57..90b5a86c62 100644 --- a/adbc.h +++ b/adbc.h @@ -683,6 +683,9 @@ struct ADBC_EXPORT AdbcDriver { /// worrying about multiple definitions of the same symbol. struct ADBC_EXPORT AdbcDriver110 { struct AdbcDriver base; + + AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); }; /// @} @@ -1072,6 +1075,24 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error); +/// \brief Get the schema of the result set of a query without +/// executing it. +/// +/// This invalidates any prior result sets. +/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] statement The statement to execute. +/// \param[out] out The result schema. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support this. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error); + /// \brief Turn this statement into a prepared statement to be /// executed multiple times. /// diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index afe44a908a..faea85cf26 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -191,6 +191,12 @@ AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -560,6 +566,16 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, ->base.StatementExecuteQuery(statement, out, rows_affected, error); } +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + return static_cast(statement->private_driver) + ->StatementExecuteSchema(statement, schema, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -852,7 +868,7 @@ AdbcStatusCode PolyfillDriver100(AdbcDriver* driver, AdbcError* error) { } AdbcStatusCode PolyfillDriver110(AdbcDriver110* driver, AdbcError* error) { - // No new functions yet + FILL_DEFAULT(driver, StatementExecuteSchema); return PolyfillDriver100(&driver->base, error); } diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index 2ecd415436..0469b41818 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -536,3 +536,16 @@ type Statement interface { // an error with a StatusNotImplemented code. ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, error) } + +// Statement110 is an extension interface for methods added to Statement in +// ADBC API revision 1.1.0. +type Statement110 interface { + Statement + + // ExecuteSchema returns the schema of the result set of a query without + // executing it. + // + // If the driver does not support this, this will return an error with a + // StatusNotImplemented code. + ExecuteSchema(context.Context) (*arrow.Schema, error) +} diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index 9bbd7d1491..8591fb2713 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -1352,6 +1352,22 @@ func (c *cnxn) executeSubstraitUpdate(ctx context.Context, plan flightsql.Substr return c.cl.ExecuteSubstraitUpdate(ctx, plan, opts...) } +func (c *cnxn) getExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (schemaResult *flight.SchemaResult, err error) { + if c.txn != nil { + return c.txn.GetExecuteSchema(ctx, query, opts...) + } + + return c.cl.GetExecuteSchema(ctx, query, opts...) +} + +func (c *cnxn) getExecuteSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (schemaResult *flight.SchemaResult, err error) { + if c.txn != nil { + return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) + } + + return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) +} + func (c *cnxn) prepare(ctx context.Context, query string, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.Prepare(ctx, query, opts...) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 11d74f2f91..8070b686e3 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -232,6 +232,7 @@ func (s *FlightSQLQuirks) DropTable(cnxn adbc.Connection, tblname string) error func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) SupportsExecuteSchema() bool { return true } func (s *FlightSQLQuirks) SupportsPartitionedData() bool { return true } func (s *FlightSQLQuirks) SupportsTransactions() bool { return true } func (s *FlightSQLQuirks) SupportsGetParameterSchema() bool { return false } diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index 3d051ef6f5..301ab3444b 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -86,6 +86,35 @@ func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ... } } +func (s *sqlOrSubstrait) getExecuteSchema(ctx context.Context, alloc memory.Allocator, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { + var result *flight.SchemaResult + var err error + if s.sqlQuery != "" { + result, err = cnxn.getExecuteSchema(ctx, s.sqlQuery, opts...) + } else if s.substraitPlan != nil { + result, err = cnxn.getExecuteSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) + } else { + return nil, adbc.Error{ + Code: adbc.StatusInvalidState, + Msg: "[Flight SQL Statement] cannot call ExecuteSchema without a query or prepared statement", + } + } + + if err != nil { + return nil, err + } + + schema, err := flight.DeserializeSchema(result.GetSchema(), alloc) + if err != nil { + return nil, adbc.Error{ + Code: adbc.StatusInternal, + Msg: "[Flight SQL Statement] server returned invalid schema", + } + } + + return schema, nil +} + func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if s.sqlQuery != "" { return cnxn.prepare(ctx, s.sqlQuery, opts...) @@ -247,6 +276,32 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n return } +// ExecuteSchema returns the schema of the result set of a query without +// executing it. +// +// If the driver does not support this, this will return an error with a +// StatusNotImplemented code. +func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { + ctx = metadata.NewOutgoingContext(ctx, s.hdrs) + + if s.prepared != nil { + schema = s.prepared.DatasetSchema() + if schema == nil { + err = adbc.Error{ + Msg: "[Flight SQL Statement] server did not provide schema for prepared statement", + Code: adbc.StatusNotImplemented} + } + } else { + schema, err = s.query.getExecuteSchema(ctx, s.alloc, s.cnxn, s.timeouts) + } + + if err != nil { + err = adbcFromFlightStatus(err) + } + + return +} + // ExecuteUpdate executes a statement that does not generate a result // set. It returns the number of rows affected if known, otherwise -1. func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index bc6e368990..01ca642975 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -46,6 +46,8 @@ type DriverQuirks interface { BindParameter(index int) string // Whether two statements can be used at the same time on a single connection SupportsConcurrentStatements() bool + // Whether retrieving the schema of a query is supported + SupportsExecuteSchema() bool // Whether AdbcStatementExecutePartitions should work SupportsPartitionedData() bool // Whether transactions are supported (Commit/Rollback on connection) @@ -407,6 +409,28 @@ func (s *StatementTests) TestNewStatement() { s.Equal(adbc.StatusInvalidState, adbcError.Code) } +func (s *StatementTests) TestSQLExecuteSchema() { + stmt, err := s.Cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + query := "SELECT 1" + s.NoError(stmt.SetSqlQuery(query)) + + sc, err := stmt.(adbc.Statement110).ExecuteSchema(s.ctx) + if !s.Quirks.SupportsExecuteSchema() { + var adbcError adbc.Error + s.ErrorAs(err, &adbcError) + s.Equal(adbc.StatusNotImplemented, adbcError.Code) + return + } + // TODO: check for NotImplemented here, too, since even if the + // driver supports it the server may not + s.NoError(err) + + s.Len(sc.Fields(), 1) +} + func (s *StatementTests) TestSqlPartitionedInts() { stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) @@ -460,6 +484,28 @@ func (s *StatementTests) TestSqlPartitionedInts() { s.False(rdr.Next()) } +func (s *StatementTests) TestSQLPrepareExecuteSchema() { + stmt, err := s.Cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + query := "SELECT 1" + s.NoError(stmt.SetSqlQuery(query)) + s.NoError(stmt.Prepare(s.ctx)) + + // TODO: move new validation tests into new file? + sc, err := stmt.(adbc.Statement110).ExecuteSchema(s.ctx) + if !s.Quirks.SupportsExecuteSchema() { + var adbcError adbc.Error + s.ErrorAs(err, &adbcError) + s.Equal(adbc.StatusNotImplemented, adbcError.Code) + return + } + s.NoError(err) + + s.Len(sc.Fields(), 1) +} + func (s *StatementTests) TestSQLPrepareGetParameterSchema() { stmt, err := s.Cnxn.NewStatement() s.NoError(err) diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java index ef2be487e2..5952aca8e1 100644 --- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java +++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java @@ -83,6 +83,17 @@ default void bind(VectorSchemaRoot root) throws AdbcException { */ UpdateResult executeUpdate() throws AdbcException; + /** + * Get the schema of the result set of a query without executing it. + * + * @since ADBC API revision 1.1.0 + * @throws AdbcException with status {@link AdbcStatusCode#NOT_IMPLEMENTED} if the driver does not + * support this. + */ + default Schema executeSchema() throws AdbcException { + throw AdbcException.notImplemented("Statement does not support executeSchema"); + } + /** * Execute a result set-generating query and get a list of partitions of the result set. * diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java index 306f69e44f..a4270e1e89 100644 --- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java +++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatementTest.java @@ -30,4 +30,12 @@ public static void beforeAll() { @Override @Disabled("Requires spec clarification") public void prepareQueryWithParameters() {} + + @Override + @Disabled("Not supported by the SQLite test server") + public void executeSchema() {} + + @Override + @Disabled("Not supported by the SQLite test server") + public void executeSchemaPrepared() {} } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java index 1fd8b910c7..e64508b4bf 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java @@ -247,6 +247,18 @@ public QueryResult executeQuery() throws AdbcException { new FlightInfoReader(allocator, client, clientCache, info.getEndpoints())); } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[Flight SQL] Must executeUpdate() for bulk ingestion"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[Flight SQL] Must setSqlQuery() before execute"); + } + return execute( + FlightSqlClient.PreparedStatement::getResultSetSchema, + (client) -> client.getExecuteSchema(sqlQuery).getSchema()); + } + @Override public UpdateResult executeUpdate() throws AdbcException { if (bulkOperation != null) { diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java index ba3f30ecab..3ac84720c1 100644 --- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java +++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java @@ -22,7 +22,9 @@ import java.sql.ParameterMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; import java.sql.Statement; import java.util.ArrayList; import java.util.List; @@ -30,6 +32,8 @@ import java.util.stream.LongStream; import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder; import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; import org.apache.arrow.adbc.core.AdbcException; import org.apache.arrow.adbc.core.AdbcStatement; @@ -199,6 +203,37 @@ private void invalidatePriorQuery() throws AdbcException { } } + @Override + public Schema executeSchema() throws AdbcException { + if (bulkOperation != null) { + throw AdbcException.invalidState("[JDBC] Ingestion operations have no schema"); + } else if (sqlQuery == null) { + throw AdbcException.invalidState("[JDBC] Must setSqlQuery() first"); + } + + try (final PreparedStatement preparedStatement = + connection.prepareStatement( + sqlQuery, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY)) { + final ResultSetMetaData rsmd = preparedStatement.getMetaData(); + final JdbcToArrowConfig config = + new JdbcToArrowConfigBuilder() + .setAllocator(allocator) + .setCalendar(JdbcToArrowUtils.getUtcCalendar()) + .build(); + try { + return JdbcToArrowUtils.jdbcToArrowSchema(rsmd, config); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException("Failed to convert JDBC schema to Arrow schema:", e); + } + } catch (SQLFeatureNotSupportedException e) { + throw AdbcException.notImplemented( + "[JDBC] Driver does not support getting a result set schema") + .withCause(e); + } catch (SQLException e) { + throw JdbcDriverUtil.fromSqlException("Failed to prepare statement:", e); + } + } + @Override public UpdateResult executeUpdate() throws AdbcException { if (bulkOperation != null) { diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java index e7a1a5743a..69eca3bd9a 100644 --- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java +++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java @@ -322,6 +322,38 @@ public void prepareQueryWithParameters() throws Exception { } } + @Test + public void executeSchema() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String column = quirks.caseFoldColumnName("ints"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery(String.format("SELECT %s FROM %s", column, tableName)); + final Schema paramsSchema = stmt.executeSchema(); + assertThat(paramsSchema).isNotNull(); + assertThat(paramsSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(column, Types.MinorType.INT.getType())))); + } + } + + @Test + public void executeSchemaPrepared() throws Exception { + util.ingestTableIntsStrs(allocator, connection, tableName); + final String column = quirks.caseFoldColumnName("ints"); + try (final AdbcStatement stmt = connection.createStatement()) { + stmt.setSqlQuery(String.format("SELECT %s FROM %s", column, tableName)); + stmt.prepare(); + final Schema paramsSchema = stmt.executeSchema(); + assertThat(paramsSchema) + .isEqualTo( + new Schema( + Collections.singletonList( + Field.nullable(column, Types.MinorType.INT.getType())))); + } + } + @Test public void getParameterSchema() throws Exception { util.ingestTableIntsStrs(allocator, connection, tableName);