Skip to content

Commit

Permalink
fix: proper order for custom EncoderDecoder
Browse files Browse the repository at this point in the history
This commit changes code it correctly tries to use custom Encoders/Decoders before MarshalBinary/UnmarshalBinary. It also forbids the usage of custom Encoders/Decoders on top level structs since we pretty much control those.

Signed-off-by: Dmitriy Matrenichev <[email protected]>
  • Loading branch information
DmitriyMV committed Sep 7, 2022
1 parent 3617e19 commit dceb5a6
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 42 deletions.
41 changes: 17 additions & 24 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

// Marshal a Go struct into protocol buffer format.
// The caller must pass a pointer to the struct to encode.
func Marshal(ptr interface{}) (result []byte, err error) {
func Marshal(structPtr interface{}) (result []byte, err error) {
defer func() {
if recovered := recover(); recovered != nil {
e, ok := recovered.(error)
Expand All @@ -36,21 +36,25 @@ func Marshal(ptr interface{}) (result []byte, err error) {
}
}()

if ptr == nil {
if structPtr == nil {
return nil, nil
}

if bu, ok := ptr.(encoding.BinaryMarshaler); ok {
if hasCustomEncoders(reflect.TypeOf(structPtr)) {
return nil, errors.New("custom encoders are not supported for top-level structs, use BinaryMarshaler instead")
}

if bu, ok := structPtr.(encoding.BinaryMarshaler); ok {
return bu.MarshalBinary()
}

m := marshaller{
buf: make([]byte, 0, 32),
}

val := reflect.ValueOf(ptr)
if val.Kind() != reflect.Pointer {
return nil, errors.New("encode takes a pointer to struct")
val := reflect.ValueOf(structPtr)
if val.Kind() != reflect.Pointer || val.Type().Elem().Kind() != reflect.Struct {
return nil, errors.New("marshal takes a pointer to struct")
}

m.encodeStruct(val.Elem())
Expand All @@ -71,17 +75,6 @@ func (m *marshaller) Bytes() []byte {
}

func (m *marshaller) encodeStruct(val reflect.Value) {
if val.Type().Kind() != reflect.Struct {
panic("encodeStruct takes a struct")
}

res, ok := tryEncodeFunc(val)
if ok {
m.buf = append(m.buf, res...)

return
}

structFields, err := StructFields(val.Type())
if err != nil {
panic(err)
Expand Down Expand Up @@ -146,7 +139,7 @@ func fieldByIndex(structVal reflect.Value, data FieldData) reflect.Value {
return result
}

//nolint:cyclop
//nolint:cyclop,gocyclo
func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) {
if m.tryEncodePredefined(num, val) {
return
Expand Down Expand Up @@ -187,15 +180,15 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) {

case reflect.Struct:
var b []byte

bmarshaler, ok := asBinaryMarshaler(val)
if ok {
var err error

b, err = bmarshaler.MarshalBinary()
if result, ok := tryEncodeFunc(val); ok {
b = result
} else if bmarshaler, ok := asBinaryMarshaler(val); ok {
result, err := bmarshaler.MarshalBinary()
if err != nil {
panic(err)
}

b = result
} else {
inner := marshaller{}
inner.encodeStruct(val)
Expand Down
69 changes: 69 additions & 0 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,72 @@ func TestResursiveTypes(t *testing.T) {
)(t)
})
}

type MarshallableString struct {
Value string
}

func (m *MarshallableString) MarshalBinary() ([]byte, error) {
return []byte(m.Value + "MarshalBinary"), nil
}

func (m *MarshallableString) UnmarshalBinary(data []byte) error {
m.Value = string(data) + "UnmarshalBinary"

return nil
}

func TestCustomEncodersWithMarshalBinary(t *testing.T) {
t.Cleanup(func() {
protoenc.CleanEncoderDecoder()
})

protoenc.RegisterEncoderDecoder(
func(v *MarshallableString) ([]byte, error) { return []byte(v.Value + "EncoderPtr"), nil },
func(slc []byte) (*MarshallableString, error) {
return &MarshallableString{
Value: string(slc) + "DecoderPtr",
}, nil
},
)

protoenc.RegisterEncoderDecoder(
func(v MarshallableString) ([]byte, error) { return []byte(v.Value + "Encoder"), nil },
func(slc []byte) (MarshallableString, error) {
return MarshallableString{
Value: string(slc) + "Decoder",
}, nil
},
)

type T = MarshallableString

sliceOriginal := []MarshallableString{{Value: "MyVal"}, {Value: "MyVal2"}}
sliceExpected := []MarshallableString{{Value: "MyValEncoderDecoder"}, {Value: "MyVal2EncoderDecoder"}}

tests := map[string]struct {
fn func(t *testing.T)
}{
"custom encoder on struct field": {testEncodeDecodeResult(makeValue(T{Value: "MyVal"}), makeValue(T{Value: "MyValEncoderDecoder"}))},
"custom encoder on struct field pointer": {testEncodeDecodeResult(makeValue(&T{Value: "MyVal"}), makeValue(&T{Value: "MyValEncoderPtrDecoderPtr"}))},
"custom encoder on slice": {testEncodeDecodeResult(makeValue(sliceOriginal), makeValue(sliceExpected))},
}

for name, test := range tests {
t.Run(name, test.fn)
}
}

func testEncodeDecodeResult[V any](original, expected V) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()
encoded := must(protoenc.Marshal(&original))(t)

t.Logf("\n%s", hex.Dump(encoded))

var result V

require.NoError(t, protoenc.Unmarshal(encoded, &result))
require.Equal(t, expected, result)
}
}
6 changes: 5 additions & 1 deletion slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ type Value[T any] struct {
V T `protobuf:"1"`
}

func makeValue[T any](t T) Value[T] {
return Value[T]{V: t}
}

func (v Value[T]) Val() T {
return v.V
}
Expand Down Expand Up @@ -392,7 +396,7 @@ func testDisallowedTypes[T any](t *testing.T) {

_, err := protoenc.Marshal(&original)
require.Error(t, err)
assert.Regexp(t, "(is not supported)|(takes a struct)", err.Error())
assert.Regexp(t, "(is not supported)|(takes a pointer to struct)", err.Error())
}

func TestDuration(t *testing.T) {
Expand Down
31 changes: 14 additions & 17 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (

// Unmarshal a protobuf value into a Go value.
// The caller must pass a pointer to the struct to decode into.
func Unmarshal(buf []byte, ptr interface{}) error {
return unmarshal(buf, ptr)
func Unmarshal(buf []byte, structPtr interface{}) error {
return unmarshal(buf, structPtr)
}

func unmarshal(buf []byte, structPtr interface{}) (returnErr error) {
Expand All @@ -41,34 +41,25 @@ func unmarshal(buf []byte, structPtr interface{}) (returnErr error) {
return nil
}

if hasCustomEncoders(reflect.TypeOf(structPtr)) {
return errors.New("custom decoders are not supported for top-level structs, use BinaryUnmarshaler instead")
}

if bu, ok := structPtr.(encoding.BinaryUnmarshaler); ok {
return bu.UnmarshalBinary(buf)
}

val := reflect.ValueOf(structPtr)
if val.Kind() != reflect.Pointer {
return errors.New("decode has been given a non pointer type")
if val.Kind() != reflect.Pointer || val.Type().Elem().Kind() != reflect.Struct {
return errors.New("unmarshal takes a pointer to struct")
}

return unmarshalStruct(val.Elem(), buf)
}

func unmarshalStruct(structVal reflect.Value, buf []byte) error {
if structVal.Kind() != reflect.Struct {
return errors.New("not a struct")
}

zeroStructFields(structVal)

ok, err := tryDecodeFunc(buf, structVal)
if err != nil {
return err
}

if ok {
return nil
}

structFields, err := StructFields(structVal.Type())
if err != nil {
return err
Expand Down Expand Up @@ -398,6 +389,12 @@ func unmarshalBytes(dst reflect.Value, value complexValue) (err error) {
return nil

case reflect.Struct:
if ok, err := tryDecodeFunc(bytes, dst); ok {
return nil
} else if err != nil {
return err
}

if enc, ok := dst.Addr().Interface().(encoding.BinaryUnmarshaler); ok {
return enc.UnmarshalBinary(bytes)
}
Expand Down

0 comments on commit dceb5a6

Please sign in to comment.