diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 32d24a07ef..09adb7c321 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -315,6 +315,7 @@ jobs: popd - name: Go Test env: + SNOWFLAKE_DATABASE: ADBC_TESTING SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} run: | ./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" diff --git a/adbc.h b/adbc.h index 122b0605ae..ab066483de 100644 --- a/adbc.h +++ b/adbc.h @@ -1612,7 +1612,7 @@ AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, /// | Field Name | Field Type | /// |--------------------------|----------------------------------| /// | db_schema_name | utf8 | -/// | db_schema_functions | list | +/// | db_schema_statistics | list | /// /// STATISTICS_SCHEMA is a Struct with fields: /// diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index 99a4f81b75..c297608cec 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -58,13 +58,15 @@ type Error struct { // SqlState is a SQLSTATE error code, if provided, as defined // by the SQL:2003 standard. If not set, it will be "\0\0\0\0\0" SqlState [5]byte - // Details is an array of additional driver-specific binary error details. + // Details is an array of additional driver-specific error details. // // This allows drivers to return custom, structured error information (for // example, JSON or Protocol Buffers) that can be optionally parsed by // clients, beyond the standard Error fields, without having to encode it in - // the error message. The encoding of the data is driver-defined. - Details [][]byte + // the error message. The encoding of the data is driver-defined. It is + // suggested to use proto.Message for Protocol Buffers and error for wrapped + // errors. + Details []interface{} } func (e Error) Error() string { @@ -621,23 +623,6 @@ type Statement interface { ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, error) } -// Cancellable is a Connection or Statement that also supports Cancel. -// -// Since ADBC API revision 1.1.0. -type Cancellable interface { - // Cancel stops execution of an in-progress query. - // - // This can be called during ExecuteQuery, GetObjects, or other - // methods that produce result sets, or while consuming a - // RecordReader returned from such. Calling this function should - // make the other functions return an error with a StatusCancelled - // code. - // - // This must always be thread-safe (other operations are not - // necessarily thread-safe). - Cancel() error -} - // ConnectionGetStatistics is a Connection that supports getting // statistics on data in the database. // @@ -657,7 +642,7 @@ type ConnectionGetStatistics interface { // Field Name | Field Type // -------------------------|---------------------------------- // db_schema_name | utf8 - // db_schema_functions | list + // db_schema_statistics | list // // STATISTICS_SCHEMA is a Struct with fields: // @@ -684,7 +669,6 @@ type ConnectionGetStatistics interface { // int64 | int64 // uint64 | uint64 // float64 | float64 - // decimal256 | decimal256 // binary | binary // // For the parameters: If nil is passed, then that parameter will not @@ -719,7 +703,10 @@ type StatementExecuteSchema interface { ExecuteSchema(context.Context) (*arrow.Schema, error) } -// GetSetOptions is a PostInitOptions that also supports getting and setting property values of different types. +// GetSetOptions is a PostInitOptions that also supports getting and setting option values of different types. +// +// GetOption functions should return an error with StatusNotFound for unsupported options. +// SetOption functions should return an error with StatusNotImplemented for unsupported options. // // Since ADBC API revision 1.1.0. type GetSetOptions interface { @@ -728,7 +715,7 @@ type GetSetOptions interface { SetOptionBytes(key string, value []byte) error SetOptionInt(key string, value int64) error SetOptionDouble(key string, value float64) error - GetOption(key, value string) (string, error) + GetOption(key string) (string, error) GetOptionBytes(key string) ([]byte, error) GetOptionInt(key string) (int64, error) GetOptionDouble(key string) (float64, error) diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index e038354cc5..00d123b322 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -36,7 +36,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "io" "math" @@ -119,20 +118,13 @@ func init() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, } } -func getTimeoutOptionValue(v string) (time.Duration, error) { - timeout, err := strconv.ParseFloat(v, 64) - if math.IsNaN(timeout) || math.IsInf(timeout, 0) || timeout < 0 { - return 0, errors.New("timeout must be positive and finite") - } - return time.Duration(timeout * float64(time.Second)), err -} - type Driver struct { Alloc memory.Allocator } @@ -164,6 +156,8 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) { db.dialOpts.block = false db.dialOpts.maxMsgSize = 16 * 1024 * 1024 + db.options = make(map[string]string) + return db, db.SetOptions(opts) } @@ -192,6 +186,7 @@ type database struct { timeout timeoutOption dialOpts dbDialOpts enableCookies bool + options map[string]string alloc memory.Allocator } @@ -199,6 +194,10 @@ type database struct { func (d *database) SetOptions(cnOptions map[string]string) error { var tlsConfig tls.Config + for k, v := range cnOptions { + d.options[k] = v + } + mtlsCert := cnOptions[OptionMTLSCertChain] mtlsKey := cnOptions[OptionMTLSPrivateKey] switch { @@ -287,33 +286,24 @@ func (d *database) SetOptions(cnOptions map[string]string) error { var err error if tv, ok := cnOptions[OptionTimeoutFetch]; ok { - if d.timeout.fetchTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutFetch) } if tv, ok := cnOptions[OptionTimeoutQuery]; ok { - if d.timeout.queryTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutQuery, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutQuery) } if tv, ok := cnOptions[OptionTimeoutUpdate]; ok { - if d.timeout.updateTimeout, err = getTimeoutOptionValue(tv); err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutUpdate, tv, err.Error()), - Code: adbc.StatusInvalidArgument, - } + if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv); err != nil { + return err } + delete(cnOptions, OptionTimeoutUpdate) } if val, ok := cnOptions[OptionWithBlock]; ok { @@ -369,7 +359,7 @@ func (d *database) SetOptions(cnOptions map[string]string) error { continue } return adbc.Error{ - Msg: fmt.Sprintf("Unknown database option '%s'", key), + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), Code: adbc.StatusInvalidArgument, } } @@ -377,6 +367,118 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) GetOption(key string) (string, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.String(), nil + } + if val, ok := d.options[key]; ok { + return val, nil + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := d.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return d.timeout.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return d.timeout.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return d.timeout.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) SetOption(key, value string) error { + // We can't change most options post-init + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeoutString(key, value) + } + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} +func (d *database) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return d.timeout.setTimeout(key, value) + } + + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + type timeoutOption struct { grpc.EmptyCallOption @@ -388,6 +490,45 @@ type timeoutOption struct { updateTimeout time.Duration } +func (t *timeoutOption) setTimeout(key string, value float64) error { + if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite", + key, value), + Code: adbc.StatusInvalidArgument, + } + } + + timeout := time.Duration(value * float64(time.Second)) + + switch key { + case OptionTimeoutFetch: + t.fetchTimeout = timeout + case OptionTimeoutQuery: + t.queryTimeout = timeout + case OptionTimeoutUpdate: + t.updateTimeout = timeout + default: + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key), + Code: adbc.StatusNotImplemented, + } + } + return nil +} + +func (t *timeoutOption) setTimeoutString(key string, value string) error { + timeout, err := strconv.ParseFloat(value, 64) + if err != nil { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %s: %s", + key, value, err.Error()), + Code: adbc.StatusInvalidArgument, + } + } + return t.setTimeout(key, timeout) +} + func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) { for _, opt := range callOptions { if to, ok := opt.(timeoutOption); ok { @@ -729,6 +870,96 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } +func (c *cnxn) GetOption(key string) (string, error) { + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + headers := c.hdrs.Get(name) + if len(headers) > 0 { + return headers[0], nil + } + return "", adbc.Error{ + Msg: "[Flight SQL] unknown header", + Code: adbc.StatusNotFound, + } + } + + switch key { + case OptionTimeoutFetch: + return c.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return c.timeouts.updateTimeout.String(), nil + case adbc.OptionKeyAutoCommit: + if c.txn != nil { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return "", adbc.Error{ + Msg: "[Flight SQL] current catalog not supported", + Code: adbc.StatusNotFound, + } + + case adbc.OptionKeyCurrentDbSchema: + return "", adbc.Error{ + Msg: "[Flight SQL] current schema not supported", + Code: adbc.StatusNotFound, + } + } + + return "", adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := c.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return c.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return c.timeouts.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return c.timeouts.updateTimeout.Seconds(), nil + } + + return 0.0, adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotFound, + } +} + func (c *cnxn) SetOption(key, value string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) @@ -742,39 +973,16 @@ func (c *cnxn) SetOption(key, value string) error { switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - c.timeouts.fetchTimeout = timeout + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - c.timeouts.queryTimeout = timeout + fallthrough case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(value) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, value, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - c.timeouts.updateTimeout = timeout + return c.timeouts.setTimeoutString(key, value) case adbc.OptionKeyAutoCommit: autocommit := true switch value { case adbc.OptionValueEnabled: + autocommit = true case adbc.OptionValueDisabled: autocommit = false default: @@ -827,6 +1035,45 @@ func (c *cnxn) SetOption(key, value string) error { return nil } +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, float64(value)) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return c.timeouts.setTimeout(key, value) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + // GetInfo returns metadata about the database/driver. // // The result is an Arrow dataset with the following schema: @@ -853,6 +1100,7 @@ func (c *cnxn) SetOption(key, value string) error { // codes (the row will be omitted from the result). func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -864,7 +1112,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { @@ -886,6 +1135,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) } } @@ -1350,6 +1603,14 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Execute(ctx, query, opts...) } +func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSchema(ctx, query, opts...) + } + + return c.cl.GetExecuteSchema(ctx, query, opts...) +} + func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) @@ -1358,6 +1619,14 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla return c.cl.ExecuteSubstrait(ctx, plan, opts...) } +func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + if c.txn != nil { + return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) + } + + return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) +} + func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 9d959ac4c6..50f1d9b1f0 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -42,6 +42,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) // ---- Common Infra -------------------- @@ -95,6 +96,14 @@ func TestAuthn(t *testing.T) { suite.Run(t, &AuthnTests{}) } +func TestErrorDetails(t *testing.T) { + suite.Run(t, &ErrorDetailsTests{}) +} + +func TestExecuteSchema(t *testing.T) { + suite.Run(t, &ExecuteSchemaTests{}) +} + func TestTimeout(t *testing.T) { suite.Run(t, &TimeoutTests{}) } @@ -202,6 +211,196 @@ func (suite *AuthnTests) TestBearerTokenUpdated() { defer reader.Release() } +// ---- Error Details Tests -------------------- + +type ErrorDetailsTestServer struct { + flightsql.BaseServer +} + +func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + if query.GetQuery() == "details" { + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, err + } + return nil, st.Err() + } else if query.GetQuery() == "query" { + tkt, err := flightsql.CreateStatementQueryTicket([]byte("fetch")) + if err != nil { + panic(err) + } + return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (ts *ErrorDetailsTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { + sc := arrow.NewSchema([]arrow.Field{}, nil) + detail := wrapperspb.Int32Value{Value: 42} + st, err := status.New(codes.Unknown, "details").WithDetails(&detail) + if err != nil { + return nil, nil, err + } + + ch := make(chan flight.StreamChunk) + go func() { + defer close(ch) + ch <- flight.StreamChunk{ + Data: nil, + Desc: nil, + Err: st.Err(), + } + }() + return sc, ch, nil +} + +type ErrorDetailsTests struct { + ServerBasedTests +} + +func (suite *ErrorDetailsTests) SetupSuite() { + srv := ErrorDetailsTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ErrorDetailsTests) TestGetFlightInfo() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("details")) + + _, _, err = stmt.ExecuteQuery(context.Background()) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + + ts.Equal(1, len(adbcErr.Details)) + + message, ok := adbcErr.Details[0].(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +func (ts *ErrorDetailsTests) TestDoGet() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("query")) + + reader, _, err := stmt.ExecuteQuery(context.Background()) + ts.NoError(err) + + defer reader.Release() + + for reader.Next() { + } + err = reader.Err() + + ts.Error(err) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr, "Error was: %#v", err) + + ts.Equal(1, len(adbcErr.Details)) + + message, ok := adbcErr.Details[0].(*wrapperspb.Int32Value) + ts.True(ok, "Got message: %#v", message) + ts.Equal(int32(42), message.Value) +} + +// ---- ExecuteSchema Tests -------------------- + +type ExecuteSchemaTestServer struct { + flightsql.BaseServer +} + +func (srv *ExecuteSchemaTestServer) GetSchemaStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) { + if query.GetQuery() == "sample query" { + return &flight.SchemaResult{ + Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), srv.Alloc), + }, nil + } + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + +func (srv *ExecuteSchemaTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (res flightsql.ActionCreatePreparedStatementResult, err error) { + if req.GetQuery() == "sample query" { + return flightsql.ActionCreatePreparedStatementResult{ + DatasetSchema: arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil), + }, nil + } + return flightsql.ActionCreatePreparedStatementResult{}, status.Error(codes.Unimplemented, "CreatePreparedStatement not implemented") +} + +type ExecuteSchemaTests struct { + ServerBasedTests +} + +func (suite *ExecuteSchemaTests) SetupSuite() { + srv := ExecuteSchemaTestServer{} + srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(&srv, nil, nil) +} + +func (ts *ExecuteSchemaTests) TestNoQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + es := stmt.(adbc.StatementExecuteSchema) + _, err = es.ExecuteSchema(context.Background()) + + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + ts.Equal(adbc.StatusInvalidState, adbcErr.Code, adbcErr.Error()) +} + +func (ts *ExecuteSchemaTests) TestPreparedQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + ts.NoError(stmt.Prepare(context.Background())) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + +func (ts *ExecuteSchemaTests) TestQuery() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("sample query")) + + es := stmt.(adbc.StatementExecuteSchema) + schema, err := es.ExecuteSchema(context.Background()) + ts.NoError(err) + ts.NotNil(schema) + + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "ints", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + ts.True(expectedSchema.Equal(schema), schema.String()) +} + // ---- Timeout Tests -------------------- type TimeoutTestServer struct { @@ -321,6 +520,67 @@ func (ts *TimeoutTests) TestRemoveTimeout() { } } +func (ts *TimeoutTests) TestGetSet() { + keys := []string{ + "adbc.flight.sql.rpc.timeout_seconds.fetch", + "adbc.flight.sql.rpc.timeout_seconds.query", + "adbc.flight.sql.rpc.timeout_seconds.update", + } + stmt, err := ts.cnxn.NewStatement() + ts.Require().NoError(err) + defer stmt.Close() + + for _, v := range []interface{}{ts.db, ts.cnxn, stmt} { + getset := v.(adbc.GetSetOptions) + + for _, k := range keys { + strval, err := getset.GetOption(k) + ts.NoError(err) + ts.Equal("0s", strval) + + intval, err := getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(0), intval) + + floatval, err := getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.0, floatval) + + err = getset.SetOptionInt(k, 1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("1s", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + ts.Equal(int64(1), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(1.0, floatval) + + err = getset.SetOptionDouble(k, 0.1) + ts.NoError(err) + + strval, err = getset.GetOption(k) + ts.NoError(err) + ts.Equal("100ms", strval) + + intval, err = getset.GetOptionInt(k) + ts.NoError(err) + // truncated + ts.Equal(int64(0), intval) + + floatval, err = getset.GetOptionDouble(k) + ts.NoError(err) + ts.Equal(0.1, floatval) + } + } + +} + func (ts *TimeoutTests) TestDoActionTimeout() { ts.NoError(ts.cnxn.(adbc.PostInitOptions). SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1")) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 53dbac2412..83313e243e 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -229,14 +229,20 @@ func (s *FlightSQLQuirks) DropTable(cnxn adbc.Connection, tblname string) error return err } -func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } -func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } -func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) Alloc() memory.Allocator { return s.mem } +func (s *FlightSQLQuirks) BindParameter(_ int) string { return "?" } +func (s *FlightSQLQuirks) SupportsBulkIngest(string) bool { return false } +func (s *FlightSQLQuirks) SupportsConcurrentStatements() bool { return true } +func (s *FlightSQLQuirks) SupportsCurrentCatalogSchema() bool { return false } + +// The driver supports it, but the server we use for testing does not. +func (s *FlightSQLQuirks) SupportsExecuteSchema() bool { return false } +func (s *FlightSQLQuirks) SupportsGetSetOptions() bool { return true } func (s *FlightSQLQuirks) SupportsPartitionedData() bool { return true } +func (s *FlightSQLQuirks) SupportsStatistics() bool { return false } func (s *FlightSQLQuirks) SupportsTransactions() bool { return true } func (s *FlightSQLQuirks) SupportsGetParameterSchema() bool { return false } func (s *FlightSQLQuirks) SupportsDynamicParameterBinding() bool { return true } -func (s *FlightSQLQuirks) SupportsBulkIngest() bool { return false } func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { case adbc.InfoDriverName: @@ -247,6 +253,8 @@ func (s *FlightSQLQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "db_name" case adbc.InfoVendorVersion: @@ -273,6 +281,7 @@ func (s *FlightSQLQuirks) SampleTableSchemaMetadata(tblName string, dt arrow.Dat } } +func (s *FlightSQLQuirks) Catalog() string { return "" } func (s *FlightSQLQuirks) DBSchema() string { return "" } func TestADBCFlightSQL(t *testing.T) { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index c7f074a800..04f46498f4 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -73,6 +73,29 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C } } +func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { + var ( + res *flight.SchemaResult + err error + ) + if s.sqlQuery != "" { + res, err = cnxn.executeSchema(ctx, s.sqlQuery, opts...) + } else if s.substraitPlan != nil { + res, err = cnxn.executeSubstraitSchema(ctx, flightsql.SubstraitPlan{Plan: s.substraitPlan, Version: s.substraitVersion}, opts...) + } else { + return nil, adbc.Error{ + Code: adbc.StatusInvalidState, + Msg: "[Flight SQL Statement] cannot call ExecuteQuery without a query or prepared statement", + } + } + + if err != nil { + return nil, err + } + + return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) +} + func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) @@ -138,6 +161,72 @@ func (s *statement) Close() (err error) { return err } +func (s *statement) GetOption(key string) (string, error) { + switch key { + case OptionStatementSubstraitVersion: + return s.query.substraitVersion, nil + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.String(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.String(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.String(), nil + } + + if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { + name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) + values := s.hdrs.Get(name) + if len(values) > 0 { + return values[0], nil + } + } + + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + val, err := s.GetOptionDouble(key) + if err != nil { + return 0, err + } + return int64(val), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (s *statement) GetOptionDouble(key string) (float64, error) { + switch key { + case OptionTimeoutFetch: + return s.timeouts.fetchTimeout.Seconds(), nil + case OptionTimeoutQuery: + return s.timeouts.queryTimeout.Seconds(), nil + case OptionTimeoutUpdate: + return s.timeouts.updateTimeout.Seconds(), nil + } + + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (s *statement) SetOption(key string, val string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { @@ -152,35 +241,11 @@ func (s *statement) SetOption(key string, val string) error { switch key { case OptionTimeoutFetch: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.fetchTimeout = timeout + fallthrough case OptionTimeoutQuery: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.queryTimeout = timeout + fallthrough case OptionTimeoutUpdate: - timeout, err := getTimeoutOptionValue(val) - if err != nil { - return adbc.Error{ - Msg: fmt.Sprintf("invalid timeout option value %s = %s : %s", - OptionTimeoutFetch, val, err.Error()), - Code: adbc.StatusInvalidArgument, - } - } - s.timeouts.updateTimeout = timeout + return s.timeouts.setTimeoutString(key, val) case OptionStatementQueueSize: var err error var size int @@ -189,13 +254,8 @@ func (s *statement) SetOption(key string, val string) error { Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), Code: adbc.StatusInvalidArgument, } - } else if size <= 0 { - return adbc.Error{ - Msg: fmt.Sprintf("Invalid value for statement option '%s': '%s' is not a positive integer", OptionStatementQueueSize, val), - Code: adbc.StatusInvalidArgument, - } } - s.queueSize = size + return s.SetOptionInt(key, int64(size)) case OptionStatementSubstraitVersion: s.query.substraitVersion = val default: @@ -207,6 +267,43 @@ func (s *statement) SetOption(key string, val string) error { return nil } +func (s *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (s *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + s.queueSize = int(value) + return nil + } + return s.SetOptionDouble(key, float64(value)) +} + +func (s *statement) SetOptionDouble(key string, value float64) error { + switch key { + case OptionTimeoutFetch: + fallthrough + case OptionTimeoutQuery: + fallthrough + case OptionTimeoutUpdate: + return s.timeouts.setTimeout(key, value) + } + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -422,3 +519,21 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. return sc, out, info.TotalRecords, nil } + +// ExecuteSchema gets the schema of the result set of a query without executing it. +func (s *statement) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { + ctx = metadata.NewOutgoingContext(ctx, s.hdrs) + + if s.prepared != nil { + schema = s.prepared.DatasetSchema() + if schema == nil { + err = adbc.Error{ + Msg: "[Flight SQL Statement] Database server did not provide schema for prepared statement", + Code: adbc.StatusNotImplemented, + } + } + return + } + + return s.query.executeSchema(ctx, s.cnxn, s.timeouts) +} diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go index 409ce58e61..297d35f8dc 100644 --- a/go/adbc/driver/flightsql/record_reader.go +++ b/go/adbc/driver/flightsql/record_reader.go @@ -104,7 +104,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rec.Retain() ch <- rec } - return rdr.Err() + return adbcFromFlightStatus(rdr.Err()) }) endpoints = endpoints[1:] @@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rdr, err := doGet(ctx, cl, endpoint, clCache, opts...) if err != nil { - return err + return adbcFromFlightStatus(err) } defer rdr.Release() @@ -150,7 +150,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. chs[endpointIndex] <- rec } - return rdr.Err() + return adbcFromFlightStatus(rdr.Err()) }) } diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go index 4f1f165c6b..fb7f323150 100644 --- a/go/adbc/driver/flightsql/utils.go +++ b/go/adbc/driver/flightsql/utils.go @@ -29,7 +29,9 @@ func adbcFromFlightStatus(err error) error { } var adbcCode adbc.Status - switch status.Code(err) { + // If not a status.Status, will return codes.Unknown + grpcStatus := status.Convert(err) + switch grpcStatus.Code() { case codes.OK: return nil case codes.Canceled: @@ -71,5 +73,7 @@ func adbcFromFlightStatus(err error) error { return adbc.Error{ Msg: err.Error(), Code: adbcCode, + // slice of proto.Message or error + Details: grpcStatus.Details(), } } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 8f965597c5..c321e77a60 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -22,6 +22,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "io" "strconv" "strings" "time" @@ -95,6 +96,7 @@ type cnxn struct { // codes (the row will be omitted from the result). func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { const strValTypeID arrow.UnionTypeCode = 0 + const intValTypeID arrow.UnionTypeCode = 2 if len(infoCodes) == 0 { infoCodes = infoSupportedCodes @@ -106,7 +108,8 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr := bldr.Field(0).(*array.Uint32Builder) infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(0).(*array.StringBuilder) + strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) for _, code := range infoCodes { switch code { @@ -122,6 +125,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) strInfoBldr.Append(infoDriverArrowVersion) + case adbc.InfoDriverADBCVersion: + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(intValTypeID) + intInfoBldr.Append(adbc.AdbcVersion1_1_0) case adbc.InfoVendorName: infoNameBldr.Append(uint32(code)) infoValueBldr.Append(strValTypeID) @@ -674,6 +681,85 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie return } +func (c *cnxn) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyAutoCommit: + if c.activeTransaction { + // No autocommit + return adbc.OptionValueDisabled, nil + } else { + // Autocommit + return adbc.OptionValueEnabled, nil + } + case adbc.OptionKeyCurrentCatalog: + return c.getStringQuery("SELECT CURRENT_DATABASE()") + case adbc.OptionKeyCurrentDbSchema: + return c.getStringQuery("SELECT CURRENT_SCHEMA()") + } + + return "", adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) getStringQuery(query string) (string, error) { + result, err := c.cn.QueryContext(context.Background(), query, nil) + if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + defer result.Close() + + if len(result.Columns()) != 1 { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong number of columns: %s", result.Columns()), + Code: adbc.StatusInternal, + } + } + + dest := make([]driver.Value, 1) + err = result.Next(dest) + if err == io.EOF { + return "", adbc.Error{ + Msg: "[Snowflake] Internal query returned no rows", + Code: adbc.StatusInternal, + } + } else if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + value, ok := dest[0].(string) + if !ok { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Internal query returned wrong type of value: %s", dest[0]), + Code: adbc.StatusInternal, + } + } + + return value, nil +} + +func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + +func (c *cnxn) GetOptionDouble(key string) (float64, error) { + return 0.0, adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotFound, + } +} + func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { tblParts := make([]string, 0, 3) if catalog != nil { @@ -840,6 +926,12 @@ func (c *cnxn) SetOption(key, value string) error { Code: adbc.StatusInvalidArgument, } } + case adbc.OptionKeyCurrentCatalog: + _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) + return err + case adbc.OptionKeyCurrentDbSchema: + _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) + return err default: return adbc.Error{ Msg: "[Snowflake] unknown connection option " + key + ": " + value, @@ -847,3 +939,24 @@ func (c *cnxn) SetOption(key, value string) error { } } } + +func (c *cnxn) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} + +func (c *cnxn) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: "[Snowflake] unknown connection option", + Code: adbc.StatusNotImplemented, + } +} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index c02b58ddec..a00513817b 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -209,6 +209,105 @@ type database struct { alloc memory.Allocator } +func (d *database) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyUsername: + return d.cfg.User, nil + case adbc.OptionKeyPassword: + return d.cfg.Password, nil + case OptionDatabase: + return d.cfg.Database, nil + case OptionSchema: + return d.cfg.Schema, nil + case OptionWarehouse: + return d.cfg.Warehouse, nil + case OptionRole: + return d.cfg.Role, nil + case OptionRegion: + return d.cfg.Region, nil + case OptionAccount: + return d.cfg.Account, nil + case OptionProtocol: + return d.cfg.Protocol, nil + case OptionHost: + return d.cfg.Host, nil + case OptionPort: + return strconv.Itoa(d.cfg.Port), nil + case OptionAuthType: + return d.cfg.Authenticator.String(), nil + case OptionLoginTimeout: + return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil + case OptionRequestTimeout: + return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil + case OptionJwtExpireTimeout: + return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil + case OptionClientTimeout: + return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil + case OptionApplicationName: + return d.cfg.Application, nil + case OptionSSLSkipVerify: + if d.cfg.InsecureMode { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionOCSPFailOpenMode: + return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil + case OptionAuthToken: + return d.cfg.Token, nil + case OptionAuthOktaUrl: + return d.cfg.OktaURL.String(), nil + case OptionKeepSessionAlive: + if d.cfg.KeepSessionAlive { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionDisableTelemetry: + if d.cfg.DisableTelemetry { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientRequestMFAToken: + if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionClientStoreTempCred: + if d.cfg.ClientStoreTemporaryCredential == gosnowflake.ConfigBoolTrue { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + case OptionLogTracing: + return d.cfg.Tracing, nil + default: + val, ok := d.cfg.Params[key] + if ok { + return *val, nil + } + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionInt(key string) (int64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (d *database) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotFound, + } +} + func (d *database) SetOptions(cnOptions map[string]string) error { uri, ok := cnOptions[adbc.OptionKeyURI] if ok { @@ -421,6 +520,35 @@ func (d *database) SetOptions(cnOptions map[string]string) error { return nil } +func (d *database) SetOption(key string, val string) error { + // Can't set options after init + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionInt(key string, value int64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (d *database) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + func (d *database) Open(ctx context.Context) (adbc.Connection, error) { connector := gosnowflake.NewConnector(drv, *d.cfg) diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 7ac1f27f84..ca3800ffaf 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -38,10 +38,11 @@ import ( ) type SnowflakeQuirks struct { - dsn string - mem *memory.CheckedAllocator - connector gosnowflake.Connector - schemaName string + dsn string + mem *memory.CheckedAllocator + connector gosnowflake.Connector + catalogName string + schemaName string } func (s *SnowflakeQuirks) SetupDriver(t *testing.T) adbc.Driver { @@ -180,12 +181,17 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error func (s *SnowflakeQuirks) Alloc() memory.Allocator { return s.mem } func (s *SnowflakeQuirks) BindParameter(_ int) string { return "?" } +func (s *SnowflakeQuirks) SupportsBulkIngest(string) bool { return true } func (s *SnowflakeQuirks) SupportsConcurrentStatements() bool { return true } +func (s *SnowflakeQuirks) SupportsCurrentCatalogSchema() bool { return true } +func (s *SnowflakeQuirks) SupportsExecuteSchema() bool { return false } +func (s *SnowflakeQuirks) SupportsGetSetOptions() bool { return true } func (s *SnowflakeQuirks) SupportsPartitionedData() bool { return false } +func (s *SnowflakeQuirks) SupportsStatistics() bool { return true } func (s *SnowflakeQuirks) SupportsTransactions() bool { return true } func (s *SnowflakeQuirks) SupportsGetParameterSchema() bool { return false } func (s *SnowflakeQuirks) SupportsDynamicParameterBinding() bool { return false } -func (s *SnowflakeQuirks) SupportsBulkIngest() bool { return true } +func (s *SnowflakeQuirks) Catalog() string { return s.catalogName } func (s *SnowflakeQuirks) DBSchema() string { return s.schemaName } func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { switch code { @@ -197,6 +203,8 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoDriverADBCVersion: + return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: return "Snowflake" } @@ -225,7 +233,7 @@ func createTempSchema(uri string) string { } defer db.Close() - schemaName := "ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_") + schemaName := strings.ToUpper("ADBC_TESTING_" + strings.ReplaceAll(uuid.New().String(), "-", "_")) _, err = db.Exec(`CREATE SCHEMA ADBC_TESTING.` + schemaName) if err != nil { panic(err) @@ -247,18 +255,28 @@ func dropTempSchema(uri, schema string) { } } -func TestADBCSnowflake(t *testing.T) { +func withQuirks(t *testing.T, fn func(*SnowflakeQuirks)) { uri := os.Getenv("SNOWFLAKE_URI") + database := os.Getenv("SNOWFLAKE_DATABASE") if uri == "" { t.Skip("no SNOWFLAKE_URI defined, skip snowflake driver tests") + } else if database == "" { + t.Skip("no SNOWFLAKE_DATABASE defined, skip snowflake driver tests") } // avoid multiple runs clashing by operating in a fresh schema and then // dropping that schema when we're done. - q := &SnowflakeQuirks{dsn: uri, schemaName: createTempSchema(uri)} + q := &SnowflakeQuirks{dsn: uri, catalogName: database, schemaName: createTempSchema(uri)} defer dropTempSchema(uri, q.schemaName) - suite.Run(t, &validation.DatabaseTests{Quirks: q}) - suite.Run(t, &validation.ConnectionTests{Quirks: q}) - suite.Run(t, &validation.StatementTests{Quirks: q}) + + fn(q) +} + +func TestValidation(t *testing.T) { + withQuirks(t, func(q *SnowflakeQuirks) { + suite.Run(t, &validation.DatabaseTests{Quirks: q}) + suite.Run(t, &validation.ConnectionTests{Quirks: q}) + suite.Run(t, &validation.StatementTests{Quirks: q}) + }) } diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index 481e7f7cec..90e456bce8 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -43,7 +43,7 @@ type statement struct { query string targetTable string - append bool + ingestMode string bound arrow.Record streamBind array.RecordReader @@ -71,6 +71,35 @@ func (st *statement) Close() error { return nil } +func (st *statement) GetOption(key string) (string, error) { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionBytes(key string) ([]byte, error) { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionInt(key string) (int64, error) { + switch key { + case OptionStatementQueueSize: + return int64(st.queueSize), nil + } + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} +func (st *statement) GetOptionDouble(key string) (float64, error) { + return 0, adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotFound, + } +} + // SetOption sets a string option on this statement func (st *statement) SetOption(key string, val string) error { switch key { @@ -80,9 +109,13 @@ func (st *statement) SetOption(key string, val string) error { case adbc.OptionKeyIngestMode: switch val { case adbc.OptionValueIngestModeAppend: - st.append = true + fallthrough case adbc.OptionValueIngestModeCreate: - st.append = false + fallthrough + case adbc.OptionValueIngestModeReplace: + fallthrough + case adbc.OptionValueIngestModeCreateAppend: + st.ingestMode = val default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -97,7 +130,7 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } - st.queueSize = sz + return st.SetOptionInt(key, int64(sz)) default: return adbc.Error{ Msg: fmt.Sprintf("invalid statement option %s=%s", key, val), @@ -107,6 +140,38 @@ func (st *statement) SetOption(key string, val string) error { return nil } +func (st *statement) SetOptionBytes(key string, value []byte) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionInt(key string, value int64) error { + switch key { + case OptionStatementQueueSize: + if value <= 0 { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Invalid value for statement option '%s': '%d' is not a positive integer", OptionStatementQueueSize, value), + Code: adbc.StatusInvalidArgument, + } + } + st.queueSize = int(value) + return nil + } + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + +func (st *statement) SetOptionDouble(key string, value float64) error { + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] Unknown statement option '%s'", key), + Code: adbc.StatusNotImplemented, + } +} + // SetSqlQuery sets the query string to be executed. // // The query can then be executed with any of the Execute methods. @@ -173,6 +238,9 @@ func (st *statement) initIngest(ctx context.Context) (string, error) { ) createBldr.WriteString("CREATE TABLE ") + if st.ingestMode == adbc.OptionValueIngestModeCreateAppend { + createBldr.WriteString(" IF NOT EXISTS ") + } createBldr.WriteString(st.targetTable) createBldr.WriteString(" (") @@ -214,7 +282,22 @@ func (st *statement) initIngest(ctx context.Context) (string, error) { createBldr.WriteString(")") insertBldr.WriteString(")") - if !st.append { + switch st.ingestMode { + case adbc.OptionValueIngestModeAppend: + // Do nothing + case adbc.OptionValueIngestModeReplace: + replaceQuery := "DROP TABLE IF EXISTS " + st.targetTable + _, err := st.cnxn.cn.ExecContext(ctx, replaceQuery, nil) + if err != nil { + return "", errToAdbcErr(adbc.StatusInternal, err) + } + + fallthrough + case adbc.OptionValueIngestModeCreate: + fallthrough + case adbc.OptionValueIngestModeCreateAppend: + fallthrough + default: // create the table! createQuery := createBldr.String() _, err := st.cnxn.cn.ExecContext(ctx, createQuery, nil) diff --git a/go/adbc/standard_schemas.go b/go/adbc/standard_schemas.go index b5ca7d42b5..5ec888b40a 100644 --- a/go/adbc/standard_schemas.go +++ b/go/adbc/standard_schemas.go @@ -92,6 +92,34 @@ var ( {Name: "catalog_db_schemas", Type: arrow.ListOf(DBSchemaSchema), Nullable: true}, }, nil) + StatisticsSchema = arrow.StructOf( + arrow.Field{Name: "table_name", Type: arrow.BinaryTypes.String, Nullable: false}, + arrow.Field{Name: "column_name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "statistic_key", Type: arrow.PrimitiveTypes.Int16, Nullable: false}, + arrow.Field{Name: "statistic_value", Type: arrow.DenseUnionOf([]arrow.Field{ + {Name: "int64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "uint64", Type: arrow.PrimitiveTypes.Uint64, Nullable: true}, + {Name: "float64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "binary", Type: arrow.BinaryTypes.Binary, Nullable: true}, + }, []arrow.UnionTypeCode{0, 1, 2, 3}), Nullable: false}, + arrow.Field{Name: "statistic_is_approximate", Type: arrow.FixedWidthTypes.Boolean, Nullable: false}, + ) + + StatisticsDBSchemaSchema = arrow.StructOf( + arrow.Field{Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "db_schema_statistics", Type: arrow.ListOf(StatisticsSchema), Nullable: true}, + ) + + GetStatisticsSchema = arrow.NewSchema([]arrow.Field{ + {Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "catalog_db_schemas", Type: arrow.ListOf(StatisticsDBSchemaSchema), Nullable: true}, + }, nil) + + GetStatisticNamesSchema = arrow.NewSchema([]arrow.Field{ + {Name: "statistic_name", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "statistic_key", Type: arrow.PrimitiveTypes.Int16, Nullable: false}, + }, nil) + GetTableSchemaSchema = arrow.NewSchema([]arrow.Field{ {Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true}, {Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index ffc9e93dc8..7f9832ce3d 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -44,10 +44,20 @@ type DriverQuirks interface { DatabaseOptions() map[string]string // Return the SQL to reference the bind parameter for a given index BindParameter(index int) string + // Whether the driver supports bulk ingest + SupportsBulkIngest(mode string) bool // Whether two statements can be used at the same time on a single connection SupportsConcurrentStatements() bool + // Whether current catalog/schema are supported + SupportsCurrentCatalogSchema() bool + // Whether GetSetOptions is supported + SupportsGetSetOptions() bool + // Whether AdbcStatementExecuteSchema should work + SupportsExecuteSchema() bool // Whether AdbcStatementExecutePartitions should work SupportsPartitionedData() bool + // Whether statistics are supported + SupportsStatistics() bool // Whether transactions are supported (Commit/Rollback on connection) SupportsTransactions() bool // Whether retrieving the schema of prepared statement params is supported @@ -60,11 +70,10 @@ type DriverQuirks interface { CreateSampleTable(tableName string, r arrow.Record) error // Field Metadata for Sample Table for comparison SampleTableSchemaMetadata(tblName string, dt arrow.DataType) arrow.Metadata - // Whether the driver supports bulk ingest - SupportsBulkIngest() bool // have the driver drop a table with the correct SQL syntax DropTable(adbc.Connection, string) error + Catalog() string DBSchema() string Alloc() memory.Allocator @@ -115,6 +124,30 @@ func (c *ConnectionTests) TearDownTest() { c.DB = nil } +func (c *ConnectionTests) TestGetSetOptions() { + cnxn, err := c.DB.Open(context.Background()) + c.NoError(err) + c.NotNil(cnxn) + + stmt, err := cnxn.NewStatement() + c.NoError(err) + c.NotNil(stmt) + + expected := c.Quirks.SupportsGetSetOptions() + + _, ok := c.DB.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = cnxn.(adbc.GetSetOptions) + c.Equal(expected, ok) + + _, ok = stmt.(adbc.GetSetOptions) + c.Equal(expected, ok) + + c.NoError(stmt.Close()) + c.NoError(cnxn.Close()) +} + func (c *ConnectionTests) TestNewConn() { cnxn, err := c.DB.Open(context.Background()) c.NoError(err) @@ -152,6 +185,12 @@ func (c *ConnectionTests) TestAutocommitDefault() { cnxn, _ := c.DB.Open(ctx) defer cnxn.Close() + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueEnabled, value) + } + expectedCode := adbc.StatusInvalidState var adbcError adbc.Error err := cnxn.Commit(ctx) @@ -188,8 +227,60 @@ func (c *ConnectionTests) TestAutocommitToggle() { c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueEnabled)) c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } + // it is ok to disable autocommit when it isn't enabled c.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled)) + + if getset, ok := cnxn.(adbc.GetSetOptions); ok { + value, err := getset.GetOption(adbc.OptionKeyAutoCommit) + c.NoError(err) + c.Equal(adbc.OptionValueDisabled, value) + } +} + +func (c *ConnectionTests) TestMetadataCurrentCatalog() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentCatalog) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.Catalog(), value) + } else { + c.Error(err) + } +} + +func (c *ConnectionTests) TestMetadataCurrentDbSchema() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + getset, ok := cnxn.(adbc.GetSetOptions) + + if !c.Quirks.SupportsGetSetOptions() { + c.False(ok) + return + } + c.True(ok) + value, err := getset.GetOption(adbc.OptionKeyCurrentDbSchema) + if c.Quirks.SupportsCurrentCatalogSchema() { + c.NoError(err) + c.Equal(c.Quirks.DBSchema(), value) + } else { + c.Error(err) + } } func (c *ConnectionTests) TestMetadataGetInfo() { @@ -201,6 +292,7 @@ func (c *ConnectionTests) TestMetadataGetInfo() { adbc.InfoDriverName, adbc.InfoDriverVersion, adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, adbc.InfoVendorName, adbc.InfoVendorVersion, adbc.InfoVendorArrowVersion, @@ -219,19 +311,55 @@ func (c *ConnectionTests) TestMetadataGetInfo() { valUnion := rec.Column(1).(*array.DenseUnion) for i := 0; i < int(rec.NumRows()); i++ { code := codeCol.Value(i) - child := valUnion.Field(valUnion.ChildID(i)) - if child.IsNull(i) { + offset := int(valUnion.ValueOffset(i)) + valUnion.GetOneForMarshal(i) + if child.IsNull(offset) { exp := c.Quirks.GetMetadata(adbc.InfoCode(code)) c.Nilf(exp, "got nil for info %s, expected: %s", adbc.InfoCode(code), exp) } else { - // currently we only define utf8 values for metadata - c.Equal(c.Quirks.GetMetadata(adbc.InfoCode(code)), child.(*array.String).Value(i), adbc.InfoCode(code).String()) + expected := c.Quirks.GetMetadata(adbc.InfoCode(code)) + var actual interface{} + + switch valUnion.ChildID(i) { + case 0: + // String + actual = child.(*array.String).Value(offset) + case 2: + // int64 + actual = child.(*array.Int64).Value(offset) + default: + c.FailNow("Unknown union type code", valUnion.ChildID(i)) + } + + c.Equal(expected, actual, adbc.InfoCode(code).String()) } } } } +func (c *ConnectionTests) TestMetadataGetStatistics() { + ctx := context.Background() + cnxn, _ := c.DB.Open(ctx) + defer cnxn.Close() + + if c.Quirks.SupportsStatistics() { + stats, ok := cnxn.(adbc.ConnectionGetStatistics) + c.True(ok) + reader, err := stats.GetStatistics(ctx, nil, nil, nil, true) + c.NoError(err) + defer reader.Release() + } else { + stats, ok := cnxn.(adbc.ConnectionGetStatistics) + if ok { + _, err := stats.GetStatistics(ctx, nil, nil, nil, true) + var adbcErr adbc.Error + c.ErrorAs(err, &adbcErr) + c.Equal(adbc.StatusNotImplemented, adbcErr.Code) + } + } +} + func (c *ConnectionTests) TestMetadataGetTableSchema() { rec, _, err := array.RecordFromJSON(c.Quirks.Alloc(), arrow.NewSchema( []arrow.Field{ @@ -407,6 +535,49 @@ func (s *StatementTests) TestNewStatement() { s.Equal(adbc.StatusInvalidState, adbcError.Code) } +func (s *StatementTests) TestSqlExecuteSchema() { + if !s.Quirks.SupportsExecuteSchema() { + s.T().SkipNow() + } + + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + es, ok := stmt.(adbc.StatementExecuteSchema) + s.Require().True(ok, "%#v does not support ExecuteSchema", es) + + s.Run("no query", func() { + var adbcErr adbc.Error + + schema, err := es.ExecuteSchema(s.ctx) + s.ErrorAs(err, &adbcErr) + s.Equal(adbc.StatusInvalidState, adbcErr.Code) + s.Nil(schema) + }) + + s.Run("query", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) + + s.Run("prepared", func() { + s.NoError(stmt.SetSqlQuery("SELECT 1, 'string'")) + s.NoError(stmt.Prepare(s.ctx)) + + schema, err := es.ExecuteSchema(s.ctx) + s.NoError(err) + s.Equal(2, len(schema.Fields())) + s.True(schema.Field(0).Type.ID() == arrow.INT32 || schema.Field(0).Type.ID() == arrow.INT64) + s.Equal(arrow.STRING, schema.Field(1).Type.ID()) + }) +} + func (s *StatementTests) TestSqlPartitionedInts() { stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) @@ -596,7 +767,7 @@ func (s *StatementTests) TestSqlPrepareErrorParamCountMismatch() { } func (s *StatementTests) TestSqlIngestInts() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { s.T().SkipNow() } @@ -647,7 +818,7 @@ func (s *StatementTests) TestSqlIngestInts() { } func (s *StatementTests) TestSqlIngestAppend() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { s.T().SkipNow() } @@ -683,6 +854,10 @@ func (s *StatementTests) TestSqlIngestAppend() { defer batch2.Release() s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { + s.T().SkipNow() + } s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) s.Require().NoError(stmt.Bind(s.ctx, batch2)) @@ -716,11 +891,151 @@ func (s *StatementTests) TestSqlIngestAppend() { s.Require().NoError(rdr.Err()) } +func (s *StatementTests) TestSqlIngestReplace() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeReplace) { + s.T().SkipNow() + } + + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + + schema := arrow.NewSchema([]arrow.Field{{ + Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + + batchbldr := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr.Release() + bldr := batchbldr.Field(0).(*array.Int64Builder) + bldr.AppendValues([]int64{42}, []bool{true}) + batch := batchbldr.NewRecord() + defer batch.Release() + + // ingest and create table + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err := stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // now replace + schema = arrow.NewSchema([]arrow.Field{{ + Name: "newintcol", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + batchbldr2 := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr2.Release() + bldr2 := batchbldr2.Field(0).(*array.Int64Builder) + bldr2.AppendValues([]int64{42}, []bool{true}) + batch2 := batchbldr2.NewRecord() + defer batch2.Release() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeReplace)) + s.Require().NoError(stmt.Bind(s.ctx, batch2)) + + affected, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + rdr, rows, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + if rows != -1 && rows != 1 { + s.FailNowf("invalid number of returned rows", "should be -1 or 1, got: %d", rows) + } + defer rdr.Release() + + s.Truef(schema.Equal(utils.RemoveSchemaMetadata(rdr.Schema())), "expected: %s\n got: %s", schema, rdr.Schema()) + s.Require().True(rdr.Next()) + rec := rdr.Record() + s.EqualValues(1, rec.NumRows()) + s.EqualValues(1, rec.NumCols()) + col, ok := rec.Column(0).(*array.Int64) + s.True(ok) + s.Equal(int64(42), col.Value(0)) + + s.Require().False(rdr.Next()) + s.Require().NoError(rdr.Err()) +} + +func (s *StatementTests) TestSqlIngestCreateAppend() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreateAppend) { + s.T().SkipNow() + } + + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + + schema := arrow.NewSchema([]arrow.Field{{ + Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) + + batchbldr := array.NewRecordBuilder(s.Quirks.Alloc(), schema) + defer batchbldr.Release() + bldr := batchbldr.Field(0).(*array.Int64Builder) + bldr.AppendValues([]int64{42}, []bool{true}) + batch := batchbldr.NewRecord() + defer batch.Release() + + // ingest and create table + stmt, err := s.Cnxn.NewStatement() + s.Require().NoError(err) + defer stmt.Close() + + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err := stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // append + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreateAppend)) + s.Require().NoError(stmt.Bind(s.ctx, batch)) + + affected, err = stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) + if affected != -1 && affected != 1 { + s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) + } + + // validate + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + rdr, rows, err := stmt.ExecuteQuery(s.ctx) + s.Require().NoError(err) + if rows != -1 && rows != 2 { + s.FailNowf("invalid number of returned rows", "should be -1 or 2, got: %d", rows) + } + defer rdr.Release() + + s.Truef(schema.Equal(utils.RemoveSchemaMetadata(rdr.Schema())), "expected: %s\n got: %s", schema, rdr.Schema()) + s.Require().True(rdr.Next()) + rec := rdr.Record() + s.EqualValues(2, rec.NumRows()) + s.EqualValues(1, rec.NumCols()) + col, ok := rec.Column(0).(*array.Int64) + s.True(ok) + s.Equal(int64(42), col.Value(0)) + s.Equal(int64(42), col.Value(1)) + + s.Require().False(rdr.Next()) + s.Require().NoError(rdr.Err()) +} + func (s *StatementTests) TestSqlIngestErrors() { - if !s.Quirks.SupportsBulkIngest() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { s.T().SkipNow() } + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) + stmt, err := s.Cnxn.NewStatement() s.Require().NoError(err) defer stmt.Close() @@ -735,6 +1050,10 @@ func (s *StatementTests) TestSqlIngestErrors() { }) s.Run("append to nonexistent table", func() { + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeAppend) { + s.T().SkipNow() + } + s.Require().NoError(s.Quirks.DropTable(s.Cnxn, "bulk_ingest")) schema := arrow.NewSchema([]arrow.Field{{ Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil) @@ -795,6 +1114,10 @@ func (s *StatementTests) TestSqlIngestErrors() { batch = batchbldr.NewRecord() defer batch.Release() + if !s.Quirks.SupportsBulkIngest(adbc.OptionValueIngestModeCreate) { + s.T().SkipNow() + } + s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest")) s.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) s.Require().NoError(stmt.Bind(s.ctx, batch))