Skip to content

Commit

Permalink
middleware: basic auth middleware can extract and check multiple auth…
Browse files Browse the repository at this point in the history
… headers
  • Loading branch information
aldas committed Nov 7, 2023
1 parent c7d6d43 commit 5f9f144
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 80 deletions.
44 changes: 26 additions & 18 deletions middleware/basic_auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"bytes"
"encoding/base64"
"net/http"
"strconv"
Expand All @@ -15,7 +16,8 @@ type (
// Skipper defines a function to skip middleware.
Skipper Skipper

// Validator is a function to validate BasicAuth credentials.
// Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic
// auth headers this function would be called once for each header until first valid result is returned
// Required.
Validator BasicAuthValidator

Expand Down Expand Up @@ -71,30 +73,36 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c)
}

auth := c.Request().Header.Get(echo.HeaderAuthorization)
var lastError error
l := len(basic)
for i, auth := range c.Request().Header[echo.HeaderAuthorization] {
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
continue
}

if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
// Invalid base64 shouldn't be treated as error
// instead should be treated as invalid client input
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err)
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode)
continue
}

cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
valid, err := config.Validator(cred[:i], cred[i+1:], c)
if err != nil {
return err
} else if valid {
return next(c)
}
break
idx := bytes.IndexByte(b, ':')
if idx >= 0 {
valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c)
if errValidate != nil {
lastError = errValidate
} else if valid {
return next(c)
}
}
if i >= headerCountLimit { // guard against attacker maliciously sending huge amount of invalid headers
break
}
}

if lastError != nil {
return lastError
}

realm := defaultRealm
Expand Down
182 changes: 136 additions & 46 deletions middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -11,11 +12,139 @@ import (
"github.com/stretchr/testify/assert"
)

func TestBasicAuthWithConfig(t *testing.T) {
validatorFunc := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
if u == "error" {
return false, errors.New(p)
}
return false, nil
}
defaultConfig := BasicAuthConfig{Validator: validatorFunc}

// we can not add OK value here because ranging over map returns random order. We just try to trigger break
tooManyAuths := make([]string, 0)
for i := 0; i < headerCountLimit+2; i++ {
tooManyAuths = append(tooManyAuths, basic+" "+base64.StdEncoding.EncodeToString([]byte("nope:nope")))
}

var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, from multiple auth headers one is ok",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), // different type
basic + " NOT_BASE64", // invalid basic auth
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), // OK
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, too many invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: tooManyAuths,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()

mw := BasicAuthWithConfig(tc.givenConfig)

h := mw(func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()

if len(tc.whenAuth) != 0 {
for _, a := range tc.whenAuth {
req.Header.Add(echo.HeaderAuthorization, a)
}
}
err := h(e.NewContext(req, res))

if tc.expectErr != "" {
assert.Equal(t, http.StatusOK, res.Code)
assert.EqualError(t, err, tc.expectErr)
} else {
assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
}
if tc.expectHeader != "" {
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
}
})
}
}

func TestBasicAuth(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
Expand All @@ -26,50 +155,11 @@ func TestBasicAuth(t *testing.T) {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))

// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
}
20 changes: 10 additions & 10 deletions middleware/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
)

const (
// extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
// headerCountLimit is arbitrary number to limit number of headers processed. this limits possible resource exhaustion
// attack vector
extractorLimit = 20
headerCountLimit = 20
)

var errHeaderExtractorValueMissing = errors.New("missing value in request header")
Expand Down Expand Up @@ -105,14 +105,14 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
for i, value := range values {
if prefixLen == 0 {
result = append(result, value)
if i >= extractorLimit-1 {
if i >= headerCountLimit-1 {
break
}
continue
}
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
if i >= extractorLimit-1 {
if i >= headerCountLimit-1 {
break
}
}
Expand All @@ -134,8 +134,8 @@ func valuesFromQuery(param string) ValuesExtractor {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, errQueryExtractorValueMissing
} else if len(result) > extractorLimit-1 {
result = result[:extractorLimit]
} else if len(result) > headerCountLimit-1 {
result = result[:headerCountLimit]
}
return result, nil
}
Expand All @@ -149,7 +149,7 @@ func valuesFromParam(param string) ValuesExtractor {
for i, p := range c.ParamNames() {
if param == p {
result = append(result, paramVales[i])
if i >= extractorLimit-1 {
if i >= headerCountLimit-1 {
break
}
}
Expand All @@ -173,7 +173,7 @@ func valuesFromCookie(name string) ValuesExtractor {
for i, cookie := range cookies {
if name == cookie.Name {
result = append(result, cookie.Value)
if i >= extractorLimit-1 {
if i >= headerCountLimit-1 {
break
}
}
Expand All @@ -195,8 +195,8 @@ func valuesFromForm(name string) ValuesExtractor {
if len(values) == 0 {
return nil, errFormExtractorValueMissing
}
if len(values) > extractorLimit-1 {
values = values[:extractorLimit]
if len(values) > headerCountLimit-1 {
values = values[:headerCountLimit]
}
result := append([]string{}, values...)
return result, nil
Expand Down
Loading

0 comments on commit 5f9f144

Please sign in to comment.