From 5f9f144251e2f4aab8cd1ee31eb0ba15a76da43f Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 7 Nov 2023 10:07:46 +0200 Subject: [PATCH] middleware: basic auth middleware can extract and check multiple auth headers --- middleware/basic_auth.go | 44 ++++---- middleware/basic_auth_test.go | 182 +++++++++++++++++++++++++--------- middleware/extractor.go | 20 ++-- middleware/extractor_test.go | 12 +-- 4 files changed, 178 insertions(+), 80 deletions(-) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index f9e8caafe..88591cea7 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,6 +1,7 @@ package middleware import ( + "bytes" "encoding/base64" "net/http" "strconv" @@ -15,7 +16,8 @@ type ( // Skipper defines a function to skip middleware. Skipper Skipper - // Validator is a function to validate BasicAuth credentials. + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic + // auth headers this function would be called once for each header until first valid result is returned // Required. Validator BasicAuthValidator @@ -71,30 +73,36 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + for i, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) + continue } - - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } + if i >= headerCountLimit { // guard against attacker maliciously sending huge amount of invalid headers + break + } + } + + if lastError != nil { + return lastError } realm := defaultRealm diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 20e769214..c081c30c0 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -2,6 +2,7 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" @@ -11,11 +12,139 @@ import ( "github.com/stretchr/testify/assert" ) +func TestBasicAuthWithConfig(t *testing.T) { + validatorFunc := func(u, p string, c echo.Context) (bool, error) { + if u == "joe" && p == "secret" { + return true, nil + } + if u == "error" { + return false, errors.New(p) + } + return false, nil + } + defaultConfig := BasicAuthConfig{Validator: validatorFunc} + + // we can not add OK value here because ranging over map returns random order. We just try to trigger break + tooManyAuths := make([]string, 0) + for i := 0; i < headerCountLimit+2; i++ { + tooManyAuths = append(tooManyAuths, basic+" "+base64.StdEncoding.EncodeToString([]byte("nope:nope"))) + } + + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, from multiple auth headers one is ok", + givenConfig: defaultConfig, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), // different type + basic + " NOT_BASE64", // invalid basic auth + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), // OK + }, + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, too many invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: tooManyAuths, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + mw := BasicAuthWithConfig(tc.givenConfig) + + h := mw(func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err := h(e.NewContext(req, res)) + + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} + func TestBasicAuth(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) f := func(u, p string, c echo.Context) (bool, error) { if u == "joe" && p == "secret" { return true, nil @@ -26,50 +155,11 @@ func TestBasicAuth(t *testing.T) { return c.String(http.StatusOK, "test") }) - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) + auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, auth) assert.NoError(t, h(c)) - - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) - - // Invalid base64 string - auth = basic + " invalidString" - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code) - - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) } diff --git a/middleware/extractor.go b/middleware/extractor.go index 5d9cee6d0..4825c8d20 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -9,9 +9,9 @@ import ( ) const ( - // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion + // headerCountLimit is arbitrary number to limit number of headers processed. this limits possible resource exhaustion // attack vector - extractorLimit = 20 + headerCountLimit = 20 ) var errHeaderExtractorValueMissing = errors.New("missing value in request header") @@ -105,14 +105,14 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { for i, value := range values { if prefixLen == 0 { result = append(result, value) - if i >= extractorLimit-1 { + if i >= headerCountLimit-1 { break } continue } if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { result = append(result, value[prefixLen:]) - if i >= extractorLimit-1 { + if i >= headerCountLimit-1 { break } } @@ -134,8 +134,8 @@ func valuesFromQuery(param string) ValuesExtractor { result := c.QueryParams()[param] if len(result) == 0 { return nil, errQueryExtractorValueMissing - } else if len(result) > extractorLimit-1 { - result = result[:extractorLimit] + } else if len(result) > headerCountLimit-1 { + result = result[:headerCountLimit] } return result, nil } @@ -149,7 +149,7 @@ func valuesFromParam(param string) ValuesExtractor { for i, p := range c.ParamNames() { if param == p { result = append(result, paramVales[i]) - if i >= extractorLimit-1 { + if i >= headerCountLimit-1 { break } } @@ -173,7 +173,7 @@ func valuesFromCookie(name string) ValuesExtractor { for i, cookie := range cookies { if name == cookie.Name { result = append(result, cookie.Value) - if i >= extractorLimit-1 { + if i >= headerCountLimit-1 { break } } @@ -195,8 +195,8 @@ func valuesFromForm(name string) ValuesExtractor { if len(values) == 0 { return nil, errFormExtractorValueMissing } - if len(values) > extractorLimit-1 { - values = values[:extractorLimit] + if len(values) > headerCountLimit-1 { + values = values[:headerCountLimit] } result := append([]string{}, values...) return result, nil diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 428c5563e..772afaf48 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -202,7 +202,7 @@ func TestValuesFromHeader(t *testing.T) { expectError: errHeaderExtractorValueMissing.Error(), }, { - name: "ok, prefix, cut values over extractorLimit", + name: "ok, prefix, cut values over headerCountLimit", givenRequest: func(req *http.Request) { for i := 1; i <= 25; i++ { req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) @@ -216,7 +216,7 @@ func TestValuesFromHeader(t *testing.T) { }, }, { - name: "ok, cut values over extractorLimit", + name: "ok, cut values over headerCountLimit", givenRequest: func(req *http.Request) { for i := 1; i <= 25; i++ { req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) @@ -282,7 +282,7 @@ func TestValuesFromQuery(t *testing.T) { expectError: errQueryExtractorValueMissing.Error(), }, { - name: "ok, cut values over extractorLimit", + name: "ok, cut values over headerCountLimit", givenQueryPart: "?name=test" + "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + @@ -361,7 +361,7 @@ func TestValuesFromParam(t *testing.T) { expectError: errParamExtractorValueMissing.Error(), }, { - name: "ok, cut values over extractorLimit", + name: "ok, cut values over headerCountLimit", givenPathParams: examplePathParams20, whenName: "id", expectValues: []string{ @@ -437,7 +437,7 @@ func TestValuesFromCookie(t *testing.T) { expectError: errCookieExtractorValueMissing.Error(), }, { - name: "ok, cut values over extractorLimit", + name: "ok, cut values over headerCountLimit", givenRequest: func(req *http.Request) { for i := 1; i < 25; i++ { req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) @@ -570,7 +570,7 @@ func TestValuesFromForm(t *testing.T) { expectError: errFormExtractorValueMissing.Error(), }, { - name: "ok, cut values over extractorLimit", + name: "ok, cut values over headerCountLimit", givenRequest: examplePostFormRequest(func(v *url.Values) { for i := 1; i < 25; i++ { v.Add("id[]", fmt.Sprintf("%v", i))