Skip to content

Commit

Permalink
expose source for KeyAuth/JWT key/token validation/parsing function t…
Browse files Browse the repository at this point in the history
…o allow custom logic depending from where key/token value was extracted
  • Loading branch information
aldas committed May 21, 2022
1 parent 0d85116 commit 2b4c5a4
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 50 deletions.
2 changes: 1 addition & 1 deletion middleware/basic_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {

b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
lastError = echo.ErrUnauthorized.WithInternal(fmt.Errorf("invalid basic auth value: %w", errDecode))
continue
}
idx := bytes.IndexByte(b, ':')
Expand Down
2 changes: 1 addition & 1 deletion middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestBasicAuth(t *testing.T) {
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
expectErr: "code=401, message=Unauthorized, internal=invalid basic auth value: illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
Expand Down
2 changes: 1 addition & 1 deletion middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastTokenErr error
outer:
for _, extractor := range extractors {
clientTokens, err := extractor(c)
clientTokens, _, err := extractor(c)
if err != nil {
lastExtractorErr = err
continue
Expand Down
56 changes: 37 additions & 19 deletions middleware/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ const (
extractorLimit = 20
)

// ExtractorSource is type to indicate source for extracted value
type ExtractorSource string

const (
// ExtractorSourceHeader means value was extracted from request header
ExtractorSourceHeader ExtractorSource = "header"
// ExtractorSourceQuery means value was extracted from request query parameters
ExtractorSourceQuery ExtractorSource = "query"
// ExtractorSourcePathParam means value was extracted from route path parameters
ExtractorSourcePathParam ExtractorSource = "param"
// ExtractorSourceCookie means value was extracted from request cookies
ExtractorSourceCookie ExtractorSource = "cookie"
// ExtractorSourceForm means value was extracted from request form values
ExtractorSourceForm ExtractorSource = "form"
// ExtractorSourceCustom means value was extracted by custom extractor
ExtractorSourceCustom ExtractorSource = "custom"
)

// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
type ValueExtractorError struct {
message string
Expand All @@ -31,7 +49,7 @@ var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing valu
var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}

// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
type ValuesExtractor func(c echo.Context) ([]string, error)
type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error)

func createExtractors(lookups string) ([]ValuesExtractor, error) {
if lookups == "" {
Expand Down Expand Up @@ -75,10 +93,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
prefixLen := len(valuePrefix)
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
header = textproto.CanonicalMIMEHeaderKey(header)
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
values := c.Request().Header.Values(header)
if len(values) == 0 {
return nil, errHeaderExtractorValueMissing
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
}

result := make([]string, 0)
Expand All @@ -100,30 +118,30 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {

if len(result) == 0 {
if prefixLen > 0 {
return nil, errHeaderExtractorValueInvalid
return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
}
return nil, errHeaderExtractorValueMissing
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
}
return result, nil
return result, ExtractorSourceHeader, nil
}
}

// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, errQueryExtractorValueMissing
return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
} else if len(result) > extractorLimit-1 {
result = result[:extractorLimit]
}
return result, nil
return result, ExtractorSourceQuery, nil
}
}

// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
result := make([]string, 0)
for i, p := range c.PathParams() {
if param == p.Name {
Expand All @@ -134,18 +152,18 @@ func valuesFromParam(param string) ValuesExtractor {
}
}
if len(result) == 0 {
return nil, errParamExtractorValueMissing
return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
}
return result, nil
return result, ExtractorSourcePathParam, nil
}
}

// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, errCookieExtractorValueMissing
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
}

result := make([]string, 0)
Expand All @@ -158,26 +176,26 @@ func valuesFromCookie(name string) ValuesExtractor {
}
}
if len(result) == 0 {
return nil, errCookieExtractorValueMissing
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
}
return result, nil
return result, ExtractorSourceCookie, nil
}
}

// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
if c.Request().Form == nil {
_ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, errFormExtractorValueMissing
return nil, ExtractorSourceForm, errFormExtractorValueMissing
}
if len(values) > extractorLimit-1 {
values = values[:extractorLimit]
}
result := append([]string{}, values...)
return result, nil
return result, ExtractorSourceForm, nil
}
}
24 changes: 18 additions & 6 deletions middleware/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func TestCreateExtractors(t *testing.T) {
givenPathParams echo.PathParams
whenLoopups string
expectValues []string
expectSource ExtractorSource
expectCreateError string
expectError string
}{
Expand All @@ -32,6 +33,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
expectSource: ExtractorSourceHeader,
},
{
name: "ok, form",
Expand All @@ -45,6 +47,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
expectSource: ExtractorSourceForm,
},
{
name: "ok, cookie",
Expand All @@ -55,6 +58,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
expectSource: ExtractorSourceCookie,
},
{
name: "ok, param",
Expand All @@ -63,6 +67,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "param:id",
expectValues: []string{"123"},
expectSource: ExtractorSourcePathParam,
},
{
name: "ok, query",
Expand All @@ -72,6 +77,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "query:id",
expectValues: []string{"999"},
expectSource: ExtractorSourceQuery,
},
{
name: "nok, invalid lookup",
Expand Down Expand Up @@ -102,8 +108,9 @@ func TestCreateExtractors(t *testing.T) {
assert.NoError(t, err)

for _, e := range extractors {
values, eErr := e(c)
values, source, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, tc.expectSource, source)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
Expand Down Expand Up @@ -228,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) {

extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)

values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceHeader, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
Expand Down Expand Up @@ -289,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) {

extractor := valuesFromQuery(tc.whenName)

values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceQuery, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
Expand Down Expand Up @@ -368,8 +377,9 @@ func TestValuesFromParam(t *testing.T) {

extractor := valuesFromParam(tc.whenName)

values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourcePathParam, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
Expand Down Expand Up @@ -448,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) {

extractor := valuesFromCookie(tc.whenName)

values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceCookie, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
Expand Down Expand Up @@ -578,8 +589,9 @@ func TestValuesFromForm(t *testing.T) {

extractor := valuesFromForm(tc.whenName)

values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceForm, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
Expand Down
8 changes: 4 additions & 4 deletions middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type JWTConfig struct {
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
ParseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)
}

// JWTSuccessHandler defines a function which is executed for a valid token.
Expand Down Expand Up @@ -101,7 +101,7 @@ var DefaultJWTConfig = JWTConfig{
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
func JWT(parseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)) echo.MiddlewareFunc {
c := DefaultJWTConfig
c.ParseTokenFunc = parseTokenFunc
return JWTWithConfig(c)
Expand Down Expand Up @@ -152,13 +152,13 @@ func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastExtractorErr error
var lastTokenErr error
for _, extractor := range extractors {
auths, extrErr := extractor(c)
auths, source, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
for _, auth := range auths {
token, err := config.ParseTokenFunc(c, auth)
token, err := config.ParseTokenFunc(c, auth, source)
if err != nil {
lastTokenErr = err
continue
Expand Down
4 changes: 2 additions & 2 deletions middleware/jwt_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKey is not provided
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
// keyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
Expand All @@ -41,7 +41,7 @@ func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]in
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}

return func(c echo.Context, auth string) (interface{}, error) {
return func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 2b4c5a4

Please sign in to comment.