Skip to content

Commit

Permalink
feat: support multiple extract token key
Browse files Browse the repository at this point in the history
Implement support for multiple custom token keys and simplify the JWT authentication configuration.
`WithTokenKeys` function enables setting token keys, improving the authentication process by accommodating various token header extraction strategies. by accommodating various token header extraction strategies.
  • Loading branch information
ch3nnn authored and kevwan committed Aug 27, 2024
1 parent 075817a commit 7d33874
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 33 deletions.
16 changes: 16 additions & 0 deletions rest/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ type (
PrivateKeys []PrivateKeyConf
}

// JWTConf Key and expiration time configuration required for JWT authentication
JWTConf struct {
AccessSecret string
AccessExpire int64
// extract a jwt from custom request header or url arguments
TokenKeys []string `json:",optional"`
}

// A JWTTransConf is a jwtTrans config.
JWTTransConf struct {
Secret string
PrevSecret string
// extract a jwt from custom request header or url arguments
TokenKeys []string `json:",optional"`
}

// A RestConf is a http service config.
// Why not name it as Conf, because we need to consider usage like:
// type Config struct {
Expand Down
17 changes: 10 additions & 7 deletions rest/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ func (ng *engine) addRoutes(r featuredRoutes) {
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(chain.Chain) chain.Chain) chain.Chain {
if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 {
chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} else {
chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithPrevSecret(fr.jwt.prevSecret),
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
authOpts := []handler.AuthorizeOption{
handler.WithUnauthorizedCallback(ng.unauthorizedCallback),
}
if len(fr.jwt.prevSecret) > 0 {
authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret))
}
if len(fr.jwt.tokenKeys) > 0 {
authOpts = append(authOpts, handler.WithTokenKeys(fr.jwt.tokenKeys))
}

chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...))
}

return verifier(chn)
Expand Down
15 changes: 14 additions & 1 deletion rest/handler/authhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type (
AuthorizeOptions struct {
PrevSecret string
Callback UnauthorizedCallback
TokenKeys []string
}

// UnauthorizedCallback defines the method of unauthorized callback.
Expand All @@ -48,7 +49,12 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
opt(&authOpts)
}

parser := token.NewTokenParser()
var parseOpts []token.ParseOption
if len(authOpts.TokenKeys) > 0 {
parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenKeys))
}

parser := token.NewTokenParser(parseOpts...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
Expand Down Expand Up @@ -97,6 +103,13 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
}
}

// WithTokenKeys custom token key
func WithTokenKeys(tokenKeys []string) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.TokenKeys = tokenKeys
}
}

func detailAuthLog(r *http.Request, reason string) {
// discard dump error, only for debug purpose
details, _ := httputil.DumpRequest(r, true)
Expand Down
17 changes: 10 additions & 7 deletions rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,27 @@ func WithFileServer(path string, fs http.FileSystem) RunOption {
}

// WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption {
func WithJwt(jwt JWTConf) RouteOption {
return func(r *featuredRoutes) {
validateSecret(secret)
validateSecret(jwt.AccessSecret)
r.jwt.enabled = true
r.jwt.secret = secret
r.jwt.secret = jwt.AccessSecret
r.jwt.tokenKeys = jwt.TokenKeys
}
}

// WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition.
// Which means old and new jwt secrets work together for a period.
func WithJwtTransition(secret, prevSecret string) RouteOption {
func WithJwtTransition(jwt JWTTransConf) RouteOption {
return func(r *featuredRoutes) {
// why not validate prevSecret, because prevSecret is an already used one,
// even it not meet our requirement, we still need to allow the transition.
validateSecret(secret)
validateSecret(jwt.Secret)
r.jwt.enabled = true
r.jwt.secret = secret
r.jwt.prevSecret = prevSecret
r.jwt.secret = jwt.Secret
r.jwt.prevSecret = jwt.PrevSecret
r.jwt.tokenKeys = jwt.TokenKeys

}
}

Expand Down
4 changes: 2 additions & 2 deletions rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ Port: 0
Method: http.MethodGet,
Path: "/",
Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone"))
}, WithJwt(JWTConf{AccessSecret: "thesecret"}), WithSignature(SignatureConf{}),
WithJwtTransition(JWTTransConf{Secret: "preivous", PrevSecret: "thenewone"}))

func() {
defer func() {
Expand Down
21 changes: 17 additions & 4 deletions rest/token/tokenparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type (
resetTime time.Duration
resetDuration time.Duration
history sync.Map
extractor request.MultiExtractor
}
)

Expand All @@ -30,6 +31,7 @@ func NewTokenParser(opts ...ParseOption) *TokenParser {
parser := &TokenParser{
resetTime: timex.Now(),
resetDuration: claimHistoryResetDuration,
extractor: request.MultiExtractor{request.AuthorizationHeaderExtractor},
}

for _, opt := range opts {
Expand Down Expand Up @@ -79,10 +81,11 @@ func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*
}

func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
func(token *jwt.Token) (any, error) {
return []byte(secret), nil
}, request.WithParser(newParser()))
keyFunc := func(token *jwt.Token) (any, error) {
return []byte(secret), nil
}

return request.ParseFromRequest(r, tp.extractor, keyFunc, request.WithParser(newParser()))
}

func (tp *TokenParser) incrementCount(secret string) {
Expand Down Expand Up @@ -119,6 +122,16 @@ func WithResetDuration(duration time.Duration) ParseOption {
}
}

func WithExtractor(tokenKeys []string) ParseOption {
return func(parser *TokenParser) {
parser.extractor = request.MultiExtractor{
request.HeaderExtractor(tokenKeys),
request.ArgumentExtractor(tokenKeys),
parser.extractor,
}
}
}

func newParser() *jwt.Parser {
return jwt.NewParser(jwt.WithJSONNumber())
}
46 changes: 46 additions & 0 deletions rest/token/tokenparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,52 @@ func TestTokenParser(t *testing.T) {
}
}

func TestTokenParser_CustomHeader(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)
req.Header.Set("Token", token)

parser := NewTokenParser(WithExtractor([]string{"Token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}

func TestTokenParser_URLArgument(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)

req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody)

parser := NewTokenParser(WithExtractor([]string{"token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}

func TestTokenParser_Expired(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
Expand Down
1 change: 1 addition & 0 deletions rest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type (
enabled bool
secret string
prevSecret string
tokenKeys []string
}

signatureSetting struct {
Expand Down
12 changes: 2 additions & 10 deletions tools/goctl/api/gogen/genconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,8 @@ import (
const (
configFile = "config"

jwtTemplate = ` struct {
AccessSecret string
AccessExpire int64
}
`
jwtTransTemplate = ` struct {
Secret string
PrevSecret string
}
`
jwtTemplate = ` rest.JWTConf`
jwtTransTemplate = ` rest.JWTTransConf`
)

//go:embed config.tpl
Expand Down
4 changes: 2 additions & 2 deletions tools/goctl/api/gogen/genroutes.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error

var jwt string
if g.jwtEnabled {
jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s.AccessSecret),", g.authName)
jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s),", g.authName)
}
if len(g.jwtTrans) > 0 {
jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s.PrevSecret,serverCtx.Config.%s.Secret),", g.jwtTrans, g.jwtTrans)
jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s),", g.jwtTrans)
}
var signature, prefix string
if g.signatureEnabled {
Expand Down

0 comments on commit 7d33874

Please sign in to comment.