Skip to content

Commit

Permalink
feat(go/adbc): add FFI support for 1.1.0 features (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Aug 10, 2023
1 parent 2a7b7bd commit 04b5565
Show file tree
Hide file tree
Showing 16 changed files with 4,225 additions and 639 deletions.
49 changes: 29 additions & 20 deletions adbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,25 +334,35 @@ struct ADBC_EXPORT AdbcError {
/// Buffers) that can be optionally parsed by clients, beyond the
/// standard AdbcError fields, without having to encode it in the
/// error message. The encoding of the data is driver-defined.
/// Drivers may provide multiple error details.
///
/// This can be called immediately after any API call that returns an
/// This can be used immediately after any API call that returns an
/// error. Additionally, if an ArrowArrayStream returned from an
/// AdbcConnection or an AdbcStatement returns an error, this can be
/// immediately called from the associated AdbcConnection or
/// AdbcStatement to get further error details (if available). Making
/// other API calls with that connection or statement may clear this
/// error value.
///
/// Drivers may provide multiple error details. Each call to
/// GetOptionBytes will return the next error detail. The driver
/// should return ADBC_STATUS_NOT_FOUND if there are no (more) error
/// details.
/// To use, call GetOptionInt with this option to get the number of
/// available details. Then, call GetOption with the option key
/// ADBC_OPTION_ERROR_DETAILS_PREFIX + (zero-indexed index) to get the
/// name of the error detail (for example, drivers that use gRPC
/// underneath may provide the name of the gRPC trailer corresponding
/// to the error detail). GetOptionBytes with that option name will
/// retrieve the value of the error detail (for example, a serialized
/// Any-wrapped Protobuf).
///
/// The type is uint8_t*.
/// \since ADBC API revision 1.1.0
/// \addtogroup adbc-1.1.0
#define ADBC_OPTION_ERROR_DETAILS "adbc.error_details"

/// \brief Canonical option name for error details.
///
/// \see ADBC_OPTION_ERROR_DETAILS
/// \since ADBC API revision 1.1.0
/// \addtogroup adbc-1.1.0
#define ADBC_OPTION_ERROR_DETAILS "error_details"
#define ADBC_OPTION_ERROR_DETAILS_PREFIX "adbc.error_details."

/// \brief The database vendor/product name (e.g. the server name).
/// (type: utf8).
Expand Down Expand Up @@ -888,26 +898,26 @@ struct ADBC_EXPORT AdbcDriver {
struct AdbcError*);
AdbcStatusCode (*DatabaseGetOptionBytes)(struct AdbcDatabase*, const char*, uint8_t*,
size_t*, struct AdbcError*);
AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*,
struct AdbcError*);
AdbcStatusCode (*DatabaseGetOptionDouble)(struct AdbcDatabase*, const char*, double*,
struct AdbcError*);
AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*,
struct AdbcError*);
AdbcStatusCode (*DatabaseSetOptionBytes)(struct AdbcDatabase*, const char*,
const uint8_t*, size_t, struct AdbcError*);
AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t,
struct AdbcError*);
AdbcStatusCode (*DatabaseSetOptionDouble)(struct AdbcDatabase*, const char*, double,
struct AdbcError*);
AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t,
struct AdbcError*);

AdbcStatusCode (*ConnectionCancel)(struct AdbcConnection*, struct AdbcError*);
AdbcStatusCode (*ConnectionGetOption)(struct AdbcConnection*, const char*, char*,
size_t*, struct AdbcError*);
AdbcStatusCode (*ConnectionGetOptionBytes)(struct AdbcConnection*, const char*,
uint8_t*, size_t*, struct AdbcError*);
AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*,
struct AdbcError*);
AdbcStatusCode (*ConnectionGetOptionDouble)(struct AdbcConnection*, const char*,
double*, struct AdbcError*);
AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*,
struct AdbcError*);
AdbcStatusCode (*ConnectionGetStatistics)(struct AdbcConnection*, const char*,
const char*, const char*, char,
struct ArrowArrayStream*, struct AdbcError*);
Expand All @@ -916,10 +926,10 @@ struct ADBC_EXPORT AdbcDriver {
struct AdbcError*);
AdbcStatusCode (*ConnectionSetOptionBytes)(struct AdbcConnection*, const char*,
const uint8_t*, size_t, struct AdbcError*);
AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t,
struct AdbcError*);
AdbcStatusCode (*ConnectionSetOptionDouble)(struct AdbcConnection*, const char*, double,
struct AdbcError*);
AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t,
struct AdbcError*);

