diff --git a/mapstructure.go b/mapstructure.go index 1cd6204..05bc140 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -278,6 +278,10 @@ type DecoderConfig struct { // field name or tag. Defaults to `strings.EqualFold`. This can be used // to implement case-sensitive tag values, support snake casing, etc. MatchName func(mapKey, fieldName string) bool + + // DecodeNil, if set to true, will cause the DecodeHook (if present) to run + // even if the input is nil. This can be used to provide default values. + DecodeNil bool } // A Decoder takes a raw interface value and turns it into structured @@ -451,6 +455,8 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e } } + decodeNil := d.config.DecodeNil && d.config.DecodeHook != nil + if input == nil { // If the data is nil, then we don't set anything, unless ZeroFields is set // to true. @@ -461,17 +467,27 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) } } - return nil + + if !decodeNil { + return nil + } } if !inputVal.IsValid() { - // If the input value is invalid, then we just set the value - // to be the zero value. - outVal.Set(reflect.Zero(outVal.Type())) - if d.config.Metadata != nil && name != "" { - d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + if !decodeNil { + // If the input value is invalid, then we just set the value + // to be the zero value. + outVal.Set(reflect.Zero(outVal.Type())) + if d.config.Metadata != nil && name != "" { + d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) + } + return nil } - return nil + + // If we get here, we have an untyped nil so the type of the input is assumed. + // We do this because all subsequent code requires a valid value for inputVal. + var mapVal map[string]interface{} + inputVal = reflect.MakeMap(reflect.TypeOf(mapVal)) } if d.cachedDecodeHook != nil { diff --git a/mapstructure_test.go b/mapstructure_test.go index 4d246cc..e30ff47 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -3083,6 +3083,171 @@ func TestDecoder_IgnoreUntaggedFieldsWithStruct(t *testing.T) { } } +func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) { + t.Parallel() + + type Transformed struct { + Message string + When string + } + + helloHook := func(reflect.Type, reflect.Type, interface{}) (interface{}, error) { + return Transformed{Message: "hello"}, nil + } + goodbyeHook := func(reflect.Type, reflect.Type, interface{}) (interface{}, error) { + return Transformed{Message: "goodbye"}, nil + } + appendHook := func(from reflect.Value, to reflect.Value) (interface{}, error) { + if from.Kind() == reflect.Map { + stringMap := from.Interface().(map[string]interface{}) + stringMap["when"] = "see you later" + return stringMap, nil + } + return from.Interface(), nil + } + + tests := []struct { + name string + decodeNil bool + input interface{} + result Transformed + expectedResult Transformed + decodeHook DecodeHookFunc + }{ + { + name: "decodeNil=true for nil input with hook", + decodeNil: true, + input: nil, + decodeHook: helloHook, + expectedResult: Transformed{Message: "hello"}, + }, + { + name: "decodeNil=true for nil input without hook", + decodeNil: true, + input: nil, + expectedResult: Transformed{Message: ""}, + }, + { + name: "decodeNil=false for nil input with hook", + decodeNil: false, + input: nil, + decodeHook: helloHook, + expectedResult: Transformed{Message: ""}, + }, + { + name: "decodeNil=false for nil input without hook", + decodeNil: false, + input: nil, + expectedResult: Transformed{Message: ""}, + }, + { + name: "decodeNil=true for non-nil input without hook", + decodeNil: true, + input: map[string]interface{}{"message": "bar"}, + expectedResult: Transformed{Message: "bar"}, + }, + { + name: "decodeNil=true for non-nil input with hook", + decodeNil: true, + input: map[string]interface{}{"message": "bar"}, + decodeHook: goodbyeHook, + expectedResult: Transformed{Message: "goodbye"}, + }, + { + name: "decodeNil=false for non-nil input without hook", + decodeNil: false, + input: map[string]interface{}{"message": "bar"}, + expectedResult: Transformed{Message: "bar"}, + }, + { + name: "decodeNil=false for non-nil input with hook", + decodeNil: false, + input: map[string]interface{}{"message": "bar"}, + decodeHook: goodbyeHook, + expectedResult: Transformed{Message: "goodbye"}, + }, + { + name: "decodeNil=true for nil input without hook and non-empty result", + decodeNil: true, + input: nil, + result: Transformed{Message: "foo"}, + expectedResult: Transformed{Message: "foo"}, + }, + { + name: "decodeNil=true for nil input with hook and non-empty result", + decodeNil: true, + input: nil, + result: Transformed{Message: "foo"}, + decodeHook: helloHook, + expectedResult: Transformed{Message: "hello"}, + }, + { + name: "decodeNil=false for nil input without hook and non-empty result", + decodeNil: false, + input: nil, + result: Transformed{Message: "foo"}, + expectedResult: Transformed{Message: "foo"}, + }, + { + name: "decodeNil=false for nil input with hook and non-empty result", + decodeNil: false, + input: nil, + result: Transformed{Message: "foo"}, + decodeHook: helloHook, + expectedResult: Transformed{Message: "foo"}, + }, + { + name: "decodeNil=false for non-nil input with hook that appends a value", + decodeNil: false, + input: map[string]interface{}{"message": "bar"}, + decodeHook: appendHook, + expectedResult: Transformed{Message: "bar", When: "see you later"}, + }, + { + name: "decodeNil=true for non-nil input with hook that appends a value", + decodeNil: true, + input: map[string]interface{}{"message": "bar"}, + decodeHook: appendHook, + expectedResult: Transformed{Message: "bar", When: "see you later"}, + }, + { + name: "decodeNil=true for nil input with hook that appends a value", + decodeNil: true, + decodeHook: appendHook, + expectedResult: Transformed{When: "see you later"}, + }, + { + name: "decodeNil=false for nil input with hook that appends a value", + decodeNil: false, + decodeHook: appendHook, + expectedResult: Transformed{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + config := &DecoderConfig{ + Result: &test.result, + DecodeNil: test.decodeNil, + DecodeHook: test.decodeHook, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := decoder.Decode(test.input); err != nil { + t.Fatalf("got an err: %s", err) + } + + if test.result != test.expectedResult { + t.Errorf("result should be: %#v, got %#v", test.expectedResult, test.result) + } + }) + } +} + func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { var result Slice err := Decode(input, &result)