From 57751694be29f98352f32d8a49a941c4e845dec0 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Thu, 14 Nov 2024 17:22:37 +0100 Subject: [PATCH] Allow setting websocket connection parameters --- graphql/client.go | 11 +++ graphql/websocket.go | 11 ++- internal/integration/generated.go | 60 ++++++++++++++++ internal/integration/integration_test.go | 83 +++++++++++++++++++++- internal/integration/roundtrip.go | 11 ++- internal/integration/schema.graphql | 1 + internal/integration/server/gqlgen_exec.go | 71 +++++++++++++++++- internal/integration/server/server.go | 36 +++++++++- 8 files changed, 276 insertions(+), 8 deletions(-) diff --git a/graphql/client.go b/graphql/client.go index 526395b8..e60b9733 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -132,6 +132,16 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client { // The client does not support queries nor mutations, and will return an error // if passed a request that attempts one. func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient { + return NewClientUsingWebSocketWithConnectionParams(endpoint, wsDialer, headers, nil) +} + +// NewClientUsingWebSocketWithConnectionParams returns a [WebSocketClient] which makes subscription requests +// to the given endpoint using webSocket. It allows to pass additional connection parameters +// to the server during the initial connection handshake. +// +// connectionParams is a map of connection parameters to be sent to the server +// during the initial connection handshake. +func NewClientUsingWebSocketWithConnectionParams(endpoint string, wsDialer Dialer, headers http.Header, connParams map[string]interface{}) WebSocketClient { if headers == nil { headers = http.Header{} } @@ -141,6 +151,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head return &webSocketClient{ Dialer: wsDialer, Header: headers, + connParams: connParams, errChan: make(chan error), endpoint: endpoint, subscriptions: subscriptionMap{map_: make(map[string]subscription)}, diff --git a/graphql/websocket.go b/graphql/websocket.go index 10b77daa..5fb97cec 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -48,12 +48,18 @@ type webSocketClient struct { Header http.Header endpoint string conn WSConn + connParams map[string]interface{} errChan chan error subscriptions subscriptionMap isClosing bool sync.Mutex } +type webSocketInitMessage struct { + Payload map[string]interface{} `json:"payload"` + Type string `json:"type"` +} + type webSocketSendMessage struct { Payload *Request `json:"payload"` Type string `json:"type"` @@ -67,8 +73,9 @@ type webSocketReceiveMessage struct { } func (w *webSocketClient) sendInit() error { - connInitMsg := webSocketSendMessage{ - Type: webSocketTypeConnInit, + connInitMsg := webSocketInitMessage{ + Type: webSocketTypeConnInit, + Payload: w.connParams, } return w.sendStructAsJSON(connInitMsg) } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 8a8f70cb..660efdf6 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -1311,6 +1311,14 @@ type __queryWithVariablesInput struct { // GetId returns __queryWithVariablesInput.Id, and is useful for accessing the field via an interface. func (v *__queryWithVariablesInput) GetId() string { return v.Id } +// countAuthorizedResponse is returned by countAuthorized on success. +type countAuthorizedResponse struct { + CountAuthorized int `json:"countAuthorized"` +} + +// GetCountAuthorized returns countAuthorizedResponse.CountAuthorized, and is useful for accessing the field via an interface. +func (v *countAuthorizedResponse) GetCountAuthorized() int { return v.CountAuthorized } + // countResponse is returned by count on success. type countResponse struct { Count int `json:"count"` @@ -3148,6 +3156,58 @@ func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) err return nil } +// The subscription executed by countAuthorized. +const countAuthorized_Operation = ` +subscription countAuthorized { + countAuthorized +} +` + +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] +func countAuthorized( + ctx_ context.Context, + client_ graphql.WebSocketClient, +) (dataChan_ chan countAuthorizedWsResponse, subscriptionID_ string, err_ error) { + req_ := &graphql.Request{ + OpName: "countAuthorized", + Query: countAuthorized_Operation, + } + + dataChan_ = make(chan countAuthorizedWsResponse) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, countAuthorizedForwardData) + + return dataChan_, subscriptionID_, err_ +} + +type countAuthorizedWsResponse struct { + Data *countAuthorizedResponse `json:"data"` + Extensions map[string]interface{} `json:"extensions,omitempty"` + Errors error `json:"errors"` +} + +func countAuthorizedForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { + var gqlResp graphql.Response + var wsResp countAuthorizedWsResponse + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan countAuthorizedWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan countAuthorizedWsResponse'") + } + dataChan_ <- wsResp + return nil +} + // The mutation executed by createUser. const createUser_Operation = ` mutation createUser ($user: NewUser!) { diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 96bc0b55..508f7a59 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -95,7 +95,8 @@ func TestSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - wsClient := newRoundtripWebSocketClient(t, server.URL) + wsClient := newRoundtripWebSocketClient(t, server.URL, nil) + errChan, err := wsClient.Start(ctx) require.NoError(t, err) @@ -144,6 +145,86 @@ func TestSubscription(t *testing.T) { } } +func TestSubscriptionConnectionParams(t *testing.T) { + _ = `# @genqlient + subscription countAuthorized { countAuthorized }` + + authKey := server.AuthKey + + ctx := context.Background() + server := server.RunServer() + defer server.Close() + + cases := []struct { + connParams map[string]interface{} + name string + expectedError string + }{ + { + name: "authorized_user_gets_counter", + connParams: map[string]interface{}{ + authKey: "authorized-user-token", + }, + }, + { + name: "unauthorized_user_gets_error", + expectedError: "input: countAuthorized unauthorized\n", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + wsClient := newRoundtripWebSocketClient(t, server.URL, tc.connParams) + + errChan, err := wsClient.Start(ctx) + require.NoError(t, err) + + dataChan, subscriptionID, err := countAuthorized(ctx, wsClient) + require.NoError(t, err) + defer wsClient.Close() + + var ( + counter = 0 + start = time.Now() + ) + + for loop := true; loop; { + select { + case resp, more := <-dataChan: + if !more { + loop = false + break + } + + if tc.expectedError != "" { + require.Error(t, resp.Errors) + assert.Equal(t, tc.expectedError, resp.Errors.Error()) + continue + } + + require.NotNil(t, resp.Data) + assert.Equal(t, counter, resp.Data.CountAuthorized) + require.Nil(t, resp.Errors) + + if time.Since(start) > 5*time.Second { + err := wsClient.Unsubscribe(subscriptionID) + require.NoError(t, err) + loop = false + } + + counter++ + + case err := <-errChan: + require.NoError(t, err) + + case <-time.After(10 * time.Second): + require.NoError(t, fmt.Errorf("subscription timed out")) + } + } + }) + } +} + func TestServerError(t *testing.T) { _ = `# @genqlient query failingQuery { fail me { id } }` diff --git a/internal/integration/roundtrip.go b/internal/integration/roundtrip.go index a020d7b4..810b5659 100644 --- a/internal/integration/roundtrip.go +++ b/internal/integration/roundtrip.go @@ -156,14 +156,19 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade return graphql.WSConn(conn), err } -func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient { +func newRoundtripWebSocketClient(t *testing.T, endpoint string, connectionParams map[string]interface{}) graphql.WebSocketClient { dialer := websocket.DefaultDialer if !strings.HasPrefix(endpoint, "ws") { _, address, _ := strings.Cut(endpoint, "://") endpoint = "ws://" + address } return &roundtripClient{ - wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil), - t: t, + wsWrapped: graphql.NewClientUsingWebSocketWithConnectionParams( + endpoint, + &MyDialer{Dialer: dialer}, + nil, + connectionParams, + ), + t: t, } } diff --git a/internal/integration/schema.graphql b/internal/integration/schema.graphql index 9f148ccf..16e13cab 100644 --- a/internal/integration/schema.graphql +++ b/internal/integration/schema.graphql @@ -19,6 +19,7 @@ type Mutation { type Subscription { count: Int! + countAuthorized: Int! } type User implements Being & Lucky { diff --git a/internal/integration/server/gqlgen_exec.go b/internal/integration/server/gqlgen_exec.go index ccf5d2aa..aab1b6d8 100644 --- a/internal/integration/server/gqlgen_exec.go +++ b/internal/integration/server/gqlgen_exec.go @@ -78,7 +78,8 @@ type ComplexityRoot struct { } Subscription struct { - Count func(childComplexity int) int + Count func(childComplexity int) int + CountAuthorized func(childComplexity int) int } User struct { @@ -108,6 +109,7 @@ type QueryResolver interface { } type SubscriptionResolver interface { Count(ctx context.Context) (<-chan int, error) + CountAuthorized(ctx context.Context) (<-chan int, error) } type executableSchema struct { @@ -291,6 +293,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Subscription.Count(childComplexity), true + case "Subscription.countAuthorized": + if e.complexity.Subscription.CountAuthorized == nil { + break + } + return e.complexity.Subscription.CountAuthorized(childComplexity), true + case "User.birthdate": if e.complexity.User.Birthdate == nil { break @@ -484,6 +492,7 @@ type Mutation { type Subscription { count: Int! + countAuthorized: Int! } type User implements Being & Lucky { @@ -1811,6 +1820,64 @@ func (ec *executionContext) fieldContext_Subscription_count(ctx context.Context, return fc, nil } +func (ec *executionContext) _Subscription_countAuthorized(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { + fc, err := ec.fieldContext_Subscription_countAuthorized(ctx, field) + if err != nil { + return nil + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Subscription().CountAuthorized(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return nil + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return nil + } + return func(ctx context.Context) graphql.Marshaler { + select { + case res, ok := <-resTmp.(<-chan int): + if !ok { + return nil + } + return graphql.WriterFunc(func(w io.Writer) { + w.Write([]byte{'{'}) + graphql.MarshalString(field.Alias).MarshalGQL(w) + w.Write([]byte{':'}) + ec.marshalNInt2int(ctx, field.Selections, res).MarshalGQL(w) + w.Write([]byte{'}'}) + }) + case <-ctx.Done(): + return nil + } + } +} + +func (ec *executionContext) fieldContext_Subscription_countAuthorized(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Subscription", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _User_id(ctx context.Context, field graphql.CollectedField, obj *User) (ret graphql.Marshaler) { fc, err := ec.fieldContext_User_id(ctx, field) if err != nil { @@ -4398,6 +4465,8 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection switch fields[0].Name { case "count": return ec._Subscription_count(ctx, fields[0]) + case "countAuthorized": + return ec._Subscription_countAuthorized(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index 2c79909f..e626e34a 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -173,11 +173,45 @@ func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { return respChan, nil } +func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, error) { + if getAuthToken(ctx) != "authorized-user-token" { + return nil, fmt.Errorf("unauthorized") + } + + return s.Count(ctx) +} + +const AuthKey = "authToken" + +type ( + authTokenCtxKey struct{} +) + +func withAuthToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, authTokenCtxKey{}, token) +} + +func getAuthToken(ctx context.Context) string { + if tkn, ok := ctx.Value(authTokenCtxKey{}).(string); ok { + return tkn + } + return "" +} + func RunServer() *httptest.Server { gqlgenServer := handler.New(NewExecutableSchema(Config{Resolvers: &resolver{}})) gqlgenServer.AddTransport(transport.POST{}) gqlgenServer.AddTransport(transport.GET{}) - gqlgenServer.AddTransport(transport.Websocket{}) + + gqlgenServer.AddTransport(transport.Websocket{ + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + if authToken, ok := initPayload[AuthKey].(string); ok && authToken != "" { + ctx = withAuthToken(ctx, authToken) + } + return ctx, &initPayload, nil + }, + }) + gqlgenServer.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { graphql.RegisterExtension(ctx, "foobar", "test") return next(ctx)