Skip to content

Commit

Permalink
Add nil aware decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Mar 16, 2021
1 parent da5c475 commit 19b8e3c
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 100 deletions.
128 changes: 48 additions & 80 deletions decode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,43 +59,43 @@ func _getDecoder(typ reflect.Type) decoderFunc {

if kind == reflect.Ptr {
if _, ok := typeDecMap.Load(typ.Elem()); ok {
return ptrDecoderFunc(typ)
return ptrValueDecoder(typ)
}
}

if typ.Implements(customDecoderType) {
return decodeCustomValue
return nilAwareDecoder(typ, decodeCustomValue)
}
if typ.Implements(unmarshalerType) {
return unmarshalValue
return nilAwareDecoder(typ, unmarshalValue)
}
if typ.Implements(binaryUnmarshalerType) {
return unmarshalBinaryValue
return nilAwareDecoder(typ, unmarshalBinaryValue)
}
if typ.Implements(textUnmarshalerType) {
return unmarshalTextValue
return nilAwareDecoder(typ, unmarshalTextValue)
}

// Addressable struct field value.
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(customDecoderType) {
return decodeCustomValueAddr
return addrDecoder(nilAwareDecoder(typ, decodeCustomValue))
}
if ptr.Implements(unmarshalerType) {
return unmarshalValueAddr
return addrDecoder(nilAwareDecoder(typ, unmarshalValue))
}
if ptr.Implements(binaryUnmarshalerType) {
return unmarshalBinaryValueAddr
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue))
}
if ptr.Implements(textUnmarshalerType) {
return unmarshalTextValueAddr
return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue))
}
}

