Skip to content

Commit

Permalink
feat(c/driver/postgresql): implement StatementExecuteSchema
Browse files Browse the repository at this point in the history
A draft implementation of the new StatementExecuteSchema method
for the PostgreSQL driver.

See apache#318.
  • Loading branch information
lidavidm committed Jun 26, 2023
1 parent 362638d commit efe9f79
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 44 deletions.
8 changes: 6 additions & 2 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,13 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type,

class PostgresCopyStreamReader {
public:
ArrowErrorCode Init(const PostgresType& pg_type) {
ArrowErrorCode Init(PostgresType pg_type) {
if (pg_type.type_id() != PostgresTypeId::kRecord) {
return EINVAL;
}

root_reader_.Init(pg_type);
pg_type_ = std::move(pg_type);
root_reader_.Init(pg_type_);
return NANOARROW_OK;
}

Expand Down Expand Up @@ -791,7 +792,10 @@ class PostgresCopyStreamReader {
return NANOARROW_OK;
}

const PostgresType& pg_type() const { return pg_type_; }

private:
PostgresType pg_type_;
PostgresCopyFieldTupleReader root_reader_;
nanoarrow::UniqueSchema schema_;
nanoarrow::UniqueArray array_;
Expand Down
23 changes: 22 additions & 1 deletion c/driver/postgresql/postgresql.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,15 @@ AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement,
return (*ptr)->ExecuteQuery(output, rows_affected, error);
}

AdbcStatusCode PostgresStatementExecuteSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
auto* ptr =
reinterpret_cast<std::shared_ptr<PostgresStatement>*>(statement->private_data);
return (*ptr)->ExecuteSchema(schema, error);
}

AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement,
uint8_t* partition_desc,
struct AdbcError* error) {
Expand Down Expand Up @@ -423,6 +432,11 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
return PostgresStatementExecuteQuery(statement, output, rows_affected, error);
}

AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement,
ArrowSchema* schema, struct AdbcError* error) {
return PostgresStatementExecuteSchema(statement, schema, error);
}

AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement,
uint8_t* partition_desc,
struct AdbcError* error) {
Expand Down Expand Up @@ -474,7 +488,13 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e
if (!raw_driver) return ADBC_STATUS_INVALID_ARGUMENT;

auto* driver = reinterpret_cast<struct AdbcDriver*>(raw_driver);
std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE);
if (version >= ADBC_VERSION_1_1_0) {
std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE);
driver->StatementExecuteSchema = PostgresStatementExecuteSchema;
} else {
std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE);
}

driver->DatabaseInit = PostgresDatabaseInit;
driver->DatabaseNew = PostgresDatabaseNew;
driver->DatabaseRelease = PostgresDatabaseRelease;
Expand Down Expand Up @@ -502,6 +522,7 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e
driver->StatementRelease = PostgresStatementRelease;
driver->StatementSetOption = PostgresStatementSetOption;
driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery;

return ADBC_STATUS_OK;
}
}
2 changes: 2 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {

std::string catalog() const override { return "postgres"; }
std::string db_schema() const override { return "public"; }

bool supports_execute_schema() const override { return true; }
};

class PostgresDatabaseTest : public ::testing::Test,
Expand Down
105 changes: 64 additions & 41 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,50 +676,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,

// 1. Prepare the query to get the schema
{
// TODO: we should pipeline here and assume this will succeed
PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(),
/*nParams=*/0, nullptr);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error,
"[libpq] Failed to execute query: could not infer schema: failed to "
"prepare query: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
PQclear(result);
result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error,
"[libpq] Failed to execute query: could not infer schema: failed to "
"describe prepared statement: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}

// Resolve the information from the PGresult into a PostgresType
PostgresType root_type;
AdbcStatusCode status =
ResolvePostgresType(*type_resolver_, result, &root_type, error);
PQclear(result);
if (status != ADBC_STATUS_OK) return status;

