From 569efb29978d4b864cdc6ef61a464aaf8ff77ed2 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Sat, 16 Nov 2024 21:49:11 +0100 Subject: [PATCH] Create separate test function --- internal/integration/generated.go | 60 +++++++++++++ internal/integration/integration_test.go | 100 +++++++++++++++++---- internal/integration/schema.graphql | 1 + internal/integration/server/gqlgen_exec.go | 71 ++++++++++++++- internal/integration/server/server.go | 32 +++++-- 5 files changed, 236 insertions(+), 28 deletions(-) diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 8a8f70cb..660efdf6 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -1311,6 +1311,14 @@ type __queryWithVariablesInput struct { // GetId returns __queryWithVariablesInput.Id, and is useful for accessing the field via an interface. func (v *__queryWithVariablesInput) GetId() string { return v.Id } +// countAuthorizedResponse is returned by countAuthorized on success. +type countAuthorizedResponse struct { + CountAuthorized int `json:"countAuthorized"` +} + +// GetCountAuthorized returns countAuthorizedResponse.CountAuthorized, and is useful for accessing the field via an interface. +func (v *countAuthorizedResponse) GetCountAuthorized() int { return v.CountAuthorized } + // countResponse is returned by count on success. type countResponse struct { Count int `json:"count"` @@ -3148,6 +3156,58 @@ func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) err return nil } +// The subscription executed by countAuthorized. +const countAuthorized_Operation = ` +subscription countAuthorized { + countAuthorized +} +` + +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] +func countAuthorized( + ctx_ context.Context, + client_ graphql.WebSocketClient, +) (dataChan_ chan countAuthorizedWsResponse, subscriptionID_ string, err_ error) { + req_ := &graphql.Request{ + OpName: "countAuthorized", + Query: countAuthorized_Operation, + } + + dataChan_ = make(chan countAuthorizedWsResponse) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, countAuthorizedForwardData) + + return dataChan_, subscriptionID_, err_ +} + +type countAuthorizedWsResponse struct { + Data *countAuthorizedResponse `json:"data"` + Extensions map[string]interface{} `json:"extensions,omitempty"` + Errors error `json:"errors"` +} + +func countAuthorizedForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { + var gqlResp graphql.Response + var wsResp countAuthorizedWsResponse + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan countAuthorizedWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan countAuthorizedWsResponse'") + } + dataChan_ <- wsResp + return nil +} + // The mutation executed by createUser. const createUser_Operation = ` mutation createUser ($user: NewUser!) { diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 6611acf3..508f7a59 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -66,17 +66,13 @@ func TestSubscription(t *testing.T) { _ = `# @genqlient subscription count { count }` - authKey := server.AuthKey - ctx := context.Background() server := server.RunServer() defer server.Close() cases := []struct { - connParams map[string]interface{} name string unsubThreshold time.Duration - counterStart int expected subscriptionResult }{ { @@ -87,18 +83,6 @@ func TestSubscription(t *testing.T) { serverChannelClosed: true, }, }, - { - name: "server_closed_authorized_user_gets_incremented_counter", - unsubThreshold: 5 * time.Second, - counterStart: 1000, - connParams: map[string]interface{}{ - authKey: "authorized-user-token", - }, - expected: subscriptionResult{ - clientUnsubscribed: false, - serverChannelClosed: true, - }, - }, { name: "client_unsubscribed", unsubThreshold: 300 * time.Millisecond, @@ -111,7 +95,7 @@ func TestSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - wsClient := newRoundtripWebSocketClient(t, server.URL, tc.connParams) + wsClient := newRoundtripWebSocketClient(t, server.URL, nil) errChan, err := wsClient.Start(ctx) require.NoError(t, err) @@ -121,7 +105,7 @@ func TestSubscription(t *testing.T) { defer wsClient.Close() var ( - counter = tc.counterStart + counter = 0 start = time.Now() result = subscriptionResult{} ) @@ -161,6 +145,86 @@ func TestSubscription(t *testing.T) { } } +func TestSubscriptionConnectionParams(t *testing.T) { + _ = `# @genqlient + subscription countAuthorized { countAuthorized }` + + authKey := server.AuthKey + + ctx := context.Background() + server := server.RunServer() + defer server.Close() + + cases := []struct { + connParams map[string]interface{} + name string + expectedError string + }{ + { + name: "authorized_user_gets_counter", + connParams: map[string]interface{}{ + authKey: "authorized-user-token", + }, + }, + { + name: "unauthorized_user_gets_error", + expectedError: "input: countAuthorized unauthorized\n", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + wsClient := newRoundtripWebSocketClient(t, server.URL, tc.connParams) + + errChan, err := wsClient.Start(ctx) + require.NoError(t, err) + + dataChan, subscriptionID, err := countAuthorized(ctx, wsClient) + require.NoError(t, err) + defer wsClient.Close() + + var ( + counter = 0 + start = time.Now() + ) + + for loop := true; loop; { + select { + case resp, more := <-dataChan: + if !more { + loop = false + break + } + + if tc.expectedError != "" { + require.Error(t, resp.Errors) + assert.Equal(t, tc.expectedError, resp.Errors.Error()) + continue + } + + require.NotNil(t, resp.Data) + assert.Equal(t, counter, resp.Data.CountAuthorized) + require.Nil(t, resp.Errors) + + if time.Since(start) > 5*time.Second { + err := wsClient.Unsubscribe(subscriptionID) + require.NoError(t, err) + loop = false + } + + counter++ + + case err := <-errChan: + require.NoError(t, err) + + case <-time.After(10 * time.Second): + require.NoError(t, fmt.Errorf("subscription timed out")) + } + } + }) + } +} + func TestServerError(t *testing.T) { _ = `# @genqlient query failingQuery { fail me { id } }` diff --git a/internal/integration/schema.graphql b/internal/integration/schema.graphql index 9f148ccf..16e13cab 100644 --- a/internal/integration/schema.graphql +++ b/internal/integration/schema.graphql @@ -19,6 +19,7 @@ type Mutation { type Subscription { count: Int! + countAuthorized: Int! } type User implements Being & Lucky { diff --git a/internal/integration/server/gqlgen_exec.go b/internal/integration/server/gqlgen_exec.go index ccf5d2aa..aab1b6d8 100644 --- a/internal/integration/server/gqlgen_exec.go +++ b/internal/integration/server/gqlgen_exec.go @@ -78,7 +78,8 @@ type ComplexityRoot struct { } Subscription struct { - Count func(childComplexity int) int + Count func(childComplexity int) int + CountAuthorized func(childComplexity int) int } User struct { @@ -108,6 +109,7 @@ type QueryResolver interface { } type SubscriptionResolver interface { Count(ctx context.Context) (<-chan int, error) + CountAuthorized(ctx context.Context) (<-chan int, error) } type executableSchema struct { @@ -291,6 +293,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Subscription.Count(childComplexity), true + case "Subscription.countAuthorized": + if e.complexity.Subscription.CountAuthorized == nil { + break + } + return e.complexity.Subscription.CountAuthorized(childComplexity), true + case "User.birthdate": if e.complexity.User.Birthdate == nil { break @@ -484,6 +492,7 @@ type Mutation { type Subscription { count: Int! + countAuthorized: Int! } type User implements Being & Lucky { @@ -1811,6 +1820,64 @@ func (ec *executionContext) fieldContext_Subscription_count(ctx context.Context, return fc, nil } +func (ec *executionContext) _Subscription_countAuthorized(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { + fc, err := ec.fieldContext_Subscription_countAuthorized(ctx, field) + if err != nil { + return nil + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Subscription().CountAuthorized(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return nil + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return nil + } + return func(ctx context.Context) graphql.Marshaler { + select { + case res, ok := <-resTmp.(<-chan int): + if !ok { + return nil + } + return graphql.WriterFunc(func(w io.Writer) { + w.Write([]byte{'{'}) + graphql.MarshalString(field.Alias).MarshalGQL(w) + w.Write([]byte{':'}) + ec.marshalNInt2int(ctx, field.Selections, res).MarshalGQL(w) + w.Write([]byte{'}'}) + }) + case <-ctx.Done(): + return nil + } + } +} + +func (ec *executionContext) fieldContext_Subscription_countAuthorized(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Subscription", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _User_id(ctx context.Context, field graphql.CollectedField, obj *User) (ret graphql.Marshaler) { fc, err := ec.fieldContext_User_id(ctx, field) if err != nil { @@ -4398,6 +4465,8 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection switch fields[0].Name { case "count": return ec._Subscription_count(ctx, fields[0]) + case "countAuthorized": + return ec._Subscription_countAuthorized(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index f364490c..67cdcf1a 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -168,23 +168,37 @@ func getAuthToken(ctx context.Context) string { } func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { - respCounter := 0 - if getAuthToken(ctx) == "authorized-user-token" { - respCounter = 1000 + respChan := make(chan int, 1) + go func(respChan chan int) { + defer close(respChan) + counter := 0 + for { + if counter == 10 { + return + } + respChan <- counter + counter++ + time.Sleep(100 * time.Millisecond) + } + }(respChan) + return respChan, nil +} + +func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, error) { + if getAuthToken(ctx) != "authorized-user-token" { + return nil, fmt.Errorf("unauthorized") } respChan := make(chan int, 1) go func(respChan chan int) { defer close(respChan) - closeCounter := 0 + counter := 0 for { - if closeCounter == 10 { + if counter == 10 { return } - closeCounter++ - respChan <- respCounter - respCounter++ - time.Sleep(100 * time.Millisecond) + respChan <- counter + counter++ } }(respChan) return respChan, nil