Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssh: add top-level DialContext #280

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package ssh

import (
"bytes"
"context"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -168,6 +169,50 @@ func (c *Client) handleChannelOpens(in <-chan NewChannel) {
c.mu.Unlock()
}

// DialContext starts a client connection to the given SSH server. It is a
// convenience function that connects to the given network address,
// initiates the SSH handshake, and then sets up a Client.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// See [Dial] for additional information.
func DialContext(ctx context.Context, network, addr string, config *ClientConfig) (*Client, error) {
d := net.Dialer{
Timeout: config.Timeout,
}
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
type result struct {
client *Client
err error
}
ch := make(chan result)
go func() {
var client *Client
c, chans, reqs, err := NewClientConn(conn, addr, config)
if err == nil {
client = NewClient(c, chans, reqs)
}
select {
case ch <- result{client, err}:
case <-ctx.Done():
if client != nil {
client.Close()
}
}
}()
select {
case res := <-ch:
return res.client, res.err
case <-ctx.Done():
return nil, context.Cause(ctx)
}
}

// Dial starts a client connection to the given SSH server. It is a
// convenience function that connects to the given network address,
// initiates the SSH handshake, and then sets up a Client. For access
Expand Down
26 changes: 26 additions & 0 deletions ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package ssh

import (
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"strings"
"testing"
"time"
)

func TestClientVersion(t *testing.T) {
Expand Down Expand Up @@ -365,3 +367,27 @@ func TestUnsupportedAlgorithm(t *testing.T) {
})
}
}

func TestDialContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := DialContext(ctx, "tcp", ":22", &ClientConfig{})
wantErr := context.Canceled
if !errors.Is(err, wantErr) {
t.Errorf("DialContext: err == %v, expected %v", err, wantErr)
}

ctx, cancel = context.WithDeadline(context.Background(), time.Now())
defer cancel()
_, err = DialContext(ctx, "tcp", ":22", &ClientConfig{})
wantErr = context.DeadlineExceeded
if !errors.Is(err, wantErr) {
t.Errorf("DialContext: err == %v, expected %v", err, wantErr)
}

ctx = context.Background()
_, err = DialContext(ctx, "tcp", ":22", &ClientConfig{})
if _, ok := err.(*net.OpError); !ok {
t.Errorf("DialContext: err == %#v, expected *net.OpError", err)
}
}
26 changes: 26 additions & 0 deletions ssh/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package ssh_test
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
Expand All @@ -17,6 +18,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
Expand Down Expand Up @@ -262,6 +264,30 @@ func ExampleDial() {
fmt.Println(b.String())
}

func ExampleDialContext() {
var hostKey ssh.PublicKey
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("yourpassword"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}

// The Context supplied to DialContext allows the caller to control
// the timeout or cancel opening an SSH connection.
//
// Cancelling the context after DialContext returns will not effect
// the resulting Client.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
client, err := ssh.DialContext(ctx, "tcp", "yourserver.com:22", config)
if err != nil {
log.Fatal("Failed to dial: ", err)
}
defer client.Close()
}

func ExamplePublicKeys() {
var hostKey ssh.PublicKey
// A public key may be used to authenticate against the remote
Expand Down