diff --git a/executor_test.go b/executor_test.go index 856aadf3..236fa3bd 100644 --- a/executor_test.go +++ b/executor_test.go @@ -403,6 +403,82 @@ func TestThreadsSourceCorrectly(t *testing.T) { } } +func TestCorrectlyListArgumentsWithNull(t *testing.T) { + query := ` + query Example { + b(listStringArg: null, listBoolArg: [true,false,null],listIntArg:[123,null,12],listStringNonNullArg:[null]) + } + ` + var resolvedArgs map[string]interface{} + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Type", + Fields: graphql.Fields{ + "b": &graphql.Field{ + Args: graphql.FieldConfigArgument{ + "listStringArg": &graphql.ArgumentConfig{ + Type: graphql.NewList(graphql.String), + }, + "listStringNonNullArg": &graphql.ArgumentConfig{ + Type: graphql.NewNonNull(graphql.NewList(graphql.String)), + }, + "listBoolArg": &graphql.ArgumentConfig{ + Type: graphql.NewList(graphql.Boolean), + }, + "listIntArg": &graphql.ArgumentConfig{ + Type: graphql.NewList(graphql.Int), + }, + }, + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + resolvedArgs = p.Args + return resolvedArgs, nil + }, + }, + }, + }), + }) + if err != nil { + t.Fatalf("Error in schema %v", err.Error()) + } + ast := testutil.TestParse(t, query) + + ep := graphql.ExecuteParams{ + Schema: schema, + AST: ast, + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + tests := []struct { + key string + expected interface{} + }{ + { + "listStringArg", nil, + }, + + { + "listStringNonNullArg", []interface{}{nil}, + }, + + { + "listBoolArg", []interface{}{true, false, nil}, + }, + + { + "listIntArg", []interface{}{123, nil, 12}, + }, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("TestCorrectlyListArgumentsWithNull_%s", tt.key), func(t *testing.T) { + if !reflect.DeepEqual(resolvedArgs[tt.key], tt.expected) { + t.Fatalf("Expected args.%s to equal `%v`, got `%v`", tt.key, tt.expected, resolvedArgs[tt.key]) + } + }) + } +} func TestCorrectlyThreadsArguments(t *testing.T) { query := ` diff --git a/language/ast/values.go b/language/ast/values.go index 6c3c8864..f02e1c16 100644 --- a/language/ast/values.go +++ b/language/ast/values.go @@ -19,6 +19,7 @@ var _ Value = (*BooleanValue)(nil) var _ Value = (*EnumValue)(nil) var _ Value = (*ListValue)(nil) var _ Value = (*ObjectValue)(nil) +var _ Value = (*NullValue)(nil) // Variable implements Node, Value type Variable struct { @@ -202,6 +203,39 @@ func (v *EnumValue) GetValue() interface{} { return v.Value } +// NullValue represents the GraphQL null value. +// +// It is used to support passing null as an input value. +// +// Reference: https://spec.graphql.org/October2021/#sec-Null-Value +type NullValue struct { + Kind string + Loc *Location + Value interface{} +} + +func NewNullValue(v *NullValue) *NullValue { + if v == nil { + v = &NullValue{} + } + return &NullValue{ + Kind: kinds.NullValue, + Loc: v.Loc, + Value: nil, + } +} +func (n *NullValue) GetKind() string { + return n.Kind +} + +func (n *NullValue) GetLoc() *Location { + return n.Loc +} + +func (n *NullValue) GetValue() interface{} { + return n.Value +} + // ListValue implements Node, Value type ListValue struct { Kind string diff --git a/language/kinds/kinds.go b/language/kinds/kinds.go index 40bc994e..d5be9ed8 100644 --- a/language/kinds/kinds.go +++ b/language/kinds/kinds.go @@ -27,6 +27,7 @@ const ( ListValue = "ListValue" ObjectValue = "ObjectValue" ObjectField = "ObjectField" + NullValue = "NullValue" // Directives Directive = "Directive" diff --git a/language/lexer/lexer.go b/language/lexer/lexer.go index 1988c5fd..a50c335f 100644 --- a/language/lexer/lexer.go +++ b/language/lexer/lexer.go @@ -34,6 +34,7 @@ const ( STRING BLOCK_STRING AMP + NULL ) var tokenDescription = map[TokenKind]string{ @@ -57,6 +58,7 @@ var tokenDescription = map[TokenKind]string{ STRING: "String", BLOCK_STRING: "BlockString", AMP: "&", + NULL: "null", } func (kind TokenKind) String() string { diff --git a/language/parser/parser.go b/language/parser/parser.go index 4ae3dc33..0e8bc74a 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -614,6 +614,14 @@ func parseValueLiteral(parser *Parser, isConst bool) (ast.Value, error) { Value: token.Value, Loc: loc(parser, token.Start), }), nil + } else { + // If the value literal in the GraphQL input is `null`, converts it into a NullValue AST node. + if err := advance(parser); err != nil { + return nil, err + } + return ast.NewNullValue(&ast.NullValue{ + Loc: loc(parser, token.Start), + }), nil } case lexer.DOLLAR: if !isConst { @@ -1562,7 +1570,8 @@ func unexpectedEmpty(parser *Parser, beginLoc int, openKind, closeKind lexer.Tok return gqlerrors.NewSyntaxError(parser.Source, beginLoc, description) } -// Returns list of parse nodes, determined by +// Returns list of parse nodes, determined by +// // the parseFn. This list begins with a lex token of openKind // and ends with a lex token of closeKind. Advances the parser // to the next lex token after the closing token. diff --git a/language/parser/parser_test.go b/language/parser/parser_test.go index 8f0e0715..062dd577 100644 --- a/language/parser/parser_test.go +++ b/language/parser/parser_test.go @@ -183,15 +183,6 @@ func TestDoesNotAcceptFragmentsSpreadOfOn(t *testing.T) { testErrorMessage(t, test) } -func TestDoesNotAllowNullAsValue(t *testing.T) { - test := errorMessageTest{ - `{ fieldWithNullableStringInput(input: null) }'`, - `Syntax Error GraphQL (1:39) Unexpected Name "null"`, - false, - } - testErrorMessage(t, test) -} - func TestParsesMultiByteCharacters_Unicode(t *testing.T) { doc := ` diff --git a/language/printer/printer.go b/language/printer/printer.go index ac771ba6..43ba45c4 100644 --- a/language/printer/printer.go +++ b/language/printer/printer.go @@ -8,6 +8,7 @@ import ( "reflect" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/lexer" "github.com/graphql-go/graphql/language/visitor" ) @@ -472,7 +473,13 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ } return visitor.ActionNoChange, nil }, - + "NullValue": func(p visitor.VisitFuncParams) (string, interface{}) { + switch p.Node.(type) { + case *ast.NullValue: + return visitor.ActionUpdate, lexer.NULL.String() + } + return visitor.ActionNoChange, nil + }, // Type System Definitions "SchemaDefinition": func(p visitor.VisitFuncParams) (string, interface{}) { switch node := p.Node.(type) { diff --git a/language/printer/printer_test.go b/language/printer/printer_test.go index b6d7de7d..5b8e064f 100644 --- a/language/printer/printer_test.go +++ b/language/printer/printer_test.go @@ -200,3 +200,17 @@ func TestPrinter_CorrectlyPrintsStringArgumentsWithProperQuoting(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, results)) } } + +func TestPrinter_CorrectlyPrintsNullArguments(t *testing.T) { + queryAst := `query { foo(nullArg: null) }` + expected := `{ + foo(nullArg: null) +} +` + astDoc := parse(t, queryAst) + results := printer.Print(astDoc) + + if !reflect.DeepEqual(expected, results) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, results)) + } +} diff --git a/rules.go b/rules.go index ae0c75b9..95b918a4 100644 --- a/rules.go +++ b/rules.go @@ -1735,6 +1735,11 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { if valueAST.GetKind() == kinds.Variable { return true, nil } + // Supplying a nullable variable type to a non-null input type is considered invalid. + // nullValue is only valid for nullable input types. + if valueAST.GetKind() == kinds.NullValue { + return true, nil + } } switch ttype := ttype.(type) { case *NonNull: @@ -1742,7 +1747,7 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { if e := ttype.Error(); e != nil { return false, []string{e.Error()} } - if valueAST == nil { + if valueAST == nil || valueAST.GetKind() == kinds.NullValue { if ttype.OfType.Name() != "" { return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())} } diff --git a/rules_arguments_of_correct_type_test.go b/rules_arguments_of_correct_type_test.go index ecd4bea4..b27b78c1 100644 --- a/rules_arguments_of_correct_type_test.go +++ b/rules_arguments_of_correct_type_test.go @@ -8,6 +8,41 @@ import ( "github.com/graphql-go/graphql/testutil" ) +func TestValidate_ArgValuesOfCorrectType_ValidValue_GoodNullValue(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.ArgumentsOfCorrectTypeRule, ` + { + complicatedArgs { + intArgField(intArg: null) + } + } + `) +} + +func TestValidator_NonNullArgsUsingNullValue(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.ArgumentsOfCorrectTypeRule, ` + { + complicatedArgs { + nonNullIntArgField(nonNullIntArg: null) + } + } + `, []gqlerrors.FormattedError{ + testutil.RuleError( + "Argument \"nonNullIntArg\" has invalid value null.\nExpected \"Int!\", found null.", + 4, 47, + ), + }) +} + +func TestValidator_NullArgsUsingNullValue(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.ArgumentsOfCorrectTypeRule, ` + { + complicatedArgs { + stringArgField(stringArg: null) + } + } + `) +} + func TestValidate_ArgValuesOfCorrectType_ValidValue_GoodIntValue(t *testing.T) { testutil.ExpectPassesRule(t, graphql.ArgumentsOfCorrectTypeRule, ` { diff --git a/scalars.go b/scalars.go index 45479b54..68fdbf39 100644 --- a/scalars.go +++ b/scalars.go @@ -163,6 +163,7 @@ var Int = NewScalar(ScalarConfig{ return intValue } } + return nil }, }) @@ -332,6 +333,9 @@ var String = NewScalar(ScalarConfig{ }) func coerceBool(value interface{}) interface{} { + if value == nil { + return nil + } switch value := value.(type) { case bool: return value diff --git a/scalars_test.go b/scalars_test.go index 26987e5c..aec04ffd 100644 --- a/scalars_test.go +++ b/scalars_test.go @@ -240,6 +240,10 @@ func TestCoerceInt(t *testing.T) { in: make(map[string]interface{}), want: nil, }, + { + in: nil, + want: nil, + }, } for i, tt := range tests { @@ -438,6 +442,10 @@ func TestCoerceFloat(t *testing.T) { in: make(map[string]interface{}), want: nil, }, + { + in: nil, + want: nil, + }, } for i, tt := range tests { @@ -740,6 +748,10 @@ func TestCoerceBool(t *testing.T) { in: make(map[string]interface{}), want: false, }, + { + in: nil, + want: nil, + }, } for i, tt := range tests { diff --git a/validator_test.go b/validator_test.go index 6eaf0005..b390cbef 100644 --- a/validator_test.go +++ b/validator_test.go @@ -45,6 +45,7 @@ func TestValidator_SupportsFullValidation_ValidatesQueries(t *testing.T) { `) } + // NOTE: experimental func TestValidator_SupportsFullValidation_ValidatesUsingACustomTypeInfo(t *testing.T) { diff --git a/values.go b/values.go index 06c08af6..c1b20c7e 100644 --- a/values.go +++ b/values.go @@ -57,16 +57,34 @@ func getArgumentValues( if tmpValue, ok := argASTMap[argDef.PrivateName]; ok { value = tmpValue.Value } - if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) { - tmp = argDef.DefaultValue - } - if !isNullish(tmp) { - results[argDef.PrivateName] = tmp + // if ast value is NullValue, and keep args's key + if value != nil && value.GetKind() == kinds.NullValue { + results[argDef.PrivateName] = nil + } else { + if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) { + tmp = argDef.DefaultValue + } + if !isNullish(tmp) { + results[argDef.PrivateName] = tmp + } else { + if nullValueWithVairableProvided(value, argDef.PrivateName, variableValues) { + results[argDef.PrivateName] = nil + } + } } } return results } +func nullValueWithVairableProvided(valueAST ast.Value, key string, variables map[string]interface{}) bool { + if valueAST != nil && valueAST.GetKind() == kinds.Variable { + if _, ok := variables[key]; ok { + return true + } + } + return false +} + // Given a variable definition, and any value of input, return a value which // adheres to the variable definition, or throw an error. func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, input interface{}) (interface{}, error) { @@ -349,7 +367,7 @@ func isIterable(src interface{}) bool { * */ func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interface{}) interface{} { - if valueAST == nil { + if valueAST == nil || valueAST.GetKind() == kinds.NullValue { return nil } // precedence: value > type diff --git a/variables_test.go b/variables_test.go index 9dc430df..54ed6c36 100644 --- a/variables_test.go +++ b/variables_test.go @@ -70,6 +70,9 @@ func inputResolved(p graphql.ResolveParams) (interface{}, error) { if !ok { return nil, nil } + if input == nil { + return nil, nil + } b, err := json.Marshal(input) if err != nil { return nil, nil @@ -960,6 +963,36 @@ func TestVariables_ListsAndNullability_AllowsListsToBeNull(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) } } + +func TestVariables_ListsAndNullability_AllowsListsToBeNullWithMoreListValues(t *testing.T) { + doc := ` + query q($input: [String]) { + list(input: $input) + } + ` + params := map[string]interface{}{ + "input": []interface{}{nil, "ok", nil}, + } + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "list": `[null,"ok",null]`, + }, + } + ast := testutil.TestParse(t, doc) + ep := graphql.ExecuteParams{ + Schema: variablesTestSchema, + AST: ast, + Args: params, + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} func TestVariables_ListsAndNullability_AllowsListsToContainValues(t *testing.T) { doc := ` query q($input: [String]) {