Skip to content

Commit

Permalink
feat(go/adbc): implement 1.1.0 features
Browse files Browse the repository at this point in the history
- ADBC_INFO_DRIVER_ADBC_VERSION
- StatementExecuteSchema (apache#318)
- ADBC_CONNECTION_OPTION_CURRENT_{CATALOG, DB_SCHEMA} (apache#319)
  • Loading branch information
lidavidm committed Jun 28, 2023
1 parent 5d6586b commit e5371a9
Show file tree
Hide file tree
Showing 8 changed files with 684 additions and 34 deletions.
7 changes: 5 additions & 2 deletions go/adbc/adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,10 @@ type StatementExecuteSchema interface {
ExecuteSchema(context.Context) (*arrow.Schema, error)
}

// GetSetOptions is a PostInitOptions that also supports getting and setting property values of different types.
// GetSetOptions is a PostInitOptions that also supports getting and setting option values of different types.
//
// GetOption functions should return an error with StatusNotFound for unsupported options.
// SetOption functions should return an error with StatusNotImplemented for unsupported options.
//
// Since ADBC API revision 1.1.0.
type GetSetOptions interface {
Expand All @@ -728,7 +731,7 @@ type GetSetOptions interface {
SetOptionBytes(key string, value []byte) error
SetOptionInt(key string, value int64) error
SetOptionDouble(key string, value float64) error
GetOption(key, value string) (string, error)
GetOption(key string) (string, error)
GetOptionBytes(key string) ([]byte, error)
GetOptionInt(key string) (int64, error)
GetOptionDouble(key string) (float64, error)
Expand Down
186 changes: 184 additions & 2 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func init() {
adbc.InfoDriverName,
adbc.InfoDriverVersion,
adbc.InfoDriverArrowVersion,
adbc.InfoDriverADBCVersion,
adbc.InfoVendorName,
adbc.InfoVendorVersion,
adbc.InfoVendorArrowVersion,
Expand Down Expand Up @@ -369,14 +370,63 @@ func (d *database) SetOptions(cnOptions map[string]string) error {
continue
}
return adbc.Error{
Msg: fmt.Sprintf("Unknown database option '%s'", key),
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusInvalidArgument,
}
}

return nil
}

func (d *database) GetOption(key string) (string, error) {
return "", adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotFound,
}
}
func (d *database) GetOptionBytes(key string) ([]byte, error) {
return nil, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotFound,
}
}
func (d *database) GetOptionInt(key string) (int64, error) {
return 0, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotFound,
}
}
func (d *database) GetOptionDouble(key string) (float64, error) {
return 0, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotFound,
}
}
func (d *database) SetOption(key, value string) error {
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotImplemented,
}
}
func (d *database) SetOptionBytes(key string, value []byte) error {
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotImplemented,
}
}
func (d *database) SetOptionInt(key string, value int64) error {
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotImplemented,
}
}
func (d *database) SetOptionDouble(key string, value float64) error {
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key),
Code: adbc.StatusNotImplemented,
}
}

type timeoutOption struct {
grpc.EmptyCallOption

Expand Down Expand Up @@ -729,6 +779,94 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd
return nil, err
}

func (c *cnxn) GetOption(key string) (string, error) {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
headers := c.hdrs.Get(name)
if len(headers) > 0 {
return headers[0], nil
}
return "", adbc.Error{
Msg: "[Flight SQL] unknown header",
Code: adbc.StatusNotFound,
}
}

switch key {
case OptionTimeoutFetch:
return c.timeouts.fetchTimeout.String(), nil
case OptionTimeoutQuery:
return c.timeouts.queryTimeout.String(), nil
case OptionTimeoutUpdate:
return c.timeouts.updateTimeout.String(), nil
case adbc.OptionKeyAutoCommit:
if c.txn != nil {
// No autocommit
return adbc.OptionValueDisabled, nil
} else {
// Autocommit
return adbc.OptionValueEnabled, nil
}
case adbc.OptionKeyCurrentCatalog:
return "", adbc.Error{
Msg: "[Flight SQL] current catalog not supported",
Code: adbc.StatusNotFound,
}

case adbc.OptionKeyCurrentDbSchema:
return "", adbc.Error{
Msg: "[Flight SQL] current schema not supported",
Code: adbc.StatusNotFound,
}
}

return "", adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotFound,
}
}

func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
return nil, adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotFound,
}
}

func (c *cnxn) GetOptionInt(key string) (int64, error) {
switch key {
case OptionTimeoutFetch:
return int64(c.timeouts.fetchTimeout.Seconds()), nil
case OptionTimeoutQuery:
return int64(c.timeouts.queryTimeout.Seconds()), nil
case OptionTimeoutUpdate:
return int64(c.timeouts.updateTimeout.Seconds()), nil
case adbc.OptionKeyAutoCommit:
}

return 0, adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotFound,
}
}

func (c *cnxn) GetOptionDouble(key string) (float64, error) {
switch key {
case OptionTimeoutFetch:
return c.timeouts.fetchTimeout.Seconds(), nil
case OptionTimeoutQuery:
return c.timeouts.queryTimeout.Seconds(), nil
case OptionTimeoutUpdate:
return c.timeouts.updateTimeout.Seconds(), nil
case adbc.OptionKeyAutoCommit:
}

return 0.0, adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotFound,
}
}

