Skip to content

Commit

Permalink
feat: hot-reload CORS origins (#3423)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Aug 16, 2023
1 parent 09bcb71 commit 157d934
Show file tree
Hide file tree
Showing 10 changed files with 3,429 additions and 1,759 deletions.
26 changes: 16 additions & 10 deletions cmd/daemon/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"net/http"
"time"

"github.com/rs/cors"

"github.com/ory/x/otelx/semconv"

"github.com/pkg/errors"
"github.com/rs/cors"
"github.com/spf13/cobra"
"github.com/urfave/negroni"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
Expand Down Expand Up @@ -68,7 +69,7 @@ func WithContext(ctx stdctx.Context) Option {
}
}

func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *servicelocatorx.Options, opts []Option) error {
func ServePublic(r driver.Registry, cmd *cobra.Command, _ []string, slOpts *servicelocatorx.Options, opts []Option) error {
modifiers := NewOptions(cmd.Context(), opts)
ctx := modifiers.ctx

Expand Down Expand Up @@ -99,6 +100,16 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
router := x.NewRouterPublic()
csrf := x.NewCSRFHandler(router, r)

// we need to always load the CORS middleware even if it is disabled, to allow hot-enabling CORS
n.UseFunc(func(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
cfg, enabled := r.Config().CORS(req.Context(), "public")
if !enabled {
next(w, req)
return
}
cors.New(cfg).ServeHTTP(w, req, next)
})

n.UseFunc(x.CleanPath) // Prevent double slashes from breaking CSRF.
r.WithCSRFHandler(csrf)
n.UseHandler(r.CSRFHandler())
Expand All @@ -112,14 +123,9 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
r.RegisterPublicRoutes(ctx, router)
r.PrometheusManager().RegisterRouter(router.Router)

var handler http.Handler = n
options, enabled := r.Config().CORS(ctx, "public")
if enabled {
handler = cors.New(options).Handler(handler)
}

certs := c.GetTLSCertificatesForPublic(ctx)

var handler http.Handler = n
if tracer := r.Tracer(ctx); tracer.IsLoaded() {
handler = otelx.TraceHandler(handler, otelhttp.WithTracerProvider(tracer.Provider()))
}
Expand Down Expand Up @@ -152,7 +158,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
return nil
}

func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *servicelocatorx.Options, opts []Option) error {
func ServeAdmin(r driver.Registry, cmd *cobra.Command, _ []string, slOpts *servicelocatorx.Options, opts []Option) error {
modifiers := NewOptions(cmd.Context(), opts)
ctx := modifiers.ctx

Expand Down Expand Up @@ -299,7 +305,7 @@ func sqa(ctx stdctx.Context, cmd *cobra.Command, d driver.Registry) *metricsx.Se
)
}

