Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle pointers to anonymous structs when (un)mashaling #139

Merged
merged 1 commit into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,15 +329,28 @@ func fieldsInStruct(rv reflect.Value) map[string]reflect.Value {
for i := 0; i < rv.Type().NumField(); i++ {
field := rv.Type().Field(i)
fv := rv.Field(i)
isPtr := fv.Type().Kind() == reflect.Ptr

name, _ := fieldInfo(field)
if name == "-" {
// skip
continue
}

// embed anonymous structs
if fv.Type().Kind() == reflect.Struct && field.Anonymous {
// need to protect from setting unexported pointers because it will panic
if !fv.CanSet() && isPtr {
continue
}

// embed anonymous structs, they could be pointers so test that too
if (fv.Type().Kind() == reflect.Struct || isPtr && fv.Type().Elem().Kind() == reflect.Struct) && field.Anonymous {
// set zero value for pointer
if isPtr {
zero := reflect.New(fv.Type().Elem())
fv.Set(zero)
fv = zero
}

innerFields := fieldsInStruct(fv)
for k, v := range innerFields {
// don't clobber top-level fields
Expand Down
12 changes: 9 additions & 3 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,23 @@ func TestUnmarshal(t *testing.T) {

func TestUnmarshalItem(t *testing.T) {
for _, tc := range itemEncodingTests {
expected := tc.in
if tc.asymmetric {
continue
if tc.expectedDecode == nil {
continue
}

expected = tc.expectedDecode
}

rv := reflect.New(reflect.TypeOf(tc.in))
err := unmarshalItem(tc.out, rv.Interface())
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
}

if !reflect.DeepEqual(rv.Elem().Interface(), tc.in) {
t.Errorf("%s: bad result: %#v ≠ %#v", tc.name, rv.Elem().Interface(), tc.in)
if !reflect.DeepEqual(rv.Elem().Interface(), expected) {
t.Errorf("%s: bad result: %#v ≠ %#v", tc.name, rv.Elem().Interface(), expected)
}
}
}
Expand Down
12 changes: 10 additions & 2 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ func marshalStruct(rv reflect.Value) (map[string]*dynamodb.AttributeValue, error
name, flags := fieldInfo(field)
omitempty := flags&flagOmitEmpty != 0
anonStruct := fv.Type().Kind() == reflect.Struct && field.Anonymous
pointerAnonStruct := fv.Type().Kind() == reflect.Ptr && fv.Type().Elem().Kind() == reflect.Struct && field.Anonymous
switch {
case !fv.CanInterface():
if !anonStruct {
if !(!anonStruct || !pointerAnonStruct) {
continue
}
case name == "-":
Expand All @@ -92,7 +93,14 @@ func marshalStruct(rv reflect.Value) (map[string]*dynamodb.AttributeValue, error
}

// embed anonymous structs
if anonStruct {
if anonStruct || pointerAnonStruct {
if pointerAnonStruct {
if fv.IsNil() {
continue
}
fv = fv.Elem()
}

avs, err := marshalStruct(fv)
if err != nil {
return nil, err
Expand Down
100 changes: 96 additions & 4 deletions encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,11 @@ var encodingTests = []struct {
}

var itemEncodingTests = []struct {
name string
in interface{}
out map[string]*dynamodb.AttributeValue
asymmetric bool
name string
in interface{}
out map[string]*dynamodb.AttributeValue
expectedDecode interface{}
asymmetric bool
}{
{
name: "strings",
Expand Down Expand Up @@ -425,6 +426,49 @@ var itemEncodingTests = []struct {
"Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)},
},
},
{
name: "pointer embedded struct",
in: struct {
*embedded
}{
embedded: &embedded{
Embedded: true,
},
},
out: map[string]*dynamodb.AttributeValue{
"Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)},
},
expectedDecode: struct {
*embedded
}{},
asymmetric: true,
},
{
name: "exported embedded struct",
in: struct {
ExportedEmbedded
}{
ExportedEmbedded: ExportedEmbedded{
Embedded: true,
},
},
out: map[string]*dynamodb.AttributeValue{
"Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)},
},
},
{
name: "exported pointer embedded struct",
in: struct {
*ExportedEmbedded
}{
ExportedEmbedded: &ExportedEmbedded{
Embedded: true,
},
},
out: map[string]*dynamodb.AttributeValue{
"Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)},
},
},
{
name: "embedded struct clobber",
in: struct {
Expand All @@ -437,6 +481,50 @@ var itemEncodingTests = []struct {
"Embedded": &dynamodb.AttributeValue{S: aws.String("OK")},
},
},
{
name: "pointer embedded struct clobber",
in: struct {
Embedded string
*embedded
}{
Embedded: "OK",
},
out: map[string]*dynamodb.AttributeValue{
"Embedded": &dynamodb.AttributeValue{S: aws.String("OK")},
},
},
{
name: "exported embedded struct clobber",
in: struct {
Embedded string
ExportedEmbedded
}{
Embedded: "OK",
},
out: map[string]*dynamodb.AttributeValue{
"Embedded": &dynamodb.AttributeValue{S: aws.String("OK")},
},
},
{
name: "exported pointer embedded struct clobber",
in: struct {
Embedded string
*ExportedEmbedded
}{
Embedded: "OK",
},
out: map[string]*dynamodb.AttributeValue{
"Embedded": &dynamodb.AttributeValue{S: aws.String("OK")},
},
expectedDecode: struct {
Embedded string
*ExportedEmbedded
}{
Embedded: "OK",
ExportedEmbedded: &ExportedEmbedded{},
},
asymmetric: true,
},
{
name: "sets",
in: struct {
Expand Down Expand Up @@ -626,6 +714,10 @@ type embedded struct {
Embedded bool
}

type ExportedEmbedded struct {
Embedded bool
}

type customMarshaler int

func (cm customMarshaler) MarshalDynamo() (*dynamodb.AttributeValue, error) {
Expand Down