func (c *cnxn) SetOption(key, value string) error {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
Expand Down Expand Up @@ -775,6 +913,7 @@ func (c *cnxn) SetOption(key, value string) error {
autocommit := true
switch value {
case adbc.OptionValueEnabled:
autocommit = true
case adbc.OptionValueDisabled:
autocommit = false
default:
Expand Down Expand Up @@ -827,6 +966,27 @@ func (c *cnxn) SetOption(key, value string) error {
return nil
}

func (c *cnxn) SetOptionBytes(key string, value []byte) error {
return adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotImplemented,
}
}

func (c *cnxn) SetOptionInt(key string, value int64) error {
return adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotImplemented,
}
}

func (c *cnxn) SetOptionDouble(key string, value float64) error {
return adbc.Error{
Msg: "[Flight SQL] unknown connection option",
Code: adbc.StatusNotImplemented,
}
}

// GetInfo returns metadata about the database/driver.
//
// The result is an Arrow dataset with the following schema:
Expand All @@ -853,6 +1013,7 @@ func (c *cnxn) SetOption(key, value string) error {
// codes (the row will be omitted from the result).
func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) {
const strValTypeID arrow.UnionTypeCode = 0
const intValTypeID arrow.UnionTypeCode = 2

if len(infoCodes) == 0 {
infoCodes = infoSupportedCodes
Expand All @@ -864,7 +1025,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re

infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder)
strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder)
intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder)

translated := make([]flightsql.SqlInfo, 0, len(infoCodes))
for _, code := range infoCodes {
Expand All @@ -886,6 +1048,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(strValTypeID)
strInfoBldr.Append(infoDriverArrowVersion)
case adbc.InfoDriverADBCVersion:
infoNameBldr.Append(uint32(code))
infoValueBldr.Append(intValTypeID)
intInfoBldr.Append(adbc.AdbcVersion1_1_0)
}
}

Expand Down Expand Up @@ -1350,6 +1516,14 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio
return c.cl.Execute(ctx, query, opts...)
}

func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
if c.txn != nil {
return c.txn.GetExecuteSchema(ctx, query, opts...)
}

return c.cl.GetExecuteSchema(ctx, query, opts...)
}

func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
if c.txn != nil {
return c.txn.ExecuteSubstrait(ctx, plan, opts...)
Expand All @@ -1358,6 +1532,14 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla
return c.cl.ExecuteSubstrait(ctx, plan, opts...)
}

func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
if c.txn != nil {
return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...)
}

return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...)
}

func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) {
if c.txn != nil {
return c.txn.ExecuteUpdate(ctx, query, opts...)
Expand Down
94 changes: 94 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ func TestAuthn(t *testing.T) {
suite.Run(t, &AuthnTests{})
}

func TestExecuteSchema(t *testing.T) {
suite.Run(t, &ExecuteSchemaTests{})
}

func TestTimeout(t *testing.T) {
suite.Run(t, &TimeoutTests{})
}
Expand Down Expand Up @@ -202,6 +206,96 @@ func (suite *AuthnTests) TestBearerTokenUpdated() {
defer reader.Release()
}

// ---- ExecuteSchema Tests --------------------

type ExecuteSchemaTestServer struct {
flightsql.BaseServer
}

func (srv *ExecuteSchemaTestServer) GetSchemaStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) {
if query.GetQuery() == "sample query" {
return &flight.SchemaResult{
Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{
{Name: "ints", Type: arrow.PrimitiveTypes.Int32},
}, nil), srv.Alloc),
}, nil
}
return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented")
}

func (srv *ExecuteSchemaTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) {
if req.GetQuery() == "sample query" {
return flightsql.ActionCreatePreparedStatementResult{
DatasetSchema: arrow.NewSchema([]arrow.Field{
{Name: "ints", Type: arrow.PrimitiveTypes.Int32},
}, nil),
}, nil
}
return flightsql.ActionCreatePreparedStatementResult{}, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented")
}

type ExecuteSchemaTests struct {
ServerBasedTests
}

func (suite *ExecuteSchemaTests) SetupSuite() {
srv := ExecuteSchemaTestServer{}
srv.Alloc = memory.DefaultAllocator
suite.DoSetupSuite(&srv, nil, nil)
}

func (ts *ExecuteSchemaTests) TestNoQuery() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

es := stmt.(adbc.StatementExecuteSchema)
_, err = es.ExecuteSchema(context.Background())

var adbcErr adbc.Error
ts.ErrorAs(err, &adbcErr)
ts.Equal(adbc.StatusInvalidState, adbcErr.Code, adbcErr.Error())
}

func (ts *ExecuteSchemaTests) TestPreparedQuery() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

ts.NoError(stmt.SetSqlQuery("sample query"))
ts.NoError(stmt.Prepare(context.Background()))

es := stmt.(adbc.StatementExecuteSchema)
schema, err := es.ExecuteSchema(context.Background())
ts.NoError(err)
ts.NotNil(schema)

expectedSchema := arrow.NewSchema([]arrow.Field{
{Name: "ints", Type: arrow.PrimitiveTypes.Int32},
}, nil)

ts.True(expectedSchema.Equal(schema), schema.String())
}

func (ts *ExecuteSchemaTests) TestQuery() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

ts.NoError(stmt.SetSqlQuery("sample query"))

es := stmt.(adbc.StatementExecuteSchema)
schema, err := es.ExecuteSchema(context.Background())
ts.NoError(err)
ts.NotNil(schema)

expectedSchema := arrow.NewSchema([]arrow.Field{
{Name: "ints", Type: arrow.PrimitiveTypes.Int32},
}, nil)

ts.True(expectedSchema.Equal(schema), schema.String())
}

// ---- Timeout Tests --------------------

type TimeoutTestServer struct {
Expand Down
Loading

0 comments on commit e5371a9

Please sign in to comment.