From 80b58f21b881e2d17f48987733c84e0c97ef61dc Mon Sep 17 00:00:00 2001 From: David Shiflet Date: Wed, 6 Mar 2024 13:43:29 -0600 Subject: [PATCH] Fix: Support nullable types in Always Encrypted (#179) * preserve type information for Valuer parameters * support uniqueidentifier in AE * update readme --- README.md | 3 +-- alwaysencrypted_test.go | 17 +++++++++++++---- mssql.go | 34 ++++++++++++++++++++++++++++++++++ mssql_go19.go | 2 ++ queries_go19_test.go | 27 +++++++++++++++++++++++++++ queries_test.go | 13 +++++++++++++ 6 files changed, 90 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index e254b6b0..5e98abe1 100644 --- a/README.md +++ b/README.md @@ -427,9 +427,8 @@ If the correct key provider is included in your application, decryption of encry Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`. -*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`. +*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. https://github.com/microsoft/go-mssqldb/issues/129 -https://github.com/microsoft/go-mssqldb/issues/130 ### Local certificate AE key provider diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index c78d638f..9dd08b59 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "database/sql" + "database/sql/driver" "fmt" "math/big" "strings" @@ -65,8 +66,10 @@ func TestAlwaysEncryptedE2E(t *testing.T) { {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt}, {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)}, {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, - // TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that. - // {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, + {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, + {"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}}, + {"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}}, + {"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}}, } for _, test := range providerTests { // turn off key caching @@ -230,13 +233,19 @@ func comparisonValueFromObject(object interface{}) string { case time.Time: return civil.DateTimeOf(v).String() //return v.Format(time.RFC3339) - case fmt.Stringer: - return v.String() case bool: if v == true { return "1" } return "0" + case driver.Valuer: + val, _ := v.Value() + if val == nil { + return "" + } + return comparisonValueFromObject(val) + case fmt.Stringer: + return v.String() default: return fmt.Sprintf("%v", v) } diff --git a/mssql.go b/mssql.go index 8870410d..f86d5361 100644 --- a/mssql.go +++ b/mssql.go @@ -982,7 +982,37 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.Size = 0 return } + switch valuer := val.(type) { + case UniqueIdentifier: + case NullUniqueIdentifier: + default: + break + case driver.Valuer: + // If the value has a non-nil value, call MakeParam on its Value + val, e := driver.DefaultParameterConverter.ConvertValue(valuer) + if e != nil { + err = e + return + } + if val != nil { + return s.makeParam(val) + } + } switch val := val.(type) { + case UniqueIdentifier: + res.ti.TypeId = typeGuid + res.ti.Size = 16 + guid, _ := val.Value() + res.buffer = guid.([]byte) + case NullUniqueIdentifier: + res.ti.TypeId = typeGuid + res.ti.Size = 16 + if val.Valid { + guid, _ := val.Value() + res.buffer = guid.([]byte) + } else { + res.buffer = []byte{} + } case int: res.ti.TypeId = typeIntN // Rather than guess if the caller intends to pass a 32bit int from a 64bit app based on the @@ -1021,6 +1051,10 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.TypeId = typeIntN res.ti.Size = 8 res.buffer = []byte{} + case sql.NullInt32: + res.ti.TypeId = typeIntN + res.ti.Size = 4 + res.buffer = []byte{} case byte: res.ti.TypeId = typeIntN res.buffer = []byte{val} diff --git a/mssql_go19.go b/mssql_go19.go index b0285eef..6435f67e 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -75,6 +75,8 @@ func convertInputParameter(val interface{}) (interface{}, error) { // return nil case float32: return val, nil + case driver.Valuer: + return val, nil default: return driver.DefaultParameterConverter.ConvertValue(v) } diff --git a/queries_go19_test.go b/queries_go19_test.go index 7578a175..68eb1db9 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -10,6 +10,7 @@ import ( "fmt" "reflect" "regexp" + "strings" "testing" "time" @@ -31,6 +32,32 @@ func TestOutputParam(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + t.Run("varchar(max) to sql.NullString", func(t *testing.T) { + sqltextcreate := `CREATE PROCEDURE [GetTask] + @strparam varchar(max) = NULL OUTPUT + AS + SELECT @strparam = REPLICATE('a', 8000) + RETURN 0` + sqltextdrop := `drop procedure GetTask` + sqltextrun := `GetTask` + _, _ = db.ExecContext(ctx, sqltextdrop) + _, err = db.ExecContext(ctx, sqltextcreate) + if err != nil { + t.Fatal(err) + } + defer db.ExecContext(ctx, sqltextdrop) + nullstr := sql.NullString{} + _, err := db.ExecContext(ctx, sqltextrun, + sql.Named("strparam", sql.Out{Dest: &nullstr}), + ) + if err != nil { + t.Error(err) + } + defer db.ExecContext(ctx, sqltextdrop) + if nullstr.String != strings.Repeat("a", 8000) { + t.Error("Got incorrect NullString of length:", len(nullstr.String)) + } + }) t.Run("sp with rows", func(t *testing.T) { sqltextcreate := ` CREATE PROCEDURE spwithrows diff --git a/queries_test.go b/queries_test.go index 3d150d21..679ae184 100644 --- a/queries_test.go +++ b/queries_test.go @@ -198,6 +198,19 @@ func TestSelect(t *testing.T) { } }) }) + t.Run("scan into sql.NullString", func(t *testing.T) { + row := conn.QueryRow("SELECT REPLICATE('a', 8000)") + var out sql.NullString + err := row.Scan(&out) + if err != nil { + t.Error("Scan to NullString failed", err.Error()) + return + } + + if out.String != strings.Repeat("a", 8000) { + t.Error("got back a string with count:", len(out.String)) + } + }) } func TestSelectDateTimeOffset(t *testing.T) {