-
Notifications
You must be signed in to change notification settings - Fork 105
/
connection.go
181 lines (157 loc) · 4.24 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
package nostr
import (
"bytes"
"compress/flate"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
)
type Connection struct {
conn net.Conn
enableCompression bool
controlHandler wsutil.FrameHandlerFunc
flateReader *wsflate.Reader
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
}
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP(requestHeader),
Extensions: []httphead.Option{
wsflate.DefaultParameters.Option(),
},
TLSConfig: tlsConfig,
}
conn, _, hs, err := dialer.Dial(ctx, url)
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}
enableCompression := false
state := ws.StateClientSide
for _, extension := range hs.Extensions {
if string(extension.Name) == wsflate.ExtensionName {
enableCompression = true
state |= ws.StateExtended
break
}
}
// reader
var flateReader *wsflate.Reader
var msgStateR wsflate.MessageState
if enableCompression {
msgStateR.SetCompressed(true)
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
})
}
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
reader := &wsutil.Reader{
Source: conn,
State: state,
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgStateR,
},
}
// writer
var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression {
msgStateW.SetCompressed(true)
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4)
if err != nil {
InfoLogger.Printf("Failed to create flate writer: %v", err)
}
return fw
})
}
writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgStateW)
return &Connection{
conn: conn,
enableCompression: enableCompression,
controlHandler: controlHandler,
flateReader: flateReader,
reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter,
writer: writer,
msgStateW: &msgStateW,
}, nil
}
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
select {
case <-ctx.Done():
return errors.New("context canceled")
default:
}
if c.msgStateW.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
if err := c.flateWriter.Close(); err != nil {
return fmt.Errorf("failed to close flate writer: %w", err)
}
} else {
if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
}
if err := c.writer.Flush(); err != nil {
return fmt.Errorf("failed to flush writer: %w", err)
}
return nil
}
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
for {
select {
case <-ctx.Done():
return errors.New("context canceled")
default:
}
h, err := c.reader.NextFrame()
if err != nil {
c.conn.Close()
return fmt.Errorf("failed to advance frame: %w", err)
}
if h.OpCode.IsControl() {
if err := c.controlHandler(h, c.reader); err != nil {
return fmt.Errorf("failed to handle control frame: %w", err)
}
} else if h.OpCode == ws.OpBinary ||
h.OpCode == ws.OpText {
break
}
if err := c.reader.Discard(); err != nil {
return fmt.Errorf("failed to discard: %w", err)
}
}
if c.msgStateR.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
} else {
if _, err := io.Copy(buf, c.reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
}
return nil
}
func (c *Connection) Close() error {
return c.conn.Close()
}