Skip to content

Commit

Permalink
feat: implement custom encoders/decoders
Browse files Browse the repository at this point in the history
This PR adds the ability to use free functions as custom encoders/decoders. This is important because we want to implement encoding for foreign types (for example netip.*).

Signed-off-by: Dmitriy Matrenichev <[email protected]>
  • Loading branch information
DmitriyMV committed Aug 1, 2022
1 parent 549761b commit 8a48bf0
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 9 deletions.
32 changes: 32 additions & 0 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package protoenc_test
import (
"testing"

"github.com/stretchr/testify/require"

"github.com/siderolabs/protoenc"
)

Expand Down Expand Up @@ -43,3 +45,33 @@ func BenchmarkEncode(b *testing.B) {
Store = result
}
}

func BenchmarkCustom(b *testing.B) {
b.Cleanup(func() {
protoenc.CleanEncoderDecoder()
})

o := OneFieldStruct[CustomEncoderStruct]{
Field: CustomEncoderStruct{
Value: 150,
},
}

protoenc.RegisterEncoderDecoder(encodeCustomEncoderStruct, decodeCustomEncoderStruct)

encoded, err := protoenc.Marshal(&o)
require.NoError(b, err)

b.ResetTimer()
b.ReportAllocs()

target := &OneFieldStruct[CustomEncoderStruct]{}
for i := 0; i < b.N; i++ {
*target = OneFieldStruct[CustomEncoderStruct]{}

err := protoenc.Unmarshal(encoded, target)
if err != nil {
b.Fatal(err)
}
}
}
23 changes: 23 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ func (m *marshaller) encodeStruct(val reflect.Value) {
panic("encodeStruct takes a struct")
}

res, ok := tryEncodeFunc(val)
if ok {
m.buf = append(m.buf, res...)

return
}

structFields, err := StructFields(val.Type())
if err != nil {
panic(err)
Expand Down Expand Up @@ -283,6 +290,22 @@ func (m *marshaller) tryEncodePredefined(num protowire.Number, val reflect.Value
return true
}

func tryEncodeFunc(val reflect.Value) ([]byte, bool) {
typ := val.Type()

enc, ok := encoders.Get(typ)
if !ok {
return nil, false
}

b, err := enc(val.Interface())
if err != nil {
panic(err)
}

return b, true
}

func asBinaryMarshaler(val reflect.Value) (encoding.BinaryMarshaler, bool) {
if enc, ok := val.Interface().(encoding.BinaryMarshaler); ok {
return enc, true
Expand Down
127 changes: 127 additions & 0 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding"
"encoding/hex"
"math/big"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -449,3 +450,129 @@ func makeIncorrectEmbedTest[V any](v V) func(t *testing.T) {
require.Error(t, err)
}
}

func TestCustomEcnoders(t *testing.T) {
tests := map[string]struct {
fn func(t *testing.T)
}{
"should use custom encoder": {
testCustomEncodersDecoders(
encodeCustomEncoderStruct,
decodeCustomEncoderStruct,
OneFieldStruct[CustomEncoderStruct]{
Field: CustomEncoderStruct{
Value: 150,
},
},
OneFieldStruct[CustomEncoderStruct]{
Field: CustomEncoderStruct{
Value: 152,
},
},
),
},
"should use custom encoder on pointer": {
testCustomEncodersDecoders(
encodeCustomEncoderStruct,
decodeCustomEncoderStruct,
OneFieldStruct[*CustomEncoderStruct]{
Field: &CustomEncoderStruct{
Value: 150,
},
},
OneFieldStruct[*CustomEncoderStruct]{
Field: &CustomEncoderStruct{
Value: 152,
},
},
),
},
"should use custom encoder on slice": {
testCustomEncodersDecoders(
encodeCustomEncoderStruct,
decodeCustomEncoderStruct,
OneFieldStruct[[]CustomEncoderStruct]{
Field: []CustomEncoderStruct{
{Value: 150},
{Value: 151},
},
},
OneFieldStruct[[]CustomEncoderStruct]{
Field: []CustomEncoderStruct{
{Value: 152},
{Value: 153},
},
},
),
},
}

for name, test := range tests {
t.Run(name, test.fn)
}
}

type CustomEncoderStruct struct {
Value int
}

func encodeCustomEncoderStruct(v CustomEncoderStruct) ([]byte, error) {
return []byte(strconv.Itoa(v.Value + 1)), nil
}

func decodeCustomEncoderStruct(slc []byte) (CustomEncoderStruct, error) {
res, err := strconv.Atoi(string(slc))
if err != nil {
return CustomEncoderStruct{}, err
}

return CustomEncoderStruct{
Value: res + 1,
}, err
}

func testCustomEncodersDecoders[V any, T any](
enc func(T) ([]byte, error),
dec func([]byte) (T, error),
original V,
expected V,
) func(t *testing.T) {
return func(t *testing.T) {
t.Cleanup(func() {
protoenc.CleanEncoderDecoder()
})

protoenc.RegisterEncoderDecoder(enc, dec)

encoded := must(protoenc.Marshal(&original))(t)

var result V

require.NoError(t, protoenc.Unmarshal(encoded, &result))
require.Equal(t, expected, result)
}
}

type OneFieldStruct[T any] struct {
Field T `protobuf:"1"`
}

func TestIncorrectCustomEncoders(t *testing.T) {
t.Cleanup(func() {
protoenc.CleanEncoderDecoder()
})

require.Panics(t, func() {
protoenc.RegisterEncoderDecoder(
func(v []CustomEncoderStruct) ([]byte, error) { return nil, nil },
func(slc []byte) ([]CustomEncoderStruct, error) { return nil, nil },
)
})

require.Panics(t, func() {
protoenc.RegisterEncoderDecoder(
func(v string) ([]byte, error) { return nil, nil },
func(slc []byte) (string, error) { return "", nil },
)
})
}
25 changes: 25 additions & 0 deletions pointer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package protoenc

//TODO: remove this once Go 1.19 lands

import (
"sync/atomic"
"unsafe"
)

// A Pointer is an atomic pointer of type *T. The zero value is a nil *T.
type Pointer[T any] struct {
v unsafe.Pointer
}

// Load atomically loads and returns the value stored in x.
func (x *Pointer[T]) Load() *T { return (*T)(atomic.LoadPointer(&x.v)) }

// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) {
return atomic.CompareAndSwapPointer(&x.v, unsafe.Pointer(old), unsafe.Pointer(new))
}
82 changes: 73 additions & 9 deletions type_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strconv"
"sync"
"unsafe"

"google.golang.org/protobuf/encoding/protowire"
)
Expand Down Expand Up @@ -147,21 +148,84 @@ func ParseTag(field reflect.StructField) int {
return num
}

