Skip to content

Commit

Permalink
Fix: Support nullable types in Always Encrypted (denisenkom#179)
Browse files Browse the repository at this point in the history
* preserve type information for Valuer parameters

* support uniqueidentifier in AE

* update readme
  • Loading branch information
shueybubbles authored Mar 6, 2024
1 parent fe7c3d4 commit 80b58f2
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 6 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,8 @@ If the correct key provider is included in your application, decryption of encry

Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`.

*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`.
*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values.
https://github.com/microsoft/go-mssqldb/issues/129
https://github.com/microsoft/go-mssqldb/issues/130
### Local certificate AE key provider
Expand Down
17 changes: 13 additions & 4 deletions alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"strings"
Expand Down Expand Up @@ -65,8 +66,10 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt},
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
// TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that.
// {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
}
for _, test := range providerTests {
// turn off key caching
Expand Down Expand Up @@ -230,13 +233,19 @@ func comparisonValueFromObject(object interface{}) string {
case time.Time:
return civil.DateTimeOf(v).String()
//return v.Format(time.RFC3339)
case fmt.Stringer:
return v.String()
case bool:
if v == true {
return "1"
}
return "0"
case driver.Valuer:
val, _ := v.Value()
if val == nil {
return "<nil>"
}
return comparisonValueFromObject(val)
case fmt.Stringer:
return v.String()
default:
return fmt.Sprintf("%v", v)
}
Expand Down
34 changes: 34 additions & 0 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,37 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 0
return
}
switch valuer := val.(type) {
case UniqueIdentifier:
case NullUniqueIdentifier:
default:
break
case driver.Valuer:
// If the value has a non-nil value, call MakeParam on its Value
val, e := driver.DefaultParameterConverter.ConvertValue(valuer)
if e != nil {
err = e
return
}
if val != nil {
return s.makeParam(val)
}
}
switch val := val.(type) {
case UniqueIdentifier:
res.ti.TypeId = typeGuid
res.ti.Size = 16
guid, _ := val.Value()
res.buffer = guid.([]byte)
case NullUniqueIdentifier:
res.ti.TypeId = typeGuid
res.ti.Size = 16
if val.Valid {
guid, _ := val.Value()
res.buffer = guid.([]byte)
} else {
res.buffer = []byte{}
}
case int:
res.ti.TypeId = typeIntN
// Rather than guess if the caller intends to pass a 32bit int from a 64bit app based on the
Expand Down Expand Up @@ -1021,6 +1051,10 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.TypeId = typeIntN
res.ti.Size = 8
res.buffer = []byte{}
case sql.NullInt32:
res.ti.TypeId = typeIntN
res.ti.Size = 4
res.buffer = []byte{}
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
Expand Down
2 changes: 2 additions & 0 deletions mssql_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func convertInputParameter(val interface{}) (interface{}, error) {
// return nil
case float32:
return val, nil
case driver.Valuer:
return val, nil
default:
return driver.DefaultParameterConverter.ConvertValue(v)
}
Expand Down
27 changes: 27 additions & 0 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"reflect"
"regexp"
"strings"
"testing"
"time"

Expand All @@ -31,6 +32,32 @@ func TestOutputParam(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

t.Run("varchar(max) to sql.NullString", func(t *testing.T) {
sqltextcreate := `CREATE PROCEDURE [GetTask]
@strparam varchar(max) = NULL OUTPUT
AS
SELECT @strparam = REPLICATE('a', 8000)
RETURN 0`
sqltextdrop := `drop procedure GetTask`
sqltextrun := `GetTask`
_, _ = db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
nullstr := sql.NullString{}
_, err := db.ExecContext(ctx, sqltextrun,
sql.Named("strparam", sql.Out{Dest: &nullstr}),
)
if err != nil {
t.Error(err)
}
defer db.ExecContext(ctx, sqltextdrop)
if nullstr.String != strings.Repeat("a", 8000) {
t.Error("Got incorrect NullString of length:", len(nullstr.String))
}
})
t.Run("sp with rows", func(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE spwithrows
Expand Down
13 changes: 13 additions & 0 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ func TestSelect(t *testing.T) {
}
})
})
t.Run("scan into sql.NullString", func(t *testing.T) {
row := conn.QueryRow("SELECT REPLICATE('a', 8000)")
var out sql.NullString
err := row.Scan(&out)
if err != nil {
t.Error("Scan to NullString failed", err.Error())
return
}

if out.String != strings.Repeat("a", 8000) {
t.Error("got back a string with count:", len(out.String))
}
})
}

func TestSelectDateTimeOffset(t *testing.T) {
Expand Down

0 comments on commit 80b58f2

Please sign in to comment.