Skip to content

Commit

Permalink
feat: add support of text marshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma committed Oct 27, 2020
1 parent 96df394 commit e2aac9a
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 2 deletions.
8 changes: 8 additions & 0 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder {
}

func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil {
return dec
}

// Handle eface case when it isnt a union
if typ.Kind() == reflect.Interface && schema.Type() != Union {
if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
Expand Down Expand Up @@ -152,6 +156,10 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
}

func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil {
return enc
}

if typ.Kind() == reflect.Interface {
return &interfaceEncoder{schema: schema, typ: typ}
}
Expand Down
70 changes: 70 additions & 0 deletions codec_marshaler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package avro

import (
"encoding"
"unsafe"

"github.com/modern-go/reflect2"
)

var (
textMarshalerType = reflect2.TypeOfPtr((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerType = reflect2.TypeOfPtr((*encoding.TextUnmarshaler)(nil)).Elem()
)

func createDecoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
if typ.Implements(textUnmarshalerType) && schema.Type() == String {
return &textMarshalerCodec{typ}
}
ptrType := reflect2.PtrTo(typ)
if ptrType.Implements(textUnmarshalerType) && schema.Type() == String {
return &referenceDecoder{
&textMarshalerCodec{ptrType},
}
}
return nil
}

func createEncoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
if typ.Implements(textMarshalerType) && schema.Type() == String {
return &textMarshalerCodec{
typ: typ,
}
}
return nil
}

type textMarshalerCodec struct {
typ reflect2.Type
}

func (c textMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
obj := c.typ.UnsafeIndirect(ptr)
if reflect2.IsNil(obj) {
ptrType := c.typ.(*reflect2.UnsafePtrType)
newPtr := ptrType.Elem().UnsafeNew()
*((*unsafe.Pointer)(ptr)) = newPtr
obj = c.typ.UnsafeIndirect(ptr)
}
unmarshaler := (obj).(encoding.TextUnmarshaler)
b := r.ReadBytes()
err := unmarshaler.UnmarshalText(b)
if err != nil {
r.ReportError("textMarshalerCodec", err.Error())
}
}

func (c textMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) {
obj := c.typ.UnsafeIndirect(ptr)
if c.typ.IsNullable() && reflect2.IsNil(obj) {
w.WriteBytes(nil)
return
}
marshaler := (obj).(encoding.TextMarshaler)
b, err := marshaler.MarshalText()
if err != nil {
w.Error = err
return
}
w.WriteBytes(b)
}
143 changes: 143 additions & 0 deletions codec_marshaler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package avro_test

import (
"bytes"
"errors"
"testing"
"time"

"github.com/hamba/avro"
"github.com/stretchr/testify/assert"
)

func TestDecoder_TextUnmarshalerPtr(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}
schema := "string"
dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
assert.NoError(t, err)

var ts TestTimestampPtr
err = dec.Decode(&ts)

assert.NoError(t, err)
want := TestTimestampPtr(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC))
assert.Equal(t, want, ts)
}

func TestDecoder_TextUnmarshalerPtrPtr(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}
schema := "string"
dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
assert.NoError(t, err)

var ts *TestTimestampPtr
err = dec.Decode(&ts)

assert.NoError(t, err)
assert.NotNil(t, ts)
want := TestTimestampPtr(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC))
assert.Equal(t, want, *ts)
}

func TestDecoder_TextUnmarshalerError(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}
schema := "string"
dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
assert.NoError(t, err)

var ts *TestTimestampError
err = dec.Decode(&ts)

assert.Error(t, err)
}

func TestEncoder_TextMarshaler(t *testing.T) {
defer ConfigTeardown()

schema := "string"
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)
ts := TestTimestamp(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC))

err = enc.Encode(ts)

assert.NoError(t, err)
assert.Equal(t, []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}, buf.Bytes())
}

func TestEncoder_TextMarshalerPtr(t *testing.T) {
defer ConfigTeardown()

schema := "string"
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)
ts := TestTimestampPtr(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC))

err = enc.Encode(&ts)

assert.NoError(t, err)
assert.Equal(t, []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}, buf.Bytes())
}

func TestEncoder_TextMarshalerPtrNil(t *testing.T) {
defer ConfigTeardown()

schema := "string"
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)
var ts *TestTimestampPtr

err = enc.Encode(ts)

assert.NoError(t, err)
assert.Equal(t, []byte{0x00}, buf.Bytes())
}

func TestEncoder_TextMarshalerError(t *testing.T) {
defer ConfigTeardown()

schema := "string"
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)
ts := TestTimestampError{}

err = enc.Encode(&ts)

assert.Error(t, err)
}

type TestTimestamp time.Time

func (t TestTimestamp) MarshalText() ([]byte, error) {
return (time.Time)(t).MarshalText()
}

type TestTimestampPtr time.Time

func (t *TestTimestampPtr) UnmarshalText(data []byte) error {
return (*time.Time)(t).UnmarshalText(data)
}

func (t *TestTimestampPtr) MarshalText() ([]byte, error) {
return (*time.Time)(t).MarshalText()
}

type TestTimestampError time.Time

func (t *TestTimestampError) UnmarshalText(data []byte) error {
return errors.New("test")
}

func (t *TestTimestampError) MarshalText() ([]byte, error) {
return nil, errors.New("test")
}
8 changes: 8 additions & 0 deletions codec_ptr.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ func (d *dereferenceEncoder) Encode(ptr unsafe.Pointer, w *Writer) {

d.encoder.Encode(*((*unsafe.Pointer)(ptr)), w)
}

type referenceDecoder struct {
decoder ValDecoder
}

func (decoder *referenceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
decoder.decoder.Decode(unsafe.Pointer(&ptr), r)
}
4 changes: 2 additions & 2 deletions encoder_native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func TestEncoder_BytesRat_Zero(t *testing.T) {
func TestEncoder_BytesRatInvalidSchema(t *testing.T) {
defer ConfigTeardown()

schema := `{"type":"string"}`
schema := `{"type":"int"}`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)
Expand All @@ -475,7 +475,7 @@ func TestEncoder_BytesRatInvalidSchema(t *testing.T) {
func TestEncoder_BytesRatInvalidLogicalSchema(t *testing.T) {
defer ConfigTeardown()

schema := `{"type":"string","logicalType":"uuid"}`
schema := `{"type":"int","logicalType":"uuid"}`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)
Expand Down

0 comments on commit e2aac9a

Please sign in to comment.