Skip to content

Commit

Permalink
Feat: Add NullableUniqueIdentifier type
Browse files Browse the repository at this point in the history
* Refactor UniqueIdentifier tests

* Parallelize tests

* Add NullableUniqueIdentifier type

* Add missing test case for UniqueIdentifier

* Improve error message

* Rename to NullUniqueIdentifier

* Add NullUniqueIdentifier to TestBulkcopy

* Add uniqueidentifier parsing to the list of Features

* Add Valid bool to NullUniqueIdentifier

* Handle null in UnmarshalJSON()

* Handle !Valid in Value(),String(),MarshalText()

---------

Co-authored-by: Norman Gehrsitz <[email protected]>
  • Loading branch information
ngehrsitz and Norman Gehrsitz authored Feb 21, 2024
1 parent f37ff1d commit fe7c3d4
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 39 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host
* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas.
* Supports query notifications
* Supports Kerberos Authentication
* Supports handling the `uniqueidentifier` data type with the `UniqueIdentifier` and `NullUniqueIdentifier` go types
* Pluggable Dialer implementations through `msdsn.ProtocolParsers` and `msdsn.ProtocolDialers`
* A `namedpipe` package to support connections using named pipes (np:) on Windows
* A `sharedmemory` package to support connections using shared memory (lpc:) on Windows
Expand Down
2 changes: 2 additions & 0 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func TestBulkcopy(t *testing.T) {
{"test_intf32", float32(1234.56), 1234},
{"test_geom", geom, string(geom)},
{"test_uniqueidentifier", uid, string(uid)},
{"test_nulluniqueidentifier", nil, nil},
// {"test_smallmoney", 1234.56, nil},
// {"test_money", 1234.56, nil},
{"test_decimal_18_0", 1234.0001, "1234"},
Expand Down Expand Up @@ -270,6 +271,7 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_geog] [geography] NULL,
[text_xml] [xml] NULL,
[test_uniqueidentifier] [uniqueidentifier] NULL,
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
[test_decimal_18_0] [decimal](18, 0) NULL,
[test_decimal_18_2] [decimal](18, 2) NULL,
[test_decimal_9_2] [decimal](9, 2) NULL,
Expand Down
65 changes: 65 additions & 0 deletions uniqueidentifier_null.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package mssql

import (
"database/sql/driver"
)

type NullUniqueIdentifier struct {
UUID UniqueIdentifier
Valid bool // Valid is true if UUID is not NULL
}

func (n *NullUniqueIdentifier) Scan(v interface{}) error {
if v == nil {
*n = NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}
return nil
}
u := n.UUID
err := u.Scan(v)
*n = NullUniqueIdentifier{
UUID: u,
Valid: true,
}
return err
}

func (n NullUniqueIdentifier) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.UUID.Value()
}

func (n NullUniqueIdentifier) String() string {
if !n.Valid {
return "NULL"
}
return n.UUID.String()
}

func (n NullUniqueIdentifier) MarshalText() (text []byte, err error) {
if !n.Valid {
return []byte("null"), nil
}
return n.UUID.MarshalText()
}

func (n *NullUniqueIdentifier) UnmarshalJSON(b []byte) error {
u := n.UUID
if string(b) == "null" {
*n = NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}
return nil
}
err := u.UnmarshalJSON(b)
*n = NullUniqueIdentifier{
UUID: u,
Valid: true,
}
return err
}
215 changes: 215 additions & 0 deletions uniqueidentifier_null_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package mssql

import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"testing"
)

func TestNullableUniqueIdentifierScanNull(t *testing.T) {
t.Parallel()
nullUUID := NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}

sut := NullUniqueIdentifier{
UUID: [16]byte{0x1},
Valid: true,
}
scanErr := sut.Scan(nil) // NULL in the DB
if scanErr != nil {
t.Fatal("NullUniqueIdentifier should not error out on Scan(nil)")
}
if sut != nullUUID {
t.Errorf("bytes not swapped correctly: got %q; want %q", sut, nullUUID)
}
}

