Skip to content

Commit

Permalink
fix encoding recursive structs (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed May 4, 2024
1 parent 3ba3d5d commit aa3c35c
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 37 deletions.
22 changes: 13 additions & 9 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ func marshal(v interface{}, flags encodeFlags) (*dynamodb.AttributeValue, error)
}

rt := rv.Type()
enc, err := encodeType(rt, flags)
def, err := typedefOf(rt)
if err != nil {
return nil, err
}
enc, err := def.encodeType(rt, flags)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -106,7 +110,7 @@ type isZeroer interface {
IsZero() bool
}

func isZeroFunc(rt reflect.Type) func(rv reflect.Value) bool {
func (def *typedef) isZeroFunc(rt reflect.Type) func(rv reflect.Value) bool {
if rt.Implements(rtypeIsZeroer) {
return isZeroIface(rt, func(v isZeroer) bool {
return v.IsZero()
Expand All @@ -131,10 +135,10 @@ func isZeroFunc(rt reflect.Type) func(rv reflect.Value) bool {
return isNil

case reflect.Array:
return isZeroArray(rt)
return def.isZeroArray(rt)

case reflect.Struct:
return isZeroStruct(rt)
return def.isZeroStruct(rt)
}

return isZeroValue
Expand All @@ -160,13 +164,13 @@ func isZeroIface[T any](rt reflect.Type, isZero func(v T) bool) func(rv reflect.
}
}

func isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool {
fields, err := structFields(rt)
func (def *typedef) isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool {
fields, err := def.structFields(rt, false)
if err != nil {
return nil
}
return func(rv reflect.Value) bool {
for _, info := range fields {
for _, info := range *fields {
if info.isZero == nil {
continue
}
Expand All @@ -184,8 +188,8 @@ func isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool {
}
}

func isZeroArray(rt reflect.Type) func(reflect.Value) bool {
elemIsZero := isZeroFunc(rt.Elem())
func (def *typedef) isZeroArray(rt reflect.Type) func(reflect.Value) bool {
elemIsZero := def.isZeroFunc(rt.Elem())
return func(rv reflect.Value) bool {
for i := 0; i < rv.Len(); i++ {
if !elemIsZero(rv.Index(i)) {
Expand Down
70 changes: 70 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,73 @@ func TestMarshalItemBypass(t *testing.T) {
t.Error("bad unmarshal")
}
}

func TestMarshalRecursive(t *testing.T) {
t.SkipNow()

type Person struct {
Spouse *Person
Children []Person
Name string
}
type Friend struct {
ID int
Person Person
Nickname string
}
children := []Person{
{Name: "Bobby"},
}

hank := Person{
Spouse: &Person{
Name: "Peggy",
Children: children,
},
Children: children,
Name: "Hank",
}

t.Run("self-recursive", func(t *testing.T) {

want := map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Hank")},
"Spouse": {M: map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Peggy")},
"Children": {L: []*dynamodb.AttributeValue{
{M: map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Bobby")},
"Children": {L: []*dynamodb.AttributeValue{}},
}},
},
},
}},
"Children": {L: []*dynamodb.AttributeValue{
{M: map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Bobby")},
"Children": {L: []*dynamodb.AttributeValue{}},
}},
}},
}

got, err := MarshalItem(hank)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Error("bad", got)
}
})

t.Run("field is recursive", func(t *testing.T) {
friend := Friend{
Person: hank,
Nickname: "H-Dawg",
}
got, err := MarshalItem(friend)
if err != nil {
t.Fatal(err)
}
t.Fatal(got)
})
}
46 changes: 28 additions & 18 deletions encodefunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ import (

type encodeFunc func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error)

func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
try := rt
for {
// deref := func()
switch try {
case rtypeAttr:
return encode2(func(av *dynamodb.AttributeValue, _ encodeFlags) (*dynamodb.AttributeValue, error) {
Expand Down Expand Up @@ -54,7 +53,7 @@ func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {

switch rt.Kind() {
case reflect.Pointer:
return encodePtr(rt, flags)
return def.encodePtr(rt, flags)

// BOOL
case reflect.Bool:
Expand Down Expand Up @@ -84,30 +83,30 @@ func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
return encodeSet(rt, flags)
}
// lists (L)
return encodeList(rt, flags)
return def.encodeList(rt, flags)

case reflect.Map:
// sets (NS, SS, BS)
if flags&flagSet != 0 {
return encodeSet(rt, flags)
}
// M
return encodeMapM(rt, flags)
return def.encodeMapM(rt, flags)

// M
case reflect.Struct:
return encodeStruct(rt)
return def.encodeStruct(rt)

case reflect.Interface:
if rt.NumMethod() == 0 {
return encodeAny, nil
return def.encodeAny, nil
}
}
return nil, fmt.Errorf("dynamo marshal: unsupported type %s", rt.String())
}

func encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
elem, err := encodeType(rt.Elem(), flags)
func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
elem, err := def.encodeType(rt.Elem(), flags)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -210,13 +209,24 @@ func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc {
}
}

func encodeStruct(rt reflect.Type) (encodeFunc, error) {
fields, err := structFields(rt)
func (def *typedef) encodeStruct(rt reflect.Type) (encodeFunc, error) {
var fields *[]structField
var err error
if def.sameAsRoot(rt) {
fields, err = def.structFields(rt, false)
} else {
var subdef *typedef
subdef, err = typedefOf(rt)
if subdef != nil {
fields = &subdef.fields
}
}
if err != nil {
return nil, err
}

return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) {
item, err := encodeItem(fields, rv)
item, err := encodeItem(*fields, rv)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -303,7 +313,7 @@ func encodeSliceBS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValu
return &dynamodb.AttributeValue{BS: bs}, nil
}

func encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
keyString := encodeMapKeyFunc(rt)
if keyString == nil {
return nil, fmt.Errorf("dynamo marshal: map key type must be string or encoding.TextMarshaler, have %v", rt)
Expand All @@ -319,7 +329,7 @@ func encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
subflags |= flagOmitEmpty
}

valueEnc, err := encodeType(rt.Elem(), subflags)
valueEnc, err := def.encodeType(rt.Elem(), subflags)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -516,7 +526,7 @@ func encodeSet(rt /* []T | map[T]bool | map[T]struct{} */ reflect.Type, flags en
return nil, fmt.Errorf("dynamo: marshal: invalid type for set %s", rt.String())
}

func encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
// lists CAN be empty
subflags := flagNone
if flags&flagOmitEmptyElem == 0 {
Expand All @@ -530,7 +540,7 @@ func encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
subflags |= flagAllowEmptyElem
}

valueEnc, err := encodeType(rt.Elem(), subflags)
valueEnc, err := def.encodeType(rt.Elem(), subflags)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -560,14 +570,14 @@ func encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
}, nil
}

func encodeAny(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) {
func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) {
if !rv.CanInterface() || rv.IsNil() {
if flags&flagNull != 0 {
return nullAV, nil
}
return nil, nil
}
enc, err := encodeType(rv.Elem().Type(), flags)
enc, err := def.encodeType(rv.Elem().Type(), flags)
if err != nil {
return nil, err
}
Expand Down
45 changes: 35 additions & 10 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ import (
var typeCache sync.Map // unmarshalKey → *typedef

type typedef struct {
decoders map[unmarshalKey]decodeFunc
fields []structField
marshaler bool
decoders map[unmarshalKey]decodeFunc
fields []structField
root reflect.Type
}

func newTypedef(rt reflect.Type) (*typedef, error) {
def := &typedef{
decoders: make(map[unmarshalKey]decodeFunc),
// encoders: make(map[encodeKey]encodeFunc),
root: rt,
}
err := def.init(rt)
return def, err
Expand All @@ -44,8 +46,10 @@ func (def *typedef) init(rt reflect.Type) error {
return nil
}

var err error
def.fields, err = structFields(rt)
fieldptr, err := def.structFields(rt, true)
if fieldptr != nil {
def.fields = *fieldptr
}
return err
}

Expand Down Expand Up @@ -95,7 +99,7 @@ func (def *typedef) encodeItem(rv reflect.Value) (Item, error) {
case reflect.Struct:
return encodeItem(def.fields, rv)
case reflect.Map:
enc, err := encodeMapM(rv.Type(), flagNone)
enc, err := def.encodeMapM(rv.Type(), flagNone)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -395,10 +399,31 @@ type structField struct {
isZero func(reflect.Value) bool
}

func structFields(rt reflect.Type) ([]structField, error) {
// type encodeKey struct {
// rt reflect.Type
// flags encodeFlags
// }

func (def *typedef) sameAsRoot(rt reflect.Type) bool {
switch {
case rt == def.root:
return true
case def.root.Kind() == reflect.Pointer && rt.Kind() != reflect.Pointer:
return def.root.Elem() == rt
case def.root.Kind() != reflect.Pointer && rt.Kind() == reflect.Pointer:
return rt.Elem() == def.root
}
return false
}

func (def *typedef) structFields(rt reflect.Type, isRoot bool) (*[]structField, error) {
if !isRoot && def.sameAsRoot(rt) {
return &def.fields, nil
}

var fields []structField
err := visitTypeFields(rt, nil, nil, func(name string, index []int, flags encodeFlags, vt reflect.Type) error {
enc, err := encodeType(vt, flags)
enc, err := def.encodeType(vt, flags)
if err != nil {
return err
}
Expand All @@ -407,12 +432,12 @@ func structFields(rt reflect.Type) ([]structField, error) {
name: name,
flags: flags,
enc: enc,
isZero: isZeroFunc(vt),
isZero: def.isZeroFunc(vt),
}
fields = append(fields, field)
return nil
})
return fields, err
return &fields, err
}

var (
Expand Down
Loading

0 comments on commit aa3c35c

Please sign in to comment.