Skip to content

Commit

Permalink
More server tests
Browse files Browse the repository at this point in the history
  • Loading branch information
smira committed Mar 18, 2019
1 parent d72b78c commit 2ac731f
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 7 deletions.
9 changes: 7 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Client) Close() error {
// DiscoverVersions with the server
func (c *Client) DiscoverVersions(versions []ProtocolVersion) (serverVersions []ProtocolVersion, err error) {
var resp interface{}
resp, err = c.send(OPERATION_DISCOVER_VERSIONS,
resp, err = c.Send(OPERATION_DISCOVER_VERSIONS,
DiscoverVersionsRequest{
ProtocolVersions: versions,
})
Expand All @@ -93,7 +93,12 @@ func (c *Client) DiscoverVersions(versions []ProtocolVersion) (serverVersions []
return
}

func (c *Client) send(operation Enum, req interface{}) (resp interface{}, err error) {
// Send request to server and deliver response/error back
//
// Request payload should be passed as req, and response payload will be
// returned back as resp. Operation will be sent as a batch with single
// item.
func (c *Client) Send(operation Enum, req interface{}) (resp interface{}, err error) {
if c.conn == nil {
err = errors.New("not connected")
return
Expand Down
10 changes: 7 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,14 @@ func (s *Server) serve(conn net.Conn, session string) {
}
}

if s.SessionAuthHandler != nil {
s.mu.Lock()
sessionAuthHandler := s.SessionAuthHandler
s.mu.Unlock()

if sessionAuthHandler != nil {
var err error

sessionCtx.SessionAuth, err = s.SessionAuthHandler(conn)
sessionCtx.SessionAuth, err = sessionAuthHandler(conn)
if err != nil {
s.Log.Printf("[ERROR] [%s] Error in session auth handler: %s", session, err)
return
Expand Down Expand Up @@ -376,7 +380,7 @@ func (s *Server) handleWrapped(request *RequestContext, item *RequestBatchItem)
handler := s.handlers[item.Operation]

if handler == nil {
err = wrapError(errors.New("operation not supported"), RESULT_REASON_FEATURE_NOT_SUPPORTED)
err = wrapError(errors.New("operation not supported"), RESULT_REASON_OPERATION_NOT_SUPPORTED)
return
}

Expand Down
123 changes: 121 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ package kmip
import (
"context"
"crypto/tls"
"crypto/x509"
"log"
"net"
"os"
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -76,12 +78,16 @@ func (s *ServerSuite) SetupTest() {

s.client.ReadTimeout = time.Second
s.client.WriteTimeout = time.Second

s.Require().NoError(s.client.Connect())
}

func (s *ServerSuite) TearDownTest() {
s.Require().NoError(s.client.Close())

// reset server state
s.server.mu.Lock()
s.server.SessionAuthHandler = nil
s.server.initHandlers()
s.server.mu.Unlock()
}

func (s *ServerSuite) TearDownSuite() {
Expand All @@ -93,6 +99,8 @@ func (s *ServerSuite) TearDownSuite() {
}

func (s *ServerSuite) TestDiscoverVersions() {
s.Require().NoError(s.client.Connect())

versions, err := s.client.DiscoverVersions(DefaultSupportedVersions)
s.Require().NoError(err)
s.Require().Equal(DefaultSupportedVersions, versions)
Expand All @@ -110,6 +118,117 @@ func (s *ServerSuite) TestDiscoverVersions() {
s.Require().Equal([]ProtocolVersion(nil), versions)
}

func (s *ServerSuite) TestSessionAuthHandlerOkay() {
s.server.SessionAuthHandler = func(conn net.Conn) (interface{}, error) {
commonName := conn.(*tls.Conn).ConnectionState().PeerCertificates[0].Subject.CommonName

if commonName != "client_auth_test_cert" {
return nil, errors.New("wrong common name")
}

return commonName, nil
}

s.server.Handle(OPERATION_DISCOVER_VERSIONS, func(req *RequestContext, item *RequestBatchItem) (interface{}, error) {
if req.SessionAuth.(string) != "client_auth_test_cert" {
return nil, errors.New("wrong session auth")
}

return DiscoverVersionsResponse{
ProtocolVersions: nil,
}, nil
})

s.Require().NoError(s.client.Connect())

versions, err := s.client.DiscoverVersions(nil)
s.Require().NoError(err)
s.Require().Equal([]ProtocolVersion(nil), versions)
}

func (s *ServerSuite) TestSessionAuthHandlerFail() {
s.server.SessionAuthHandler = func(conn net.Conn) (interface{}, error) {
commonName := conn.(*tls.Conn).ConnectionState().PeerCertificates[0].Subject.CommonName

if commonName != "xxx" {
return nil, errors.New("wrong common name")
}

return commonName, nil
}

s.Require().NoError(s.client.Connect())

_, err := s.client.DiscoverVersions(nil)
s.Require().Regexp("broken pipe$", errors.Cause(err).Error())

s.client.Close()
}

func (s *ServerSuite) TestConnectTLSNoCert() {
var savedCerts []tls.Certificate
savedCerts, s.client.TLSConfig.Certificates = s.client.TLSConfig.Certificates, nil
defer func() {
s.client.TLSConfig.Certificates = savedCerts
}()

s.Require().EqualError(errors.Cause(s.client.Connect()), "remote error: tls: bad certificate")
}

func (s *ServerSuite) TestConnectTLSNoCA() {
var savedPool *x509.CertPool
savedPool, s.client.TLSConfig.RootCAs = s.client.TLSConfig.RootCAs, nil
defer func() {
s.client.TLSConfig.RootCAs = savedPool
}()

s.Require().EqualError(errors.Cause(s.client.Connect()), "x509: certificate signed by unknown authority")
}

func (s *ServerSuite) TestOperationGenericFail() {
s.server.Handle(OPERATION_DISCOVER_VERSIONS, func(req *RequestContext, item *RequestBatchItem) (interface{}, error) {
return nil, errors.New("oops!")
})

s.Require().NoError(s.client.Connect())

_, err := s.client.DiscoverVersions(nil)
s.Require().EqualError(errors.Cause(err), "oops!")
s.Require().Equal(errors.Cause(err).(Error).ResultReason(), RESULT_REASON_GENERAL_FAILURE)
}

func (s *ServerSuite) TestOperationPanic() {
s.server.Handle(OPERATION_DISCOVER_VERSIONS, func(req *RequestContext, item *RequestBatchItem) (interface{}, error) {
panic("oops!")
})

s.Require().NoError(s.client.Connect())

_, err := s.client.DiscoverVersions(nil)
s.Require().EqualError(errors.Cause(err), "panic: oops!")
s.Require().Equal(errors.Cause(err).(Error).ResultReason(), RESULT_REASON_GENERAL_FAILURE)
}

func (s *ServerSuite) TestOperationFailWithReason() {
s.server.Handle(OPERATION_DISCOVER_VERSIONS, func(req *RequestContext, item *RequestBatchItem) (interface{}, error) {
return nil, wrapError(errors.New("oops!"), RESULT_REASON_CRYPTOGRAPHIC_FAILURE)
})

s.Require().NoError(s.client.Connect())

_, err := s.client.DiscoverVersions(nil)
s.Require().EqualError(errors.Cause(err), "oops!")
s.Require().Equal(errors.Cause(err).(Error).ResultReason(), RESULT_REASON_CRYPTOGRAPHIC_FAILURE)
}

func (s *ServerSuite) TestOperationNotSupported() {
s.Require().NoError(s.client.Connect())

_, err := s.client.Send(OPERATION_GET, GetRequest{})
s.Require().EqualError(errors.Cause(err), "operation not supported")
s.Require().Equal(errors.Cause(err).(Error).ResultReason(), RESULT_REASON_OPERATION_NOT_SUPPORTED)
}

func TestServerSuite(t *testing.T) {
suite.Run(t, new(ServerSuite))
}

0 comments on commit 2ac731f

Please sign in to comment.