AdbcStatusCode (*StatementCancel)(struct AdbcStatement*, struct AdbcError*);
AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*,
Expand All @@ -928,16 +938,16 @@ struct ADBC_EXPORT AdbcDriver {
struct AdbcError*);
AdbcStatusCode (*StatementGetOptionBytes)(struct AdbcStatement*, const char*, uint8_t*,
size_t*, struct AdbcError*);
AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*,
struct AdbcError*);
AdbcStatusCode (*StatementGetOptionDouble)(struct AdbcStatement*, const char*, double*,
struct AdbcError*);
AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*,
struct AdbcError*);
AdbcStatusCode (*StatementSetOptionBytes)(struct AdbcStatement*, const char*,
const uint8_t*, size_t, struct AdbcError*);
AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t,
struct AdbcError*);
AdbcStatusCode (*StatementSetOptionDouble)(struct AdbcStatement*, const char*, double,
struct AdbcError*);
AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t,
struct AdbcError*);

/// @}
};
Expand Down Expand Up @@ -1639,7 +1649,6 @@ AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection,
/// | int64 | int64 |
/// | uint64 | uint64 |
/// | float64 | float64 |
/// | decimal256 | decimal256 |
/// | binary | binary |
///
/// This AdbcConnection must outlive the returned ArrowArrayStream.
Expand Down
20 changes: 20 additions & 0 deletions c/driver/flightsql/sqlite_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include <chrono>
#include <optional>
#include <random>
#include <thread>

