diff --git a/mapstructure.go b/mapstructure.go index 7581806a..768dc9e3 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -1333,11 +1333,16 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } if squash { - if fieldVal.Kind() != reflect.Struct { + switch fieldVal.Kind() { + case reflect.Struct: + structs = append(structs, fieldVal) + case reflect.Interface: + if !fieldVal.IsNil() { + structs = append(structs, fieldVal.Elem().Elem()) + } + default: errors = appendErrors(errors, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) - } else { - structs = append(structs, fieldVal) } continue } diff --git a/mapstructure_test.go b/mapstructure_test.go index d31129d7..acd74192 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -106,6 +106,60 @@ type SquashOnNonStructType struct { InvalidSquashType int `mapstructure:",squash"` } +type TestInterface interface { + GetVfoo() string + GetVbarfoo() string + GetVfoobar() string +} + +type TestInterfaceImpl struct { + Vfoo string +} + +func (t *TestInterfaceImpl) GetVfoo() string { + return t.Vfoo +} + +func (t *TestInterfaceImpl) GetVbarfoo() string { + return "" +} + +func (t *TestInterfaceImpl) GetVfoobar() string { + return "" +} + +type TestNestedInterfaceImpl struct { + SquashOnNestedInterfaceType `mapstructure:",squash"` + Vfoo string +} + +func (t *TestNestedInterfaceImpl) GetVfoo() string { + return t.Vfoo +} + +func (t *TestNestedInterfaceImpl) GetVbarfoo() string { + return t.Vbarfoo +} + +func (t *TestNestedInterfaceImpl) GetVfoobar() string { + return t.NestedSquash.Vfoobar +} + +type SquashOnInterfaceType struct { + TestInterface `mapstructure:",squash"` + Vbar string +} + +type NestedSquash struct { + SquashOnInterfaceType `mapstructure:",squash"` + Vfoobar string +} + +type SquashOnNestedInterfaceType struct { + NestedSquash NestedSquash `mapstructure:",squash"` + Vbarfoo string +} + type Map struct { Vfoo string Vother map[string]string @@ -978,6 +1032,147 @@ func TestDecode_SquashOnNonStructType(t *testing.T) { } } +func TestDecode_SquashOnInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + } + + var result = SquashOnInterfaceType{ + TestInterface: &TestInterfaceImpl{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.GetVfoo() + if res != "42" { + t.Errorf("unexpected value for VFoo: %s", res) + } + + res = result.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } +} + +func TestDecode_SquashOnOuterNestedInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + "Vfoobar": "44", + "Vbarfoo": "45", + } + + var result = SquashOnNestedInterfaceType{ + NestedSquash: NestedSquash{ + SquashOnInterfaceType: SquashOnInterfaceType{ + TestInterface: &TestInterfaceImpl{}, + }, + }, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.NestedSquash.GetVfoo() + if res != "42" { + t.Errorf("unexpected value for VFoo: %s", res) + } + + res = result.NestedSquash.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } + + res = result.NestedSquash.Vfoobar + if res != "44" { + t.Errorf("unexpected value for Vfoobar: %s", res) + } + + res = result.Vbarfoo + if res != "45" { + t.Errorf("unexpected value for Vbarfoo: %s", res) + } +} + +func TestDecode_SquashOnInnerNestedInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + "Vfoobar": "44", + "Vbarfoo": "45", + } + + var result = SquashOnInterfaceType{ + TestInterface: &TestNestedInterfaceImpl{ + SquashOnNestedInterfaceType: SquashOnNestedInterfaceType{ + NestedSquash: NestedSquash{ + SquashOnInterfaceType: SquashOnInterfaceType{ + TestInterface: &TestInterfaceImpl{}, + }, + }, + }, + }, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.GetVfoo() + if res != "42" { + t.Errorf("unexpected value for VFoo: %s", res) + } + + res = result.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } + + res = result.GetVfoobar() + if res != "44" { + t.Errorf("unexpected value for Vfoobar: %s", res) + } + + res = result.GetVbarfoo() + if res != "45" { + t.Errorf("unexpected value for Vbarfoo: %s", res) + } +} + +func TestDecode_SquashOnNilInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + } + + var result = SquashOnInterfaceType{ + TestInterface: nil, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } +} + func TestDecode_DecodeHook(t *testing.T) { t.Parallel()