diff --git a/scalars.go b/scalars.go index 94c1943a..b7dd73dd 100644 --- a/scalars.go +++ b/scalars.go @@ -3,12 +3,79 @@ package graphql import ( "fmt" "math" + "reflect" "strconv" "time" "github.com/graphql-go/graphql/language/ast" ) +func unwrapInt(value interface{}) (interface{}, bool) { + r := reflect.Indirect(reflect.ValueOf(value)) + if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) { + return nil, false + } + + switch r.Kind() { + case reflect.Int: + return int(r.Int()), true + case reflect.Int8: + return int8(r.Int()), true + case reflect.Int16: + return int16(r.Int()), true + case reflect.Int32: + return int32(r.Int()), true + case reflect.Int64: + return r.Int(), true + default: + return nil, false + } +} + +func unwrapFloat(value interface{}) (interface{}, bool) { + r := reflect.Indirect(reflect.ValueOf(value)) + if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) { + return nil, false + } + + switch r.Kind() { + case reflect.Float32: + return float32(r.Float()), true + case reflect.Float64: + return r.Float(), true + default: + return nil, false + } +} + +func unwrapBool(value interface{}) (interface{}, bool) { + r := reflect.Indirect(reflect.ValueOf(value)) + if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) { + return nil, false + } + + switch r.Kind() { + case reflect.Bool: + return r.Bool(), true + default: + return nil, false + } +} + +func unwrapString(value interface{}) (interface{}, bool) { + r := reflect.Indirect(reflect.ValueOf(value)) + if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) { + return nil, false + } + + switch r.Kind() { + case reflect.String: + return r.String(), true + default: + return nil, false + } +} + // As per the GraphQL Spec, Integers are only treated as valid when a valid // 32-bit signed integer, providing the broadest support across platforms. // @@ -142,11 +209,14 @@ func coerceInt(value interface{}) interface{} { return nil } return coerceInt(*value) + default: + if v, ok := unwrapInt(value); ok { + return coerceInt(v) + } + // If the value cannot be transformed into an int, return nil instead of '0' + // to denote 'no integer found' + return nil } - - // If the value cannot be transformed into an int, return nil instead of '0' - // to denote 'no integer found' - return nil } // Int is the GraphQL Integer type definition. @@ -276,6 +346,10 @@ func coerceFloat(value interface{}) interface{} { return coerceFloat(*value) } + if v, ok := unwrapFloat(value); ok { + return coerceFloat(v) + } + // If the value cannot be transformed into an float, return nil instead of '0.0' // to denote 'no float found' return nil @@ -305,13 +379,23 @@ var Float = NewScalar(ScalarConfig{ }) func coerceString(value interface{}) interface{} { - if v, ok := value.(*string); ok { - if v == nil { + switch t := value.(type) { + case *string: + if t == nil { return nil } - return *v + return *t + case string: + return t + default: + if v, ok := unwrapString(value); ok { + return coerceString(v) + } + if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() { + return nil + } + return fmt.Sprintf("%v", value) } - return fmt.Sprintf("%v", value) } // String is the GraphQL string type definition @@ -472,6 +556,13 @@ func coerceBool(value interface{}) interface{} { } return coerceBool(*value) } + + if v, ok := unwrapBool(value); ok { + return coerceBool(v) + } + if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() { + return nil + } return false } diff --git a/scalars_test.go b/scalars_test.go index 26987e5c..ca3746b6 100644 --- a/scalars_test.go +++ b/scalars_test.go @@ -5,6 +5,50 @@ import ( "testing" ) +type ( + myInt int + myString string + myBool bool + myFloat32 float32 +) + +func TestCoerceString(t *testing.T) { + tests := []struct { + in interface{} + want interface{} + }{ + { + in: "hello", + want: "hello", + }, + { + in: func() interface{} { s := "hello"; return &s }(), + want: "hello", + }, + // Typedef + { + in: myString("hello"), + want: "hello", + }, + // Typedef with pointer + { + in: func() interface{} { v := myString("hello"); return &v }(), + want: "hello", + }, + // Typedef with nil pointer + { + in: (*myString)(nil), + want: nil, + }, + } + + for i, tt := range tests { + if got, want := coerceString(tt.in), tt.want; got != want { + t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want) + } + } +} + func TestCoerceInt(t *testing.T) { tests := []struct { in interface{} @@ -240,11 +284,26 @@ func TestCoerceInt(t *testing.T) { in: make(map[string]interface{}), want: nil, }, + // Typedef + { + in: myInt(42), + want: int(42), + }, + // Typedef with pointer + { + in: func() interface{} { v := myInt(42); return &v }(), + want: int(42), + }, + // Typedef with nil pointer + { + in: (*myInt)(nil), + want: nil, + }, } for i, tt := range tests { if got, want := coerceInt(tt.in), tt.want; got != want { - t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want) + t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want) } } } @@ -438,11 +497,26 @@ func TestCoerceFloat(t *testing.T) { in: make(map[string]interface{}), want: nil, }, + // Typedef + { + in: myFloat32(3.14), + want: float32(3.14), + }, + // Typedef with pointer + { + in: func() interface{} { v := myFloat32(3.14); return &v }(), + want: float32(3.14), + }, + // Typedef with nil pointer + { + in: (*myFloat32)(nil), + want: nil, + }, } for i, tt := range tests { if got, want := coerceFloat(tt.in), tt.want; got != want { - t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want) + t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want) } } } @@ -740,11 +814,26 @@ func TestCoerceBool(t *testing.T) { in: make(map[string]interface{}), want: false, }, + // Typedef + { + in: myBool(true), + want: true, + }, + // Typedef with pointer + { + in: func() interface{} { v := myBool(true); return &v }(), + want: true, + }, + // Typedef with nil pointer + { + in: (*myBool)(nil), + want: nil, + }, } for i, tt := range tests { if got, want := coerceBool(tt.in), tt.want; got != want { - t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want) + t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want) } } }