diff --git a/decode_hooks.go b/decode_hooks.go index 3a754ca7..ed723133 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -99,6 +99,19 @@ func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { } } +// RecoveringDecodeHookFunc executes the input hook function and turns a panic into an error. +func RecoveringDecodeHookFunc(hook DecodeHookFunc) DecodeHookFunc { + return func(from, to reflect.Value) (v interface{}, err error) { + defer func() { + if r := recover(); r != nil { + v = nil + err = fmt.Errorf("internal error while parsing: %s", r) + } + }() + return DecodeHookExec(hook, from, to) + } +} + // StringToSliceHookFunc returns a DecodeHookFunc that converts // string to []string by splitting on the given sep. func StringToSliceHookFunc(sep string) DecodeHookFunc { diff --git a/decode_hooks_test.go b/decode_hooks_test.go index bf029526..66b4c50e 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -204,6 +204,67 @@ func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) { } } +func TestRecoveringDecodeHook(t *testing.T) { + f1 := func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + f := RecoveringDecodeHookFunc(f1) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "bar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestRecoveringDecodeHook_err(t *testing.T) { + f1 := func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() == reflect.String { + panic(errors.New("noooo")) + } + return data, nil + } + f := RecoveringDecodeHookFunc(f1) + + type myStruct struct { + A string + B string + } + src := map[string]string{ + "A": "one", + "B": "two", + } + dst := &myStruct{} + dConf := &DecoderConfig{ + Result: dst, + ErrorUnused: true, + DecodeHook: f, + } + d, err := NewDecoder(dConf) + if err != nil { + t.Fatal(err) + } + err = d.Decode(src) + if err == nil { + t.Fatalf("bad: should return an error") + } + if err.Error() != `2 error(s) decoding: + +* error decoding 'A': internal error while parsing: noooo +* error decoding 'B': internal error while parsing: noooo` { + t.Fatalf("bad: %s", err) + } +} + func TestStringToSliceHookFunc(t *testing.T) { f := StringToSliceHookFunc(",")