Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix(🩹): error handling in CSRF token storage retrieval #3021

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
)

var (
ErrTokenNotFound = errors.New("csrf token not found")
ErrTokenInvalid = errors.New("csrf token invalid")
ErrRefererNotFound = errors.New("referer not supplied")
ErrRefererInvalid = errors.New("referer invalid")
ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin")
ErrOriginInvalid = errors.New("origin invalid")
ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin")
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
dummyValue = []byte{'+'}
ErrTokenNotFound = errors.New("csrf token not found")
ErrTokenInvalid = errors.New("csrf token invalid")
ErrRefererNotFound = errors.New("referer not supplied")
ErrRefererInvalid = errors.New("referer invalid")
ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin")
ErrOriginInvalid = errors.New("origin invalid")
ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin")
ErrStorageRetrievalFailed = errors.New("unable to retrieve data from CSRF storage")

errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
dummyValue = []byte{'+'}
)

// Handler for CSRF middleware
Expand Down Expand Up @@ -103,10 +105,12 @@
switch c.Method() {
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
cookieToken := c.Cookies(cfg.CookieName)

if cookieToken != "" {
raw := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager)

raw, err := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager)
if err != nil {
println("hereee+" + err.Error())

Check failure on line 111 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

use of `println` forbidden by pattern `^print(ln)?$` (forbidigo)
return cfg.ErrorHandler(c, err)
}
if raw != nil {
token = cookieToken // Token is valid, safe to set it
}
Expand Down Expand Up @@ -149,14 +153,18 @@
return cfg.ErrorHandler(c, ErrTokenInvalid)
}

raw := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
raw, err := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
if err != nil {

Check failure on line 157 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

empty-lines: extra empty line at the start of a block (revive)

Check failure on line 157 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

return cfg.ErrorHandler(c, err)
} else if raw == nil {

Check failure on line 160 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

empty-lines: extra empty line at the start of a block (revive)

Check failure on line 160 in middleware/csrf/csrf.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

if raw == nil {
// If token is not in storage, expire the cookie
expireCSRFCookie(c, cfg)
// and return an error
return cfg.ErrorHandler(c, ErrTokenNotFound)
return cfg.ErrorHandler(c, ErrTokenInvalid)
}

if cfg.SingleUseToken {
// If token is single use, delete it from storage
deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
Expand Down Expand Up @@ -210,7 +218,7 @@

// getRawFromStorage returns the raw value from the storage for the given token
// returns nil if the token does not exist, is expired or is invalid
func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte {
func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) ([]byte, error) {
if cfg.Session != nil {
return sessionManager.getRaw(c, token, dummyValue)
}
Expand Down
66 changes: 65 additions & 1 deletion middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package csrf

import (
"fmt"
"net/http/httptest"
"strings"
"testing"
Expand Down Expand Up @@ -1263,7 +1264,6 @@
ctx.Request.SetRequestURI("/")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]

// Exploit CSRF token we just injected
ctx.Request.Reset()
Expand Down Expand Up @@ -1509,3 +1509,67 @@
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

type mockStorage struct{}

func (m *mockStorage) Get(key string) ([]byte, error) {

Check failure on line 1515 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-receiver: method receiver 'm' is not referenced in method's body, consider removing or renaming it as _ (revive)

Check failure on line 1515 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'key' seems to be unused, consider removing or renaming it as _ (revive)
return nil, fmt.Errorf("not found")

Check failure on line 1516 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

fmt.Errorf can be replaced with errors.New (perfsprint)
}

func (m *mockStorage) Set(key string, val []byte, exp time.Duration) error {

Check failure on line 1519 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-receiver: method receiver 'm' is not referenced in method's body, consider removing or renaming it as _ (revive)

Check failure on line 1519 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'key' seems to be unused, consider removing or renaming it as _ (revive)
return nil
}

func (m *mockStorage) Delete(key string) error {
return nil
}

func (m *mockStorage) Reset() error {
return nil
}

func (m *mockStorage) Close() error {
return nil
}

func Test_NotGetTokenInSessionStorage(t *testing.T) {
t.Parallel()

errHandler := func(c fiber.Ctx, err error) error {
require.Equal(t, ErrStorageRetrievalFailed.Error(), err.Error())
return c.Status(419).Send([]byte(err.Error()))
}

// &session.Store{}.Storage.Set(ConfigDefault.CookieName, "fiber", 300)

app := fiber.New()
app.Use(New(Config{
ErrorHandler: errHandler,
Session: &session.Store{
Config: session.Config{
Storage: &mockStorage{},
KeyGenerator: ConfigDefault.KeyGenerator,
KeyLookup: ConfigDefault.KeyLookup,
Expiration: ConfigDefault.Expiration,
CookieSameSite: "Lax",
},
},
}))

app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}

ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, "fiber")
h(ctx)

require.Equal(t, 419, ctx.Response.StatusCode())
require.Equal(t, "invalid CSRF token", string(ctx.Response.Body()))

}
22 changes: 14 additions & 8 deletions middleware/csrf/session_manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package csrf

import (
"fmt"
"time"

"github.com/gofiber/fiber/v3"
Expand All @@ -26,20 +27,25 @@ func newSessionManager(s *session.Store, k string) *sessionManager {
}

// get token from session
func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte {
func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) {
sess, err := m.session.Get(c)
if err != nil {
return nil
return nil, ErrStorageRetrievalFailed
}

fmt.Println("key: ", sess)

token, ok := sess.Get(m.key).(Token)
if ok {
if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) {
return nil
}
return token.Raw
fmt.Println("key: ", token, ok)
if !ok {
return nil, ErrTokenInvalid
}

if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) {
return nil, ErrTokenInvalid
}

return nil
return token.Raw, nil
}

// set token in session
Expand Down
35 changes: 28 additions & 7 deletions middleware/csrf/storage_manager.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package csrf

import (
"fmt"
"sync"
"time"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/memory"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)

Expand Down Expand Up @@ -41,20 +43,35 @@ func newStorageManager(storage fiber.Storage) *storageManager {
}

// get raw data from storage or memory
func (m *storageManager) getRaw(key string) []byte {
var raw []byte
func (m *storageManager) getRaw(key string) ([]byte, error) {
var (
raw []byte
err error
)
if m.storage != nil {
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
raw, err = m.storage.Get(key)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrStorageRetrievalFailed, err.Error())
}
} else {
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
var ok bool
raw, ok = m.memory.Get(key).([]byte)
if !ok {
return nil, ErrStorageRetrievalFailed
}
}
return raw

return raw, nil
}

// set data to storage or memory
func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
err := m.storage.Set(key, raw, exp)
if err != nil {
log.Warnf("csrf: failed to save session in storage: %s", err.Error())
return
}
} else {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), raw, exp)
Expand All @@ -64,7 +81,11 @@ func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) {
// delete data from storage or memory
func (m *storageManager) delRaw(key string) {
if m.storage != nil {
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Do not ignore error
err := m.storage.Delete(key)
if err != nil {
log.Warnf("csrf: failed to delete session in storage: %s", err.Error())
return
}
} else {
m.memory.Delete(key)
}
Expand Down
Loading