Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Logout hook plugin (#611)
Browse files Browse the repository at this point in the history
* logout hook plugin
  • Loading branch information
iaroslav-ciupin authored Sep 12, 2023
1 parent b32a3d3 commit df88d23
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 14 deletions.
3 changes: 2 additions & 1 deletion auth/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,11 @@ func NewRedirectCookie(ctx context.Context, redirectURL string) *http.Cookie {
}
}

// GetAuthFlowEndRedirect returns the redirect URI according to data in request.
// At the end of the OAuth flow, the server needs to send the user somewhere. This should have been stored as a cookie
// during the initial /login call. If that cookie is missing from the request, it will default to the one configured
// in this package's Config object.
func getAuthFlowEndRedirect(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request) string {
func GetAuthFlowEndRedirect(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request) string {
queryParams := request.URL.Query()
// Use the redirect URL specified in the request if one is available.
if redirectURL := queryParams.Get(RedirectURLParameter); len(redirectURL) > 0 {
Expand Down
9 changes: 5 additions & 4 deletions auth/cookie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
"net/url"
"testing"

"github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
stdConfig "github.com/flyteorg/flytestdlib/config"
"github.com/gorilla/securecookie"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
)

func mustParseURL(t testing.TB, u string) url.URL {
Expand Down Expand Up @@ -131,7 +132,7 @@ func TestGetAuthFlowEndRedirect(t *testing.T) {
assert.NotNil(t, cookie)
request.AddCookie(cookie)
mockAuthCtx := &mocks.AuthenticationContext{}
redirect := getAuthFlowEndRedirect(ctx, mockAuthCtx, request)
redirect := GetAuthFlowEndRedirect(ctx, mockAuthCtx, request)
assert.Equal(t, "/console", redirect)
})

Expand All @@ -145,7 +146,7 @@ func TestGetAuthFlowEndRedirect(t *testing.T) {
RedirectURL: stdConfig.URL{URL: mustParseURL(t, "/api/v1/projects")},
},
})
redirect := getAuthFlowEndRedirect(ctx, mockAuthCtx, request)
redirect := GetAuthFlowEndRedirect(ctx, mockAuthCtx, request)
assert.Equal(t, "/api/v1/projects", redirect)
})
}
21 changes: 16 additions & 5 deletions auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func (e *PreRedirectHookError) Error() string {
// PreRedirectHookError is the error interface which allows the user to set correct http status code and Message to be set in case the function returns an error
// without which the current usage in GetCallbackHandler will set this to InternalServerError
type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) *PreRedirectHookError
type LogoutHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) error
type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD
type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error

Expand All @@ -68,7 +69,7 @@ func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer,
handler.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, authCtx))

// These endpoints require authentication
handler.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, authCtx))
handler.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, authCtx, pluginRegistry))
}

// Look for access token and refresh token, if both are present and the access token is expired, then attempt to
Expand Down Expand Up @@ -123,7 +124,7 @@ func RefreshTokensIfExists(ctx context.Context, authCtx interfaces.Authenticatio
return
}

redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request)
redirectURL := GetAuthFlowEndRedirect(ctx, authCtx, request)
http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect)
}
}
Expand Down Expand Up @@ -210,7 +211,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo
}
logger.Info(ctx, "Successfully called the preRedirect hook")
}
redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request)
redirectURL := GetAuthFlowEndRedirect(ctx, authCtx, request)
http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect)
}
}
Expand Down Expand Up @@ -466,9 +467,19 @@ func GetOIdCMetadataEndpointRedirectHandler(ctx context.Context, authCtx interfa
}
}

func GetLogoutEndpointHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc {
func GetLogoutEndpointHandler(ctx context.Context, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
logger.Debugf(ctx, "Deleting auth cookies")
hook := plugins.Get[LogoutHookFunc](pluginRegistry, plugins.PluginIDLogoutHook)
if hook != nil {
if err := hook(ctx, authCtx, request, writer); err != nil {
logger.Errorf(ctx, "logout hook failed: %v", err)
writer.WriteHeader(http.StatusInternalServerError)
return
}
logger.Debugf(ctx, "logout hook called")
}

logger.Debugf(ctx, "deleting auth cookies")
authCtx.CookieManager().DeleteCookies(ctx, writer)

// Redirect if one was given
Expand Down
101 changes: 97 additions & 4 deletions auth/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -11,8 +12,11 @@ import (
"testing"

"github.com/coreos/go-oidc"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
stdConfig "github.com/flyteorg/flytestdlib/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/structpb"

Expand All @@ -21,8 +25,6 @@ import (
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/plugins"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
stdConfig "github.com/flyteorg/flytestdlib/config"
)

const (
Expand Down Expand Up @@ -50,8 +52,8 @@ func setupMockedAuthContextAtEndpoint(endpoint string) *mocks.AuthenticationCont
Timeout: IdpConnectionTimeout,
}
mockAuthCtx.OnCookieManagerMatch().Return(mockCookieHandler)
mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config)
mockAuthCtx.OnGetHTTPClient().Return(dummyHTTPClient)
return mockAuthCtx
Expand Down Expand Up @@ -255,6 +257,97 @@ func TestGetLoginHandler(t *testing.T) {
assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), "flyte_csrf_state="))
}

