diff --git a/README.md b/README.md index a505f166..e254b6b0 100644 --- a/README.md +++ b/README.md @@ -477,6 +477,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host * Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. * Supports query notifications * Supports Kerberos Authentication +* Supports handling the `uniqueidentifier` data type with the `UniqueIdentifier` and `NullUniqueIdentifier` go types * Pluggable Dialer implementations through `msdsn.ProtocolParsers` and `msdsn.ProtocolDialers` * A `namedpipe` package to support connections using named pipes (np:) on Windows * A `sharedmemory` package to support connections using shared memory (lpc:) on Windows diff --git a/bulkcopy_test.go b/bulkcopy_test.go index 5de35154..ce7168cf 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -68,6 +68,7 @@ func TestBulkcopy(t *testing.T) { {"test_intf32", float32(1234.56), 1234}, {"test_geom", geom, string(geom)}, {"test_uniqueidentifier", uid, string(uid)}, + {"test_nulluniqueidentifier", nil, nil}, // {"test_smallmoney", 1234.56, nil}, // {"test_money", 1234.56, nil}, {"test_decimal_18_0", 1234.0001, "1234"}, @@ -270,6 +271,7 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str [test_geog] [geography] NULL, [text_xml] [xml] NULL, [test_uniqueidentifier] [uniqueidentifier] NULL, + [test_nulluniqueidentifier] [uniqueidentifier] NULL, [test_decimal_18_0] [decimal](18, 0) NULL, [test_decimal_18_2] [decimal](18, 2) NULL, [test_decimal_9_2] [decimal](9, 2) NULL, diff --git a/uniqueidentifier_null.go b/uniqueidentifier_null.go new file mode 100644 index 00000000..a9c4ba47 --- /dev/null +++ b/uniqueidentifier_null.go @@ -0,0 +1,65 @@ +package mssql + +import ( + "database/sql/driver" +) + +type NullUniqueIdentifier struct { + UUID UniqueIdentifier + Valid bool // Valid is true if UUID is not NULL +} + +func (n *NullUniqueIdentifier) Scan(v interface{}) error { + if v == nil { + *n = NullUniqueIdentifier{ + UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + Valid: false, + } + return nil + } + u := n.UUID + err := u.Scan(v) + *n = NullUniqueIdentifier{ + UUID: u, + Valid: true, + } + return err +} + +func (n NullUniqueIdentifier) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.UUID.Value() +} + +func (n NullUniqueIdentifier) String() string { + if !n.Valid { + return "NULL" + } + return n.UUID.String() +} + +func (n NullUniqueIdentifier) MarshalText() (text []byte, err error) { + if !n.Valid { + return []byte("null"), nil + } + return n.UUID.MarshalText() +} + +func (n *NullUniqueIdentifier) UnmarshalJSON(b []byte) error { + u := n.UUID + if string(b) == "null" { + *n = NullUniqueIdentifier{ + UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + Valid: false, + } + return nil + } + err := u.UnmarshalJSON(b) + *n = NullUniqueIdentifier{ + UUID: u, + Valid: true, + } + return err +} diff --git a/uniqueidentifier_null_test.go b/uniqueidentifier_null_test.go new file mode 100644 index 00000000..dc29276c --- /dev/null +++ b/uniqueidentifier_null_test.go @@ -0,0 +1,215 @@ +package mssql + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "testing" +) + +func TestNullableUniqueIdentifierScanNull(t *testing.T) { + t.Parallel() + nullUUID := NullUniqueIdentifier{ + UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + Valid: false, + } + + sut := NullUniqueIdentifier{ + UUID: [16]byte{0x1}, + Valid: true, + } + scanErr := sut.Scan(nil) // NULL in the DB + if scanErr != nil { + t.Fatal("NullUniqueIdentifier should not error out on Scan(nil)") + } + if sut != nullUUID { + t.Errorf("bytes not swapped correctly: got %q; want %q", sut, nullUUID) + } +} + +func TestNullableUniqueIdentifierScanBytes(t *testing.T) { + t.Parallel() + dbUUID := [16]byte{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} + uuid := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + + var sut NullUniqueIdentifier + scanErr := sut.Scan(dbUUID[:]) + if scanErr != nil { + t.Fatal(scanErr) + } + if sut != uuid { + t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid) + } +} + +func TestNullableUniqueIdentifierScanString(t *testing.T) { + t.Parallel() + uuid := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + + var sut NullUniqueIdentifier + scanErr := sut.Scan(uuid.String()) + if scanErr != nil { + t.Fatal(scanErr) + } + if sut != uuid { + t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid) + } +} + +func TestNullableUniqueIdentifierScanUnexpectedType(t *testing.T) { + t.Parallel() + var sut NullUniqueIdentifier + scanErr := sut.Scan(int(1)) + if scanErr == nil { + t.Fatal(scanErr) + } +} + +func TestNullableUniqueIdentifierValue(t *testing.T) { + t.Parallel() + dbUUID := [16]byte{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} + + uuid := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + + sut := uuid + v, valueErr := sut.Value() + if valueErr != nil { + t.Fatal(valueErr) + } + + b, ok := v.([]byte) + if !ok { + t.Fatalf("(%T) is not []byte", v) + } + + if !bytes.Equal(b, dbUUID[:]) { + t.Errorf("got %q; want %q", b, dbUUID) + } +} + +func TestNullableUniqueIdentifierValueNull(t *testing.T) { + t.Parallel() + uuid := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: false, + } + + sut := uuid + v, valueErr := sut.Value() + if valueErr != nil { + t.Errorf("unexpected error for invalid uuid: %s", valueErr) + } + + if v != nil { + t.Errorf("expected non-nil value for invalid uuid: %s", v) + } +} + +func TestNullableUniqueIdentifierString(t *testing.T) { + t.Parallel() + sut := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + expected := "01234567-89AB-CDEF-0123-456789ABCDEF" + if actual := sut.String(); actual != expected { + t.Errorf("sut.String() = %s; want %s", sut, expected) + } +} + +func TestNullableUniqueIdentifierStringNull(t *testing.T) { + t.Parallel() + sut := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: false, + } + expected := "NULL" + if actual := sut.String(); actual != expected { + t.Errorf("sut.String() = %s; want %s", sut, expected) + } +} + +func TestNullableUniqueIdentifierMarshalText(t *testing.T) { + t.Parallel() + sut := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + expected := []byte{48, 49, 50, 51, 52, 53, 54, 55, 45, 56, 57, 65, 66, 45, 67, 68, 69, 70, 45, 48, 49, 50, 51, 45, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68, 69, 70} + text, marshalErr := sut.MarshalText() + if marshalErr != nil { + t.Errorf("unexpected error while marshalling: %s", marshalErr) + } + if actual := text; !reflect.DeepEqual(actual, expected) { + t.Errorf("sut.MarshalText() = %v; want %v", actual, expected) + } +} + +func TestNullableUniqueIdentifierMarshalTextNull(t *testing.T) { + t.Parallel() + sut := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: false, + } + expected := []byte("null") + text, marshalErr := sut.MarshalText() + if marshalErr != nil { + t.Errorf("unexpected error while marshalling: %s", marshalErr) + } + if actual := text; !reflect.DeepEqual(actual, expected) { + t.Errorf("sut.MarshalText() = %v; want %v", actual, expected) + } +} + +func TestNullableUniqueIdentifierUnmarshalJSON(t *testing.T) { + t.Parallel() + input := []byte("01234567-89AB-CDEF-0123-456789ABCDEF") + var u NullUniqueIdentifier + + err := u.UnmarshalJSON(input) + if err != nil { + t.Fatal(err) + } + expected := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + if u != expected { + t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected) + } +} + +func TestNullableUniqueIdentifierUnmarshalJSONNull(t *testing.T) { + t.Parallel() + u := NullUniqueIdentifier{ + UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}, + Valid: true, + } + + err := u.UnmarshalJSON([]byte("null")) + if err != nil { + t.Fatal(err) + } + expected := NullUniqueIdentifier{ + UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + Valid: false, + } + if u != expected { + t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected) + } +} + +var _ fmt.Stringer = NullUniqueIdentifier{} +var _ sql.Scanner = &NullUniqueIdentifier{} +var _ driver.Valuer = NullUniqueIdentifier{} diff --git a/uniqueidentifier_test.go b/uniqueidentifier_test.go index 23070d03..2ec5e107 100644 --- a/uniqueidentifier_test.go +++ b/uniqueidentifier_test.go @@ -9,56 +9,86 @@ import ( "testing" ) -func TestUniqueIdentifier(t *testing.T) { +func TestUniqueIdentifierScanNull(t *testing.T) { + t.Parallel() + + sut := UniqueIdentifier{0x01} + scanErr := sut.Scan(nil) // NULL in the DB + if scanErr == nil { + t.Fatal("expected an error for Scan(nil)") + } +} + +func TestUniqueIdentifierScanBytes(t *testing.T) { + t.Parallel() dbUUID := UniqueIdentifier{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, } + uuid := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} + var sut UniqueIdentifier + scanErr := sut.Scan(dbUUID[:]) + if scanErr != nil { + t.Fatal(scanErr) + } + if sut != uuid { + t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid) + } +} + +func TestUniqueIdentifierScanString(t *testing.T) { + t.Parallel() uuid := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} - t.Run("Scan", func(t *testing.T) { - t.Run("[]byte", func(t *testing.T) { - var sut UniqueIdentifier - if err := sut.Scan(dbUUID[:]); err != nil { - t.Fatal(err) - } - if sut != uuid { - t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid) - } - }) - - t.Run("string", func(t *testing.T) { - var sut UniqueIdentifier - if err := sut.Scan(uuid.String()); err != nil { - t.Fatal(err) - } - if sut != uuid { - t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid) - } - }) - }) - - t.Run("Value", func(t *testing.T) { - sut := uuid - v, err := sut.Value() - if err != nil { - t.Fatal(err) - } - - b, ok := v.([]byte) - if !ok { - t.Fatalf("(%T) is not []byte", v) - } - - if !bytes.Equal(b, dbUUID[:]) { - t.Errorf("got %q; want %q", b, dbUUID) - } - }) + var sut UniqueIdentifier + scanErr := sut.Scan(uuid.String()) + if scanErr != nil { + t.Fatal(scanErr) + } + if sut != uuid { + t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid) + } +} + +func TestUniqueIdentifierScanUnexpectedType(t *testing.T) { + t.Parallel() + var sut UniqueIdentifier + scanErr := sut.Scan(int(1)) + if scanErr == nil { + t.Fatal(scanErr) + } +} + +func TestUniqueIdentifierValue(t *testing.T) { + t.Parallel() + dbUUID := UniqueIdentifier{0x67, 0x45, 0x23, 0x01, + 0xAB, 0x89, + 0xEF, 0xCD, + 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, + } + + uuid := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} + + sut := uuid + v, valueErr := sut.Value() + if valueErr != nil { + t.Fatal(valueErr) + } + + b, ok := v.([]byte) + if !ok { + t.Fatalf("(%T) is not []byte", v) + } + + if !bytes.Equal(b, dbUUID[:]) { + t.Errorf("got %q; want %q", b, dbUUID) + } } func TestUniqueIdentifierString(t *testing.T) { + t.Parallel() sut := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} expected := "01234567-89AB-CDEF-0123-456789ABCDEF" if actual := sut.String(); actual != expected { @@ -67,6 +97,7 @@ func TestUniqueIdentifierString(t *testing.T) { } func TestUniqueIdentifierMarshalText(t *testing.T) { + t.Parallel() sut := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF} expected := []byte{48, 49, 50, 51, 52, 53, 54, 55, 45, 56, 57, 65, 66, 45, 67, 68, 69, 70, 45, 48, 49, 50, 51, 45, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68, 69, 70} text, _ := sut.MarshalText() @@ -76,6 +107,7 @@ func TestUniqueIdentifierMarshalText(t *testing.T) { } func TestUniqueIdentifierUnmarshalJSON(t *testing.T) { + t.Parallel() input := []byte("01234567-89AB-CDEF-0123-456789ABCDEF") var u UniqueIdentifier