Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix coercion of typedef primitives and their pointers #489

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 99 additions & 8 deletions scalars.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,79 @@ package graphql
import (
"fmt"
"math"
"reflect"
"strconv"
"time"

"github.com/graphql-go/graphql/language/ast"
)

func unwrapInt(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Int:
return int(r.Int()), true
case reflect.Int8:
return int8(r.Int()), true
case reflect.Int16:
return int16(r.Int()), true
case reflect.Int32:
return int32(r.Int()), true
case reflect.Int64:
return r.Int(), true
default:
return nil, false
}
}

func unwrapFloat(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Float32:
return float32(r.Float()), true
case reflect.Float64:
return r.Float(), true
default:
return nil, false
}
}

func unwrapBool(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Bool:
return r.Bool(), true
default:
return nil, false
}
}

func unwrapString(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.String:
return r.String(), true
default:
return nil, false
}
}

// As per the GraphQL Spec, Integers are only treated as valid when a valid
// 32-bit signed integer, providing the broadest support across platforms.
//
Expand Down Expand Up @@ -142,11 +209,14 @@ func coerceInt(value interface{}) interface{} {
return nil
}
return coerceInt(*value)
default:
if v, ok := unwrapInt(value); ok {
return coerceInt(v)
}
// If the value cannot be transformed into an int, return nil instead of '0'
// to denote 'no integer found'
return nil
}

// If the value cannot be transformed into an int, return nil instead of '0'
// to denote 'no integer found'
return nil
}

// Int is the GraphQL Integer type definition.
Expand Down Expand Up @@ -276,6 +346,10 @@ func coerceFloat(value interface{}) interface{} {
return coerceFloat(*value)
}

if v, ok := unwrapFloat(value); ok {
return coerceFloat(v)
}

// If the value cannot be transformed into an float, return nil instead of '0.0'
// to denote 'no float found'
return nil
Expand Down Expand Up @@ -305,13 +379,23 @@ var Float = NewScalar(ScalarConfig{
})

func coerceString(value interface{}) interface{} {
if v, ok := value.(*string); ok {
if v == nil {
switch t := value.(type) {
case *string:
if t == nil {
return nil
}
return *v
return *t
case string:
return t
default:
if v, ok := unwrapString(value); ok {
return coerceString(v)
}
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
return nil
}
return fmt.Sprintf("%v", value)
}
return fmt.Sprintf("%v", value)
}

// String is the GraphQL string type definition
Expand Down Expand Up @@ -472,6 +556,13 @@ func coerceBool(value interface{}) interface{} {
}
return coerceBool(*value)
}

if v, ok := unwrapBool(value); ok {
return coerceBool(v)
}
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
return nil
}
return false
}

Expand Down
95 changes: 92 additions & 3 deletions scalars_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,50 @@ import (
"testing"
)

type (
myInt int
myString string
myBool bool
myFloat32 float32
)

func TestCoerceString(t *testing.T) {
tests := []struct {
in interface{}
want interface{}
}{
{
in: "hello",
want: "hello",
},
{
in: func() interface{} { s := "hello"; return &s }(),
want: "hello",
},
// Typedef
{
in: myString("hello"),
want: "hello",
},
// Typedef with pointer
{
in: func() interface{} { v := myString("hello"); return &v }(),
want: "hello",
},
// Typedef with nil pointer
{
in: (*myString)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceString(tt.in), tt.want; got != want {
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}

func TestCoerceInt(t *testing.T) {
tests := []struct {
in interface{}
Expand Down Expand Up @@ -240,11 +284,26 @@ func TestCoerceInt(t *testing.T) {
in: make(map[string]interface{}),
want: nil,
},
// Typedef
{
in: myInt(42),
want: int(42),
},
// Typedef with pointer
{
in: func() interface{} { v := myInt(42); return &v }(),
want: int(42),
},
// Typedef with nil pointer
{
in: (*myInt)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceInt(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
Expand Down Expand Up @@ -438,11 +497,26 @@ func TestCoerceFloat(t *testing.T) {
in: make(map[string]interface{}),
want: nil,
},
// Typedef
{
in: myFloat32(3.14),
want: float32(3.14),
},
// Typedef with pointer
{
in: func() interface{} { v := myFloat32(3.14); return &v }(),
want: float32(3.14),
},
// Typedef with nil pointer
{
in: (*myFloat32)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceFloat(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
Expand Down Expand Up @@ -740,11 +814,26 @@ func TestCoerceBool(t *testing.T) {
in: make(map[string]interface{}),
want: false,
},
// Typedef
{
in: myBool(true),
want: true,
},
// Typedef with pointer
{
in: func() interface{} { v := myBool(true); return &v }(),
want: true,
},
// Typedef with nil pointer
{
in: (*myBool)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceBool(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
Expand Down