Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hot-reload CORS origins #3423

Merged
merged 11 commits into from
Aug 16, 2023
26 changes: 16 additions & 10 deletions cmd/daemon/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"net/http"
"time"

"github.com/rs/cors"

"github.com/ory/x/otelx/semconv"

"github.com/pkg/errors"
"github.com/rs/cors"
"github.com/spf13/cobra"
"github.com/urfave/negroni"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
Expand Down Expand Up @@ -68,7 +69,7 @@ func WithContext(ctx stdctx.Context) Option {
}
}

func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *servicelocatorx.Options, opts []Option) error {
func ServePublic(r driver.Registry, cmd *cobra.Command, _ []string, slOpts *servicelocatorx.Options, opts []Option) error {
modifiers := NewOptions(cmd.Context(), opts)
ctx := modifiers.ctx

Expand Down Expand Up @@ -99,6 +100,16 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
router := x.NewRouterPublic()
csrf := x.NewCSRFHandler(router, r)

// we need to always load the CORS middleware even if it is disabled, to allow hot-enabling CORS
n.UseFunc(func(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
cfg, enabled := r.Config().CORS(req.Context(), "public")
if !enabled {
next(w, req)
return
}
cors.New(cfg).ServeHTTP(w, req, next)
})

n.UseFunc(x.CleanPath) // Prevent double slashes from breaking CSRF.
r.WithCSRFHandler(csrf)
n.UseHandler(r.CSRFHandler())
Expand All @@ -112,14 +123,9 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
r.RegisterPublicRoutes(ctx, router)
r.PrometheusManager().RegisterRouter(router.Router)

var handler http.Handler = n
options, enabled := r.Config().CORS(ctx, "public")
if enabled {
handler = cors.New(options).Handler(handler)
}

certs := c.GetTLSCertificatesForPublic(ctx)

var handler http.Handler = n
if tracer := r.Tracer(ctx); tracer.IsLoaded() {
handler = otelx.TraceHandler(handler, otelhttp.WithTracerProvider(tracer.Provider()))
}
Expand Down Expand Up @@ -152,7 +158,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
return nil
}

func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *servicelocatorx.Options, opts []Option) error {
func ServeAdmin(r driver.Registry, cmd *cobra.Command, _ []string, slOpts *servicelocatorx.Options, opts []Option) error {
modifiers := NewOptions(cmd.Context(), opts)
ctx := modifiers.ctx

Expand Down Expand Up @@ -299,7 +305,7 @@ func sqa(ctx stdctx.Context, cmd *cobra.Command, d driver.Registry) *metricsx.Se
)
}

