From b9b454735cf68b7bbd9051e07f48fedc2f3416e6 Mon Sep 17 00:00:00 2001 From: James Ross Date: Thu, 23 May 2019 14:16:09 -0700 Subject: [PATCH] Add merge package for merging maps (#157) --- merge/merge.go | 100 ++++++++++++++++++++++ merge/merge_test.go | 201 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 merge/merge.go create mode 100644 merge/merge_test.go diff --git a/merge/merge.go b/merge/merge.go new file mode 100644 index 00000000..c0ad0f30 --- /dev/null +++ b/merge/merge.go @@ -0,0 +1,100 @@ +// Copyright (c) 2019 Palantir Technologies. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package merge + +import ( + "fmt" + "reflect" +) + +// Maps returns a new map that is the result of merging the two provided inputs, which must both be maps. Returns an +// error if either of the inputs are not maps. If the types of the input values differ, an error is returned. +// Merging is performed by creating a new map, setting its contents to be "dest", and then setting the key/value pairs in +// "src" on the new map (unless the value is a map, in which case a merge is performed recursively). +func Maps(dest, src interface{}) (interface{}, error) { + result, err := mergeMaps(reflect.ValueOf(dest), reflect.ValueOf(src)) + if err != nil { + return nil, err + } + return result.Interface(), nil +} + +// mergeMaps requires both inputs to be maps; if not, an error is returned. If both input maps have the same type, +// the returned map has the same type as well. If the input maps have different +// types, an error is returned. Otherwise, a new map is created and populated +// with the merge result for the return value. For map entries with the same key, +// the determineValue helper method is used to determine the resulting value for the key. +// Entries with nil values are preserved in the map. +func mergeMaps(dest, src reflect.Value) (reflect.Value, error) { + if dest.Kind() != reflect.Map { + return reflect.Value{}, fmt.Errorf("expected destination to be a map") + } + if src.Kind() != reflect.Map { + return reflect.Value{}, fmt.Errorf("expected source be a map") + } + + if dest.Type() != src.Type() { + return reflect.Value{}, fmt.Errorf("expected maps of same type") + } + result := reflect.MakeMap(dest.Type()) + for _, destKey := range dest.MapKeys() { + result.SetMapIndex(destKey, dest.MapIndex(destKey)) + } + for _, srcKey := range src.MapKeys() { + srcVal := src.MapIndex(srcKey) + destVal := dest.MapIndex(srcKey) + var resultVal reflect.Value + var err error + if !destVal.IsValid() { + if safeIsNil(srcVal) { + result.SetMapIndex(srcKey, srcVal) + continue + } + resultVal = srcVal + } else { + if safeIsNil(srcVal) { + result.SetMapIndex(srcKey, srcVal) + continue + } + if resultVal, err = determineValue(destVal, srcVal); err != nil { + return reflect.Value{}, err + } + } + result.SetMapIndex(srcKey, resultVal) + } + return result, nil +} + +// determineValue inspects the 'dest' and 'src' values and follows these rules: +// 1. If the values have different kinds, the value of 'src' is returned. +// 2. If the values are maps with the same type, the maps are recursively merged using the mergeMaps helper method. +// 3. If the values are interfaces, determineValue is called with the element values that the interfaces contain. +// 4. If the values are pointers, determineValue is called with the pointer's elements, and the address of the result is returned. +// 5. If the values are any other kind, the value of 'src' is returned. +func determineValue(destVal, srcVal reflect.Value) (reflect.Value, error) { + if destVal.Kind() != srcVal.Kind() { + return srcVal, nil + } + switch srcVal.Kind() { + case reflect.Map: + return mergeMaps(destVal, srcVal) + case reflect.Interface: + return determineValue(destVal.Elem(), srcVal.Elem()) + default: + return srcVal, nil + } +} + +// safeIsNil only calls IsNil if the value is an interface, pointer, map, or slice (IsNil will not panic in these cases) +func safeIsNil(val reflect.Value) bool { + switch val.Kind() { + case reflect.Interface, reflect.Ptr: + return val.IsNil() || safeIsNil(val.Elem()) + case reflect.Slice, reflect.Map: + return val.IsNil() + default: + return false + } +} diff --git a/merge/merge_test.go b/merge/merge_test.go new file mode 100644 index 00000000..aee0759a --- /dev/null +++ b/merge/merge_test.go @@ -0,0 +1,201 @@ +// Copyright (c) 2019 Palantir Technologies. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package merge_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/palantir/pkg/merge" +) + +type TestStruct1 struct { + Foo string +} + +func TestMergeMaps(t *testing.T) { + srcVal := "src" + destVal := "dest" + for _, test := range []struct { + name string + src, dest, expected interface{} + expectedErr string + }{ + { + name: "config maps", + src: map[string]interface{}{ + "conf": map[string]interface{}{ + "map": map[string]interface{}{ + "value1": 1, + "value2": 2, + }, + "string": "What number am I thinking of?", + "array": []string{"one", "two", "three"}, + }, + "location": "src location", + }, + dest: map[string]interface{}{ + "conf": map[string]interface{}{ + "map": map[string]interface{}{ + "value1": 5, + }, + "array": map[string]string{"key": "four", "key2": "five"}, + }, + "string": "What letter am I thinking of?", + }, + expected: map[string]interface{}{ + "conf": map[string]interface{}{ + "map": map[string]interface{}{ + "value1": 1, + "value2": 2, + }, + "string": "What number am I thinking of?", + "array": []string{"one", "two", "three"}, + }, + "string": "What letter am I thinking of?", + "location": "src location", + }, + }, + { + name: "no overlap", + src: map[string]interface{}{ + "b": &srcVal, + }, + dest: map[string]interface{}{ + "c": &destVal, + }, + expected: map[string]interface{}{ + "b": &srcVal, + "c": &destVal, + }, + }, + { + name: "pointers", + src: map[string]*string{ + "a": &srcVal, + "b": &srcVal, + }, + dest: map[string]*string{ + "a": &destVal, + "c": &destVal, + }, + expected: map[string]*string{ + "a": &srcVal, + "b": &srcVal, + "c": &destVal, + }, + }, + { + name: "different map types returns error", + src: map[string]interface{}{ + "a": "a", + "b": "b", + }, + dest: map[string]string{ + "a": "a", + "c": "c", + }, + expectedErr: "expected maps of same type", + }, + { + name: "different map entry value types return the value from src", + src: map[string]interface{}{ + "a": "a string", + }, + dest: map[string]interface{}{ + "a": []string{"a string in a slice that will be overridden"}, + "b": "c", + }, + expected: map[string]interface{}{ + "a": "a string", + "b": "c", + }, + }, + { + name: "typed nil value for a src map entry results in a typed nil entry for that key", + src: map[string]interface{}{ + "a": (*string)(nil), + }, + dest: map[string]interface{}{ + "a": "foo", + "b": "c", + }, + expected: map[string]interface{}{ + "a": (*string)(nil), + "b": "c", + }, + }, + { + name: "untyped nil value for a src map entry results in a nil entry for that key", + src: map[string]interface{}{ + "a": nil, + "c": nil, + }, + dest: map[string]interface{}{ + "a": "foo", + "b": "c", + }, + expected: map[string]interface{}{ + "a": nil, + "b": "c", + "c": nil, + }, + }, + { + name: "src val for structs is used", + src: map[string]interface{}{ + "a": TestStruct1{ + Foo: "src foo value", + }, + }, + dest: map[string]interface{}{ + "a": "dest bar value", + }, + expected: map[string]interface{}{ + "a": TestStruct1{ + Foo: "src foo value", + }, + }, + }, + { + name: "src value for pointers is used", + src: map[string]interface{}{ + "a": &map[string]interface{}{ + "b": "c", + }, + "b": (*string)(nil), + "c": &[]string{"d"}, + }, + dest: map[string]interface{}{ + "a": &map[string]interface{}{ + "c": "d", + }, + "b": &destVal, + "c": "d", + "d": "non-pointer type", + }, + expected: map[string]interface{}{ + "a": &map[string]interface{}{ + "b": "c", + }, + "b": (*string)(nil), + "c": &[]string{"d"}, + "d": "non-pointer type", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + merged, err := merge.Maps(test.dest, test.src) + if test.expectedErr == "" { + assert.NoError(t, err) + assert.Equal(t, test.expected, merged) + } else { + assert.EqualError(t, err, test.expectedErr) + assert.Nil(t, merged) + } + }) + } +}