diff --git a/docs/middleware/csrf.md b/docs/middleware/csrf.md index a034f9dfd7..8127432438 100644 --- a/docs/middleware/csrf.md +++ b/docs/middleware/csrf.md @@ -34,7 +34,7 @@ app.Use(csrf.New(csrf.Config{ KeyLookup: "header:X-Csrf-Token", CookieName: "csrf_", CookieSameSite: "Lax", - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, Extractor: func(c fiber.Ctx) (string, error) { ... }, })) @@ -106,15 +106,14 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error | CookieSecure | `bool` | Indicates if the CSRF cookie is secure. | false | | CookieHTTPOnly | `bool` | Indicates if the CSRF cookie is HTTP-only. | false | | CookieSameSite | `string` | Value of SameSite cookie. | "Lax" | -| CookieSessionOnly | `bool` | Decides whether the cookie should last for only the browser session. Ignores Expiration if set to true. | false | -| Expiration | `time.Duration` | Expiration is the duration before the CSRF token will expire. | 1 * time.Hour | +| CookieSessionOnly | `bool` | Decides whether the cookie should last for only the browser session. (cookie expires on close). | false | +| IdleTimeout | `time.Duration` | IdleTimeout is the duration of inactivity before the CSRF token will expire. | 30 * time.Minute | | KeyGenerator | `func() string` | KeyGenerator creates a new CSRF token. | utils.UUID | | ErrorHandler | `fiber.ErrorHandler` | ErrorHandler is executed when an error is returned from fiber.Handler. | DefaultErrorHandler | | Extractor | `func(fiber.Ctx) (string, error)` | Extractor returns the CSRF token. If set, this will be used in place of an Extractor based on KeyLookup. | Extractor based on KeyLookup | | SingleUseToken | `bool` | SingleUseToken indicates if the CSRF token be destroyed and a new one generated on each use. (See TokenLifecycle) | false | | Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` | | Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` | -| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" | | TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` | ### Default Config @@ -124,11 +123,10 @@ var ConfigDefault = Config{ KeyLookup: "header:" + HeaderName, CookieName: "csrf_", CookieSameSite: "Lax", - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), - SessionKey: "csrfToken", } ``` @@ -144,12 +142,11 @@ var ConfigDefault = Config{ CookieSecure: true, CookieSessionOnly: true, CookieHTTPOnly: true, - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), Session: session.Store, - SessionKey: "csrfToken", } ``` @@ -304,7 +301,7 @@ The Referer header is automatically included in requests by all modern browsers, ## Token Lifecycle -Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time. +Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 30 minutes, and each subsequent request extends the expiration by the idle timeout. The token only expires if the user doesn't make a request for the duration of the idle timeout. ### Token Reuse diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 39b9ccc801..ff73ff6094 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -2,142 +2,481 @@ id: session --- -# Session +# Session Middleware for [Fiber](https://github.com/gofiber/fiber) -Session middleware for [Fiber](https://github.com/gofiber/fiber). +The `session` middleware provides session management for Fiber applications, utilizing the [Storage](https://github.com/gofiber/storage) package for multi-database support via a unified interface. By default, session data is stored in memory, but custom storage options are easily configurable (see examples below). + +As of v3, we recommend using the middleware handler for session management. However, for backward compatibility, v2's session methods are still available, allowing you to continue using the session management techniques from earlier versions of Fiber. Both methods are demonstrated in the examples. + +## Table of Contents + +- [Migration Guide](#migration-guide) + - [v2 to v3](#v2-to-v3) +- [Types](#types) + - [Config](#config) + - [Middleware](#middleware) + - [Session](#session) + - [Store](#store) +- [Signatures](#signatures) + - [Session Package Functions](#session-package-functions) + - [Config Methods](#config-methods) + - [Middleware Methods](#middleware-methods) + - [Session Methods](#session-methods) + - [Store Methods](#store-methods) +- [Examples](#examples) + - [Middleware Handler (Recommended)](#middleware-handler-recommended) + - [Custom Storage Example](#custom-storage-example) + - [Session Without Middleware Handler](#session-without-middleware-handler) + - [Custom Types in Session Data](#custom-types-in-session-data) +- [Config](#config) +- [Default Config](#default-config) + +## Migration Guide + +### v2 to v3 + +- **Function Signature Change**: In v3, the `New` function now returns a middleware handler instead of a `*Store`. To access the store, use the `Store` method on `*Middleware` (obtained from `session.FromContext(c)` in a handler) or use `NewStore` or `NewWithStore`. + +- **Session Lifecycle Management**: The `*Store.Save` method no longer releases the instance automatically. You must manually call `sess.Release()` after using the session to manage its lifecycle properly. + +- **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with `IdleTimeout` and `AbsoluteTimeout` fields, which explicitly defines the session's idle and absolute timeout periods. + + - **Idle Timeout**: The new `IdleTimeout`, handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. + + - **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. + +For more details about Fiber v3, see [What’s New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). + +### Migrating v2 to v3 Example (Legacy Approach) + +To convert a v2 example to use the v3 legacy approach, follow these steps: + +1. **Initialize with Store**: Use `session.NewStore()` to obtain a store. +2. **Retrieve Session**: Access the session store using the `store.Get(c)` method. +3. **Release Session**: Ensure that you call `sess.Release()` after you are done with the session to manage its lifecycle. :::note -This middleware uses our [Storage](https://github.com/gofiber/storage) package to support various databases through a single interface. The default configuration for this middleware saves data to memory, see the examples below for other databases. +When using the legacy approach, the IdleTimeout will be updated when the session is saved. ::: +#### Example Conversion + +**v2 Example:** + +```go +store := session.New() + +app.Get("/", func(c *fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return err + } + + key, ok := sess.Get("key").(string) + if !ok { + return c.SendStatus(fiber.StatusInternalServerError) + } + + sess.Set("key", "value") + + err = sess.Save() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return nil +}) +``` + +**v3 Legacy Approach:** + +```go +store := session.NewStore() + +app.Get("/", func(c fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return err + } + defer sess.Release() // Important: Release the session + + key, ok := sess.Get("key").(string) + if !ok { + return c.SendStatus(fiber.StatusInternalServerError) + } + + sess.Set("key", "value") + + err = sess.Save() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return nil +}) +``` + +### v3 Example (Recommended Middleware Handler) + +Do not call `sess.Release()` when using the middleware handler. `sess.Save()` is also not required, as the middleware automatically saves the session data. + +For the recommended approach, use the middleware handler. See the [Middleware Handler (Recommended)](#middleware-handler-recommended) section for details. + +## Types + +### Config + +Defines the configuration options for the session middleware. + +```go +type Config struct { + Storage fiber.Storage + Next func(fiber.Ctx) bool + Store *Store + ErrorHandler func(fiber.Ctx, error) + KeyGenerator func() string + KeyLookup string + CookieDomain string + CookiePath string + CookieSameSite string + IdleTimeout time.Duration + AbsoluteTimeout time.Duration + CookieSecure bool + CookieHTTPOnly bool + CookieSessionOnly bool +} +``` + +### Middleware + +The `Middleware` struct encapsulates the session middleware configuration and storage, created via `New` or `NewWithStore`. + +```go +type Middleware struct { + Session *Session +} +``` + +### Session + +Represents a user session, accessible through `FromContext` or `Store.Get`. + +```go +type Session struct {} +``` + +### Store + +Handles session data management and is created using `NewStore`, `NewWithStore` or by accessing the `Store` method of a middleware instance. + +```go +type Store struct { + Config +} +``` + ## Signatures +### Session Package Functions + ```go -func New(config ...Config) *Store -func (s *Store) RegisterType(i any) -func (s *Store) Get(c fiber.Ctx) (*Session, error) -func (s *Store) Delete(id string) error -func (s *Store) Reset() error +func New(config ...Config) *Middleware +func NewWithStore(config ...Config) (fiber.Handler, *Store) +func FromContext(c fiber.Ctx) *Middleware +``` + +### Config Methods + +```go +func DefaultErrorHandler(fiber.Ctx, err error) +``` + +### Middleware Methods + +```go +func (m *Middleware) Set(key string, value any) +func (m *Middleware) Get(key string) any +func (m *Middleware) Delete(key string) +func (m *Middleware) Destroy() error +func (m *Middleware) Reset() error +func (m *Middleware) Store() *Store +``` +### Session Methods + +```go +func (s *Session) Fresh() bool +func (s *Session) ID() string func (s *Session) Get(key string) any func (s *Session) Set(key string, val any) -func (s *Session) Delete(key string) func (s *Session) Destroy() error -func (s *Session) Reset() error func (s *Session) Regenerate() error +func (s *Session) Release() +func (s *Session) Reset() error func (s *Session) Save() error -func (s *Session) Fresh() bool -func (s *Session) ID() string func (s *Session) Keys() []string -func (s *Session) SetExpiry(exp time.Duration) +func (s *Session) SetIdleTimeout(idleTimeout time.Duration) +``` + +### Store Methods + +```go +func (*Store) RegisterType(i any) +func (s *Store) Get(c fiber.Ctx) (*Session, error) +func (s *Store) GetByID(id string) (*Session, error) +func (s *Store) Reset() error +func (s *Store) Delete(id string) error ``` -:::caution -Storing `any` values are limited to built-ins Go types. +:::note + +#### `GetByID` Method + +The `GetByID` method retrieves a session from storage using its session ID. Unlike `Get`, which ties the session to a `fiber.Ctx` (request-response cycle), `GetByID` operates independently of any HTTP context. This makes it ideal for scenarios such as background processing, scheduled tasks, or non-HTTP-related session management. + +##### Key Features + +- **Context Independence**: Sessions retrieved via `GetByID` are not bound to `fiber.Ctx`. This means the session can be manipulated in contexts that aren't tied to an active HTTP request-response cycle. +- **Background Task Suitability**: Use this method when you need to manage sessions outside of the standard HTTP workflow, such as in scheduled jobs, background tasks, or any non-HTTP context where session data needs to be accessed or modified. + +##### Usage Considerations + +- **Manual Persistence**: Since there is no associated `fiber.Ctx`, changes made to the session (e.g., modifying data) will **not** automatically be saved to storage. You **must** call `session.Save()` explicitly to persist any updates to storage. +- **No Automatic Cookie Handling**: Any updates made to the session will **not** affect the client-side cookies. If the session changes need to be reflected in the client (e.g., in a future HTTP response), you will need to handle this manually by setting the cookies via other methods. +- **Resource Management**: After using a session retrieved by `GetByID`, you should call `session.Release()` to properly release the session back to the pool and free up resources. + +##### Example Use Cases + +- **Scheduled Jobs**: Retrieve and update session data periodically without triggering an HTTP request. +- **Background Processing**: Manage sessions for tasks running in the background, such as user inactivity checks or batch processing. + ::: ## Examples -Import the middleware package that is part of the Fiber web framework +:::note +**Security Notice**: For robust security, especially during sensitive operations like account changes or transactions, consider using CSRF protection. Fiber provides a [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) that can be used with sessions to prevent CSRF attacks. +::: + +:::note +**Middleware Order**: The order of middleware matters. The session middleware should come before any handler or middleware that uses the session (for example, the CSRF middleware). +::: + +### Middleware Handler (Recommended) ```go +package main + import ( "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/csrf" "github.com/gofiber/fiber/v3/middleware/session" ) + +func main() { + app := fiber.New() + + sessionMiddleware, sessionStore := session.NewWithStore() + + app.Use(sessionMiddleware) + app.Use(csrf.New(csrf.Config{ + Store: sessionStore, + })) + + app.Get("/", func(c fiber.Ctx) error { + sess := session.FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + name, ok := sess.Get("name").(string) + if !ok { + return c.SendString("Welcome anonymous user!") + } + + return c.SendString("Welcome " + name) + }) + + app.Listen(":3000") +} ``` -After you initiate your Fiber app, you can use the following possibilities: +### Custom Storage Example ```go -// Initialize default config -// This stores all of your app's sessions -store := session.New() +package main -app.Get("/", func(c fiber.Ctx) error { - // Get session from storage - sess, err := store.Get(c) - if err != nil { - panic(err) - } +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/storage/sqlite3" + "github.com/gofiber/fiber/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/session" +) - // Get value - name := sess.Get("name") +func main() { + app := fiber.New() - // Set key/value - sess.Set("name", "john") + storage := sqlite3.New() + sessionMiddleware, sessionStore := session.NewWithStore(session.Config{ + Storage: storage, + }) - // Get all Keys - keys := sess.Keys() + app.Use(sessionMiddleware) + app.Use(csrf.New(csrf.Config{ + Store: sessionStore, + })) - // Delete key - sess.Delete("name") + app.Listen(":3000") +} +``` - // Destroy session - if err := sess.Destroy(); err != nil { - panic(err) - } +### Session Without Middleware Handler - // Sets a specific expiration for this session - sess.SetExpiry(time.Second * 2) +```go +package main - // Save session - if err := sess.Save(); err != nil { - panic(err) - } +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/session" +) - return c.SendString(fmt.Sprintf("Welcome %v", name)) -}) -``` +func main() { + app := fiber.New() -## Config + sessionStore := session.NewStore() -| Property | Type | Description | Default | -|:------------------------|:----------------|:------------------------------------------------------------------------------------------------------------|:----------------------| -| Expiration | `time.Duration` | Allowed session duration. | `24 * time.Hour` | -| Storage | `fiber.Storage` | Storage interface to store the session data. | `memory.New()` | -| KeyLookup | `string` | KeyLookup is a string in the form of "`:`" that is used to extract session id from the request. | `"cookie:session_id"` | -| CookieDomain | `string` | Domain of the cookie. | `""` | -| CookiePath | `string` | Path of the cookie. | `""` | -| CookieSecure | `bool` | Indicates if cookie is secure. | `false` | -| CookieHTTPOnly | `bool` | Indicates if cookie is HTTP only. | `false` | -| CookieSameSite | `string` | Value of SameSite cookie. | `"Lax"` | -| CookieSessionOnly | `bool` | Decides whether cookie should last for only the browser session. Ignores Expiration if set to true. | `false` | -| KeyGenerator | `func() string` | KeyGenerator generates the session key. | `utils.UUIDv4` | -| CookieName (Deprecated) | `string` | Deprecated: Please use KeyLookup. The session name. | `""` | + app.Use(csrf.New(csrf.Config{ + Store: sessionStore, + })) -## Default Config + app.Get("/", func(c fiber.Ctx) error { + sess, err := sessionStore.Get(c) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + defer sess.Release() -```go -var ConfigDefault = Config{ - Expiration: 24 * time.Hour, - KeyLookup: "cookie:session_id", - KeyGenerator: utils.UUIDv4, - source: "cookie", - sessionName: "session_id", + name, ok := sess.Get("name").(string) + if !ok { + return c.SendString("Welcome anonymous user!") + } + + return c.SendString("Welcome " + name) + }) + + app.Post("/login", func(c fiber.Ctx) error { + sess, err := sessionStore.Get(c) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + defer sess.Release() + + if !sess.Fresh() { + if err := sess.Regenerate(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + } + + sess.Set("name", "John Doe") + + err = sess.Save() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return c.SendString("Logged in!") + }) + + app.Listen(":3000") } ``` -## Constants +### Custom Types in Session Data + +Session data can only be of the following types by default: + +- `string` +- `int` +- `int8` +- `int16` +- `int32` +- `int64` +- `uint` +- `uint8` +- `uint16` +- `uint32` +- `uint64` +- `bool` +- `float32` +- `float64` +- `[]byte` +- `complex64` +- `complex128` +- `interface{}` + +To support other types in session data, you can register custom types. Here is an example of how to register a custom type: ```go -const ( - SourceCookie Source = "cookie" - SourceHeader Source = "header" - SourceURLQuery Source = "query" +package main + +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/session" ) -``` -### Custom Storage/Database +type User struct { + Name string + Age int +} -You can use any storage from our [storage](https://github.com/gofiber/storage/) package. +func main() { + app := fiber.New() -```go -storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3 + sessionMiddleware, sessionStore := session.NewWithStore() + sessionStore.RegisterType(User{}) -store := session.New(session.Config{ - Storage: storage, -}) + app.Use(sessionMiddleware) + + app.Listen(":3000") +} ``` -To use the store, see the [Examples](#examples). +## Config + +| Property | Type | Description | Default | +|-----------------------|--------------------------------|--------------------------------------------------------------------------------------------|---------------------------| +| **Storage** | `fiber.Storage` | Defines where session data is stored. | `nil` (in-memory storage) | +| **Next** | `func(c fiber.Ctx) bool` | Function to skip this middleware under certain conditions. | `nil` | +| **ErrorHandler** | `func(c fiber.Ctx, err error)` | Custom error handler for session middleware errors. | `nil` | +| **KeyGenerator** | `func() string` | Function to generate session IDs. | `UUID()` | +| **KeyLookup** | `string` | Key used to store session ID in cookie or header. | `"cookie:session_id"` | +| **CookieDomain** | `string` | The domain scope of the session cookie. | `""` | +| **CookiePath** | `string` | The path scope of the session cookie. | `"/"` | +| **CookieSameSite** | `string` | The SameSite attribute of the session cookie. | `"Lax"` | +| **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `30 * time.Minute` | +| **AbsoluteTimeout** | `time.Duration` | Maximum duration before session expires. | `0` (no expiration) | +| **CookieSecure** | `bool` | Ensures session cookie is only sent over HTTPS. | `false` | +| **CookieHTTPOnly** | `bool` | Ensures session cookie is not accessible to JavaScript (HTTP only). | `true` | +| **CookieSessionOnly** | `bool` | Prevents session cookie from being saved after the session ends (cookie expires on close). | `false` | + +## Default Config + +```go +session.Config{ + Storage: memory.New(), + Next: nil, + Store: nil, + ErrorHandler: nil, + KeyGenerator: utils.UUIDv4, + KeyLookup: "cookie:session_id", + CookieDomain: "", + CookiePath: "", + CookieSameSite: "Lax", + IdleTimeout: 30 * time.Minute, + AbsoluteTimeout: 0, + CookieSecure: false, + CookieHTTPOnly: false, + CookieSessionOnly: false, +} +``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 963d1daece..4779a57364 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -30,6 +30,7 @@ Here's a quick overview of the changes in Fiber `v3`: - [🧰 Generic functions](#-generic-functions) - [🧬 Middlewares](#-middlewares) - [CORS](#cors) + - [CSRF](#csrf) - [Session](#session) - [Filesystem](#filesystem) - [Monitor](#monitor) @@ -314,9 +315,19 @@ Added support for specifying Key length when using `encryptcookie.GenerateKey(le ### Session -:::caution -DRAFT section -::: +The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. + +#### Key Updates + +- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration. + +- **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. + +- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. + +- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. + +For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). ### Filesystem @@ -494,6 +505,24 @@ app.Use(cors.New(cors.Config{ })) ``` +#### CSRF + +- **Field Renaming**: The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. Additionally, the default value has been reduced from 1 hour to 30 minutes. Update your code as follows: + +```go +// Before +app.Use(csrf.New(csrf.Config{ + Expiration: 10 * time.Minute, +})) + +// After +app.Use(csrf.New(csrf.Config{ + IdleTimeout: 10 * time.Minute, +})) +``` + +- **Session Key Removal**: The `SessionKey` field has been removed from the CSRF middleware configuration. The session key is now an unexported constant within the middleware to avoid potential key collisions in the session store. + #### Filesystem You need to move filesystem middleware to static middleware due to it has been removed from the core. diff --git a/middleware/csrf/config.go b/middleware/csrf/config.go index d37c33a58e..e718b15874 100644 --- a/middleware/csrf/config.go +++ b/middleware/csrf/config.go @@ -78,11 +78,6 @@ type Config struct { // Optional. Default value "Lax". CookieSameSite string - // SessionKey is the key used to store the token in the session - // - // Default: "csrfToken" - SessionKey string - // TrustedOrigins is a list of trusted origins for unsafe requests. // For requests that use the Origin header, the origin must match the // Host header or one of the TrustedOrigins. @@ -96,10 +91,10 @@ type Config struct { // Optional. Default: [] TrustedOrigins []string - // Expiration is the duration before csrf token will expire + // IdleTimeout is the duration of time the CSRF token is valid. // - // Optional. Default: 1 * time.Hour - Expiration time.Duration + // Optional. Default: 30 * time.Minute + IdleTimeout time.Duration // Indicates if CSRF cookie is secure. // Optional. Default value false. @@ -127,11 +122,10 @@ var ConfigDefault = Config{ KeyLookup: "header:" + HeaderName, CookieName: "csrf_", CookieSameSite: "Lax", - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), - SessionKey: "csrfToken", } // default ErrorHandler that process return error from fiber.Handler @@ -153,8 +147,8 @@ func configDefault(config ...Config) Config { if cfg.KeyLookup == "" { cfg.KeyLookup = ConfigDefault.KeyLookup } - if int(cfg.Expiration.Seconds()) <= 0 { - cfg.Expiration = ConfigDefault.Expiration + if cfg.IdleTimeout <= 0 { + cfg.IdleTimeout = ConfigDefault.IdleTimeout } if cfg.CookieName == "" { cfg.CookieName = ConfigDefault.CookieName @@ -168,9 +162,6 @@ func configDefault(config ...Config) Config { if cfg.ErrorHandler == nil { cfg.ErrorHandler = ConfigDefault.ErrorHandler } - if cfg.SessionKey == "" { - cfg.SessionKey = ConfigDefault.SessionKey - } // Generate the correct extractor to get the token from the correct location selectors := strings.Split(cfg.KeyLookup, ":") diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index d417730416..dedfe6bd55 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -49,10 +49,7 @@ func New(config ...Config) fiber.Handler { var sessionManager *sessionManager var storageManager *storageManager if cfg.Session != nil { - // Register the Token struct in the session store - cfg.Session.RegisterType(Token{}) - - sessionManager = newSessionManager(cfg.Session, cfg.SessionKey) + sessionManager = newSessionManager(cfg.Session) } else { storageManager = newStorageManager(cfg.Storage) } @@ -220,9 +217,9 @@ func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *se // createOrExtendTokenInStorage creates or extends the token in the storage func createOrExtendTokenInStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) { if cfg.Session != nil { - sessionManager.setRaw(c, token, dummyValue, cfg.Expiration) + sessionManager.setRaw(c, token, dummyValue, cfg.IdleTimeout) } else { - storageManager.setRaw(token, dummyValue, cfg.Expiration) + storageManager.setRaw(token, dummyValue, cfg.IdleTimeout) } } @@ -237,7 +234,7 @@ func deleteTokenFromStorage(c fiber.Ctx, token string, cfg Config, sessionManage // Update CSRF cookie // if expireCookie is true, the cookie will expire immediately func updateCSRFCookie(c fiber.Ctx, cfg Config, token string) { - setCSRFCookie(c, cfg, token, cfg.Expiration) + setCSRFCookie(c, cfg, token, cfg.IdleTimeout) } func expireCSRFCookie(c fiber.Ctx, cfg Config) { diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 82252549bd..090082f4d8 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -70,7 +70,7 @@ func Test_CSRF_WithSession(t *testing.T) { t.Parallel() // session store - store := session.New(session.Config{ + store := session.NewStore(session.Config{ KeyLookup: "cookie:_session", }) @@ -156,13 +156,68 @@ func Test_CSRF_WithSession(t *testing.T) { } } +// go test -run Test_CSRF_WithSession_Middleware +func Test_CSRF_WithSession_Middleware(t *testing.T) { + t.Parallel() + app := fiber.New() + + // session mw + smh, sstore := session.NewWithStore() + + // csrf mw + cmh := New(Config{ + Session: sstore, + }) + + app.Use(smh) + + app.Use(cmh) + + app.Get("/", func(c fiber.Ctx) error { + sess := session.FromContext(c) + sess.Set("hello", "world") + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/", func(c fiber.Ctx) error { + sess := session.FromContext(c) + if sess.Get("hello") != "world" { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + // Generate CSRF token and session_id + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + csrfTokenParts := strings.Split(string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), ";") + require.Greater(t, len(csrfTokenParts), 2) + csrfToken := strings.Split(csrfTokenParts[0], "=")[1] + require.NotEmpty(t, csrfToken) + sessionID := strings.Split(csrfTokenParts[1], "=")[1] + require.NotEmpty(t, sessionID) + + // Use the CSRF token and session_id + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.Set(HeaderName, csrfToken) + ctx.Request.Header.SetCookie(ConfigDefault.CookieName, csrfToken) + ctx.Request.Header.SetCookie("session_id", sessionID) + h(ctx) + require.Equal(t, 200, ctx.Response.StatusCode()) +} + // go test -run Test_CSRF_ExpiredToken func Test_CSRF_ExpiredToken(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ - Expiration: 1 * time.Second, + IdleTimeout: 1 * time.Second, })) app.Post("/", func(c fiber.Ctx) error { @@ -205,7 +260,7 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { t.Parallel() // session store - store := session.New(session.Config{ + store := session.NewStore(session.Config{ KeyLookup: "cookie:_session", }) @@ -229,8 +284,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { // middleware config config := Config{ - Session: store, - Expiration: 1 * time.Second, + Session: store, + IdleTimeout: 1 * time.Second, } // middleware @@ -1076,7 +1131,7 @@ func Test_CSRF_DeleteToken_WithSession(t *testing.T) { t.Parallel() // session store - store := session.New(session.Config{ + store := session.NewStore(session.Config{ KeyLookup: "cookie:_session", }) diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 3bbf173a26..8961c6a542 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -10,28 +10,46 @@ import ( type sessionManager struct { session *session.Store - key string } -func newSessionManager(s *session.Store, k string) *sessionManager { +type sessionKeyType int + +const ( + sessionKey sessionKeyType = 0 +) + +func newSessionManager(s *session.Store) *sessionManager { // Create new storage handler - sessionManager := &sessionManager{ - key: k, - } + sessionManager := new(sessionManager) if s != nil { // Use provided storage if provided sessionManager.session = s + + // Register the sessionKeyType and Token type + s.RegisterType(sessionKeyType(0)) + s.RegisterType(Token{}) } return sessionManager } // get token from session func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { - sess, err := m.session.Get(c) - if err != nil { - return nil + sess := session.FromContext(c) + var token Token + var ok bool + + if sess != nil { + token, ok = sess.Get(sessionKey).(Token) + } else { + // Try to get the session from the store + storeSess, err := m.session.Get(c) + if err != nil { + // Handle error + return nil + } + token, ok = storeSess.Get(sessionKey).(Token) } - token, ok := sess.Get(m.key).(Token) + if ok { if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { return nil @@ -44,25 +62,39 @@ func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { // set token in session func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Duration) { - sess, err := m.session.Get(c) - if err != nil { - return - } - // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here - sess.Set(m.key, &Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) - if err := sess.Save(); err != nil { - log.Warn("csrf: failed to save session: ", err) + sess := session.FromContext(c) + if sess != nil { + // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here + sess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) + } else { + // Try to get the session from the store + storeSess, err := m.session.Get(c) + if err != nil { + // Handle error + return + } + storeSess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) + if err := storeSess.Save(); err != nil { + log.Warn("csrf: failed to save session: ", err) + } } } // delete token from session func (m *sessionManager) delRaw(c fiber.Ctx) { - sess, err := m.session.Get(c) - if err != nil { - return - } - sess.Delete(m.key) - if err := sess.Save(); err != nil { - log.Warn("csrf: failed to save session: ", err) + sess := session.FromContext(c) + if sess != nil { + sess.Delete(sessionKey) + } else { + // Try to get the session from the store + storeSess, err := m.session.Get(c) + if err != nil { + // Handle error + return + } + storeSess.Delete(sessionKey) + if err := storeSess.Save(); err != nil { + log.Warn("csrf: failed to save session: ", err) + } } } diff --git a/middleware/session/config.go b/middleware/session/config.go index 1eabc05bd4..c2a115d732 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -5,60 +5,98 @@ import ( "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) -// Config defines the config for middleware. +// Config defines the configuration for the session middleware. type Config struct { - // Storage interface to store the session data - // Optional. Default value memory.New() + // Storage interface for storing session data. + // + // Optional. Default: memory.New() Storage fiber.Storage + // Next defines a function to skip this middleware when it returns true. + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // Store defines the session store. + // + // Required. + Store *Store + + // ErrorHandler defines a function to handle errors. + // + // Optional. Default: nil + ErrorHandler func(fiber.Ctx, error) + // KeyGenerator generates the session key. - // Optional. Default value utils.UUIDv4 + // + // Optional. Default: utils.UUIDv4 KeyGenerator func() string - // KeyLookup is a string in the form of ":" that is used - // to extract session id from the request. - // Possible values: "header:", "query:" or "cookie:" - // Optional. Default value "cookie:session_id". + // KeyLookup is a string in the format ":" used to extract the session ID from the request. + // + // Possible values: "header:", "query:", "cookie:" + // + // Optional. Default: "cookie:session_id" KeyLookup string - // Domain of the cookie. - // Optional. Default value "". + // CookieDomain defines the domain of the session cookie. + // + // Optional. Default: "" CookieDomain string - // Path of the cookie. - // Optional. Default value "". + // CookiePath defines the path of the session cookie. + // + // Optional. Default: "" CookiePath string - // Value of SameSite cookie. - // Optional. Default value "Lax". + // CookieSameSite specifies the SameSite attribute of the cookie. + // + // Optional. Default: "Lax" CookieSameSite string - // Source defines where to obtain the session id + // Source defines where to obtain the session ID. source Source - // The session name + // sessionName is the name of the session. sessionName string - // Allowed session duration - // Optional. Default value 24 * time.Hour - Expiration time.Duration - // Indicates if cookie is secure. - // Optional. Default value false. + // IdleTimeout defines the maximum duration of inactivity before the session expires. + // + // Note: The idle timeout is updated on each `Save()` call. If a middleware handler is used, `Save()` is called automatically. + // + // Optional. Default: 30 * time.Minute + IdleTimeout time.Duration + + // AbsoluteTimeout defines the maximum duration of the session before it expires. + // + // If set to 0, the session will not have an absolute timeout, and will expire after the idle timeout. + // + // Optional. Default: 0 + AbsoluteTimeout time.Duration + + // CookieSecure specifies if the session cookie should be secure. + // + // Optional. Default: false CookieSecure bool - // Indicates if cookie is HTTP only. - // Optional. Default value false. + // CookieHTTPOnly specifies if the session cookie should be HTTP-only. + // + // Optional. Default: false CookieHTTPOnly bool - // Decides whether cookie should last for only the browser sesison. - // Ignores Expiration if set to true - // Optional. Default value false. + // CookieSessionOnly determines if the cookie should expire when the browser session ends. + // + // If true, the cookie will be deleted when the browser is closed. + // Note: This will not delete the session data from the store. + // + // Optional. Default: false CookieSessionOnly bool } +// Source represents the type of session ID source. type Source string const ( @@ -67,28 +105,59 @@ const ( SourceURLQuery Source = "query" ) -// ConfigDefault is the default config +// ConfigDefault provides the default configuration. var ConfigDefault = Config{ - Expiration: 24 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyLookup: "cookie:session_id", KeyGenerator: utils.UUIDv4, - source: "cookie", + source: SourceCookie, sessionName: "session_id", } -// Helper function to set default values +// DefaultErrorHandler logs the error and sends a 500 status code. +// +// Parameters: +// - c: The Fiber context. +// - err: The error to handle. +// +// Usage: +// +// DefaultErrorHandler(c, err) +func DefaultErrorHandler(c fiber.Ctx, err error) { + log.Errorf("session: %v", err) + if sendErr := c.SendStatus(fiber.StatusInternalServerError); sendErr != nil { + log.Errorf("session: %v", sendErr) + } +} + +// configDefault sets default values for the Config struct. +// +// Parameters: +// - config: Variadic parameter to override the default config. +// +// Returns: +// - Config: The configuration with default values set. +// +// Usage: +// +// cfg := configDefault() +// cfg := configDefault(customConfig) func configDefault(config ...Config) Config { - // Return default config if nothing provided + // Return default config if none provided. if len(config) < 1 { return ConfigDefault } - // Override default config + // Override default config with provided config. cfg := config[0] - // Set default values - if int(cfg.Expiration.Seconds()) <= 0 { - cfg.Expiration = ConfigDefault.Expiration + // Set default values where necessary. + if cfg.IdleTimeout <= 0 { + cfg.IdleTimeout = ConfigDefault.IdleTimeout + } + // Ensure AbsoluteTimeout is greater than or equal to IdleTimeout. + if cfg.AbsoluteTimeout > 0 && cfg.AbsoluteTimeout < cfg.IdleTimeout { + panic("[session] AbsoluteTimeout must be greater than or equal to IdleTimeout") } if cfg.KeyLookup == "" { cfg.KeyLookup = ConfigDefault.KeyLookup @@ -97,10 +166,11 @@ func configDefault(config ...Config) Config { cfg.KeyGenerator = ConfigDefault.KeyGenerator } + // Parse KeyLookup into source and session name. selectors := strings.Split(cfg.KeyLookup, ":") const numSelectors = 2 if len(selectors) != numSelectors { - panic("[session] KeyLookup must in the form of :") + panic("[session] KeyLookup must be in the format ':'") } switch Source(selectors[0]) { case SourceCookie: @@ -110,7 +180,7 @@ func configDefault(config ...Config) Config { case SourceURLQuery: cfg.source = SourceURLQuery default: - panic("[session] source is not supported") + panic("[session] unsupported source in KeyLookup") } cfg.sessionName = selectors[1] diff --git a/middleware/session/config_test.go b/middleware/session/config_test.go new file mode 100644 index 0000000000..c87ecef258 --- /dev/null +++ b/middleware/session/config_test.go @@ -0,0 +1,59 @@ +package session + +import ( + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func TestConfigDefault(t *testing.T) { + // Test default config + cfg := configDefault() + require.Equal(t, 30*time.Minute, cfg.IdleTimeout) + require.Equal(t, "cookie:session_id", cfg.KeyLookup) + require.NotNil(t, cfg.KeyGenerator) + require.Equal(t, SourceCookie, cfg.source) + require.Equal(t, "session_id", cfg.sessionName) +} + +func TestConfigDefaultWithCustomConfig(t *testing.T) { + // Test custom config + customConfig := Config{ + IdleTimeout: 48 * time.Hour, + KeyLookup: "header:custom_session_id", + KeyGenerator: func() string { return "custom_key" }, + } + cfg := configDefault(customConfig) + require.Equal(t, 48*time.Hour, cfg.IdleTimeout) + require.Equal(t, "header:custom_session_id", cfg.KeyLookup) + require.NotNil(t, cfg.KeyGenerator) + require.Equal(t, SourceHeader, cfg.source) + require.Equal(t, "custom_session_id", cfg.sessionName) +} + +func TestDefaultErrorHandler(t *testing.T) { + // Create a new Fiber app + app := fiber.New() + + // Create a new context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + // Test DefaultErrorHandler + DefaultErrorHandler(ctx, fiber.ErrInternalServerError) + require.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode()) +} + +func TestInvalidKeyLookupFormat(t *testing.T) { + require.PanicsWithValue(t, "[session] KeyLookup must be in the format ':'", func() { + configDefault(Config{KeyLookup: "invalid_format"}) + }) +} + +func TestUnsupportedSource(t *testing.T) { + require.PanicsWithValue(t, "[session] unsupported source in KeyLookup", func() { + configDefault(Config{KeyLookup: "unsupported:session_id"}) + }) +} diff --git a/middleware/session/data.go b/middleware/session/data.go index 08cb833f4e..052e43bc1b 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -8,57 +8,120 @@ import ( // //go:generate msgp -o=data_msgp.go -tests=true -unexported type data struct { - Data map[string]any + Data map[any]any sync.RWMutex `msg:"-"` } var dataPool = sync.Pool{ New: func() any { d := new(data) - d.Data = make(map[string]any) + d.Data = make(map[any]any) return d }, } +// acquireData returns a new data object from the pool. +// +// Returns: +// - *data: The data object. +// +// Usage: +// +// d := acquireData() func acquireData() *data { - return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool + obj := dataPool.Get() + if d, ok := obj.(*data); ok { + return d + } + // Handle unexpected type in the pool + panic("unexpected type in data pool") } +// Reset clears the data map and resets the data object. +// +// Usage: +// +// d.Reset() func (d *data) Reset() { d.Lock() - d.Data = make(map[string]any) - d.Unlock() + defer d.Unlock() + d.Data = make(map[any]any) } -func (d *data) Get(key string) any { +// Get retrieves a value from the data map by key. +// +// Parameters: +// - key: The key to retrieve. +// +// Returns: +// - any: The value associated with the key. +// +// Usage: +// +// value := d.Get("key") +func (d *data) Get(key any) any { d.RLock() - v := d.Data[key] - d.RUnlock() - return v + defer d.RUnlock() + return d.Data[key] } -func (d *data) Set(key string, value any) { +// Set updates or creates a new key-value pair in the data map. +// +// Parameters: +// - key: The key to set. +// - value: The value to set. +// +// Usage: +// +// d.Set("key", "value") +func (d *data) Set(key, value any) { d.Lock() + defer d.Unlock() d.Data[key] = value - d.Unlock() } -func (d *data) Delete(key string) { +// Delete removes a key-value pair from the data map. +// +// Parameters: +// - key: The key to delete. +// +// Usage: +// +// d.Delete("key") +func (d *data) Delete(key any) { d.Lock() + defer d.Unlock() delete(d.Data, key) - d.Unlock() } -func (d *data) Keys() []string { - d.Lock() - keys := make([]string, 0, len(d.Data)) +// Keys retrieves all keys in the data map. +// +// Returns: +// - []any: A slice of all keys in the data map. +// +// Usage: +// +// keys := d.Keys() +func (d *data) Keys() []any { + d.RLock() + defer d.RUnlock() + keys := make([]any, 0, len(d.Data)) for k := range d.Data { keys = append(keys, k) } - d.Unlock() return keys } +// Len returns the number of key-value pairs in the data map. +// +// Returns: +// - int: The number of key-value pairs. +// +// Usage: +// +// length := d.Len() func (d *data) Len() int { + d.RLock() + defer d.RUnlock() return len(d.Data) } diff --git a/middleware/session/data_msgp.go b/middleware/session/data_msgp.go index a93ffcfb27..a640e141b8 100644 --- a/middleware/session/data_msgp.go +++ b/middleware/session/data_msgp.go @@ -24,36 +24,6 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { return } switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[string]interface{}, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - zb0002-- - var za0001 string - var za0002 interface{} - za0001, err = dc.ReadString() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } default: err = dc.Skip() if err != nil { @@ -66,48 +36,22 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { } // EncodeMsg implements msgp.Encodable -func (z *data) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 1 - // write "Data" - err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) +func (z data) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 0 + _ = z + err = en.Append(0x80) if err != nil { return } - err = en.WriteMapHeader(uint32(len(z.Data))) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - for za0001, za0002 := range z.Data { - err = en.WriteString(za0001) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - err = en.WriteIntf(za0002) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - } return } // MarshalMsg implements msgp.Marshaler -func (z *data) MarshalMsg(b []byte) (o []byte, err error) { +func (z data) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 1 - // string "Data" - o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - o = msgp.AppendMapHeader(o, uint32(len(z.Data))) - for za0001, za0002 := range z.Data { - o = msgp.AppendString(o, za0001) - o, err = msgp.AppendIntf(o, za0002) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - } + // map header, size 0 + _ = z + o = append(o, 0x80) return } @@ -129,36 +73,6 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { return } switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[string]interface{}, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - var za0001 string - var za0002 interface{} - zb0002-- - za0001, bts, err = msgp.ReadStringBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } default: bts, err = msgp.Skip(bts) if err != nil { @@ -172,13 +86,7 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { } // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -func (z *data) Msgsize() (s int) { - s = 1 + 5 + msgp.MapHeaderSize - if z.Data != nil { - for za0001, za0002 := range z.Data { - _ = za0002 - s += msgp.StringPrefixSize + len(za0001) + msgp.GuessSize(za0002) - } - } +func (z data) Msgsize() (s int) { + s = 1 return } diff --git a/middleware/session/data_test.go b/middleware/session/data_test.go new file mode 100644 index 0000000000..1913f761d3 --- /dev/null +++ b/middleware/session/data_test.go @@ -0,0 +1,204 @@ +package session + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKeys(t *testing.T) { + t.Parallel() + + // Test case: Empty data + t.Run("Empty data", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + keys := d.Keys() + require.Empty(t, keys, "Expected no keys in empty data") + }) + + // Test case: Single key + t.Run("Single key", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + keys := d.Keys() + require.Len(t, keys, 1, "Expected one key") + require.Contains(t, keys, "key1", "Expected key1 to be present") + }) + + // Test case: Multiple keys + t.Run("Multiple keys", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + keys := d.Keys() + require.Len(t, keys, 3, "Expected three keys") + require.Contains(t, keys, "key1", "Expected key1 to be present") + require.Contains(t, keys, "key2", "Expected key2 to be present") + require.Contains(t, keys, "key3", "Expected key3 to be present") + }) + + // Test case: Concurrent access + t.Run("Concurrent access", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + + done := make(chan bool) + go func() { + keys := d.Keys() + assert.Len(t, keys, 3, "Expected three keys") + done <- true + }() + go func() { + keys := d.Keys() + assert.Len(t, keys, 3, "Expected three keys") + done <- true + }() + <-done + <-done + }) +} + +func TestData_Len(t *testing.T) { + t.Parallel() + + // Test case: Empty data + t.Run("Empty data", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + length := d.Len() + require.Equal(t, 0, length, "Expected length to be 0 for empty data") + }) + + // Test case: Single key + t.Run("Single key", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + length := d.Len() + require.Equal(t, 1, length, "Expected length to be 1 when one key is set") + }) + + // Test case: Multiple keys + t.Run("Multiple keys", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + length := d.Len() + require.Equal(t, 3, length, "Expected length to be 3 when three keys are set") + }) + + // Test case: Concurrent access + t.Run("Concurrent access", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + + done := make(chan bool, 2) // Buffered channel with size 2 + go func() { + length := d.Len() + assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") + done <- true + }() + go func() { + length := d.Len() + assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") + done <- true + }() + <-done + <-done + }) +} + +func TestData_Get(t *testing.T) { + t.Parallel() + + // Test case: Non-existent key + t.Run("Non-existent key", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + value := d.Get("non-existent-key") + require.Nil(t, value, "Expected nil for non-existent key") + }) + + // Test case: Existing key + t.Run("Existing key", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + value := d.Get("key1") + require.Equal(t, "value1", value, "Expected value1 for key1") + }) +} + +func TestData_Reset(t *testing.T) { + t.Parallel() + + // Test case: Reset data + t.Run("Reset data", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Reset() + require.Empty(t, d.Data, "Expected data map to be empty after reset") + }) +} + +func TestData_Delete(t *testing.T) { + t.Parallel() + + // Test case: Delete existing key + t.Run("Delete existing key", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Set("key1", "value1") + d.Delete("key1") + value := d.Get("key1") + require.Nil(t, value, "Expected nil for deleted key") + }) + + // Test case: Delete non-existent key + t.Run("Delete non-existent key", func(t *testing.T) { + t.Parallel() + d := acquireData() + defer dataPool.Put(d) + defer d.Reset() + d.Delete("non-existent-key") + // No assertion needed, just ensure no panic or error + }) +} diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go new file mode 100644 index 0000000000..c14bc19efe --- /dev/null +++ b/middleware/session/middleware.go @@ -0,0 +1,301 @@ +// Package session provides session management middleware for Fiber. +// This middleware handles user sessions, including storing session data in the store. +package session + +import ( + "errors" + "sync" + + "github.com/gofiber/fiber/v3" +) + +// Middleware holds session data and configuration. +type Middleware struct { + Session *Session + ctx fiber.Ctx + config Config + mu sync.RWMutex + destroyed bool +} + +// Context key for session middleware lookup. +type middlewareKey int + +const ( + // middlewareContextKey is the key used to store the *Middleware in the context locals. + middlewareContextKey middlewareKey = iota +) + +var ( + // ErrTypeAssertionFailed occurs when a type assertion fails. + ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") + + // Pool for reusing middleware instances. + middlewarePool = &sync.Pool{ + New: func() any { + return &Middleware{} + }, + } +) + +// New initializes session middleware with optional configuration. +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - fiber.Handler: The Fiber handler for the session middleware. +// +// Usage: +// +// app.Use(session.New()) +// +// Usage: +// +// app.Use(session.New()) +func New(config ...Config) fiber.Handler { + if len(config) > 0 { + handler, _ := NewWithStore(config[0]) + return handler + } + handler, _ := NewWithStore() + return handler +} + +// NewWithStore creates session middleware with an optional custom store. +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - fiber.Handler: The Fiber handler for the session middleware. +// - *Store: The session store. +// +// Usage: +// +// handler, store := session.NewWithStore() +func NewWithStore(config ...Config) (fiber.Handler, *Store) { + cfg := configDefault(config...) + + if cfg.Store == nil { + cfg.Store = NewStore(cfg) + } + + handler := func(c fiber.Ctx) error { + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + // Acquire session middleware + m := acquireMiddleware() + m.initialize(c, cfg) + + stackErr := c.Next() + + m.mu.RLock() + destroyed := m.destroyed + m.mu.RUnlock() + + if !destroyed { + m.saveSession() + } + + releaseMiddleware(m) + return stackErr + } + + return handler, cfg.Store +} + +// initialize sets up middleware for the request. +func (m *Middleware) initialize(c fiber.Ctx, cfg Config) { + m.mu.Lock() + defer m.mu.Unlock() + + session, err := cfg.Store.getSession(c) + if err != nil { + panic(err) // handle or log this error appropriately in production + } + + m.config = cfg + m.Session = session + m.ctx = c + + c.Locals(middlewareContextKey, m) +} + +// saveSession handles session saving and error management after the response. +func (m *Middleware) saveSession() { + if err := m.Session.saveSession(); err != nil { + if m.config.ErrorHandler != nil { + m.config.ErrorHandler(m.ctx, err) + } else { + DefaultErrorHandler(m.ctx, err) + } + } + + releaseSession(m.Session) +} + +// acquireMiddleware retrieves a middleware instance from the pool. +func acquireMiddleware() *Middleware { + m, ok := middlewarePool.Get().(*Middleware) + if !ok { + panic(ErrTypeAssertionFailed.Error()) + } + return m +} + +// releaseMiddleware resets and returns middleware to the pool. +// +// Parameters: +// - m: The middleware object to release. +// +// Usage: +// +// releaseMiddleware(m) +func releaseMiddleware(m *Middleware) { + m.mu.Lock() + m.config = Config{} + m.Session = nil + m.ctx = nil + m.destroyed = false + m.mu.Unlock() + middlewarePool.Put(m) +} + +// FromContext returns the Middleware from the Fiber context. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - *Middleware: The middleware object if found, otherwise nil. +// +// Usage: +// +// m := session.FromContext(c) +func FromContext(c fiber.Ctx) *Middleware { + m, ok := c.Locals(middlewareContextKey).(*Middleware) + if !ok { + return nil + } + return m +} + +// Set sets a key-value pair in the session. +// +// Parameters: +// - key: The key to set. +// - value: The value to set. +// +// Usage: +// +// m.Set("key", "value") +func (m *Middleware) Set(key, value any) { + m.mu.Lock() + defer m.mu.Unlock() + + m.Session.Set(key, value) +} + +// Get retrieves a value from the session by key. +// +// Parameters: +// - key: The key to retrieve. +// +// Returns: +// - any: The value associated with the key. +// +// Usage: +// +// value := m.Get("key") +func (m *Middleware) Get(key any) any { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.Session.Get(key) +} + +// Delete removes a key-value pair from the session. +// +// Parameters: +// - key: The key to delete. +// +// Usage: +// +// m.Delete("key") +func (m *Middleware) Delete(key any) { + m.mu.Lock() + defer m.mu.Unlock() + + m.Session.Delete(key) +} + +// Destroy destroys the session. +// +// Returns: +// - error: An error if the destruction fails. +// +// Usage: +// +// err := m.Destroy() +func (m *Middleware) Destroy() error { + m.mu.Lock() + defer m.mu.Unlock() + + err := m.Session.Destroy() + m.destroyed = true + return err +} + +// Fresh checks if the session is fresh. +// +// Returns: +// - bool: True if the session is fresh, otherwise false. +// +// Usage: +// +// isFresh := m.Fresh() +func (m *Middleware) Fresh() bool { + return m.Session.Fresh() +} + +// ID returns the session ID. +// +// Returns: +// - string: The session ID. +// +// Usage: +// +// id := m.ID() +func (m *Middleware) ID() string { + return m.Session.ID() +} + +// Reset resets the session. +// +// Returns: +// - error: An error if the reset fails. +// +// Usage: +// +// err := m.Reset() +func (m *Middleware) Reset() error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.Session.Reset() +} + +// Store returns the session store. +// +// Returns: +// - *Store: The session store. +// +// Usage: +// +// store := m.Store() +func (m *Middleware) Store() *Store { + return m.config.Store +} diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go new file mode 100644 index 0000000000..579d61c44c --- /dev/null +++ b/middleware/session/middleware_test.go @@ -0,0 +1,469 @@ +package session + +import ( + "strings" + "sync" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_Session_Middleware(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New()) + + app.Get("/get", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + value, ok := sess.Get("key").(string) + if !ok { + return c.Status(fiber.StatusNotFound).SendString("key not found") + } + return c.SendString("value=" + value) + }) + + app.Post("/set", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + // get a value from the body + value := c.FormValue("value") + sess.Set("key", value) + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/delete", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + sess.Delete("key") + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/reset", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if err := sess.Reset(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/destroy", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if err := sess.Destroy(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/fresh", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + // Reset the session to make it fresh + if err := sess.Reset(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if sess.Fresh() { + return c.SendStatus(fiber.StatusOK) + } + return c.SendStatus(fiber.StatusInternalServerError) + }) + + // Test GET, SET, DELETE, RESET, DESTROY by sending requests to the respective routes + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + h := app.Handler() + h(ctx) + require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] + require.Equal(t, "key not found", string(ctx.Response.Body())) + + // Test POST /set + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/set") + ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type + ctx.Request.SetBodyString("value=hello") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test GET /get to check if the value was set + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + require.Equal(t, "value=hello", string(ctx.Response.Body())) + + // Test POST /delete to delete the value + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/delete") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test GET /get to check if the value was deleted + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) + require.Equal(t, "key not found", string(ctx.Response.Body())) + + // Test POST /reset to reset the session + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/reset") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // verify we have a new session token + newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") + newTokenParts := strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2) + require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token") + newToken = newTokenParts[1] + require.NotEqual(t, token, newToken) + token = newToken + + // Test POST /destroy to destroy the session + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/destroy") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Verify the session cookie is set to expire + setCookieHeader := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.Contains(t, setCookieHeader, "expires=") + cookieParts := strings.Split(setCookieHeader, ";") + expired := false + for _, part := range cookieParts { + if strings.Contains(part, "expires=") { + part = strings.TrimSpace(part) + expiryDateStr := strings.TrimPrefix(part, "expires=") + // Correctly parse the date with "GMT" timezone + expiryDate, err := time.Parse(time.RFC1123, strings.TrimSpace(expiryDateStr)) + require.NoError(t, err) + if expiryDate.Before(time.Now()) { + expired = true + break + } + } + } + require.True(t, expired, "Session cookie should be expired") + + // Sleep so that the session expires + time.Sleep(1 * time.Second) + + // Test GET /get to check if the session was destroyed + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) + // check that we have a new session token + newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") + parts := strings.Split(newToken, ";") + require.Greater(t, len(parts), 1) + valueParts := strings.Split(parts[0], "=") + require.Greater(t, len(valueParts), 1) + newToken = valueParts[1] + require.NotEqual(t, token, newToken) + token = newToken + + // Test POST /fresh to check if the session is fresh + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/fresh") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // check that we have a new session token + newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") + newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2) + require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token") + newToken = newTokenParts[1] + require.NotEqual(t, token, newToken) +} + +func Test_Session_NewWithStore(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New()) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + return c.SendString("value=" + id) + }) + app.Post("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + c.Cookie(&fiber.Cookie{ + Name: "session_id", + Value: id, + }) + return nil + }) + + h := app.Handler() + + // Test GET request without cookie + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // Get session cookie + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test GET request with cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + require.Equal(t, "value="+token, string(ctx.Response.Body())) +} + +func Test_Session_FromSession(t *testing.T) { + t.Parallel() + app := fiber.New() + + sess := FromContext(app.AcquireCtx(&fasthttp.RequestCtx{})) + require.Nil(t, sess) + + app.Use(New()) +} + +func Test_Session_WithConfig(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Next: func(c fiber.Ctx) bool { + return c.Get("key") == "value" + }, + IdleTimeout: 1 * time.Second, + KeyLookup: "cookie:session_id_test", + KeyGenerator: func() string { + return "test" + }, + source: "cookie_test", + sessionName: "session_id_test", + })) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + return c.SendString("value=" + id) + }) + + app.Get("/isFresh", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess.Fresh() { + return c.SendStatus(fiber.StatusOK) + } + return c.SendStatus(fiber.StatusInternalServerError) + }) + + app.Post("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + c.Cookie(&fiber.Cookie{ + Name: "session_id_test", + Value: id, + }) + return nil + }) + + h := app.Handler() + + // Test GET request without cookie + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // Get session cookie + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test GET request with cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id_test", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test POST request with cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.SetCookie("session_id_test", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test POST request without cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test POST request with wrong key + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test POST request with wrong value + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.SetCookie("session_id_test", "wrong") + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Check idle timeout not expired + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id_test", token) + ctx.Request.SetRequestURI("/isFresh") + h(ctx) + require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode()) + + // Test idle timeout + time.Sleep(1200 * time.Millisecond) + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id_test", token) + ctx.Request.SetRequestURI("/isFresh") + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) +} + +func Test_Session_Next(t *testing.T) { + t.Parallel() + + var ( + doNext bool + muNext sync.RWMutex + ) + + app := fiber.New() + + app.Use(New(Config{ + Next: func(_ fiber.Ctx) bool { + muNext.RLock() + defer muNext.RUnlock() + return doNext + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + id := sess.ID() + return c.SendString("value=" + id) + }) + + h := app.Handler() + + // Test with Next returning false + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // Get session cookie + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test with Next returning true + muNext.Lock() + doNext = true + muNext.Unlock() + + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode()) +} + +func Test_Session_Middleware_Store(t *testing.T) { + t.Parallel() + app := fiber.New() + + handler, sessionStore := NewWithStore() + + app.Use(handler) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + st := sess.Store() + if st != sessionStore { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + // Test GET request + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) +} diff --git a/middleware/session/session.go b/middleware/session/session.go index 8a16590064..ffb5c52722 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -12,95 +12,177 @@ import ( "github.com/valyala/fasthttp" ) +// Session represents a user session. type Session struct { - ctx fiber.Ctx // fiber context - config *Store // store configuration - data *data // key value data - byteBuffer *bytes.Buffer // byte buffer for the en- and decode - id string // session id - exp time.Duration // expiration of this session - mu sync.RWMutex // Mutex to protect non-data fields - fresh bool // if new session + ctx fiber.Ctx // fiber context + config *Store // store configuration + data *data // key value data + id string // session id + idleTimeout time.Duration // idleTimeout of this session + mu sync.RWMutex // Mutex to protect non-data fields + fresh bool // if new session +} + +type absExpirationKeyType int + +const ( + // sessionIDContextKey is the key used to store the session ID in the context locals. + absExpirationKey absExpirationKeyType = iota +) + +// Session pool for reusing byte buffers. +var byteBufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, } var sessionPool = sync.Pool{ New: func() any { - return new(Session) + return &Session{} }, } +// acquireSession returns a new Session from the pool. +// +// Returns: +// - *Session: The session object. +// +// Usage: +// +// s := acquireSession() func acquireSession() *Session { s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool if s.data == nil { s.data = acquireData() } - if s.byteBuffer == nil { - s.byteBuffer = new(bytes.Buffer) - } s.fresh = true return s } +// Release releases the session back to the pool. +// +// This function should be called after the session is no longer needed. +// This function is used to reduce the number of allocations and +// to improve the performance of the session store. +// +// The session should not be used after calling this function. +// +// Important: The Release function should only be used when accessing the session directly, +// for example, when you have called func (s *Session) Get(ctx) to get the session. +// It should not be used when using the session with a *Middleware handler in the request +// call stack, as the middleware will still need to access the session. +// +// Usage: +// +// sess := session.Get(ctx) +// defer sess.Release() +func (s *Session) Release() { + if s == nil { + return + } + releaseSession(s) +} + func releaseSession(s *Session) { s.mu.Lock() s.id = "" - s.exp = 0 + s.idleTimeout = 0 s.ctx = nil s.config = nil if s.data != nil { s.data.Reset() } - if s.byteBuffer != nil { - s.byteBuffer.Reset() - } s.mu.Unlock() sessionPool.Put(s) } -// Fresh is true if the current session is new +// Fresh returns whether the session is new +// +// Returns: +// - bool: True if the session is fresh, otherwise false. +// +// Usage: +// +// isFresh := s.Fresh() func (s *Session) Fresh() bool { s.mu.RLock() defer s.mu.RUnlock() return s.fresh } -// ID returns the session id +// ID returns the session ID +// +// Returns: +// - string: The session ID. +// +// Usage: +// +// id := s.ID() func (s *Session) ID() string { s.mu.RLock() defer s.mu.RUnlock() return s.id } -// Get will return the value -func (s *Session) Get(key string) any { - // Better safe than sorry +// Get returns the value associated with the given key. +// +// Parameters: +// - key: The key to retrieve. +// +// Returns: +// - any: The value associated with the key. +// +// Usage: +// +// value := s.Get("key") +func (s *Session) Get(key any) any { if s.data == nil { return nil } return s.data.Get(key) } -// Set will update or create a new key value -func (s *Session) Set(key string, val any) { - // Better safe than sorry +// Set updates or creates a new key-value pair in the session. +// +// Parameters: +// - key: The key to set. +// - val: The value to set. +// +// Usage: +// +// s.Set("key", "value") +func (s *Session) Set(key, val any) { if s.data == nil { return } s.data.Set(key, val) } -// Delete will delete the value -func (s *Session) Delete(key string) { - // Better safe than sorry +// Delete removes the key-value pair from the session. +// +// Parameters: +// - key: The key to delete. +// +// Usage: +// +// s.Delete("key") +func (s *Session) Delete(key any) { if s.data == nil { return } s.data.Delete(key) } -// Destroy will delete the session from Storage and expire session cookie +// Destroy deletes the session from storage and expires the session cookie. +// +// Returns: +// - error: An error if the destruction fails. +// +// Usage: +// +// err := s.Destroy() func (s *Session) Destroy() error { - // Better safe than sorry if s.data == nil { return nil } @@ -121,7 +203,14 @@ func (s *Session) Destroy() error { return nil } -// Regenerate generates a new session id and delete the old one from Storage +// Regenerate generates a new session id and deletes the old one from storage. +// +// Returns: +// - error: An error if the regeneration fails. +// +// Usage: +// +// err := s.Regenerate() func (s *Session) Regenerate() error { s.mu.Lock() defer s.mu.Unlock() @@ -137,7 +226,14 @@ func (s *Session) Regenerate() error { return nil } -// Reset generates a new session id, deletes the old one from storage, and resets the associated data +// Reset generates a new session id, deletes the old one from storage, and resets the associated data. +// +// Returns: +// - error: An error if the reset fails. +// +// Usage: +// +// err := s.Reset() func (s *Session) Reset() error { // Reset local data if s.data != nil { @@ -147,12 +243,8 @@ func (s *Session) Reset() error { s.mu.Lock() defer s.mu.Unlock() - // Reset byte buffer - if s.byteBuffer != nil { - s.byteBuffer.Reset() - } // Reset expiration - s.exp = 0 + s.idleTimeout = 0 // Delete old id from storage if err := s.config.Storage.Delete(s.id); err != nil { @@ -168,75 +260,102 @@ func (s *Session) Reset() error { return nil } -// refresh generates a new session, and set session.fresh to be true +// refresh generates a new session, and sets session.fresh to be true. func (s *Session) refresh() { s.id = s.config.KeyGenerator() s.fresh = true } -// Save will update the storage and client cookie +// Save saves the session data and updates the cookie // -// sess.Save() will save the session data to the storage and update the -// client cookie, and it will release the session after saving. +// Note: If the session is being used in the handler, calling Save will have +// no effect and the session will automatically be saved when the handler returns. // -// It's not safe to use the session after calling Save(). +// Returns: +// - error: An error if the save operation fails. +// +// Usage: +// +// err := s.Save() func (s *Session) Save() error { - // Better safe than sorry + if s.ctx == nil { + return s.saveSession() + } + + // If the session is being used in the handler, it should not be saved + if m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware); ok { + if m.Session == s { + // Session is in use, so we do nothing and return + return nil + } + } + + return s.saveSession() +} + +// saveSession encodes session data to saves it to storage. +func (s *Session) saveSession() error { if s.data == nil { return nil } s.mu.Lock() + defer s.mu.Unlock() - // Check if session has your own expiration, otherwise use default value - if s.exp <= 0 { - s.exp = s.config.Expiration + // Set idleTimeout if not already set + if s.idleTimeout <= 0 { + s.idleTimeout = s.config.IdleTimeout } // Update client cookie s.setSession() - // Convert data to bytes - encCache := gob.NewEncoder(s.byteBuffer) - err := encCache.Encode(&s.data.Data) + // Encode session data + s.data.RLock() + encodedBytes, err := s.encodeSessionData() + s.data.RUnlock() if err != nil { return fmt.Errorf("failed to encode data: %w", err) } - // Copy the data in buffer - encodedBytes := make([]byte, s.byteBuffer.Len()) - copy(encodedBytes, s.byteBuffer.Bytes()) - // Pass copied bytes with session id to provider - if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil { - return err - } - - s.mu.Unlock() - - // Release session - // TODO: It's not safe to use the Session after calling Save() - releaseSession(s) - - return nil + return s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout) } -// Keys will retrieve all keys in current session -func (s *Session) Keys() []string { +// Keys retrieves all keys in the current session. +// +// Returns: +// - []string: A slice of all keys in the session. +// +// Usage: +// +// keys := s.Keys() +func (s *Session) Keys() []any { if s.data == nil { - return []string{} + return []any{} } return s.data.Keys() } -// SetExpiry sets a specific expiration for this session -func (s *Session) SetExpiry(exp time.Duration) { +// SetIdleTimeout used when saving the session on the next call to `Save()`. +// +// Parameters: +// - idleTimeout: The duration for the idle timeout. +// +// Usage: +// +// s.SetIdleTimeout(time.Hour) +func (s *Session) SetIdleTimeout(idleTimeout time.Duration) { s.mu.Lock() defer s.mu.Unlock() - s.exp = exp + s.idleTimeout = idleTimeout } func (s *Session) setSession() { + if s.ctx == nil { + return + } + if s.config.source == SourceHeader { s.ctx.Request().Header.SetBytesV(s.config.sessionName, []byte(s.id)) s.ctx.Response().Header.SetBytesV(s.config.sessionName, []byte(s.id)) @@ -249,8 +368,8 @@ func (s *Session) setSession() { // Cookies are also session cookies if they do not specify the Expires or Max-Age attribute. // refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie if !s.config.CookieSessionOnly { - fcookie.SetMaxAge(int(s.exp.Seconds())) - fcookie.SetExpire(time.Now().Add(s.exp)) + fcookie.SetMaxAge(int(s.idleTimeout.Seconds())) + fcookie.SetExpire(time.Now().Add(s.idleTimeout)) } fcookie.SetSecure(s.config.CookieSecure) fcookie.SetHTTPOnly(s.config.CookieHTTPOnly) @@ -269,6 +388,10 @@ func (s *Session) setSession() { } func (s *Session) delSession() { + if s.ctx == nil { + return + } + if s.config.source == SourceHeader { s.ctx.Request().Header.Del(s.config.sessionName) s.ctx.Response().Header.Del(s.config.sessionName) @@ -299,12 +422,92 @@ func (s *Session) delSession() { } } -// decodeSessionData decodes the session data from raw bytes. +// decodeSessionData decodes session data from raw bytes +// +// Parameters: +// - rawData: The raw byte data to decode. +// +// Returns: +// - error: An error if the decoding fails. +// +// Usage: +// +// err := s.decodeSessionData(rawData) func (s *Session) decodeSessionData(rawData []byte) error { - _, _ = s.byteBuffer.Write(rawData) - encCache := gob.NewDecoder(s.byteBuffer) - if err := encCache.Decode(&s.data.Data); err != nil { + byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + defer byteBufferPool.Put(byteBuffer) + defer byteBuffer.Reset() + _, _ = byteBuffer.Write(rawData) + decCache := gob.NewDecoder(byteBuffer) + if err := decCache.Decode(&s.data.Data); err != nil { return fmt.Errorf("failed to decode session data: %w", err) } return nil } + +// encodeSessionData encodes session data to raw bytes +// +// Parameters: +// - rawData: The raw byte data to encode. +// +// Returns: +// - error: An error if the encoding fails. +// +// Usage: +// +// err := s.encodeSessionData(rawData) +func (s *Session) encodeSessionData() ([]byte, error) { + byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + defer byteBufferPool.Put(byteBuffer) + defer byteBuffer.Reset() + encCache := gob.NewEncoder(byteBuffer) + if err := encCache.Encode(&s.data.Data); err != nil { + return nil, fmt.Errorf("failed to encode session data: %w", err) + } + // Copy the bytes + // Copy the data in buffer + encodedBytes := make([]byte, byteBuffer.Len()) + copy(encodedBytes, byteBuffer.Bytes()) + + return encodedBytes, nil +} + +// absExpiration returns the session absolute expiration time or a zero time if not set. +// +// Returns: +// - time.Time: The session absolute expiration time. Zero time if not set. +// +// Usage: +// +// expiration := s.absExpiration() +func (s *Session) absExpiration() time.Time { + absExpiration, ok := s.Get(absExpirationKey).(time.Time) + if ok { + return absExpiration + } + return time.Time{} +} + +// isAbsExpired returns true if the session is expired. +// +// If the session has an absolute expiration time set, this function will return true if the +// current time is after the absolute expiration time. +// +// Returns: +// - bool: True if the session is expired, otherwise false. +func (s *Session) isAbsExpired() bool { + absExpiration := s.absExpiration() + return !absExpiration.IsZero() && time.Now().After(absExpiration) +} + +// setAbsoluteExpiration sets the absolute session expiration time. +// +// Parameters: +// - expiration: The session expiration time. +// +// Usage: +// +// s.setExpiration(time.Now().Add(time.Hour)) +func (s *Session) setAbsExpiration(absExpiration time.Time) { + s.Set(absExpirationKey, absExpiration) +} diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index fa12d690a6..038bfc4b8d 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -8,6 +8,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -17,14 +18,13 @@ func Test_Session(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) // Get a new session sess, err := store.Get(ctx) @@ -33,6 +33,7 @@ func Test_Session(t *testing.T) { token := sess.ID() require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -46,7 +47,7 @@ func Test_Session(t *testing.T) { // get keys keys := sess.Keys() - require.Equal(t, []string{}, keys) + require.Equal(t, []any{}, keys) // get value name := sess.Get("name") @@ -60,7 +61,7 @@ func Test_Session(t *testing.T) { require.Equal(t, "john", name) keys = sess.Keys() - require.Equal(t, []string{"name"}, keys) + require.Equal(t, []any{"name"}, keys) // delete key sess.Delete("name") @@ -71,7 +72,7 @@ func Test_Session(t *testing.T) { // get keys keys = sess.Keys() - require.Equal(t, []string{}, keys) + require.Equal(t, []any{}, keys) // get id id := sess.ID() @@ -81,6 +82,9 @@ func Test_Session(t *testing.T) { err = sess.Save() require.NoError(t, err) + // release the session + sess.Release() + // release the context app.ReleaseCtx(ctx) // requesting entirely new context to prevent falsy tests @@ -93,6 +97,8 @@ func Test_Session(t *testing.T) { // this id should be randomly generated as session key was deleted require.Len(t, sess.ID(), 36) + sess.Release() + // when we use the original session for the second time // the session be should be same if the session is not expired app.ReleaseCtx(ctx) @@ -102,6 +108,7 @@ func Test_Session(t *testing.T) { // request the server with the old session ctx.Request().Header.SetCookie(store.sessionName, id) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Equal(t, sess.id, id) @@ -112,7 +119,7 @@ func Test_Session_Types(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() @@ -186,6 +193,7 @@ func Test_Session_Types(t *testing.T) { err = sess.Save() require.NoError(t, err) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -277,6 +285,8 @@ func Test_Session_Types(t *testing.T) { require.True(t, ok) require.Equal(t, vcomplex128, vcomplex128Result) + sess.Release() + app.ReleaseCtx(ctx) } @@ -284,7 +294,7 @@ func Test_Session_Types(t *testing.T) { func Test_Session_Store_Reset(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -304,6 +314,7 @@ func Test_Session_Store_Reset(t *testing.T) { require.NoError(t, store.Reset()) id := sess.ID() + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -311,11 +322,187 @@ func Test_Session_Store_Reset(t *testing.T) { // make sure the session is recreated sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.True(t, sess.Fresh()) require.Nil(t, sess.Get("hello")) } +func Test_Session_KeyTypes(t *testing.T) { + t.Parallel() + + // session store + store := NewStore() + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + // get session + sess, err := store.Get(ctx) + require.NoError(t, err) + require.True(t, sess.Fresh()) + + type Person struct { + Name string + } + + type unexportedKey int + + // register non-default types + store.RegisterType(Person{}) + store.RegisterType(unexportedKey(0)) + + type unregisteredKeyType int + type unregisteredValueType int + + // verify unregistered keys types are not allowed + var ( + unregisteredKey unregisteredKeyType + unregisteredValue unregisteredValueType + ) + sess.Set(unregisteredKey, "test") + err = sess.Save() + require.Error(t, err) + sess.Delete(unregisteredKey) + err = sess.Save() + require.NoError(t, err) + sess.Set("abc", unregisteredValue) + err = sess.Save() + require.Error(t, err) + sess.Delete("abc") + err = sess.Save() + require.NoError(t, err) + + require.NoError(t, sess.Reset()) + + var ( + kbool = true + kstring = "str" + kint = 13 + kint8 int8 = 13 + kint16 int16 = 13 + kint32 int32 = 13 + kint64 int64 = 13 + kuint uint = 13 + kuint8 uint8 = 13 + kuint16 uint16 = 13 + kuint32 uint32 = 13 + kuint64 uint64 = 13 + kuintptr uintptr = 13 + kbyte byte = 'k' + krune = 'k' + kfloat32 float32 = 13 + kfloat64 float64 = 13 + kcomplex64 complex64 = 13 + kcomplex128 complex128 = 13 + kuser = Person{Name: "John"} + kunexportedKey = unexportedKey(13) + ) + + var ( + vbool = true + vstring = "str" + vint = 13 + vint8 int8 = 13 + vint16 int16 = 13 + vint32 int32 = 13 + vint64 int64 = 13 + vuint uint = 13 + vuint8 uint8 = 13 + vuint16 uint16 = 13 + vuint32 uint32 = 13 + vuint64 uint64 = 13 + vuintptr uintptr = 13 + vbyte byte = 'k' + vrune = 'k' + vfloat32 float32 = 13 + vfloat64 float64 = 13 + vcomplex64 complex64 = 13 + vcomplex128 complex128 = 13 + vuser = Person{Name: "John"} + vunexportedKey = unexportedKey(13) + ) + + keys := []any{ + kbool, + kstring, + kint, + kint8, + kint16, + kint32, + kint64, + kuint, + kuint8, + kuint16, + kuint32, + kuint64, + kuintptr, + kbyte, + krune, + kfloat32, + kfloat64, + kcomplex64, + kcomplex128, + kuser, + kunexportedKey, + } + + values := []any{ + vbool, + vstring, + vint, + vint8, + vint16, + vint32, + vint64, + vuint, + vuint8, + vuint16, + vuint32, + vuint64, + vuintptr, + vbyte, + vrune, + vfloat32, + vfloat64, + vcomplex64, + vcomplex128, + vuser, + vunexportedKey, + } + + // loop test all key value pairs + for i, key := range keys { + sess.Set(key, values[i]) + } + + id := sess.ID() + ctx.Request().Header.SetCookie(store.sessionName, id) + // save session + err = sess.Save() + require.NoError(t, err) + + sess.Release() + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie(store.sessionName, id) + + // get session + sess, err = store.Get(ctx) + require.NoError(t, err) + defer sess.Release() + require.False(t, sess.Fresh()) + + // loop test all key value pairs + for i, key := range keys { + // get value + result := sess.Get(key) + require.Equal(t, values[i], result) + } +} + // go test -run Test_Session_Save func Test_Session_Save(t *testing.T) { t.Parallel() @@ -323,7 +510,7 @@ func Test_Session_Save(t *testing.T) { t.Run("save to cookie", func(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -338,12 +525,13 @@ func Test_Session_Save(t *testing.T) { // save session err = sess.Save() require.NoError(t, err) + sess.Release() }) t.Run("save to header", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := NewStore(Config{ KeyLookup: "header:session_id", }) // fiber instance @@ -363,10 +551,11 @@ func Test_Session_Save(t *testing.T) { require.NoError(t, err) require.Equal(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName))) require.Equal(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName))) + sess.Release() }) } -func Test_Session_Save_Expiration(t *testing.T) { +func Test_Session_Save_IdleTimeout(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { @@ -374,7 +563,7 @@ func Test_Session_Save_Expiration(t *testing.T) { const sessionDuration = 5 * time.Second // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -391,12 +580,13 @@ func Test_Session_Save_Expiration(t *testing.T) { token := sess.ID() // expire this session in 5 seconds - sess.SetExpiry(sessionDuration) + sess.SetIdleTimeout(sessionDuration) // save session err = sess.Save() require.NoError(t, err) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -409,16 +599,103 @@ func Test_Session_Save_Expiration(t *testing.T) { // just to make sure the session has been expired time.Sleep(sessionDuration + (10 * time.Millisecond)) + sess.Release() + app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) + // here you should get a new session + ctx.Request().Header.SetCookie(store.sessionName, token) + sess, err = store.Get(ctx) + defer sess.Release() + require.NoError(t, err) + require.Nil(t, sess.Get("name")) + require.NotEqual(t, sess.ID(), token) + }) +} + +func Test_Session_Save_AbsoluteTimeout(t *testing.T) { + t.Parallel() + + t.Run("save to cookie", func(t *testing.T) { + t.Parallel() + + const absoluteTimeout = 1 * time.Second + // session store + store := NewStore(Config{ + IdleTimeout: absoluteTimeout, + AbsoluteTimeout: absoluteTimeout, + }) + + // force change to IdleTimeout + store.Config.IdleTimeout = 10 * time.Second + + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // get session + sess, err := store.Get(ctx) + require.NoError(t, err) + + // set value + sess.Set("name", "john") + + token := sess.ID() + + // save session + err = sess.Save() + require.NoError(t, err) + + sess.Release() + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + + // here you need to get the old session yet + ctx.Request().Header.SetCookie(store.sessionName, token) + sess, err = store.Get(ctx) + require.NoError(t, err) + require.Equal(t, "john", sess.Get("name")) + + // just to make sure the session has been expired + time.Sleep(absoluteTimeout + (100 * time.Millisecond)) + + sess.Release() + + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + // here you should get a new session ctx.Request().Header.SetCookie(store.sessionName, token) sess, err = store.Get(ctx) require.NoError(t, err) require.Nil(t, sess.Get("name")) require.NotEqual(t, sess.ID(), token) + require.True(t, sess.Fresh()) + require.IsType(t, time.Time{}, sess.Get(absExpirationKey)) + + token = sess.ID() + + sess.Set("name", "john") + + // save session + err = sess.Save() + require.NoError(t, err) + + sess.Release() + app.ReleaseCtx(ctx) + + // just to make sure the session has been expired + time.Sleep(absoluteTimeout + (100 * time.Millisecond)) + + // try to get expired session by id + sess, err = store.GetByID(token) + require.Error(t, err) + require.ErrorIs(t, err, ErrSessionIDNotFoundInStore) + require.Nil(t, sess) }) } @@ -429,7 +706,7 @@ func Test_Session_Destroy(t *testing.T) { t.Run("destroy from cookie", func(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -438,6 +715,7 @@ func Test_Session_Destroy(t *testing.T) { // get session sess, err := store.Get(ctx) + defer sess.Release() require.NoError(t, err) sess.Set("name", "fenny") @@ -449,7 +727,7 @@ func Test_Session_Destroy(t *testing.T) { t.Run("destroy from header", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := NewStore(Config{ KeyLookup: "header:session_id", }) // fiber instance @@ -467,6 +745,7 @@ func Test_Session_Destroy(t *testing.T) { id := sess.ID() require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -475,6 +754,7 @@ func Test_Session_Destroy(t *testing.T) { ctx.Request().Header.Set(store.sessionName, id) sess, err = store.Get(ctx) require.NoError(t, err) + defer sess.Release() err = sess.Destroy() require.NoError(t, err) @@ -487,19 +767,19 @@ func Test_Session_Destroy(t *testing.T) { func Test_Session_Custom_Config(t *testing.T) { t.Parallel() - store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }}) - require.Equal(t, time.Hour, store.Expiration) + store := NewStore(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) + require.Equal(t, time.Hour, store.IdleTimeout) require.Equal(t, "very random", store.KeyGenerator()) - store = New(Config{Expiration: 0}) - require.Equal(t, ConfigDefault.Expiration, store.Expiration) + store = NewStore(Config{IdleTimeout: 0}) + require.Equal(t, ConfigDefault.IdleTimeout, store.IdleTimeout) } // go test -run Test_Session_Cookie func Test_Session_Cookie(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -511,15 +791,19 @@ func Test_Session_Cookie(t *testing.T) { require.NoError(t, err) require.NoError(t, sess.Save()) + sess.Release() + // cookie should be set on Save ( even if empty data ) - require.Len(t, ctx.Response().Header.PeekCookie(store.sessionName), 84) + cookie := ctx.Response().Header.PeekCookie(store.sessionName) + require.NotNil(t, cookie) + require.Regexp(t, `^session_id=[a-f0-9\-]{36}; max-age=\d+; path=/; SameSite=Lax$`, string(cookie)) } // go test -run Test_Session_Cookie_In_Response // Regression: https://github.com/gofiber/fiber/pull/1191 func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { t.Parallel() - store := New() + store := NewStore() app := fiber.New() // fiber context @@ -534,8 +818,11 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { id := sess.ID() require.NoError(t, sess.Save()) + sess.Release() + sess, err = store.Get(ctx) require.NoError(t, err) + defer sess.Release() sess.Set("name", "john") require.True(t, sess.Fresh()) require.Equal(t, id, sess.ID()) // session id should be the same @@ -548,7 +835,7 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { // Regression: https://github.com/gofiber/fiber/issues/1365 func Test_Session_Deletes_Single_Key(t *testing.T) { t.Parallel() - store := New() + store := NewStore() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -559,6 +846,7 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { sess.Set("id", "1") require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie(store.sessionName, id) @@ -568,11 +856,13 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { sess.Delete("id") require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie(store.sessionName, id) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Nil(t, sess.Get("id")) @@ -587,7 +877,7 @@ func Test_Session_Reset(t *testing.T) { app := fiber.New() // session store - store := New() + store := NewStore() t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) { t.Parallel() @@ -609,6 +899,7 @@ func Test_Session_Reset(t *testing.T) { err = freshSession.Save() require.NoError(t, err) + freshSession.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -630,7 +921,7 @@ func Test_Session_Reset(t *testing.T) { // Check that the session data has been reset keys := acquiredSession.Keys() - require.Equal(t, []string{}, keys) + require.Equal(t, []any{}, keys) // Set a new value for 'name' and check that it's updated acquiredSession.Set("name", "john") @@ -641,6 +932,8 @@ func Test_Session_Reset(t *testing.T) { err = acquiredSession.Save() require.NoError(t, err) + acquiredSession.Release() + // Check that the session id is not in the header or cookie anymore require.Equal(t, "", string(ctx.Response().Header.Peek(store.sessionName))) require.Equal(t, "", string(ctx.Request().Header.Peek(store.sessionName))) @@ -658,7 +951,7 @@ func Test_Session_Regenerate(t *testing.T) { t.Run("set fresh to be true when regenerating a session", func(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // a random session uuid originalSessionUUIDString := "" // fiber context @@ -674,6 +967,8 @@ func Test_Session_Regenerate(t *testing.T) { err = freshSession.Save() require.NoError(t, err) + freshSession.Release() + // release the context app.ReleaseCtx(ctx) @@ -686,6 +981,7 @@ func Test_Session_Regenerate(t *testing.T) { // as the session is in the storage, session.fresh should be false acquiredSession, err := store.Get(ctx) require.NoError(t, err) + defer acquiredSession.Release() require.False(t, acquiredSession.Fresh()) err = acquiredSession.Regenerate() @@ -704,7 +1000,7 @@ func Test_Session_Regenerate(t *testing.T) { // go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4 func Benchmark_Session(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), NewStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie(store.sessionName, "12356789") @@ -715,12 +1011,14 @@ func Benchmark_Session(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() } }) b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := NewStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -733,6 +1031,8 @@ func Benchmark_Session(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() } }) } @@ -740,7 +1040,7 @@ func Benchmark_Session(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4 func Benchmark_Session_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), NewStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -751,6 +1051,9 @@ func Benchmark_Session_Parallel(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() + app.ReleaseCtx(c) } }) @@ -758,7 +1061,7 @@ func Benchmark_Session_Parallel(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := NewStore(Config{ Storage: memory.New(), }) b.ReportAllocs() @@ -771,6 +1074,9 @@ func Benchmark_Session_Parallel(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() + app.ReleaseCtx(c) } }) @@ -780,7 +1086,7 @@ func Benchmark_Session_Parallel(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4 func Benchmark_Session_Asserted(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), NewStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie(store.sessionName, "12356789") @@ -793,12 +1099,13 @@ func Benchmark_Session_Asserted(b *testing.B) { sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) + sess.Release() } }) b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := NewStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -813,6 +1120,7 @@ func Benchmark_Session_Asserted(b *testing.B) { sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) + sess.Release() } }) } @@ -820,7 +1128,7 @@ func Benchmark_Session_Asserted(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4 func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), NewStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -832,6 +1140,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) + sess.Release() app.ReleaseCtx(c) } }) @@ -839,7 +1148,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := NewStore(Config{ Storage: memory.New(), }) b.ReportAllocs() @@ -853,6 +1162,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) + sess.Release() app.ReleaseCtx(c) } }) @@ -863,7 +1173,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { func Test_Session_Concurrency(t *testing.T) { t.Parallel() app := fiber.New() - store := New() + store := NewStore() var wg sync.WaitGroup errChan := make(chan error, 10) // Buffered channel to collect errors @@ -877,7 +1187,7 @@ func Test_Session_Concurrency(t *testing.T) { localCtx := app.AcquireCtx(&fasthttp.RequestCtx{}) - sess, err := store.Get(localCtx) + sess, err := store.getSession(localCtx) if err != nil { errChan <- err return @@ -901,6 +1211,9 @@ func Test_Session_Concurrency(t *testing.T) { return } + // release the session + sess.Release() + // Release the context app.ReleaseCtx(localCtx) @@ -917,6 +1230,7 @@ func Test_Session_Concurrency(t *testing.T) { errChan <- err return } + defer sess.Release() // Get the value name := sess.Get("name") @@ -963,3 +1277,42 @@ func Test_Session_Concurrency(t *testing.T) { require.NoError(t, err) } } + +func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) { + // Initialize a new store with default config + store := NewStore() + + // Create a new Fiber app + app := fiber.New() + + // Generate a fake session ID + sessionID := uuid.New().String() + + // Store invalid session data to simulate decode error + err := store.Storage.Set(sessionID, []byte("invalid data"), 0) + require.NoError(t, err, "Failed to set invalid session data") + + // Create a new request context + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + // Set the session ID in cookies + c.Request().Header.SetCookie(store.sessionName, sessionID) + + // Attempt to get the session + _, err = store.Get(c) + require.Error(t, err, "Expected error due to invalid session data, but got nil") + + // Check that the error message is as expected + require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message") + + // Check that the error is as expected + require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") + + // Attempt to get the session by ID + _, err = store.GetByID(sessionID) + require.Error(t, err, "Expected error due to invalid session data, but got nil") + + // Check that the error message is as expected + require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") +} diff --git a/middleware/session/store.go b/middleware/session/store.go index 01b4548c0a..013743d068 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -4,14 +4,20 @@ import ( "encoding/gob" "errors" "fmt" + "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) // ErrEmptySessionID is an error that occurs when the session ID is empty. -var ErrEmptySessionID = errors.New("session id cannot be empty") +var ( + ErrEmptySessionID = errors.New("session ID cannot be empty") + ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware") + ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") +) // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int @@ -26,7 +32,17 @@ type Store struct { } // New creates a new session store with the provided configuration. -func New(config ...Config) *Store { +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - *Store: The session store. +// +// Usage: +// +// store := session.New() +func NewStore(config ...Config) *Store { // Set default config cfg := configDefault(config...) @@ -34,18 +50,75 @@ func New(config ...Config) *Store { cfg.Storage = memory.New() } - return &Store{ + store := &Store{ Config: cfg, } + + if cfg.AbsoluteTimeout > 0 { + store.RegisterType(absExpirationKey) + store.RegisterType(time.Time{}) + } + + return store } // RegisterType registers a custom type for encoding/decoding into any storage provider. +// +// Parameters: +// - i: The custom type to register. +// +// Usage: +// +// store.RegisterType(MyCustomType{}) func (*Store) RegisterType(i any) { gob.Register(i) } -// Get retrieves or creates a session for the given context. +// Get will get/create a session. +// +// This function will return an ErrSessionAlreadyLoadedByMiddleware if +// the session is already loaded by the middleware. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - *Session: The session object. +// - error: An error if the session retrieval fails or if the session is already loaded by the middleware. +// +// Usage: +// +// sess, err := store.Get(c) +// if err != nil { +// // handle error +// } func (s *Store) Get(c fiber.Ctx) (*Session, error) { + // If session is already loaded in the context, + // it should not be loaded again + _, ok := c.Locals(middlewareContextKey).(*Middleware) + if ok { + return nil, ErrSessionAlreadyLoadedByMiddleware + } + + return s.getSession(c) +} + +// getSession retrieves a session based on the context. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - *Session: The session object. +// - error: An error if the session retrieval fails. +// +// Usage: +// +// sess, err := store.getSession(c) +// if err != nil { +// // handle error +// } +func (s *Store) getSession(c fiber.Ctx) (*Session, error) { var rawData []byte var err error @@ -79,7 +152,6 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { sess := acquireSession() sess.mu.Lock() - defer sess.mu.Unlock() sess.ctx = c sess.config = s @@ -89,16 +161,40 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { // Decode session data if found if rawData != nil { sess.data.Lock() - defer sess.data.Unlock() - if err := sess.decodeSessionData(rawData); err != nil { + err := sess.decodeSessionData(rawData) + sess.data.Unlock() + if err != nil { + sess.mu.Unlock() + sess.Release() return nil, fmt.Errorf("failed to decode session data: %w", err) } } + sess.mu.Unlock() + + if fresh && s.AbsoluteTimeout > 0 { + sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) + } else if sess.isAbsExpired() { + if err := sess.Reset(); err != nil { + return nil, fmt.Errorf("failed to reset session: %w", err) + } + sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) + } + return sess, nil } // getSessionID returns the session ID from cookies, headers, or query string. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - string: The session ID. +// +// Usage: +// +// id := store.getSessionID(c) func (s *Store) getSessionID(c fiber.Ctx) string { id := c.Cookies(s.sessionName) if len(id) > 0 { @@ -123,14 +219,113 @@ func (s *Store) getSessionID(c fiber.Ctx) string { } // Reset deletes all sessions from the storage. +// +// Returns: +// - error: An error if the reset operation fails. +// +// Usage: +// +// err := store.Reset() +// if err != nil { +// // handle error +// } func (s *Store) Reset() error { return s.Storage.Reset() } // Delete deletes a session by its ID. +// +// Parameters: +// - id: The unique identifier of the session. +// +// Returns: +// - error: An error if the deletion fails or if the session ID is empty. +// +// Usage: +// +// err := store.Delete(id) +// if err != nil { +// // handle error +// } func (s *Store) Delete(id string) error { if id == "" { return ErrEmptySessionID } return s.Storage.Delete(id) } + +// GetByID retrieves a session by its ID from the storage. +// If the session is not found, it returns nil and an error. +// +// Unlike session middleware methods, this function does not automatically: +// +// - Load the session into the request context. +// +// - Save the session data to the storage or update the client cookie. +// +// Important Notes: +// +// - The session object returned by GetByID does not have a context associated with it. +// +// - When using this method alongside session middleware, there is a potential for collisions, +// so be mindful of interactions between manually retrieved sessions and middleware-managed sessions. +// +// - If you modify a session returned by GetByID, you must call session.Save() to persist the changes. +// +// - When you are done with the session, you should call session.Release() to release the session back to the pool. +// +// Parameters: +// - id: The unique identifier of the session. +// +// Returns: +// - *Session: The session object if found, otherwise nil. +// - error: An error if the session retrieval fails or if the session ID is empty. +// +// Usage: +// +// sess, err := store.GetByID(id) +// if err != nil { +// // handle error +// } +func (s *Store) GetByID(id string) (*Session, error) { + if id == "" { + return nil, ErrEmptySessionID + } + + rawData, err := s.Storage.Get(id) + if err != nil { + return nil, err + } + if rawData == nil { + return nil, ErrSessionIDNotFoundInStore + } + + sess := acquireSession() + + sess.mu.Lock() + + sess.config = s + sess.id = id + sess.fresh = false + + sess.data.Lock() + decodeErr := sess.decodeSessionData(rawData) + sess.data.Unlock() + sess.mu.Unlock() + if decodeErr != nil { + sess.Release() + return nil, fmt.Errorf("failed to decode session data: %w", decodeErr) + } + + if s.AbsoluteTimeout > 0 { + if sess.isAbsExpired() { + if err := sess.Destroy(); err != nil { + sess.Release() + log.Errorf("failed to destroy session: %v", err) + } + return nil, ErrSessionIDNotFoundInStore + } + } + + return sess, nil +} diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index adce2e3488..8a45c7e5fb 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -20,9 +20,10 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from cookie", func(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set cookie ctx.Request().Header.SetCookie(store.sessionName, expectedID) @@ -33,11 +34,12 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from header", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := NewStore(Config{ KeyLookup: "header:session_id", }) // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set header ctx.Request().Header.Set(store.sessionName, expectedID) @@ -48,11 +50,12 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from url query", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := NewStore(Config{ KeyLookup: "query:session_id", }) // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set url parameter ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.sessionName, expectedID)) @@ -73,9 +76,10 @@ func Test_Store_Get(t *testing.T) { t.Run("session should be re-generated if it is invalid", func(t *testing.T) { t.Parallel() // session store - store := New() + store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set cookie ctx.Request().Header.SetCookie(store.sessionName, unexpectedID) @@ -93,10 +97,11 @@ func Test_Store_DeleteSession(t *testing.T) { // fiber instance app := fiber.New() // session store - store := New() + store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // Create a new session session, err := store.Get(ctx) @@ -116,3 +121,105 @@ func Test_Store_DeleteSession(t *testing.T) { // The session ID should be different now, because the old session was deleted require.NotEqual(t, sessionID, session.ID()) } + +func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { + // Create a new Fiber app + app := fiber.New() + + // Create a new context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // Mock middleware and set it in the context + middleware := &Middleware{} + ctx.Locals(middlewareContextKey, middleware) + + // Create a new store + store := &Store{} + + // Call the Get method + sess, err := store.Get(ctx) + + // Assert that the error is ErrSessionAlreadyLoadedByMiddleware + require.Nil(t, sess) + require.Equal(t, ErrSessionAlreadyLoadedByMiddleware, err) +} + +func TestStore_Delete(t *testing.T) { + // Create a new store + store := NewStore() + + t.Run("delete with empty session ID", func(t *testing.T) { + err := store.Delete("") + require.Error(t, err) + require.Equal(t, ErrEmptySessionID, err) + }) + + t.Run("delete non-existing session", func(t *testing.T) { + err := store.Delete("non-existing-session-id") + require.NoError(t, err) + }) +} + +func Test_Store_GetByID(t *testing.T) { + t.Parallel() + // Create a new store + store := NewStore() + + t.Run("empty session ID", func(t *testing.T) { + t.Parallel() + sess, err := store.GetByID("") + require.Error(t, err) + require.Nil(t, sess) + require.Equal(t, ErrEmptySessionID, err) + }) + + t.Run("non-existent session ID", func(t *testing.T) { + t.Parallel() + sess, err := store.GetByID("non-existent-session-id") + require.Error(t, err) + require.Nil(t, sess) + require.Equal(t, ErrSessionIDNotFoundInStore, err) + }) + + t.Run("valid session ID", func(t *testing.T) { + t.Parallel() + app := fiber.New() + // Create a new session + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + session, err := store.Get(ctx) + defer session.Release() + defer app.ReleaseCtx(ctx) + require.NoError(t, err) + + // Save the session ID + sessionID := session.ID() + + // Save the session + err = session.Save() + require.NoError(t, err) + + // Retrieve the session by ID + retrievedSession, err := store.GetByID(sessionID) + require.NoError(t, err) + require.NotNil(t, retrievedSession) + require.Equal(t, sessionID, retrievedSession.ID()) + + // Call Save on the retrieved session + retrievedSession.Set("key", "value") + err = retrievedSession.Save() + require.NoError(t, err) + + // Call Other Session methods + require.Equal(t, "value", retrievedSession.Get("key")) + require.False(t, retrievedSession.Fresh()) + + require.NoError(t, retrievedSession.Reset()) + require.NoError(t, retrievedSession.Destroy()) + require.IsType(t, []any{}, retrievedSession.Keys()) + require.NoError(t, retrievedSession.Regenerate()) + require.NotPanics(t, func() { + retrievedSession.Release() + }) + }) +}