diff --git a/simplejson.go b/simplejson.go index 95e73fd..5a65e57 100644 --- a/simplejson.go +++ b/simplejson.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "log" + "reflect" ) // returns the current implementation version @@ -60,14 +61,15 @@ func (j *Json) Set(key string, val interface{}) { if err != nil { return } - m[key] = val + + m[key] = handleArray(val) } // SetPath modifies `Json`, recursively checking/creating map keys for the supplied path, // and then finally writing in the value func (j *Json) SetPath(branch []string, val interface{}) { if len(branch) == 0 { - j.data = val + j.data = handleArray(val) return } @@ -99,7 +101,7 @@ func (j *Json) SetPath(branch []string, val interface{}) { } // add remaining k/v - curr[branch[len(branch)-1]] = val + curr[branch[len(branch)-1]] = handleArray(val) } // Del modifies `Json` map by deleting `key` if it is present. @@ -444,3 +446,21 @@ func (j *Json) MustUint64(args ...uint64) uint64 { return def } + +func handleArray(val interface{}) interface{} { + if val != nil { + // If val is an array convert to []interface{} + typ := reflect.TypeOf(val) + kind := typ.Kind() + if kind == reflect.Array || kind == reflect.Slice { + + v := reflect.ValueOf(val) + arr := make([]interface{}, v.Len()) + for i := 0; i < v.Len(); i++ { + arr[i] = v.Index(i).Interface() + } + val = arr + } + } + return val +} diff --git a/simplejson_test.go b/simplejson_test.go index 477f1a4..73cf323 100644 --- a/simplejson_test.go +++ b/simplejson_test.go @@ -122,6 +122,24 @@ func TestSimplejson(t *testing.T) { js.GetPath("test", "sub_obj").Set("a", 3) assert.Equal(t, 3, js.GetPath("test", "sub_obj", "a").MustInt()) + + a := [3]string{"one", "two", "three"} + js = New() + js.Set("array", a) + a2, err := js.Get("array").Array() + assert.Equal(t, nil, err) + assert.NotEqual(t, nil, a2) + assert.Equal(t, a2[0], "one") + assert.Equal(t, a2[1], "two") + assert.Equal(t, a2[2], "three") + + a3, err := js.Get("array").StringArray() + assert.Equal(t, nil, err) + assert.NotEqual(t, nil, a3) + assert.Equal(t, a3[0], "one") + assert.Equal(t, a3[1], "two") + assert.Equal(t, a3[2], "three") + } func TestStdlibInterfaces(t *testing.T) {