func bgTasks(d driver.Registry, cmd *cobra.Command, args []string, slOpts *servicelocatorx.Options, opts []Option) error {
func bgTasks(d driver.Registry, cmd *cobra.Command, _ []string, _ *servicelocatorx.Options, opts []Option) error {
modifiers := NewOptions(cmd.Context(), opts)
ctx := modifiers.ctx

Expand Down
1 change: 1 addition & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ func New(ctx context.Context, l *logrusx.Logger, stdOutOrErr io.Writer, opts ...
configx.WithStderrValidationReporter(),
configx.OmitKeysFromTracing("dsn", "courier.smtp.connection_uri", "secrets.default", "secrets.cookie", "secrets.cipher", "client_secret"),
configx.WithImmutables("serve", "profiling", "log"),
configx.WithExceptImmutables("serve.public.cors.allowed_origins"),
configx.WithLogrusWatcher(l),
configx.WithLogger(l),
configx.WithContext(ctx),
Expand Down
28 changes: 26 additions & 2 deletions driver/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ import (
)

func TestViperProvider(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
t.Parallel()

t.Run("suite=loaders", func(t *testing.T) {
p := config.MustNew(t, logrusx.New("", ""), os.Stderr,
Expand Down Expand Up @@ -397,6 +397,7 @@ func TestViperProvider(t *testing.T) {
}

func TestBcrypt(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation())

Expand All @@ -409,6 +410,7 @@ func TestBcrypt(t *testing.T) {
}

func TestProviderBaseURLs(t *testing.T) {
t.Parallel()
ctx := context.Background()
machineHostname, err := os.Hostname()
if err != nil {
Expand Down Expand Up @@ -436,6 +438,7 @@ func TestProviderBaseURLs(t *testing.T) {
}

func TestProviderSelfServiceLinkMethodBaseURL(t *testing.T) {
t.Parallel()
ctx := context.Background()
machineHostname, err := os.Hostname()
if err != nil {
Expand All @@ -450,6 +453,7 @@ func TestProviderSelfServiceLinkMethodBaseURL(t *testing.T) {
}

func TestViperProvider_Secrets(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation())

Expand All @@ -464,6 +468,7 @@ func TestViperProvider_Secrets(t *testing.T) {
}

func TestViperProvider_Defaults(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")

Expand Down Expand Up @@ -573,6 +578,7 @@ func TestViperProvider_Defaults(t *testing.T) {
}

func TestViperProvider_ReturnTo(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")
p := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
Expand All @@ -589,6 +595,7 @@ func TestViperProvider_ReturnTo(t *testing.T) {
}

func TestSession(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")
p := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
Expand All @@ -615,6 +622,7 @@ func TestSession(t *testing.T) {
}

func TestCookies(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")
p := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
Expand Down Expand Up @@ -660,6 +668,7 @@ func TestCookies(t *testing.T) {
}

func TestViperProvider_DSN(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=dsn: memory", func(t *testing.T) {
Expand Down Expand Up @@ -693,6 +702,8 @@ func TestViperProvider_DSN(t *testing.T) {
}

func TestViperProvider_ParseURIOrFail(t *testing.T) {
t.Parallel()

ctx := context.Background()
var exitCode int

Expand Down Expand Up @@ -750,6 +761,8 @@ func TestViperProvider_ParseURIOrFail(t *testing.T) {
}

func TestViperProvider_HaveIBeenPwned(t *testing.T) {
t.Parallel()

ctx := context.Background()
p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation())
t.Run("case=hipb: host", func(t *testing.T) {
Expand Down Expand Up @@ -794,8 +807,8 @@ func newTestConfig(t *testing.T) (_ *config.Config, _ *test.Hook, exited *bool)
}

func TestLoadingTLSConfig(t *testing.T) {
ctx := context.Background()
t.Parallel()
ctx := context.Background()

certPath, keyPath, certBase64, keyBase64 := testhelpers.GenerateTLSCertificateFilesForTests(t)

Expand Down Expand Up @@ -888,6 +901,7 @@ func TestLoadingTLSConfig(t *testing.T) {
}

func TestIdentitySchemaValidation(t *testing.T) {
t.Parallel()
files := []string{"stub/.identity.test.json", "stub/.identity.other.json"}

ctx := context.Background()
Expand Down Expand Up @@ -1061,6 +1075,7 @@ func TestIdentitySchemaValidation(t *testing.T) {
}

func TestPasswordless(t *testing.T) {
t.Parallel()
ctx := context.Background()

conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr,
Expand All @@ -1074,6 +1089,7 @@ func TestPasswordless(t *testing.T) {
}

func TestChangeMinPasswordLength(t *testing.T) {
t.Parallel()
t.Run("case=must fail on minimum password length below enforced minimum", func(t *testing.T) {
ctx := context.Background()

Expand All @@ -1096,6 +1112,7 @@ func TestChangeMinPasswordLength(t *testing.T) {
}

func TestCourierEmailHTTP(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1113,6 +1130,7 @@ func TestCourierEmailHTTP(t *testing.T) {
}

func TestCourierSMS(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1133,6 +1151,7 @@ func TestCourierSMS(t *testing.T) {
}

func TestCourierSMTPUrl(t *testing.T) {
t.Parallel()
ctx := context.Background()

for _, tc := range []string{
Expand Down Expand Up @@ -1161,6 +1180,7 @@ func TestCourierSMTPUrl(t *testing.T) {
}

func TestCourierMessageTTL(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1176,6 +1196,7 @@ func TestCourierMessageTTL(t *testing.T) {
}

func TestOAuth2Provider(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1195,6 +1216,7 @@ func TestOAuth2Provider(t *testing.T) {
}

func TestWebauthn(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=multiple origins", func(t *testing.T) {
Expand Down Expand Up @@ -1240,6 +1262,7 @@ func TestWebauthn(t *testing.T) {
}

func TestCourierTemplatesConfig(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=partial template update allowed", func(t *testing.T) {
Expand Down Expand Up @@ -1294,6 +1317,7 @@ func TestCourierTemplatesConfig(t *testing.T) {
}

func TestCleanup(t *testing.T) {
t.Parallel()
ctx := context.Background()

p := config.MustNew(t, logrusx.New("", ""), os.Stderr,
Expand Down
4 changes: 0 additions & 4 deletions driver/config/schema.go

This file was deleted.

Loading

0 comments on commit 157d934

Please sign in to comment.