Skip to content

Commit

Permalink
Fix coercion of typedef primitives and their pointers. For example,
Browse files Browse the repository at this point in the history
with "type MyInt int", values of type MyInt and *MyInt should be treated
as ints. Fixes #488.
  • Loading branch information
atombender committed Jun 14, 2019
1 parent 199d20b commit 0d8ef75
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 11 deletions.
107 changes: 99 additions & 8 deletions scalars.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,79 @@ package graphql
import (
"fmt"
"math"
"reflect"
"strconv"
"time"

"github.com/graphql-go/graphql/language/ast"
)

func unwrapInt(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Int:
return int(r.Int()), true
case reflect.Int8:
return int8(r.Int()), true
case reflect.Int16:
return int16(r.Int()), true
case reflect.Int32:
return int32(r.Int()), true
case reflect.Int64:
return r.Int(), true
default:
return nil, false
}
}

func unwrapFloat(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Float32:
return float32(r.Float()), true
case reflect.Float64:
return r.Float(), true
default:
return nil, false
}
}

func unwrapBool(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Bool:
return r.Bool(), true
default:
return nil, false
}
}

func unwrapString(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.String:
return r.String(), true
default:
return nil, false
}
}

// As per the GraphQL Spec, Integers are only treated as valid when a valid
// 32-bit signed integer, providing the broadest support across platforms.
//
Expand Down Expand Up @@ -142,11 +209,14 @@ func coerceInt(value interface{}) interface{} {
return nil
}
return coerceInt(*value)
default:
if v, ok := unwrapInt(value); ok {
return coerceInt(v)
}
// If the value cannot be transformed into an int, return nil instead of '0'
// to denote 'no integer found'
return nil
}

// If the value cannot be transformed into an int, return nil instead of '0'
// to denote 'no integer found'
return nil
}

// Int is the GraphQL Integer type definition.
Expand Down Expand Up @@ -276,6 +346,10 @@ func coerceFloat(value interface{}) interface{} {
return coerceFloat(*value)
}

if v, ok := unwrapFloat(value); ok {
return coerceFloat(v)
}

// If the value cannot be transformed into an float, return nil instead of '0.0'
// to denote 'no float found'
return nil
Expand Down Expand Up @@ -305,13 +379,23 @@ var Float = NewScalar(ScalarConfig{
})

func coerceString(value interface{}) interface{} {
if v, ok := value.(*string); ok {
if v == nil {
switch t := value.(type) {
case *string:
if t == nil {
return nil
}
return *v
return *t
case string:
return t
default:
if v, ok := unwrapString(value); ok {
return coerceString(v)
}
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
return nil
}
return fmt.Sprintf("%v", value)
}
return fmt.Sprintf("%v", value)
}

// String is the GraphQL string type definition
Expand Down Expand Up @@ -472,6 +556,13 @@ func coerceBool(value interface{}) interface{} {
}
return coerceBool(*value)
}

if v, ok := unwrapBool(value); ok {
return coerceBool(v)
}
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
return nil
}
return false
}

Expand Down
95 changes: 92 additions & 3 deletions scalars_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,50 @@ import (
"testing"
)

type (
myInt int
myString string
myBool bool
myFloat32 float32
)

func TestCoerceString(t *testing.T) {
tests := []struct {
in interface{}
want interface{}
}{
{
in: "hello",
want: "hello",
},
{
in: func() interface{} { s := "hello"; return &s }(),
want: "hello",
},
// Typedef
{
in: myString("hello"),
want: "hello",
},
// Typedef with pointer
{
in: func() interface{} { v := myString("hello"); return &v }(),
want: "hello",
},
// Typedef with nil pointer
{
in: (*myString)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceString(tt.in), tt.want; got != want {
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}

func TestCoerceInt(t *testing.T) {
tests := []struct {
in interface{}
Expand Down Expand Up @@ -240,11 +284,26 @@ func TestCoerceInt(t *testing.T) {
in: make(map[string]interface{}),
want: nil,
},
// Typedef
{
in: myInt(42),
want: int(42),
},
// Typedef with pointer
{
in: func() interface{} { v := myInt(42); return &v }(),
want: int(42),
},
// Typedef with nil pointer
{
in: (*myInt)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceInt(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
Expand Down Expand Up @@ -438,11 +497,26 @@ func TestCoerceFloat(t *testing.T) {
in: make(map[string]interface{}),
want: nil,
},
// Typedef
{
in: myFloat32(3.14),
want: float32(3.14),
},
// Typedef with pointer
{
in: func() interface{} { v := myFloat32(3.14); return &v }(),
want: float32(3.14),
},
// Typedef with nil pointer
{
in: (*myFloat32)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceFloat(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
Expand Down Expand Up @@ -740,11 +814,26 @@ func TestCoerceBool(t *testing.T) {
in: make(map[string]interface{}),
want: false,
},
// Typedef
{
in: myBool(true),
want: true,
},
// Typedef with pointer
{
in: func() interface{} { v := myBool(true); return &v }(),
want: true,
},
// Typedef with nil pointer
{
in: (*myBool)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceBool(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
Expand Down

0 comments on commit 0d8ef75

Please sign in to comment.