Expand Down Expand Up @@ -92,6 +93,25 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks {
bool supports_concurrent_statements() const override { return true; }
bool supports_transactions() const override { return false; }
bool supports_get_sql_info() const override { return true; }
std::optional<adbc_validation::SqlInfoValue> supports_get_sql_info(
uint32_t info_code) const override {
switch (info_code) {
case ADBC_INFO_DRIVER_NAME:
return "ADBC Flight SQL Driver - Go";
case ADBC_INFO_DRIVER_VERSION:
return "(unknown or development build)";
case ADBC_INFO_DRIVER_ADBC_VERSION:
return ADBC_VERSION_1_1_0;
case ADBC_INFO_VENDOR_NAME:
return "db_name";
case ADBC_INFO_VENDOR_VERSION:
return "sqlite 3";
case ADBC_INFO_VENDOR_ARROW_VERSION:
return "12.0.0";
default:
return std::nullopt;
}
}
bool supports_get_objects() const override { return true; }
bool supports_bulk_ingest() const override { return false; }
bool supports_partitioned_data() const override { return true; }
Expand Down
157 changes: 89 additions & 68 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <string_view>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include <adbc.h>
Expand Down Expand Up @@ -255,79 +256,99 @@ void ConnectionTest::TestMetadataGetInfo() {
GTEST_SKIP();
}

StreamReader reader;
std::vector<uint32_t> info = {
ADBC_INFO_DRIVER_NAME,
ADBC_INFO_DRIVER_VERSION,
ADBC_INFO_VENDOR_NAME,
ADBC_INFO_VENDOR_VERSION,
};
for (uint32_t info_code : {
ADBC_INFO_DRIVER_NAME,
ADBC_INFO_DRIVER_VERSION,
ADBC_INFO_DRIVER_ADBC_VERSION,
ADBC_INFO_VENDOR_NAME,
ADBC_INFO_VENDOR_VERSION,
}) {
uint32_t info[] = {info_code};

ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(),
&reader.stream.value, &error),
IsOkStatus(&error));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(CompareSchema(
&reader.schema.value, {
{"info_name", NANOARROW_TYPE_UINT32, NOT_NULL},
{"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1],
{
{"string_value", NANOARROW_TYPE_STRING, NULLABLE},
{"bool_value", NANOARROW_TYPE_BOOL, NULLABLE},
{"int64_value", NANOARROW_TYPE_INT64, NULLABLE},
{"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE},
{"string_list", NANOARROW_TYPE_LIST, NULLABLE},
{"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4],
{
{"item", NANOARROW_TYPE_STRING, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[5],
{
{"entries", NANOARROW_TYPE_STRUCT, NOT_NULL},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1]->children[5]->children[0],
{
{"key", NANOARROW_TYPE_INT32, NOT_NULL},
{"value", NANOARROW_TYPE_LIST, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1],
{
{"item", NANOARROW_TYPE_INT32, NULLABLE},
}));
StreamReader reader;
ASSERT_THAT(AdbcConnectionGetInfo(&connection, info, 1, &reader.stream.value, &error),
IsOkStatus(&error));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());

std::vector<uint32_t> seen;
while (true) {
ASSERT_NO_FATAL_FAILURE(reader.Next());
if (!reader.array->release) break;

for (int64_t row = 0; row < reader.array->length; row++) {
ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row));
const uint32_t code =
reader.array_view->children[0]->buffer_views[1].data.as_uint32[row];
seen.push_back(code);

switch (code) {
case ADBC_INFO_DRIVER_NAME:
case ADBC_INFO_DRIVER_VERSION:
case ADBC_INFO_VENDOR_NAME:
case ADBC_INFO_VENDOR_VERSION:
// UTF8
ASSERT_EQ(uint8_t(0),
reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]);
default:
// Ignored
break;
ASSERT_NO_FATAL_FAILURE(CompareSchema(
&reader.schema.value, {
{"info_name", NANOARROW_TYPE_UINT32, NOT_NULL},
{"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1],
{
{"string_value", NANOARROW_TYPE_STRING, NULLABLE},
{"bool_value", NANOARROW_TYPE_BOOL, NULLABLE},
{"int64_value", NANOARROW_TYPE_INT64, NULLABLE},
{"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE},
{"string_list", NANOARROW_TYPE_LIST, NULLABLE},
{"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4],
{
{"item", NANOARROW_TYPE_STRING, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1]->children[5],
{
{"entries", NANOARROW_TYPE_STRUCT, NOT_NULL},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1]->children[5]->children[0],
{
{"key", NANOARROW_TYPE_INT32, NOT_NULL},
{"value", NANOARROW_TYPE_LIST, NULLABLE},
}));
ASSERT_NO_FATAL_FAILURE(
CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1],
{
{"item", NANOARROW_TYPE_INT32, NULLABLE},
}));

std::vector<uint32_t> seen;
while (true) {
ASSERT_NO_FATAL_FAILURE(reader.Next());
if (!reader.array->release) break;

for (int64_t row = 0; row < reader.array->length; row++) {
ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row));
const uint32_t code =
reader.array_view->children[0]->buffer_views[1].data.as_uint32[row];
seen.push_back(code);

std::optional<SqlInfoValue> expected = quirks()->supports_get_sql_info(code);
ASSERT_TRUE(expected.has_value()) << "Got unexpected info code " << code;

uint8_t type_code =
reader.array_view->children[1]->buffer_views[0].data.as_uint8[row];
int32_t offset =
reader.array_view->children[1]->buffer_views[1].data.as_int32[row];
ASSERT_NO_FATAL_FAILURE(std::visit(
[&](auto&& expected_value) {
using T = std::decay_t<decltype(expected_value)>;
if constexpr (std::is_same_v<T, int64_t>) {
ASSERT_EQ(uint8_t(2), type_code);
ASSERT_EQ(expected_value,
ArrowArrayViewGetIntUnsafe(
reader.array_view->children[1]->children[2], offset));
} else if constexpr (std::is_same_v<T, std::string>) {
ASSERT_EQ(uint8_t(0), type_code);
struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(
reader.array_view->children[1]->children[0], offset);
ASSERT_EQ(expected_value,
std::string_view(static_cast<const char*>(view.data),
view.size_bytes));
} else {
static_assert(!sizeof(T), "not yet implemented");
}
},
*expected))
<< "code: " << type_code;
}
}
ASSERT_THAT(seen, ::testing::IsSupersetOf(info));
}
ASSERT_THAT(seen, ::testing::UnorderedElementsAreArray(info));
}

void ConnectionTest::TestMetadataGetTableSchema() {
Expand Down
8 changes: 8 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <optional>
#include <string>
#include <variant>
#include <vector>

#include <adbc.h>
Expand All @@ -31,6 +32,8 @@ namespace adbc_validation {
#define ADBCV_STRINGIFY(s) #s
#define ADBCV_STRINGIFY_VALUE(s) ADBCV_STRINGIFY(s)

using SqlInfoValue = std::variant<std::string, int64_t>;

/// \brief Configuration for driver-specific behavior.
class DriverQuirks {
public:
Expand Down Expand Up @@ -101,6 +104,11 @@ class DriverQuirks {
/// \brief Whether GetSqlInfo is implemented
virtual bool supports_get_sql_info() const { return true; }

/// \brief The expected value for a given info code
virtual std::optional<SqlInfoValue> supports_get_sql_info(uint32_t info_code) const {
return std::nullopt;
}

/// \brief Whether GetObjects is implemented
virtual bool supports_get_objects() const { return true; }

Expand Down
1 change: 1 addition & 0 deletions c/validation/adbc_validation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <nanoarrow/nanoarrow.h>

#include "common/utils.h"

namespace adbc_validation {
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ type ConnectionGetStatistics interface {
// statistic_name | utf8 not null
// statistic_key | int16 not null
//
GetStatisticNames() (array.RecordReader, error)
GetStatisticNames(ctx context.Context) (array.RecordReader, error)
}

// StatementExecuteSchema is a Statement that also supports ExecuteSchema.
Expand Down
8 changes: 3 additions & 5 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,7 @@ func (d *database) GetOptionDouble(key string) (float64, error) {
func (d *database) SetOption(key, value string) error {
// We can't change most options post-init
switch key {
case OptionTimeoutFetch:
fallthrough
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
return d.timeout.setTimeoutString(key, value)
}
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
Expand Down Expand Up @@ -1154,6 +1150,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
infoNameBldr.Append(uint32(adbc.InfoVendorVersion))
case flightsql.SqlInfoFlightSqlServerArrowVersion:
infoNameBldr.Append(uint32(adbc.InfoVendorArrowVersion))
default:
continue
}

infoValueBldr.Append(info.TypeCode(i))
Expand Down
Loading

0 comments on commit 04b5565

Please sign in to comment.