Skip to content

Commit

Permalink
fix: support pointer to structs in marshal/unmarshal
Browse files Browse the repository at this point in the history
For siderolabs/talos#6057

Signed-off-by: Dmitriy Matrenichev <[email protected]>
  • Loading branch information
DmitriyMV committed Aug 10, 2022
1 parent 49a85fa commit 3e56913
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
11 changes: 11 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,17 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) {
return
}

// If the pointer is to a struct
if deref(val.Type()).Kind() == reflect.Struct {
b, ok := tryEncodeFunc(val)
if ok {
putTag(m, num, protowire.BytesType)
putBytes(m, b)

return
}
}

m.encodeValue(num, val.Elem())

case reflect.Interface:
Expand Down
21 changes: 18 additions & 3 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,16 +489,16 @@ func TestCustomEcnoders(t *testing.T) {
},
"should use custom encoder on pointer": {
testCustomEncodersDecoders(
encodeCustomEncoderStruct,
decodeCustomEncoderStruct,
encodeCustomEncoderStructPtr,
decodeCustomEncoderStructPtr,
OneFieldStruct[*CustomEncoderStruct]{
Field: &CustomEncoderStruct{
Value: 150,
},
},
OneFieldStruct[*CustomEncoderStruct]{
Field: &CustomEncoderStruct{
Value: 152,
Value: 156,
},
},
),
Expand Down Expand Up @@ -547,6 +547,21 @@ func decodeCustomEncoderStruct(slc []byte) (CustomEncoderStruct, error) {
}, err
}

func encodeCustomEncoderStructPtr(v *CustomEncoderStruct) ([]byte, error) {
return []byte(strconv.Itoa(v.Value + 3)), nil
}

func decodeCustomEncoderStructPtr(slc []byte) (*CustomEncoderStruct, error) {
res, err := strconv.Atoi(string(slc))
if err != nil {
return &CustomEncoderStruct{}, err
}

return &CustomEncoderStruct{
Value: res + 3,
}, err
}

func testCustomEncodersDecoders[V any, T any](
enc func(T) ([]byte, error),
dec func([]byte) (T, error),
Expand Down
14 changes: 13 additions & 1 deletion unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func instantiate(dst reflect.Value) error {
return nil
}

//nolint:cyclop,gocyclo
//nolint:cyclop,gocognit,gocyclo
func unmarshalBytes(dst reflect.Value, value complexValue) (err error) {
defer func() {
if err != nil {
Expand Down Expand Up @@ -413,6 +413,18 @@ func unmarshalBytes(dst reflect.Value, value complexValue) (err error) {
}
}

// If the pointer is to a struct
if deref(dst.Type()).Kind() == reflect.Struct {
ok, err := tryDecodeFunc(bytes, dst)
if err != nil {
return err
}

if ok {
return nil
}
}

return unmarshalBytes(dst.Elem(), value)

case reflect.String:
Expand Down

0 comments on commit 3e56913

Please sign in to comment.