diff --git a/mapstructure.go b/mapstructure.go index 7581806a..2b16c4f3 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -265,6 +265,10 @@ type DecoderConfig struct { // defaults to "mapstructure" TagName string + // The name of the value in the tag that indicates a field should + // be squashed. This defaults to "squash". + SquashName string + // IgnoreUntaggedFields ignores all struct fields without explicit // TagName, comparable to `mapstructure:"-"` as default behaviour. IgnoreUntaggedFields bool @@ -400,6 +404,10 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) { config.TagName = "mapstructure" } + if config.SquashName == "" { + config.SquashName = "squash" + } + if config.MatchName == nil { config.MatchName = strings.EqualFold } @@ -945,7 +953,7 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re } // If "squash" is specified in the tag, we squash the field down. - squash = squash || strings.Index(tagValue[index+1:], "squash") != -1 + squash = squash || strings.Contains(tagValue[index+1:], d.config.SquashName) if squash { // When squashing, the embedded type can be a pointer to a struct. if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { @@ -1321,7 +1329,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e // We always parse the tags cause we're looking for other tags too tagParts := strings.Split(fieldType.Tag.Get(d.config.TagName), ",") for _, tag := range tagParts[1:] { - if tag == "squash" { + if tag == d.config.SquashName { squash = true break } diff --git a/mapstructure_test.go b/mapstructure_test.go index d31129d7..9aa03fdc 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -48,6 +48,10 @@ type BasicSquash struct { Test Basic `mapstructure:",squash"` } +type BasicJSONInline struct { + Test Basic `json:",inline"` +} + type Embedded struct { Basic Vunique string @@ -476,6 +480,62 @@ func TestDecodeFrom_BasicSquash(t *testing.T) { } } +func TestDecode_BasicJSONInline(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "vstring": "foo", + } + + var result BasicJSONInline + d, err := NewDecoder(&DecoderConfig{TagName: "json", SquashName: "inline", Result: &result}) + if err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + + if err := d.Decode(input); err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + + if result.Test.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Test.Vstring) + } +} + +func TestDecodeFrom_BasicJSONInline(t *testing.T) { + t.Parallel() + + var v interface{} + var ok bool + + input := BasicJSONInline{ + Test: Basic{ + Vstring: "foo", + }, + } + + var result map[string]interface{} + d, err := NewDecoder(&DecoderConfig{TagName: "json", SquashName: "inline", Result: &result}) + if err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + + if err := d.Decode(input); err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + + if _, ok = result["Test"]; ok { + t.Error("test should not be present in map") + } + + v, ok = result["Vstring"] + if !ok { + t.Error("vstring should be present in map") + } else if !reflect.DeepEqual(v, "foo") { + t.Errorf("vstring value should be 'foo': %#v", v) + } +} + func TestDecode_Embedded(t *testing.T) { t.Parallel()