var cache = typesCache{}

type typesCache struct {
type syncMap[K, V any] struct {
m sync.Map
}

func (tc *typesCache) Get(t reflect.Type) ([]FieldData, bool) {
value, ok := tc.m.Load(t)
func (sm *syncMap[K, V]) Get(k K) (V, bool) {
value, ok := sm.m.Load(k)
if !ok {
return nil, false
var zero V

return zero, false
}

return value.(V), true //nolint:forcetypeassert
}

func (sm *syncMap[K, V]) Add(k K, v V) {
sm.m.Store(k, v)
}

var (
cache = syncMap[reflect.Type, []FieldData]{}
encoders = syncMap[reflect.Type, encoder]{}
decoders = syncMap[reflect.Type, decoder]{}
)

type (
encoder func(any) ([]byte, error)
decoder func(slc []byte, dst reflect.Value) error
)

// RegisterEncoderDecoder registers the given encoder and decoder for the given type. T should be struct or
// pointer to struct. T and pointer to T are treated the same.
func RegisterEncoderDecoder[T any, Enc func(T) ([]byte, error), Dec func([]byte) (T, error)](enc Enc, dec Dec) {
var zero T

typ := deref(reflect.TypeOf(zero))
if typ.Kind() != reflect.Struct {
panic("RegisterEncoderDecoder: T must be a struct")
}

fnEnc := func(val any) ([]byte, error) {
v, ok := val.(T)
if !ok {
return nil, fmt.Errorf("%T is not %T", val, zero)
}

return enc(v)
}

fnDec := func(b []byte, dst reflect.Value) error {
if dst.Type() != typ {
return fmt.Errorf("%T is not %T", dst, zero)
}

v, err := dec(b)
if err != nil {
return err
}

*(*T)(unsafe.Pointer(dst.UnsafeAddr())) = v

return nil
}

if _, ok := encoders.Get(typ); ok {
panic("RegisterEncoderDecoder: encoder for type " + typ.String() + " already registered")
}

if _, ok := decoders.Get(typ); ok {
panic("RegisterEncoderDecoder: decoder for type " + typ.String() + " already registered")
}

return value.([]FieldData), true //nolint:forcetypeassert
encoders.Add(typ, fnEnc)
decoders.Add(typ, fnDec)
}

func (tc *typesCache) Add(t reflect.Type, structFields []FieldData) {
tc.m.Store(t, structFields)
// CleanEncoderDecoder cleans the map of encoders and decoders. It's not safe to it call concurrently.
func CleanEncoderDecoder() {
encoders = syncMap[reflect.Type, encoder]{}
decoders = syncMap[reflect.Type, decoder]{}
}
22 changes: 22 additions & 0 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ func (u *unmarshaller) unmarshalStruct(buf []byte, structVal reflect.Value) erro

zeroStructFields(structVal)

ok, err := tryDecodeFunc(buf, structVal)
if err != nil {
return err
}

if ok {
return nil
}

structFields, err := StructFields(structVal.Type())
if err != nil {
return err
Expand Down Expand Up @@ -392,6 +401,19 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin
return nil
}

func tryDecodeFunc(vb []byte, dst reflect.Value) (bool, error) {
dec, ok := decoders.Get(dst.Type())
if !ok {
return false, nil
}

if err := dec(vb, dst); err != nil {
return false, err
}

return true, nil
}

func decodeSignedInt(wiretype protowire.Type, v uint64) (int64, error) {
switch wiretype { //nolint:exhaustive
case protowire.VarintType:
Expand Down

0 comments on commit 8a48bf0

Please sign in to comment.