diff --git a/codec.go b/codec.go index f95814cf..cb83035a 100644 --- a/codec.go +++ b/codec.go @@ -86,6 +86,10 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder { } func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { + if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil { + return dec + } + // Handle eface case when it isnt a union if typ.Kind() == reflect.Interface && schema.Type() != Union { if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { @@ -152,6 +156,10 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { } func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { + if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil { + return enc + } + if typ.Kind() == reflect.Interface { return &interfaceEncoder{schema: schema, typ: typ} } diff --git a/codec_marshaler.go b/codec_marshaler.go new file mode 100644 index 00000000..fa705119 --- /dev/null +++ b/codec_marshaler.go @@ -0,0 +1,70 @@ +package avro + +import ( + "encoding" + "unsafe" + + "github.com/modern-go/reflect2" +) + +var ( + textMarshalerType = reflect2.TypeOfPtr((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshalerType = reflect2.TypeOfPtr((*encoding.TextUnmarshaler)(nil)).Elem() +) + +func createDecoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { + if typ.Implements(textUnmarshalerType) && schema.Type() == String { + return &textMarshalerCodec{typ} + } + ptrType := reflect2.PtrTo(typ) + if ptrType.Implements(textUnmarshalerType) && schema.Type() == String { + return &referenceDecoder{ + &textMarshalerCodec{ptrType}, + } + } + return nil +} + +func createEncoderOfMarshaler(_ *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { + if typ.Implements(textMarshalerType) && schema.Type() == String { + return &textMarshalerCodec{ + typ: typ, + } + } + return nil +} + +type textMarshalerCodec struct { + typ reflect2.Type +} + +func (c textMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { + obj := c.typ.UnsafeIndirect(ptr) + if reflect2.IsNil(obj) { + ptrType := c.typ.(*reflect2.UnsafePtrType) + newPtr := ptrType.Elem().UnsafeNew() + *((*unsafe.Pointer)(ptr)) = newPtr + obj = c.typ.UnsafeIndirect(ptr) + } + unmarshaler := (obj).(encoding.TextUnmarshaler) + b := r.ReadBytes() + err := unmarshaler.UnmarshalText(b) + if err != nil { + r.ReportError("textMarshalerCodec", err.Error()) + } +} + +func (c textMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) { + obj := c.typ.UnsafeIndirect(ptr) + if c.typ.IsNullable() && reflect2.IsNil(obj) { + w.WriteBytes(nil) + return + } + marshaler := (obj).(encoding.TextMarshaler) + b, err := marshaler.MarshalText() + if err != nil { + w.Error = err + return + } + w.WriteBytes(b) +} diff --git a/codec_marshaler_test.go b/codec_marshaler_test.go new file mode 100644 index 00000000..6d6f9b4e --- /dev/null +++ b/codec_marshaler_test.go @@ -0,0 +1,143 @@ +package avro_test + +import ( + "bytes" + "errors" + "testing" + "time" + + "github.com/hamba/avro" + "github.com/stretchr/testify/assert" +) + +func TestDecoder_TextUnmarshalerPtr(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a} + schema := "string" + dec, err := avro.NewDecoder(schema, bytes.NewReader(data)) + assert.NoError(t, err) + + var ts TestTimestampPtr + err = dec.Decode(&ts) + + assert.NoError(t, err) + want := TestTimestampPtr(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC)) + assert.Equal(t, want, ts) +} + +func TestDecoder_TextUnmarshalerPtrPtr(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a} + schema := "string" + dec, err := avro.NewDecoder(schema, bytes.NewReader(data)) + assert.NoError(t, err) + + var ts *TestTimestampPtr + err = dec.Decode(&ts) + + assert.NoError(t, err) + assert.NotNil(t, ts) + want := TestTimestampPtr(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC)) + assert.Equal(t, want, *ts) +} + +func TestDecoder_TextUnmarshalerError(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a} + schema := "string" + dec, err := avro.NewDecoder(schema, bytes.NewReader(data)) + assert.NoError(t, err) + + var ts *TestTimestampError + err = dec.Decode(&ts) + + assert.Error(t, err) +} + +func TestEncoder_TextMarshaler(t *testing.T) { + defer ConfigTeardown() + + schema := "string" + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + ts := TestTimestamp(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC)) + + err = enc.Encode(ts) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}, buf.Bytes()) +} + +func TestEncoder_TextMarshalerPtr(t *testing.T) { + defer ConfigTeardown() + + schema := "string" + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + ts := TestTimestampPtr(time.Date(2020, 01, 02, 03, 04, 05, 00, time.UTC)) + + err = enc.Encode(&ts) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x28, 0x32, 0x30, 0x32, 0x30, 0x2d, 0x30, 0x31, 0x2d, 0x30, 0x32, 0x54, 0x30, 0x33, 0x3a, 0x30, 0x34, 0x3a, 0x30, 0x35, 0x5a}, buf.Bytes()) +} + +func TestEncoder_TextMarshalerPtrNil(t *testing.T) { + defer ConfigTeardown() + + schema := "string" + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + var ts *TestTimestampPtr + + err = enc.Encode(ts) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x00}, buf.Bytes()) +} + +func TestEncoder_TextMarshalerError(t *testing.T) { + defer ConfigTeardown() + + schema := "string" + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + ts := TestTimestampError{} + + err = enc.Encode(&ts) + + assert.Error(t, err) +} + +type TestTimestamp time.Time + +func (t TestTimestamp) MarshalText() ([]byte, error) { + return (time.Time)(t).MarshalText() +} + +type TestTimestampPtr time.Time + +func (t *TestTimestampPtr) UnmarshalText(data []byte) error { + return (*time.Time)(t).UnmarshalText(data) +} + +func (t *TestTimestampPtr) MarshalText() ([]byte, error) { + return (*time.Time)(t).MarshalText() +} + +type TestTimestampError time.Time + +func (t *TestTimestampError) UnmarshalText(data []byte) error { + return errors.New("test") +} + +func (t *TestTimestampError) MarshalText() ([]byte, error) { + return nil, errors.New("test") +} diff --git a/codec_ptr.go b/codec_ptr.go index 1958b8a6..fc94a68c 100644 --- a/codec_ptr.go +++ b/codec_ptr.go @@ -56,3 +56,11 @@ func (d *dereferenceEncoder) Encode(ptr unsafe.Pointer, w *Writer) { d.encoder.Encode(*((*unsafe.Pointer)(ptr)), w) } + +type referenceDecoder struct { + decoder ValDecoder +} + +func (decoder *referenceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + decoder.decoder.Decode(unsafe.Pointer(&ptr), r) +} diff --git a/encoder_native_test.go b/encoder_native_test.go index 691a5254..b5997f94 100644 --- a/encoder_native_test.go +++ b/encoder_native_test.go @@ -462,7 +462,7 @@ func TestEncoder_BytesRat_Zero(t *testing.T) { func TestEncoder_BytesRatInvalidSchema(t *testing.T) { defer ConfigTeardown() - schema := `{"type":"string"}` + schema := `{"type":"int"}` buf := bytes.NewBuffer([]byte{}) enc, err := avro.NewEncoder(schema, buf) assert.NoError(t, err) @@ -475,7 +475,7 @@ func TestEncoder_BytesRatInvalidSchema(t *testing.T) { func TestEncoder_BytesRatInvalidLogicalSchema(t *testing.T) { defer ConfigTeardown() - schema := `{"type":"string","logicalType":"uuid"}` + schema := `{"type":"int","logicalType":"date"}` buf := bytes.NewBuffer([]byte{}) enc, err := avro.NewEncoder(schema, buf) assert.NoError(t, err)