switch kind {
case reflect.Ptr:
return ptrDecoderFunc(typ)
return ptrValueDecoder(typ)
case reflect.Slice:
elem := typ.Elem()
if elem.Kind() == reflect.Uint8 {
Expand All @@ -122,7 +122,7 @@ func _getDecoder(typ reflect.Type) decoderFunc {
return valueDecoders[kind]
}

func ptrDecoderFunc(typ reflect.Type) decoderFunc {
func ptrValueDecoder(typ reflect.Type) decoderFunc {
decoder := getDecoder(typ.Elem())
return func(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
Expand All @@ -138,53 +138,34 @@ func ptrDecoderFunc(typ reflect.Type) decoderFunc {
}
}

func decodeCustomValueAddr(d *Decoder, v reflect.Value) error {
if !v.CanAddr() {
return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}
return decodeCustomValue(d, v.Addr())
}

func decodeCustomValue(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}

if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}

decoder := v.Interface().(CustomDecoder)
return decoder.DecodeMsgpack(d)
}

func unmarshalValueAddr(d *Decoder, v reflect.Value) error {
if !v.CanAddr() {
return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
func addrDecoder(fn decoderFunc) decoderFunc {
return func(d *Decoder, v reflect.Value) error {
if !v.CanAddr() {
return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}
return fn(d, v.Addr())
}
return unmarshalValue(d, v.Addr())
}

func unmarshalValue(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}

if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
func nilAwareDecoder(typ reflect.Type, fn decoderFunc) decoderFunc {
if nilable(typ.Kind()) {
return func(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return fn(d, v)
}
}

var b []byte

d.rec = make([]byte, 0, 64)
if err := d.Skip(); err != nil {
return err
return func(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}
return fn(d, v)
}
b = d.rec
d.rec = nil

unmarshaler := v.Interface().(Unmarshaler)
return unmarshaler.UnmarshalMsgpack(b)
}

func decodeBoolValue(d *Decoder, v reflect.Value) error {
Expand Down Expand Up @@ -229,22 +210,26 @@ func decodeUnsupportedValue(d *Decoder, v reflect.Value) error {

//------------------------------------------------------------------------------

func unmarshalBinaryValueAddr(d *Decoder, v reflect.Value) error {
if !v.CanAddr() {
return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}
return unmarshalBinaryValue(d, v.Addr())
func decodeCustomValue(d *Decoder, v reflect.Value) error {
decoder := v.Interface().(CustomDecoder)
return decoder.DecodeMsgpack(d)
}

func unmarshalBinaryValue(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}
func unmarshalValue(d *Decoder, v reflect.Value) error {
var b []byte

if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
d.rec = make([]byte, 0, 64)
if err := d.Skip(); err != nil {
return err
}
b = d.rec
d.rec = nil

unmarshaler := v.Interface().(Unmarshaler)
return unmarshaler.UnmarshalMsgpack(b)
}

func unmarshalBinaryValue(d *Decoder, v reflect.Value) error {
data, err := d.DecodeBytes()
if err != nil {
return err
Expand All @@ -254,24 +239,7 @@ func unmarshalBinaryValue(d *Decoder, v reflect.Value) error {
return unmarshaler.UnmarshalBinary(data)
}

//------------------------------------------------------------------------------

func unmarshalTextValueAddr(d *Decoder, v reflect.Value) error {
if !v.CanAddr() {
return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}
return unmarshalTextValue(d, v.Addr())
}

func unmarshalTextValue(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}

if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}

data, err := d.DecodeBytes()
if err != nil {
return err
Expand Down
12 changes: 6 additions & 6 deletions encode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func encodeCustomValuePtr(e *Encoder, v reflect.Value) error {
}

func encodeCustomValue(e *Encoder, v reflect.Value) error {
if nilable(v) && v.IsNil() {
if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}

Expand All @@ -155,7 +155,7 @@ func marshalValuePtr(e *Encoder, v reflect.Value) error {
}

func marshalValue(e *Encoder, v reflect.Value) error {
if nilable(v) && v.IsNil() {
if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}

Expand Down Expand Up @@ -190,8 +190,8 @@ func encodeUnsupportedValue(e *Encoder, v reflect.Value) error {
return fmt.Errorf("msgpack: Encode(unsupported %s)", v.Type())
}

func nilable(v reflect.Value) bool {
switch v.Kind() {
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
}
Expand All @@ -208,7 +208,7 @@ func marshalBinaryValueAddr(e *Encoder, v reflect.Value) error {
}

func marshalBinaryValue(e *Encoder, v reflect.Value) error {
if nilable(v) && v.IsNil() {
if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}

Expand All @@ -231,7 +231,7 @@ func marshalTextValueAddr(e *Encoder, v reflect.Value) error {
}

func marshalTextValue(e *Encoder, v reflect.Value) error {
if nilable(v) && v.IsNil() {
if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}

Expand Down
17 changes: 3 additions & 14 deletions ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,7 @@ func makeExtDecoder(
typ reflect.Type,
decoder func(d *Decoder, v reflect.Value, extLen int) error,
) decoderFunc {
nilable := typ.Kind() == reflect.Ptr

return func(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
v.Set(reflect.Zero(typ))
return d.DecodeNil()
}

if nilable && v.IsNil() {
v.Set(reflect.New(typ.Elem()))
}

return nilAwareDecoder(typ, func(d *Decoder, v reflect.Value) error {
extID, extLen, err := d.DecodeExtHeader()
if err != nil {
return err
Expand All @@ -162,7 +151,7 @@ func makeExtDecoder(
return fmt.Errorf("msgpack: got ext type=%d, wanted %d", extID, wantedExtID)
}
return decoder(d, v, extLen)
}
})
}

func makeExtDecoderAddr(extDecoder decoderFunc) decoderFunc {
Expand Down Expand Up @@ -266,7 +255,7 @@ func (d *Decoder) decodeInterfaceExt(c byte) (interface{}, error) {
}

v := reflect.New(info.Type).Elem()
if nilable(v) && v.IsNil() {
if nilable(v.Kind()) && v.IsNil() {
v.Set(reflect.New(info.Type.Elem()))
}

Expand Down
15 changes: 15 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ type CustomEncoderField struct {
Field CustomEncoder
}

type CustomEncoderEmbeddedPtr struct {
*CustomEncoder
}

func (s *CustomEncoderEmbeddedPtr) DecodeMsgpack(dec *msgpack.Decoder) error {
if s.CustomEncoder == nil {
s.CustomEncoder = new(CustomEncoder)
}
return s.CustomEncoder.DecodeMsgpack(dec)
}

//------------------------------------------------------------------------------

type JSONFallbackTest struct {
Expand Down Expand Up @@ -561,6 +572,10 @@ var (
in: &CustomEncoderField{Field: CustomEncoder{"a", nil, 1}},
out: new(CustomEncoderField),
},
{
in: &CustomEncoderEmbeddedPtr{&CustomEncoder{"a", nil, 1}},
out: new(CustomEncoderEmbeddedPtr),
},

{in: repoURL, out: new(url.URL)},
{in: repoURL, out: new(*url.URL)},
Expand Down

0 comments on commit 19b8e3c

Please sign in to comment.