diff --git a/decode.go b/decode.go index 1b78725..785f6ec 100644 --- a/decode.go +++ b/decode.go @@ -329,6 +329,7 @@ func fieldsInStruct(rv reflect.Value) map[string]reflect.Value { for i := 0; i < rv.Type().NumField(); i++ { field := rv.Type().Field(i) fv := rv.Field(i) + isPtr := fv.Type().Kind() == reflect.Ptr name, _ := fieldInfo(field) if name == "-" { @@ -336,8 +337,20 @@ func fieldsInStruct(rv reflect.Value) map[string]reflect.Value { continue } - // embed anonymous structs - if fv.Type().Kind() == reflect.Struct && field.Anonymous { + // need to protect from setting unexported pointers because it will panic + if !fv.CanSet() && isPtr { + continue + } + + // embed anonymous structs, they could be pointers so test that too + if (fv.Type().Kind() == reflect.Struct || isPtr && fv.Type().Elem().Kind() == reflect.Struct) && field.Anonymous { + // set zero value for pointer + if isPtr { + zero := reflect.New(fv.Type().Elem()) + fv.Set(zero) + fv = zero + } + innerFields := fieldsInStruct(fv) for k, v := range innerFields { // don't clobber top-level fields diff --git a/decode_test.go b/decode_test.go index 26af6d6..5f81814 100644 --- a/decode_test.go +++ b/decode_test.go @@ -70,17 +70,23 @@ func TestUnmarshal(t *testing.T) { func TestUnmarshalItem(t *testing.T) { for _, tc := range itemEncodingTests { + expected := tc.in if tc.asymmetric { - continue + if tc.expectedDecode == nil { + continue + } + + expected = tc.expectedDecode } + rv := reflect.New(reflect.TypeOf(tc.in)) err := unmarshalItem(tc.out, rv.Interface()) if err != nil { t.Errorf("%s: unexpected error: %v", tc.name, err) } - if !reflect.DeepEqual(rv.Elem().Interface(), tc.in) { - t.Errorf("%s: bad result: %#v ≠ %#v", tc.name, rv.Elem().Interface(), tc.in) + if !reflect.DeepEqual(rv.Elem().Interface(), expected) { + t.Errorf("%s: bad result: %#v ≠ %#v", tc.name, rv.Elem().Interface(), expected) } } } diff --git a/encode.go b/encode.go index dc5df95..8b6772d 100644 --- a/encode.go +++ b/encode.go @@ -78,9 +78,10 @@ func marshalStruct(rv reflect.Value) (map[string]*dynamodb.AttributeValue, error name, flags := fieldInfo(field) omitempty := flags&flagOmitEmpty != 0 anonStruct := fv.Type().Kind() == reflect.Struct && field.Anonymous + pointerAnonStruct := fv.Type().Kind() == reflect.Ptr && fv.Type().Elem().Kind() == reflect.Struct && field.Anonymous switch { case !fv.CanInterface(): - if !anonStruct { + if !(!anonStruct || !pointerAnonStruct) { continue } case name == "-": @@ -92,7 +93,14 @@ func marshalStruct(rv reflect.Value) (map[string]*dynamodb.AttributeValue, error } // embed anonymous structs - if anonStruct { + if anonStruct || pointerAnonStruct { + if pointerAnonStruct { + if fv.IsNil() { + continue + } + fv = fv.Elem() + } + avs, err := marshalStruct(fv) if err != nil { return nil, err diff --git a/encoding_test.go b/encoding_test.go index 831bafa..6d2a353 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -231,10 +231,11 @@ var encodingTests = []struct { } var itemEncodingTests = []struct { - name string - in interface{} - out map[string]*dynamodb.AttributeValue - asymmetric bool + name string + in interface{} + out map[string]*dynamodb.AttributeValue + expectedDecode interface{} + asymmetric bool }{ { name: "strings", @@ -425,6 +426,49 @@ var itemEncodingTests = []struct { "Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)}, }, }, + { + name: "pointer embedded struct", + in: struct { + *embedded + }{ + embedded: &embedded{ + Embedded: true, + }, + }, + out: map[string]*dynamodb.AttributeValue{ + "Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)}, + }, + expectedDecode: struct { + *embedded + }{}, + asymmetric: true, + }, + { + name: "exported embedded struct", + in: struct { + ExportedEmbedded + }{ + ExportedEmbedded: ExportedEmbedded{ + Embedded: true, + }, + }, + out: map[string]*dynamodb.AttributeValue{ + "Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)}, + }, + }, + { + name: "exported pointer embedded struct", + in: struct { + *ExportedEmbedded + }{ + ExportedEmbedded: &ExportedEmbedded{ + Embedded: true, + }, + }, + out: map[string]*dynamodb.AttributeValue{ + "Embedded": &dynamodb.AttributeValue{BOOL: aws.Bool(true)}, + }, + }, { name: "embedded struct clobber", in: struct { @@ -437,6 +481,50 @@ var itemEncodingTests = []struct { "Embedded": &dynamodb.AttributeValue{S: aws.String("OK")}, }, }, + { + name: "pointer embedded struct clobber", + in: struct { + Embedded string + *embedded + }{ + Embedded: "OK", + }, + out: map[string]*dynamodb.AttributeValue{ + "Embedded": &dynamodb.AttributeValue{S: aws.String("OK")}, + }, + }, + { + name: "exported embedded struct clobber", + in: struct { + Embedded string + ExportedEmbedded + }{ + Embedded: "OK", + }, + out: map[string]*dynamodb.AttributeValue{ + "Embedded": &dynamodb.AttributeValue{S: aws.String("OK")}, + }, + }, + { + name: "exported pointer embedded struct clobber", + in: struct { + Embedded string + *ExportedEmbedded + }{ + Embedded: "OK", + }, + out: map[string]*dynamodb.AttributeValue{ + "Embedded": &dynamodb.AttributeValue{S: aws.String("OK")}, + }, + expectedDecode: struct { + Embedded string + *ExportedEmbedded + }{ + Embedded: "OK", + ExportedEmbedded: &ExportedEmbedded{}, + }, + asymmetric: true, + }, { name: "sets", in: struct { @@ -626,6 +714,10 @@ type embedded struct { Embedded bool } +type ExportedEmbedded struct { + Embedded bool +} + type customMarshaler int func (cm customMarshaler) MarshalDynamo() (*dynamodb.AttributeValue, error) {