Skip to content

Commit

Permalink
fix DiscReason encoding/decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
ucwong committed Dec 12, 2024
1 parent 8e567f5 commit 48e3126
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
25 changes: 22 additions & 3 deletions p2p/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,7 @@ func (p *Peer) handle(msg Msg) error {
case msg.Code == discMsg:
// This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it.
var m struct{ R DiscReason }
rlp.Decode(msg.Payload, &m)
return m.R
return decodeDisconnectMessage(msg.Payload)
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
return msg.Discard()
Expand All @@ -372,6 +370,27 @@ func (p *Peer) handle(msg Msg) error {
return nil
}

// decodeDisconnectMessage decodes the payload of discMsg.
func decodeDisconnectMessage(r io.Reader) (reason DiscReason) {
s := rlp.NewStream(r, 100)
k, _, err := s.Kind()
if err != nil {
return DiscInvalid
}
if k == rlp.List {
s.List()
err = s.Decode(&reason)
} else {
// Legacy path: some implementations, including geth, used to send the disconnect
// reason as a byte array by accident.
err = s.Decode(&reason)
}
if err != nil {
reason = DiscInvalid
}
return reason
}

func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
n := 0
for _, cap := range caps {
Expand Down
5 changes: 4 additions & 1 deletion p2p/peer_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ const (
DiscSelf
DiscReadTimeout
DiscSubprotocolError = DiscReason(0x10)

DiscInvalid = 0xff
)

var discReasonToString = [...]string{
Expand All @@ -86,10 +88,11 @@ var discReasonToString = [...]string{
DiscSelf: "connected to self",
DiscReadTimeout: "read timeout",
DiscSubprotocolError: "subprotocol error",
DiscInvalid: "invalid disconnect reason",
}

func (d DiscReason) String() string {
if len(discReasonToString) <= int(d) {
if len(discReasonToString) <= int(d) || discReasonToString[d] == "" {
return fmt.Sprintf("unknown disconnect reason %d", d)
}
return discReasonToString[d]
Expand Down
24 changes: 10 additions & 14 deletions p2p/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ func (t *rlpxTransport) close(err error) {
// Tell the remote end why we're disconnecting if possible.
// We only bother doing this if the underlying connection supports
// setting a timeout tough.
if t.conn != nil {
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
deadline := time.Now().Add(discWriteTimeout)
if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline.
t.wbuf.Reset()
rlp.Encode(&t.wbuf, []DiscReason{r})
t.conn.Write(discMsg, t.wbuf.Bytes())
}
if reason, ok := err.(DiscReason); ok && reason != DiscNetworkError {
// We do not use the WriteMsg func since we want a custom deadline
deadline := time.Now().Add(discWriteTimeout)
if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline.
t.wbuf.Reset()
rlp.Encode(&t.wbuf, []any{reason})
t.conn.Write(discMsg, t.wbuf.Bytes())
}
}
t.conn.Close()
Expand Down Expand Up @@ -163,11 +162,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
if msg.Code == discMsg {
// Disconnect before protocol handshake is valid according to the
// spec and we send it ourself if the post-handshake checks fail.
// We can't return the reason directly, though, because it is echoed
// back otherwise. Wrap it in a string instead.
var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason)
return nil, reason[0]
r := decodeDisconnectMessage(msg.Payload)
return nil, r
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
Expand Down
10 changes: 8 additions & 2 deletions p2p/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestProtocolHandshake(t *testing.T) {
return
}

if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil {
t.Errorf("error receiving disconnect: %v", err)
}
}()
Expand All @@ -112,7 +112,13 @@ func TestProtocolHandshakeErrors(t *testing.T) {
}{
{
code: discMsg,
msg: []DiscReason{DiscQuitting},
msg: []any{DiscQuitting},
err: DiscQuitting,
},
{
// legacy disconnect encoding as byte array
code: discMsg,
msg: []byte{byte(DiscQuitting)},
err: DiscQuitting,
},
{
Expand Down

0 comments on commit 48e3126

Please sign in to comment.