// Initialize the copy reader and infer the output schema (i.e., error for
// unsupported types before issuing the COPY query)
reader_.copy_reader_.reset(new PostgresCopyStreamReader());
reader_.copy_reader_->Init(root_type);
struct ArrowError na_error;
int na_res = reader_.copy_reader_->InferOutputSchema(&na_error);
if (na_res != NANOARROW_OK) {
SetError(error, "[libpq] Failed to infer output schema: %s", na_error.message);
return na_res;
}
RAISE_ADBC(SetupReader(error));

// If the caller did not request a result set or if there are no
// inferred output columns (e.g. a CREATE or UPDATE), then don't
// use COPY (which would fail anyways)
if (!stream || root_type.n_children() == 0) {
if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) {
RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error));
if (stream) {
struct ArrowSchema schema;
Expand All @@ -733,7 +695,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
// This resolves the reader specific to each PostgresType -> ArrowSchema
// conversion. It is unlikely that this will fail given that we have just
// inferred these conversions ourselves.
na_res = reader_.copy_reader_->InitFieldReaders(&na_error);
struct ArrowError na_error;
int na_res = reader_.copy_reader_->InitFieldReaders(&na_error);
if (na_res != NANOARROW_OK) {
SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message);
return na_res;
Expand Down Expand Up @@ -762,6 +725,23 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
return ADBC_STATUS_OK;
}

AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema,
struct AdbcError* error) {
ClearResult();
if (query_.empty()) {
SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery");
return ADBC_STATUS_INVALID_STATE;
} else if (bind_.release) {
// TODO: if we have parameters, bind them (since they can affect the output schema)
SetError(error, "[libpq] ExecuteSchema with parameters is not implemented");
return ADBC_STATUS_NOT_IMPLEMENTED;
}

RAISE_ADBC(SetupReader(error));
CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error);
return ADBC_STATUS_OK;
}

AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
struct AdbcError* error) {
if (!bind_.release) {
Expand Down Expand Up @@ -870,6 +850,49 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
return ADBC_STATUS_OK;
}

AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) {
// TODO: we should pipeline here and assume this will succeed
PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(),
/*nParams=*/0, nullptr);
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error,
"[libpq] Failed to execute query: could not infer schema: failed to "
"prepare query: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}
PQclear(result);
result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
if (PQresultStatus(result) != PGRES_COMMAND_OK) {
SetError(error,
"[libpq] Failed to execute query: could not infer schema: failed to "
"describe prepared statement: %s\nQuery was:%s",
PQerrorMessage(connection_->conn()), query_.c_str());
PQclear(result);
return ADBC_STATUS_IO;
}

// Resolve the information from the PGresult into a PostgresType
PostgresType root_type;
AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error);
PQclear(result);
if (status != ADBC_STATUS_OK) return status;

// Initialize the copy reader and infer the output schema (i.e., error for
// unsupported types before issuing the COPY query)
reader_.copy_reader_.reset(new PostgresCopyStreamReader());
reader_.copy_reader_->Init(root_type);
struct ArrowError na_error;
int na_res = reader_.copy_reader_->InferOutputSchema(&na_error);
if (na_res != NANOARROW_OK) {
SetError(error, "[libpq] Failed to infer output schema: (%d) %s: %s", na_res,
std::strerror(na_res), na_error.message);
return ADBC_STATUS_INTERNAL;
}
return ADBC_STATUS_OK;
}

void PostgresStatement::ClearResult() {
// TODO: we may want to synchronize here for safety
reader_.Release();
Expand Down
2 changes: 2 additions & 0 deletions c/driver/postgresql/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class PostgresStatement {
AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error);
AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected,
struct AdbcError* error);
AdbcStatusCode ExecuteSchema(struct ArrowSchema* schema, struct AdbcError* error);
AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error);
AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error);
AdbcStatusCode Prepare(struct AdbcError* error);
Expand All @@ -104,6 +105,7 @@ class PostgresStatement {
AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
int64_t* rows_affected,
struct AdbcError* error);
AdbcStatusCode SetupReader(struct AdbcError* error);