func bgTasks(d driver.Registry, cmd *cobra.Command, args []string, slOpts *servicelocatorx.Options, opts []Option) error {
func bgTasks(d driver.Registry, cmd *cobra.Command, _ []string, _ *servicelocatorx.Options, opts []Option) error {
modifiers := NewOptions(cmd.Context(), opts)
ctx := modifiers.ctx

Expand Down
1 change: 1 addition & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ func New(ctx context.Context, l *logrusx.Logger, stdOutOrErr io.Writer, opts ...
configx.WithStderrValidationReporter(),
configx.OmitKeysFromTracing("dsn", "courier.smtp.connection_uri", "secrets.default", "secrets.cookie", "secrets.cipher", "client_secret"),
configx.WithImmutables("serve", "profiling", "log"),
configx.WithExceptImmutables("serve.public.cors.allowed_origins"),
configx.WithLogrusWatcher(l),
configx.WithLogger(l),
configx.WithContext(ctx),
Expand Down
28 changes: 26 additions & 2 deletions driver/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ import (
)

func TestViperProvider(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
t.Parallel()

t.Run("suite=loaders", func(t *testing.T) {
p := config.MustNew(t, logrusx.New("", ""), os.Stderr,
Expand Down Expand Up @@ -397,6 +397,7 @@ func TestViperProvider(t *testing.T) {
}

func TestBcrypt(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation())

Expand All @@ -409,6 +410,7 @@ func TestBcrypt(t *testing.T) {
}

func TestProviderBaseURLs(t *testing.T) {
t.Parallel()
ctx := context.Background()
machineHostname, err := os.Hostname()
if err != nil {
Expand Down Expand Up @@ -436,6 +438,7 @@ func TestProviderBaseURLs(t *testing.T) {
}

func TestProviderSelfServiceLinkMethodBaseURL(t *testing.T) {
t.Parallel()
ctx := context.Background()
machineHostname, err := os.Hostname()
if err != nil {
Expand All @@ -450,6 +453,7 @@ func TestProviderSelfServiceLinkMethodBaseURL(t *testing.T) {
}

func TestViperProvider_Secrets(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation())

Expand All @@ -464,6 +468,7 @@ func TestViperProvider_Secrets(t *testing.T) {
}

func TestViperProvider_Defaults(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")

Expand Down Expand Up @@ -573,6 +578,7 @@ func TestViperProvider_Defaults(t *testing.T) {
}

func TestViperProvider_ReturnTo(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")
p := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
Expand All @@ -589,6 +595,7 @@ func TestViperProvider_ReturnTo(t *testing.T) {
}

func TestSession(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")
p := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
Expand All @@ -615,6 +622,7 @@ func TestSession(t *testing.T) {
}

func TestCookies(t *testing.T) {
t.Parallel()
ctx := context.Background()
l := logrusx.New("", "")
p := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
Expand Down Expand Up @@ -660,6 +668,7 @@ func TestCookies(t *testing.T) {
}

func TestViperProvider_DSN(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=dsn: memory", func(t *testing.T) {
Expand Down Expand Up @@ -693,6 +702,8 @@ func TestViperProvider_DSN(t *testing.T) {
}

func TestViperProvider_ParseURIOrFail(t *testing.T) {
t.Parallel()

ctx := context.Background()
var exitCode int

Expand Down Expand Up @@ -750,6 +761,8 @@ func TestViperProvider_ParseURIOrFail(t *testing.T) {
}

func TestViperProvider_HaveIBeenPwned(t *testing.T) {
t.Parallel()

ctx := context.Background()
p := config.MustNew(t, logrusx.New("", ""), os.Stderr, configx.SkipValidation())
t.Run("case=hipb: host", func(t *testing.T) {
Expand Down Expand Up @@ -794,8 +807,8 @@ func newTestConfig(t *testing.T) (_ *config.Config, _ *test.Hook, exited *bool)
}

func TestLoadingTLSConfig(t *testing.T) {
ctx := context.Background()
t.Parallel()
ctx := context.Background()

certPath, keyPath, certBase64, keyBase64 := testhelpers.GenerateTLSCertificateFilesForTests(t)

Expand Down Expand Up @@ -888,6 +901,7 @@ func TestLoadingTLSConfig(t *testing.T) {
}

func TestIdentitySchemaValidation(t *testing.T) {
t.Parallel()
files := []string{"stub/.identity.test.json", "stub/.identity.other.json"}

ctx := context.Background()
Expand Down Expand Up @@ -1061,6 +1075,7 @@ func TestIdentitySchemaValidation(t *testing.T) {
}

func TestPasswordless(t *testing.T) {
t.Parallel()
ctx := context.Background()

conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr,
Expand All @@ -1074,6 +1089,7 @@ func TestPasswordless(t *testing.T) {
}

func TestChangeMinPasswordLength(t *testing.T) {
t.Parallel()
t.Run("case=must fail on minimum password length below enforced minimum", func(t *testing.T) {
ctx := context.Background()

Expand All @@ -1096,6 +1112,7 @@ func TestChangeMinPasswordLength(t *testing.T) {
}

func TestCourierEmailHTTP(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1113,6 +1130,7 @@ func TestCourierEmailHTTP(t *testing.T) {
}

func TestCourierSMS(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1133,6 +1151,7 @@ func TestCourierSMS(t *testing.T) {
}

func TestCourierSMTPUrl(t *testing.T) {
t.Parallel()
ctx := context.Background()

for _, tc := range []string{
Expand Down Expand Up @@ -1161,6 +1180,7 @@ func TestCourierSMTPUrl(t *testing.T) {
}

func TestCourierMessageTTL(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1176,6 +1196,7 @@ func TestCourierMessageTTL(t *testing.T) {
}

func TestOAuth2Provider(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=configs set", func(t *testing.T) {
Expand All @@ -1195,6 +1216,7 @@ func TestOAuth2Provider(t *testing.T) {
}

func TestWebauthn(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=multiple origins", func(t *testing.T) {
Expand Down Expand Up @@ -1240,6 +1262,7 @@ func TestWebauthn(t *testing.T) {
}

func TestCourierTemplatesConfig(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("case=partial template update allowed", func(t *testing.T) {
Expand Down Expand Up @@ -1294,6 +1317,7 @@ func TestCourierTemplatesConfig(t *testing.T) {
}

func TestCleanup(t *testing.T) {
t.Parallel()
ctx := context.Background()

p := config.MustNew(t, logrusx.New("", ""), os.Stderr,
Expand Down
4 changes: 0 additions & 4 deletions driver/config/schema.go

This file was deleted.

Loading
Loading