func TestNullableUniqueIdentifierScanBytes(t *testing.T) {
t.Parallel()
dbUUID := [16]byte{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}
uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

var sut NullUniqueIdentifier
scanErr := sut.Scan(dbUUID[:])
if scanErr != nil {
t.Fatal(scanErr)
}
if sut != uuid {
t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid)
}
}

func TestNullableUniqueIdentifierScanString(t *testing.T) {
t.Parallel()
uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

var sut NullUniqueIdentifier
scanErr := sut.Scan(uuid.String())
if scanErr != nil {
t.Fatal(scanErr)
}
if sut != uuid {
t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid)
}
}

func TestNullableUniqueIdentifierScanUnexpectedType(t *testing.T) {
t.Parallel()
var sut NullUniqueIdentifier
scanErr := sut.Scan(int(1))
if scanErr == nil {
t.Fatal(scanErr)
}
}

func TestNullableUniqueIdentifierValue(t *testing.T) {
t.Parallel()
dbUUID := [16]byte{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}

uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

sut := uuid
v, valueErr := sut.Value()
if valueErr != nil {
t.Fatal(valueErr)
}

b, ok := v.([]byte)
if !ok {
t.Fatalf("(%T) is not []byte", v)
}

if !bytes.Equal(b, dbUUID[:]) {
t.Errorf("got %q; want %q", b, dbUUID)
}
}

func TestNullableUniqueIdentifierValueNull(t *testing.T) {
t.Parallel()
uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: false,
}

sut := uuid
v, valueErr := sut.Value()
if valueErr != nil {
t.Errorf("unexpected error for invalid uuid: %s", valueErr)
}

if v != nil {
t.Errorf("expected non-nil value for invalid uuid: %s", v)
}
}

func TestNullableUniqueIdentifierString(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}
expected := "01234567-89AB-CDEF-0123-456789ABCDEF"
if actual := sut.String(); actual != expected {
t.Errorf("sut.String() = %s; want %s", sut, expected)
}
}

func TestNullableUniqueIdentifierStringNull(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: false,
}
expected := "NULL"
if actual := sut.String(); actual != expected {
t.Errorf("sut.String() = %s; want %s", sut, expected)
}
}

func TestNullableUniqueIdentifierMarshalText(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}
expected := []byte{48, 49, 50, 51, 52, 53, 54, 55, 45, 56, 57, 65, 66, 45, 67, 68, 69, 70, 45, 48, 49, 50, 51, 45, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68, 69, 70}
text, marshalErr := sut.MarshalText()
if marshalErr != nil {
t.Errorf("unexpected error while marshalling: %s", marshalErr)
}
if actual := text; !reflect.DeepEqual(actual, expected) {
t.Errorf("sut.MarshalText() = %v; want %v", actual, expected)
}
}

func TestNullableUniqueIdentifierMarshalTextNull(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: false,
}
expected := []byte("null")
text, marshalErr := sut.MarshalText()
if marshalErr != nil {
t.Errorf("unexpected error while marshalling: %s", marshalErr)
}
if actual := text; !reflect.DeepEqual(actual, expected) {
t.Errorf("sut.MarshalText() = %v; want %v", actual, expected)
}
}

func TestNullableUniqueIdentifierUnmarshalJSON(t *testing.T) {
t.Parallel()
input := []byte("01234567-89AB-CDEF-0123-456789ABCDEF")
var u NullUniqueIdentifier

err := u.UnmarshalJSON(input)
if err != nil {
t.Fatal(err)
}
expected := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}
if u != expected {
t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected)
}
}

func TestNullableUniqueIdentifierUnmarshalJSONNull(t *testing.T) {
t.Parallel()
u := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

err := u.UnmarshalJSON([]byte("null"))
if err != nil {
t.Fatal(err)
}
expected := NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}
if u != expected {
t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected)
}
}

var _ fmt.Stringer = NullUniqueIdentifier{}
var _ sql.Scanner = &NullUniqueIdentifier{}
var _ driver.Valuer = NullUniqueIdentifier{}
Loading

0 comments on commit fe7c3d4

Please sign in to comment.