private:
std::shared_ptr<PostgresTypeResolver> type_resolver_;
Expand Down
6 changes: 6 additions & 0 deletions c/driver/sqlite/sqlite.c
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,12 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
return SqliteStatementExecuteQuery(statement, out, rows_affected, error);
}

AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
struct AdbcError* error) {
return SqliteStatementPrepare(statement, error);
Expand Down
81 changes: 81 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <gtest/gtest-matchers.h>
#include <gtest/gtest.h>
#include <nanoarrow/nanoarrow.h>
#include <nanoarrow/nanoarrow.hpp>

#include "adbc_validation_util.h"

Expand Down Expand Up @@ -2027,6 +2028,86 @@ void StatementTest::TestTransactions() {
}
}

void StatementTest::TestSqlSchemaInts() {
if (!quirks()->supports_execute_schema()) {
GTEST_SKIP() << "Not supported";
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
IsOkStatus(&error));

nanoarrow::UniqueSchema schema;
ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
IsOkStatus(&error));

ASSERT_EQ(1, schema->n_children);
ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({
::testing::StrEq("i"), // int32
::testing::StrEq("l"), // int64
}));

ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestSqlSchemaFloats() {
if (!quirks()->supports_execute_schema()) {
GTEST_SKIP() << "Not supported";
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT CAST(1.5 AS FLOAT)", &error),
IsOkStatus(&error));

nanoarrow::UniqueSchema schema;
ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
IsOkStatus(&error));

ASSERT_EQ(1, schema->n_children);
ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({
::testing::StrEq("f"), // float32
::testing::StrEq("g"), // float64
}));

ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestSqlSchemaStrings() {
if (!quirks()->supports_execute_schema()) {
GTEST_SKIP() << "Not supported";
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'hi'", &error),
IsOkStatus(&error));

nanoarrow::UniqueSchema schema;
ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
IsOkStatus(&error));

ASSERT_EQ(1, schema->n_children);
ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({
::testing::StrEq("u"), // string
::testing::StrEq("U"), // large_string
}));

ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestSqlSchemaErrors() {
if (!quirks()->supports_execute_schema()) {
GTEST_SKIP() << "Not supported";
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

nanoarrow::UniqueSchema schema;
ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
IsStatus(ADBC_STATUS_INVALID_STATE, &error));

ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestConcurrentStatements() {
Handle<struct AdbcStatement> statement1;
Handle<struct AdbcStatement> statement2;
Expand Down
13 changes: 13 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class DriverQuirks {
/// single connection
virtual bool supports_concurrent_statements() const { return false; }

/// \brief Whether AdbcStatementExecuteSchema should work
virtual bool supports_execute_schema() const { return false; }

/// \brief Whether AdbcStatementExecutePartitions should work
virtual bool supports_partitioned_data() const { return false; }

Expand Down Expand Up @@ -253,6 +256,12 @@ class StatementTest {

void TestSqlQueryErrors();

void TestSqlSchemaInts();
void TestSqlSchemaFloats();
void TestSqlSchemaStrings();

void TestSqlSchemaErrors();

void TestTransactions();

void TestConcurrentStatements();
Expand Down Expand Up @@ -307,6 +316,10 @@ class StatementTest {
TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \
TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \
TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); } \
TEST_F(FIXTURE, SqlSchemaInts) { TestSqlSchemaInts(); } \
TEST_F(FIXTURE, SqlSchemaFloats) { TestSqlSchemaFloats(); } \
TEST_F(FIXTURE, SqlSchemaStrings) { TestSqlSchemaStrings(); } \
TEST_F(FIXTURE, SqlSchemaErrors) { TestSqlSchemaErrors(); } \
TEST_F(FIXTURE, Transactions) { TestTransactions(); } \
TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \
TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); }
Expand Down

0 comments on commit efe9f79

Please sign in to comment.