diff --git a/marshal.go b/marshal.go index afe6c371..dcddad8d 100644 --- a/marshal.go +++ b/marshal.go @@ -320,20 +320,25 @@ func (e *Encoder) valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, er tval := e.nextTree() switch mtype.Kind() { case reflect.Struct: - for i := 0; i < mtype.NumField(); i++ { - mtypef, mvalf := mtype.Field(i), mval.Field(i) - opts := tomlOptions(mtypef, e.annotation) - if opts.include && ((mtypef.Type.Kind() != reflect.Interface && !opts.omitempty) || !isZero(mvalf)) { - val, err := e.valueToToml(mtypef.Type, mvalf) - if err != nil { - return nil, err - } + switch mval.Interface().(type) { + case Tree: + reflect.ValueOf(tval).Elem().Set(mval) + default: + for i := 0; i < mtype.NumField(); i++ { + mtypef, mvalf := mtype.Field(i), mval.Field(i) + opts := tomlOptions(mtypef, e.annotation) + if opts.include && ((mtypef.Type.Kind() != reflect.Interface && !opts.omitempty) || !isZero(mvalf)) { + val, err := e.valueToToml(mtypef.Type, mvalf) + if err != nil { + return nil, err + } - tval.SetWithOptions(opts.name, SetOptions{ - Comment: opts.comment, - Commented: opts.commented, - Multiline: opts.multiline, - }, val) + tval.SetWithOptions(opts.name, SetOptions{ + Comment: opts.comment, + Commented: opts.commented, + Multiline: opts.multiline, + }, val) + } } } case reflect.Map: @@ -570,11 +575,17 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V mval = reflect.New(mtype).Elem() } - for i := 0; i < mtype.NumField(); i++ { - mtypef := mtype.Field(i) - an := annotation{tag: d.tagName} - opts := tomlOptions(mtypef, an) - if opts.include { + switch mval.Interface().(type) { + case Tree: + mval.Set(reflect.ValueOf(tval).Elem()) + default: + for i := 0; i < mtype.NumField(); i++ { + mtypef := mtype.Field(i) + an := annotation{tag: d.tagName} + opts := tomlOptions(mtypef, an) + if !opts.include { + continue + } baseKey := opts.name keysToTry := []string{ baseKey, diff --git a/marshal_test.go b/marshal_test.go index 9b099927..fec1662d 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1444,6 +1444,54 @@ func TestMarshalCustomMultiline(t *testing.T) { } } +func TestMarshalEmbedTree(t *testing.T) { + expected := []byte(`OuterField1 = "Out" +OuterField2 = 1024 + +[TreeField] + InnerField1 = "In" + InnerField2 = 2048 + + [TreeField.EmbedStruct] + EmbedField = "Embed" +`) + type InnerStruct struct { + InnerField1 string + InnerField2 int + EmbedStruct struct{ + EmbedField string + } + } + + type OuterStruct struct { + OuterField1 string + OuterField2 int + TreeField *Tree + } + + tree, err := Load(` +InnerField1 = "In" +InnerField2 = 2048 + +[EmbedStruct] + EmbedField = "Embed" +`) + if err != nil { + t.Fatal(err) + } + + out := OuterStruct{ + "Out", + 1024, + tree, + } + actual, _ := Marshal(out) + + if !bytes.Equal(actual, expected){ + t.Errorf("Bad marshal: expected %s, got %s", expected, actual) + } +} + var testDocBasicToml = []byte(` [document] bool_val = true @@ -2674,3 +2722,53 @@ InnerField = "After4" t.Fatal(err) } } + +func TestUnmarshalEmbedTree(t *testing.T) { + toml := []byte(` +OuterField1 = "Out" +OuterField2 = 1024 + +[TreeField] +InnerField1 = "In" +InnerField2 = 2048 + + [TreeField.EmbedStruct] + EmbedField = "Embed" + +`) + type InnerStruct struct { + InnerField1 string + InnerField2 int + EmbedStruct struct{ + EmbedField string + } + } + + type OuterStruct struct { + OuterField1 string + OuterField2 int + TreeField *Tree + } + + out := OuterStruct{} + actual := InnerStruct{} + expected := InnerStruct{ + "In", + 2048, + struct{ + EmbedField string + }{ + EmbedField:"Embed", + }, + } + if err := Unmarshal(toml, &out); err != nil { + t.Fatal(err) + } + if err := out.TreeField.Unmarshal(&actual); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(actual, expected){ + t.Errorf("Bad unmarshal: expected %v, got %v", expected, actual) + } +}