Skip to content

Commit

Permalink
Allow setting websocket connection parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Nov 16, 2024
1 parent e030ff1 commit 5775169
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 8 deletions.
11 changes: 11 additions & 0 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
return NewClientUsingWebSocketWithConnectionParams(endpoint, wsDialer, headers, nil)
}

// NewClientUsingWebSocketWithConnectionParams returns a [WebSocketClient] which makes subscription requests
// to the given endpoint using webSocket. It allows to pass additional connection parameters
// to the server during the initial connection handshake.
//
// connectionParams is a map of connection parameters to be sent to the server
// during the initial connection handshake.
func NewClientUsingWebSocketWithConnectionParams(endpoint string, wsDialer Dialer, headers http.Header, connParams map[string]interface{}) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
Expand All @@ -141,6 +151,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head
return &webSocketClient{
Dialer: wsDialer,
Header: headers,
connParams: connParams,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
Expand Down
11 changes: 9 additions & 2 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,18 @@ type webSocketClient struct {
Header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
errChan chan error
subscriptions subscriptionMap
isClosing bool
sync.Mutex
}

type webSocketInitMessage struct {
Payload map[string]interface{} `json:"payload"`
Type string `json:"type"`
}

type webSocketSendMessage struct {
Payload *Request `json:"payload"`
Type string `json:"type"`
Expand All @@ -67,8 +73,9 @@ type webSocketReceiveMessage struct {
}

func (w *webSocketClient) sendInit() error {
connInitMsg := webSocketSendMessage{
Type: webSocketTypeConnInit,
connInitMsg := webSocketInitMessage{
Type: webSocketTypeConnInit,
Payload: w.connParams,
}
return w.sendStructAsJSON(connInitMsg)
}
Expand Down
60 changes: 60 additions & 0 deletions internal/integration/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

83 changes: 82 additions & 1 deletion internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ func TestSubscription(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL)
wsClient := newRoundtripWebSocketClient(t, server.URL, nil)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -144,6 +145,86 @@ func TestSubscription(t *testing.T) {
}
}

func TestSubscriptionConnectionParams(t *testing.T) {
_ = `# @genqlient
subscription countAuthorized { countAuthorized }`

authKey := server.AuthKey

ctx := context.Background()
server := server.RunServer()
defer server.Close()

cases := []struct {
connParams map[string]interface{}
name string
expectedError string
}{
{
name: "authorized_user_gets_counter",
connParams: map[string]interface{}{
authKey: "authorized-user-token",
},
},
{
name: "unauthorized_user_gets_error",
expectedError: "input: countAuthorized unauthorized\n",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL, tc.connParams)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

dataChan, subscriptionID, err := countAuthorized(ctx, wsClient)
require.NoError(t, err)
defer wsClient.Close()

var (
counter = 0
start = time.Now()
)

for loop := true; loop; {
select {
case resp, more := <-dataChan:
if !more {
loop = false
break
}

if tc.expectedError != "" {
require.Error(t, resp.Errors)
assert.Equal(t, tc.expectedError, resp.Errors.Error())
continue
}

require.NotNil(t, resp.Data)
assert.Equal(t, counter, resp.Data.CountAuthorized)
require.Nil(t, resp.Errors)

if time.Since(start) > 5*time.Second {
err := wsClient.Unsubscribe(subscriptionID)
require.NoError(t, err)
loop = false
}

counter++

case err := <-errChan:
require.NoError(t, err)

case <-time.After(10 * time.Second):
require.NoError(t, fmt.Errorf("subscription timed out"))
}
}
})
}
}

func TestServerError(t *testing.T) {
_ = `# @genqlient
query failingQuery { fail me { id } }`
Expand Down
11 changes: 8 additions & 3 deletions internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,19 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade
return graphql.WSConn(conn), err
}

func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
func newRoundtripWebSocketClient(t *testing.T, endpoint string, connectionParams map[string]interface{}) graphql.WebSocketClient {
dialer := websocket.DefaultDialer
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}
return &roundtripClient{
wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil),
t: t,
wsWrapped: graphql.NewClientUsingWebSocketWithConnectionParams(
endpoint,
&MyDialer{Dialer: dialer},
nil,
connectionParams,
),
t: t,
}
}
1 change: 1 addition & 0 deletions internal/integration/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Mutation {

type Subscription {
count: Int!
countAuthorized: Int!
}

type User implements Being & Lucky {
Expand Down
71 changes: 70 additions & 1 deletion internal/integration/server/gqlgen_exec.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5775169

Please sign in to comment.