Skip to content

Commit

Permalink
Merge branch 'main' into feat/authorization-and-roles-management-service
Browse files Browse the repository at this point in the history
  • Loading branch information
pyshx authored Oct 18, 2024
2 parents 8518ee5 + 351e923 commit 0ed1054
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 20 deletions.
6 changes: 4 additions & 2 deletions account/accountusecase/accountinteractor/user_signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ func (i *User) Signup(ctx context.Context, param accountinterfaces.SignupParam)
return nil, err
}

if err := i.sendVerificationMail(ctx, u, vr); err != nil {
return nil, err
if !param.MockAuth {
if err := i.sendVerificationMail(ctx, u, vr); err != nil {
return nil, err
}
}

return u, nil
Expand Down
1 change: 1 addition & 0 deletions account/accountusecase/accountinterfaces/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type SignupParam struct {
Theme *user.Theme
UserID *user.ID
WorkspaceID *workspace.ID
MockAuth bool
}

type UserFindOrCreateParam struct {
Expand Down
4 changes: 4 additions & 0 deletions account/accountusecase/accountproxy/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions account/user.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ input SignUpInput {
secret: String
lang: Lang
theme: Theme
mockAuth: Boolean
}

input SignupOIDCInput {
Expand Down
62 changes: 48 additions & 14 deletions appx/jwt_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"

"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/reearth/reearthx/log"
Expand All @@ -29,7 +30,7 @@ func NewJWTValidatorWithError(
audience []string,
opts ...validator.Option,
) (*JWTValidatorWithError, error) {
validator, err := validator.New(
v, err := validator.New(
keyFunc,
signatureAlgorithm,
issuerURL,
Expand All @@ -40,7 +41,7 @@ func NewJWTValidatorWithError(
return nil, err
}
return &JWTValidatorWithError{
validator: validator,
validator: v,
iss: issuerURL,
aud: slices.Clone(audience),
}, nil
Expand All @@ -49,9 +50,9 @@ func NewJWTValidatorWithError(
func (v *JWTValidatorWithError) ValidateToken(ctx context.Context, token string) (interface{}, error) {
res, err := v.validator.ValidateToken(ctx, token)
if err != nil {
err = fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err)
return nil, fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err)
}
return res, err
return res, nil
}

type JWTMultipleValidator []JWTValidator
Expand All @@ -62,20 +63,53 @@ func NewJWTMultipleValidator(providers []JWTProvider) (JWTMultipleValidator, err
})
}

// ValidateToken Trys to validate the token with each validator
// ValidateToken tries to validate the token with each validator concurrently
// NOTE: the last validation error only is returned
func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) {
func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (interface{}, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

type result struct {
res interface{}
err error
}

resultChan := make(chan result, len(mv))
var wg sync.WaitGroup

for _, v := range mv {
var err2 error
res, err2 = v.ValidateToken(ctx, tokenString)
if err2 == nil {
err = nil
return
wg.Add(1)
go func(validator JWTValidator) {
defer wg.Done()
res, err := validator.ValidateToken(ctx, tokenString)
select {
case resultChan <- result{res, err}:
case <-ctx.Done():
return
}
}(v)
}

go func() {
wg.Wait()
close(resultChan)
}()

var lastErr error
for i := 0; i < len(mv); i++ {
select {
case r := <-resultChan:
if r.err == nil {
cancel()
return r.res, nil
}
lastErr = errors.Join(lastErr, r.err)
case <-ctx.Done():
return nil, ctx.Err()
}
err = errors.Join(err, err2)
}

log.Debugfc(ctx, "auth: invalid JWT token: %s", tokenString)
log.Errorfc(ctx, "auth: invalid JWT token: %v", err)
return
log.Errorfc(ctx, "auth: invalid JWT token: %v", lastErr)
return nil, lastErr
}
130 changes: 129 additions & 1 deletion appx/jwt_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rsa"
"encoding/json"
"net/http"
"sync"
"testing"
"time"

Expand All @@ -23,7 +24,10 @@ func TestMultiValidator(t *testing.T) {
key := lo.Must(rsa.GenerateKey(rand.Reader, 2048))

httpmock.Activate()
defer httpmock.DeactivateAndReset()
t.Cleanup(func() {
httpmock.DeactivateAndReset()
})

httpmock.RegisterResponder(
http.MethodGet,
"https://example.com/.well-known/openid-configuration",
Expand Down Expand Up @@ -121,4 +125,128 @@ func TestMultiValidator(t *testing.T) {
res3, err := v.ValidateToken(context.Background(), tokenString3)
assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer)
assert.Nil(t, res3)

t.Run("all validators fail", func(t *testing.T) {
invalidTokenString := "invalid.token.string"

res, err := v.ValidateToken(context.Background(), invalidTokenString)
assert.Error(t, err)
assert.Nil(t, res)

// Check if the error is a combination of multiple errors
var multiErr interface{ Unwrap() []error }
assert.ErrorAs(t, err, &multiErr)
errs := multiErr.Unwrap()
assert.Len(t, errs, 2)

// Check if both errors are related to invalid token
for _, e := range errs {
assert.Contains(t, e.Error(), "invalid JWT")
}
})

t.Run("first validator succeeds", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.NotNil(t, res)
claims, ok := res.(*validator.ValidatedClaims)
assert.True(t, ok)
assert.Equal(t, "https://example.com/", claims.RegisteredClaims.Issuer)
})

t.Run("second validator succeeds", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.NotNil(t, res)
claims, ok := res.(*validator.ValidatedClaims)
assert.True(t, ok)
assert.Equal(t, "https://example.com/", claims.RegisteredClaims.Issuer)
})

t.Run("all validators fail", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example3.com/", AUD: []string{"d"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

res, err := v.ValidateToken(context.Background(), tokenString)
assert.Error(t, err)
assert.Nil(t, res)

var multiErr interface{ Unwrap() []error }
assert.ErrorAs(t, err, &multiErr)
errs := multiErr.Unwrap()
assert.Len(t, errs, 2)

for _, e := range errs {
assert.Contains(t, e.Error(), "invalid JWT")
}
})

t.Run("context cancellation", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
cancel()

res, err := v.ValidateToken(ctx, tokenString)
assert.Error(t, err)
assert.Nil(t, res)
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("mixed valid and invalid tokens", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

// Test with valid token
res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.NotNil(t, res)

// Test with invalid token
res, err = v.ValidateToken(context.Background(), "invalid.token")
assert.Error(t, err)
assert.Nil(t, res)
})

t.Run("concurrent validations", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.NotNil(t, res)
}()
}
wg.Wait()
})
}
6 changes: 3 additions & 3 deletions appx/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type TracerConfig struct {
TracerSample float64
}

func InitTracer(ctx context.Context, conf TracerConfig) io.Closer {
func InitTracer(ctx context.Context, conf *TracerConfig) io.Closer {
if conf.Tracer == TRACER_GCP {
initGCPTracer(ctx, conf)
} else if conf.Tracer == TRACER_JAEGER {
Expand All @@ -34,7 +34,7 @@ func InitTracer(ctx context.Context, conf TracerConfig) io.Closer {
return nil
}

func initGCPTracer(ctx context.Context, conf TracerConfig) {
func initGCPTracer(ctx context.Context, conf *TracerConfig) {
exporter, err := texporter.New()
if err != nil {
log.Fatalc(ctx, err)
Expand All @@ -50,7 +50,7 @@ func initGCPTracer(ctx context.Context, conf TracerConfig) {
log.Infofc(ctx, "tracer: initialized cloud trace with sample fraction: %g", conf.TracerSample)
}

func initJaegerTracer(conf TracerConfig) io.Closer {
func initJaegerTracer(conf *TracerConfig) io.Closer {
cfg := jaegercfg.Configuration{
Sampler: &jaegercfg.SamplerConfig{
Type: jaeger.SamplerTypeConst,
Expand Down
Loading

0 comments on commit 0ed1054

Please sign in to comment.