Skip to content
This repository has been archived by the owner on Jul 22, 2024. It is now read-only.

feat: squash with prefix #291

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 49 additions & 34 deletions mapstructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,13 @@ func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value)
copied = true

// Make *T
copy := reflect.New(elem.Type())
clone := reflect.New(elem.Type())
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prevent shadowing of built-in function copy


// *T = elem
copy.Elem().Set(elem)
clone.Elem().Set(elem)

// Set elem so we decode into it
elem = copy
elem = clone
}

// Decode. If we have an error then return. We also return right
Expand Down Expand Up @@ -857,7 +857,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
valElemType := valType.Elem()

// Accumulate errors
errors := make([]string, 0)
errs := make([]string, 0)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prevent shadowing of errors package


// If the input data is empty, then we just match what the input data is.
if dataVal.Len() == 0 {
Expand All @@ -879,15 +879,15 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
// First decode the key into the proper type
currentKey := reflect.Indirect(reflect.New(valKeyType))
if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
continue
}

// Next decode the data into the proper type
v := dataVal.MapIndex(k).Interface()
currentVal := reflect.Indirect(reflect.New(valElemType))
if err := d.decode(fieldName, v, currentVal); err != nil {
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
continue
}

Expand All @@ -898,14 +898,14 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
val.Set(valMap)

// If we had errors, return those
if len(errors) > 0 {
return &Error{errors}
if len(errs) > 0 {
return &Error{errs}
}

return nil
}

func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
func (d *Decoder) decodeMapFromStruct(_ string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused name argument

typ := dataVal.Type()
for i := 0; i < typ.NumField(); i++ {
// Get the StructField first since this is a cheap operation. If the
Expand Down Expand Up @@ -1128,7 +1128,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
}

// Accumulate any errors
errors := make([]string, 0)
errs := make([]string, 0)

for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
Expand All @@ -1139,16 +1139,16 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)

fieldName := name + "[" + strconv.Itoa(i) + "]"
if err := d.decode(fieldName, currentData, currentField); err != nil {
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
}
}

// Finally, set the value to the slice we built up
val.Set(valSlice)

// If there were errors, we return those
if len(errors) > 0 {
return &Error{errors}
if len(errs) > 0 {
return &Error{errs}
}

return nil
Expand Down Expand Up @@ -1198,24 +1198,24 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value)
}

// Accumulate any errors
errors := make([]string, 0)
errs := make([]string, 0)

for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
currentField := valArray.Index(i)

fieldName := name + "[" + strconv.Itoa(i) + "]"
if err := d.decode(fieldName, currentData, currentField); err != nil {
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
}
}

// Finally, set the value to the array we built up
val.Set(valArray)

// If there were errors, we return those
if len(errors) > 0 {
return &Error{errors}
if len(errs) > 0 {
return &Error{errs}
}

return nil
Expand Down Expand Up @@ -1280,7 +1280,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}

targetValKeysUnused := make(map[interface{}]struct{})
errors := make([]string, 0)
errs := make([]string, 0)

// This slice will keep track of all the structs we'll be decoding.
// There can be more than one struct if there are embedded structs
Expand All @@ -1291,20 +1291,23 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
// Compile the list of all the fields that we're going to be decoding
// from all the structs.
type field struct {
field reflect.StructField
val reflect.Value
field reflect.StructField
val reflect.Value
prefix string
}

// remainField is set to a valid field set with the "remain" tag if
// we are keeping track of remaining values.
var remainField *field