func TestGetLogoutHandler(t *testing.T) {
ctx := context.Background()

t.Run("no_hook_no_redirect", func(t *testing.T) {
cookieHandler := &CookieManager{}
authCtx := mocks.AuthenticationContext{}
authCtx.OnCookieManager().Return(cookieHandler).Once()
w := httptest.NewRecorder()
r := plugins.NewRegistry()
req, err := http.NewRequest(http.MethodGet, "/logout", nil)
require.NoError(t, err)

GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)

assert.Equal(t, http.StatusOK, w.Code)
require.Len(t, w.Result().Cookies(), 3)
authCtx.AssertExpectations(t)
})

t.Run("no_hook_with_redirect", func(t *testing.T) {
ctx := context.Background()
cookieHandler := &CookieManager{}
authCtx := mocks.AuthenticationContext{}
authCtx.OnCookieManager().Return(cookieHandler).Once()
w := httptest.NewRecorder()
r := plugins.NewRegistry()
req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil)
require.NoError(t, err)

GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)

assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
authCtx.AssertExpectations(t)
require.Len(t, w.Result().Cookies(), 3)
})

t.Run("with_hook_with_redirect", func(t *testing.T) {
ctx := context.Background()
cookieHandler := &CookieManager{}
authCtx := mocks.AuthenticationContext{}
authCtx.OnCookieManager().Return(cookieHandler).Once()
w := httptest.NewRecorder()
r := plugins.NewRegistry()
hook := new(mock.Mock)
err := r.Register(plugins.PluginIDLogoutHook, LogoutHookFunc(func(
ctx context.Context,
authCtx interfaces.AuthenticationContext,
request *http.Request,
w http.ResponseWriter) error {
return hook.MethodCalled("hook").Error(0)
}))
hook.On("hook").Return(nil).Once()
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil)
require.NoError(t, err)

GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)

assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
require.Len(t, w.Result().Cookies(), 3)
authCtx.AssertExpectations(t)
hook.AssertExpectations(t)
})

t.Run("hook_error", func(t *testing.T) {
ctx := context.Background()
authCtx := mocks.AuthenticationContext{}
w := httptest.NewRecorder()
r := plugins.NewRegistry()
hook := new(mock.Mock)
err := r.Register(plugins.PluginIDLogoutHook, LogoutHookFunc(func(
ctx context.Context,
authCtx interfaces.AuthenticationContext,
request *http.Request,
w http.ResponseWriter) error {
return hook.MethodCalled("hook").Error(0)
}))
hook.On("hook").Return(errors.New("fail")).Once()
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil)
require.NoError(t, err)

GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)

assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Empty(t, w.Result().Cookies())
authCtx.AssertExpectations(t)
hook.AssertExpectations(t)
})
}

func TestGetHTTPRequestCookieToMetadataHandler(t *testing.T) {
ctx := context.Background()
// These were generated for unit testing only.
Expand Down
1 change: 1 addition & 0 deletions plugins/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const (
PluginIDDataProxy PluginID = "DataProxy"
PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware"
PluginIDPreRedirectHook PluginID = "PreRedirectHook"
PluginIDLogoutHook PluginID = "LogoutHook"
)

type AtomicRegistry struct {
Expand Down
20 changes: 20 additions & 0 deletions plugins/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ func TestRedirectHook(t *testing.T) {
assert.Equal(t, fmt.Errorf("redirect hook error"), err)
}

type LogoutHook func(context.Context) error

func TestLogoutHook(t *testing.T) {
ar := NewAtomicRegistry(nil)
r := NewRegistry()

hook := LogoutHook(func(ctx context.Context) error {
return fmt.Errorf("redirect hook error")
})
err := r.Register(PluginIDLogoutHook, hook)
assert.NoError(t, err)

ar.Store(r)
r = ar.Load()
fn := Get[LogoutHook](r, PluginIDLogoutHook)
err = fn(context.Background())

assert.Equal(t, fmt.Errorf("redirect hook error"), err)
}

func TestRegistry_RegisterDefault(t *testing.T) {
r := NewRegistry()
r.RegisterDefault("hello", 5)
Expand Down

0 comments on commit df88d23

Please sign in to comment.