Skip to content

Commit

Permalink
Merge pull request multiformats#19 from libp2p/fix/close-on-err
Browse files Browse the repository at this point in the history
improve correctness of closing connections on failure
  • Loading branch information
Stebalien authored Apr 26, 2019
2 parents 773b63c + 8467c1e commit 5ddf5de
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
74 changes: 56 additions & 18 deletions listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ var _ = Describe("Listener", func() {

It("accepts a single connection", func() {
ln := createListener(defaultUpgrader)
defer ln.Close()
cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expect(err).ToNot(HaveOccurred())
sconn, err := ln.Accept()
Expand All @@ -113,6 +114,7 @@ var _ = Describe("Listener", func() {

It("accepts multiple connections", func() {
ln := createListener(defaultUpgrader)
defer ln.Close()
const num = 10
for i := 0; i < 10; i++ {
cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expand All @@ -127,11 +129,15 @@ var _ = Describe("Listener", func() {
const timeout = 200 * time.Millisecond
tpt.AcceptTimeout = timeout
ln := createListener(defaultUpgrader)
defer ln.Close()
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
if !Expect(err).ToNot(HaveOccurred()) {
return
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer conn.Close()
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
// start a Read. It will block until the connection is closed
Expand All @@ -151,10 +157,16 @@ var _ = Describe("Listener", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, _ = ln.Accept()
conn, err := ln.Accept()
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
close(done)
}()
_, _ = dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
Consistently(done).ShouldNot(BeClosed())
// make the goroutine return
ln.Close()
Expand All @@ -178,6 +190,7 @@ var _ = Describe("Listener", func() {
if err != nil {
return
}
conn.Close()
accepted <- conn
}
}()
Expand All @@ -187,8 +200,14 @@ var _ = Describe("Listener", func() {
wg.Add(1)
go func() {
defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if Expect(err).ToNot(HaveOccurred()) {
stream, err := conn.AcceptStream() // wait for conn to be accepted.
if !Expect(err).To(HaveOccurred()) {
stream.Close()
}
conn.Close()
}
wg.Done()
}()
}
Expand All @@ -201,29 +220,40 @@ var _ = Describe("Listener", func() {

It("stops setting up when the more than AcceptQueueLength connections are waiting to get accepted", func() {
ln := createListener(defaultUpgrader)
defer ln.Close()

// setup AcceptQueueLength connections, but don't accept any of them
dialed := make(chan struct{}, 10*st.AcceptQueueLength) // used as a thread-safe counter
dialed := make(chan tpt.Conn, 10*st.AcceptQueueLength) // used as a thread-safe counter
for i := 0; i < st.AcceptQueueLength; i++ {
go func() {
defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
dialed <- struct{}{}
dialed <- conn
}()
}
Eventually(dialed).Should(HaveLen(st.AcceptQueueLength))
// dial a new connection. This connection should not complete setup, since the queue is full
go func() {
defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred())
dialed <- struct{}{}
dialed <- conn
}()
Consistently(dialed).Should(HaveLen(st.AcceptQueueLength))
// accept a single connection. Now the new connection should be set up, and fill the queue again
_, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
conn, err := ln.Accept()
if Expect(err).ToNot(HaveOccurred()) {
conn.Close()
}
Eventually(dialed).Should(HaveLen(st.AcceptQueueLength + 1))

// Cleanup
for i := 0; i < st.AcceptQueueLength+1; i++ {
if c := <-dialed; c != nil {
c.Close()
}
}
})
})

Expand All @@ -233,9 +263,12 @@ var _ = Describe("Listener", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln.Accept()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
conn, err := ln.Accept()
if Expect(err).To(HaveOccurred()) {
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
} else {
conn.Close()
}
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
Expand All @@ -246,15 +279,20 @@ var _ = Describe("Listener", func() {
It("doesn't accept new connections when it is closed", func() {
ln := createListener(defaultUpgrader)
Expect(ln.Close()).To(Succeed())
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expect(err).To(HaveOccurred())
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
})

It("closes incoming connections that have not yet been accepted", func() {
ln := createListener(defaultUpgrader)
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if !Expect(err).ToNot(HaveOccurred()) {
ln.Close()
return
}
Expect(conn.IsClosed()).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(ln.Close()).To(Succeed())
Eventually(conn.IsClosed).Should(BeTrue())
})
Expand Down
6 changes: 5 additions & 1 deletion upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma
}
smconn, err := u.setupMuxer(ctx, sconn, p)
if err != nil {
conn.Close()
sconn.Close()
return nil, fmt.Errorf("failed to negotiate security stream multiplexer: %s", err)
}
return &transportConn{
Expand Down Expand Up @@ -122,6 +122,10 @@ func (u *Upgrader) setupMuxer(ctx context.Context, conn net.Conn, p peer.ID) (sm
case <-done:
return smconn, err
case <-ctx.Done():
// interrupt this process
conn.Close()
// wait to finish
<-done
return nil, ctx.Err()
}
}

0 comments on commit 5ddf5de

Please sign in to comment.