fields := []field{}
var fields []field
fieldPrefixes := make(map[reflect.Value]string)
for len(structs) > 0 {
structVal := structs[0]
structs = structs[1:]

structType := structVal.Type()
fieldPrefix := fieldPrefixes[structVal]

structs = structs[1:]

for i := 0; i < structType.NumField(); i++ {
fieldType := structType.Field(i)
Expand Down Expand Up @@ -1334,34 +1337,39 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e

if squash {
if fieldVal.Kind() != reflect.Struct {
errors = appendErrors(errors,
errs = appendErrors(errs,
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
} else {
structs = append(structs, fieldVal)
if prefix := tagParts[0]; prefix != "" {
fieldPrefixes[fieldVal] = addPrefix(prefix, fieldPrefix)
}
}
continue
}

// Build our field
if remain {
remainField = &field{fieldType, fieldVal}
remainField = &field{fieldType, fieldVal, fieldPrefix}
} else {
// Normal struct field, store it away
fields = append(fields, field{fieldType, fieldVal})
fields = append(fields, field{fieldType, fieldVal, fieldPrefix})
}
}
}

// for fieldType, field := range fields {
for _, f := range fields {
field, fieldValue := f.field, f.val
field, fieldValue, fieldPrefix := f.field, f.val, f.prefix
fieldName := field.Name

tagValue := field.Tag.Get(d.config.TagName)
tagValue = strings.SplitN(tagValue, ",", 2)[0]
if tagValue != "" {
fieldName = tagValue
}
if fieldPrefix != "" {
fieldName = addPrefix(fieldName, fieldPrefix)
}

rawMapKey := reflect.ValueOf(fieldName)
rawMapVal := dataVal.MapIndex(rawMapKey)
Expand Down Expand Up @@ -1411,7 +1419,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}

if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
}
}

Expand All @@ -1426,7 +1434,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e

// Decode it as-if we were just decoding this map onto our map.
if err := d.decodeMap(name, remain, remainField.val); err != nil {
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
}

// Set the map to nil so we have none so that the next check will
Expand All @@ -1442,7 +1450,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
sort.Strings(keys)

err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", "))
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
}

if d.config.ErrorUnset && len(targetValKeysUnused) > 0 {
Expand All @@ -1453,11 +1461,11 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
sort.Strings(keys)

err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", "))
errors = appendErrors(errors, err)
errs = appendErrors(errs, err)
}

if len(errors) > 0 {
return &Error{errors}
if len(errs) > 0 {
return &Error{errs}
}

// Add the unused keys to the list of unused keys if we're tracking metadata
Expand Down Expand Up @@ -1540,3 +1548,10 @@ func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Val
}
return v
}

func addPrefix(s string, prefix string) string {
if prefix == "" {
return s
}
return prefix + "_" + s

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add the _? Wouldn't it be more flexible to just add it to the prefix?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be more flexible if added to DecoderConfig. I preferred to make it simple.

Also, for me, it's just the same approach as https://cs.opensource.google/go/go/+/refs/tags/go1.21.3:src/path/path.go;l=162-180

}
45 changes: 45 additions & 0 deletions mapstructure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2732,6 +2732,51 @@ func TestDecoder_IgnoreUntaggedFields(t *testing.T) {
}
}

func TestDecoder_Decode_SquashWithPrefix(t *testing.T) {
type Git struct {
Remote string `mapstructure:"remote"`
}

type GitHub struct {
Git `mapstructure:"git,squash"`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems not possible to use another name than git here. For example if you rename this to

Suggested change
Git `mapstructure:"git,squash"`
Git `mapstructure:"gitx,squash"`

and update the input below accordingly (to GITHUB_GITX_REMOTE), the test panics.

Is this the desired behavior? I would have expected to be able to use any name here instead of just the name of the embedded struct.

Token string `mapstructure:"token"`
}

type Config struct {
GitHub `mapstructure:"github,squash"`
}

var cnf Config
decoder, err := NewDecoder(&DecoderConfig{
DecodeHook: nil,
ErrorUnused: false,
ZeroFields: false,
WeaklyTypedInput: false,
Squash: false,
Metadata: nil,
Result: &cnf,
TagName: "",
MatchName: nil,
})
if err != nil {
t.Fatalf("err: %s", err)
}

input := map[string]interface{}{
"GITHUB_GIT_REMOTE": "[email protected]:mitchellh/mapstructure.git",
"GITHUB_TOKEN": "secret",
}
if err := decoder.Decode(input); err != nil {
t.Fatalf("err: %s", err)
}
if cnf.Remote != input["GITHUB_GIT_REMOTE"].(string) {
t.Errorf("expected: %#v, obtained: %#v", input["GITHUB_GIT_REMOTE"], cnf.Remote)
}
if cnf.Token != input["GITHUB_TOKEN"].(string) {
t.Errorf("expected: %#v, obtained: %#v", input["GITHUB_TOKEN"], cnf.Token)
}
}

func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) {
var result Slice
err := Decode(input, &result)
Expand Down