diff --git a/app.go b/app.go index e0240d3c16..82ba039a61 100644 --- a/app.go +++ b/app.go @@ -14,6 +14,7 @@ import ( "encoding/xml" "errors" "fmt" + "io" "net" "net/http" "net/http/httputil" @@ -864,13 +865,33 @@ func (app *App) Hooks() *Hooks { return app.hooks } +// TestConfig is a struct holding Test settings +type TestConfig struct { //nolint:govet // Aligning the struct fields is not necessary. betteralign:ignore + // Sets a timeout duration for the test. + // + // Default: time.Second + Timeout time.Duration + + // When set to true, the test will discard the + // current http response and give a timeout error. + // + // Default: true + ErrOnTimeout bool +} + // Test is used for internal debugging by passing a *http.Request. -// Timeout is optional and defaults to 1s, -1 will disable it completely. -func (app *App) Test(req *http.Request, timeout ...time.Duration) (*http.Response, error) { - // Set timeout - to := 1 * time.Second - if len(timeout) > 0 { - to = timeout[0] +// Config is optional and defaults to a 1s error on timeout, +// -1 timeout will disable it completely. +func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, error) { + // Default config + cfg := TestConfig{ + Timeout: time.Second, + ErrOnTimeout: true, + } + + // Override config if provided + if len(config) > 0 { + cfg = config[0] } // Add Content-Length if not provided with body @@ -909,12 +930,15 @@ func (app *App) Test(req *http.Request, timeout ...time.Duration) (*http.Respons }() // Wait for callback - if to >= 0 { + if cfg.Timeout >= 0 { // With timeout select { case err = <-channel: - case <-time.After(to): - return nil, fmt.Errorf("test: timeout error after %s", to) + case <-time.After(cfg.Timeout): + conn.Close() + if cfg.ErrOnTimeout { + return nil, fmt.Errorf("test: timeout error after %s", cfg.Timeout) + } } } else { // Without timeout @@ -932,6 +956,9 @@ func (app *App) Test(req *http.Request, timeout ...time.Duration) (*http.Respons // Convert raw http response to *http.Response res, err := http.ReadResponse(buffer, req) if err != nil { + if err == io.ErrUnexpectedEOF { + return nil, fmt.Errorf("test: got empty response") + } return nil, fmt.Errorf("failed to read response: %w", err) } diff --git a/app_test.go b/app_test.go index 6b493de1eb..507fe999be 100644 --- a/app_test.go +++ b/app_test.go @@ -1124,7 +1124,10 @@ func Test_Test_Timeout(t *testing.T) { app.Get("/", testEmptyHandler) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil), -1) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil), TestConfig{ + Timeout: -1, + ErrOnTimeout: false, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") @@ -1133,7 +1136,10 @@ func Test_Test_Timeout(t *testing.T) { return nil }) - _, err = app.Test(httptest.NewRequest(MethodGet, "/timeout", nil), 20*time.Millisecond) + _, err = app.Test(httptest.NewRequest(MethodGet, "/timeout", nil), TestConfig{ + Timeout: 20*time.Millisecond, + ErrOnTimeout: true, + }) require.Error(t, err, "app.Test(req)") } @@ -1432,7 +1438,10 @@ func Test_App_Test_no_timeout_infinitely(t *testing.T) { }) req := httptest.NewRequest(MethodGet, "/", nil) - _, err = app.Test(req, -1) + _, err = app.Test(req, TestConfig{ + Timeout: -1, + ErrOnTimeout: true, + }) }() tk := time.NewTimer(5 * time.Second) @@ -1460,10 +1469,29 @@ func Test_App_Test_timeout(t *testing.T) { return nil }) - _, err := app.Test(httptest.NewRequest(MethodGet, "/", nil), 100*time.Millisecond) + _, err := app.Test(httptest.NewRequest(MethodGet, "/", nil), TestConfig{ + Timeout: 100*time.Millisecond, + ErrOnTimeout: true, + }) require.Equal(t, errors.New("test: timeout error after 100ms"), err) } +func Test_App_Test_timeout_empty_response(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/", func(_ Ctx) error { + time.Sleep(1 * time.Second) + return nil + }) + + _, err := app.Test(httptest.NewRequest(MethodGet, "/", nil), TestConfig{ + Timeout: 100*time.Millisecond, + ErrOnTimeout: false, + }) + require.Equal(t, errors.New("test: got empty response"), err) +} + func Test_App_SetTLSHandler(t *testing.T) { t.Parallel() tlsHandler := &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{ diff --git a/ctx_test.go b/ctx_test.go index a94e4cb42b..5f0e7d023d 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -3142,7 +3142,10 @@ func Test_Static_Compress(t *testing.T) { req := httptest.NewRequest(MethodGet, "/file", nil) req.Header.Set("Accept-Encoding", algo) - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") diff --git a/docs/api/app.md b/docs/api/app.md index ef9c2ea08d..8b563b5c57 100644 --- a/docs/api/app.md +++ b/docs/api/app.md @@ -540,7 +540,7 @@ func (app *App) SetTLSHandler(tlsHandler *TLSHandler) Testing your application is done with the **Test** method. Use this method for creating `_test.go` files or when you need to debug your routing logic. The default timeout is `1s` if you want to disable a timeout altogether, pass `-1` as a second argument. ```go title="Signature" -func (app *App) Test(req *http.Request, msTimeout ...int) (*http.Response, error) +func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, error) ``` ```go title="Examples" diff --git a/middleware/compress/compress_test.go b/middleware/compress/compress_test.go index f258ba4460..5de1127d66 100644 --- a/middleware/compress/compress_test.go +++ b/middleware/compress/compress_test.go @@ -39,7 +39,10 @@ func Test_Compress_Gzip(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "gzip") - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding)) @@ -72,7 +75,10 @@ func Test_Compress_Different_Level(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", algo) - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding)) @@ -99,7 +105,10 @@ func Test_Compress_Deflate(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "deflate") - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, "deflate", resp.Header.Get(fiber.HeaderContentEncoding)) @@ -123,7 +132,10 @@ func Test_Compress_Brotli(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "br") - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, "br", resp.Header.Get(fiber.HeaderContentEncoding)) @@ -147,7 +159,10 @@ func Test_Compress_Zstd(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "zstd") - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, "zstd", resp.Header.Get(fiber.HeaderContentEncoding)) @@ -171,7 +186,10 @@ func Test_Compress_Disabled(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "br") - resp, err := app.Test(req, 10*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 10*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, "", resp.Header.Get(fiber.HeaderContentEncoding)) diff --git a/middleware/idempotency/idempotency_test.go b/middleware/idempotency/idempotency_test.go index 91394ca26a..9cf66e7c9d 100644 --- a/middleware/idempotency/idempotency_test.go +++ b/middleware/idempotency/idempotency_test.go @@ -82,7 +82,10 @@ func Test_Idempotency(t *testing.T) { if idempotencyKey != "" { req.Header.Set("X-Idempotency-Key", idempotencyKey) } - resp, err := app.Test(req, 15*time.Second) + resp, err := app.Test(req, fiber.TestConfig{ + Timeout: 15*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 9da675fe8f..2bd52f7138 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -104,7 +104,10 @@ func Test_AuthSources(t *testing.T) { req.URL.Path = r } - res, err := app.Test(req, -1) + res, err := app.Test(req, fiber.TestConfig{ + Timeout: -1, + ErrOnTimeout: false, + }) require.NoError(t, err, test.description) @@ -209,7 +212,10 @@ func TestMultipleKeyLookup(t *testing.T) { q.Add("key", CorrectKey) req.URL.RawQuery = q.Encode() - res, err := app.Test(req, -1) + res, err := app.Test(req, fiber.TestConfig{ + Timeout: -1, + ErrOnTimeout: false, + }) require.NoError(t, err) @@ -226,7 +232,10 @@ func TestMultipleKeyLookup(t *testing.T) { // construct a second request without proper key req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil) require.NoError(t, err) - res, err = app.Test(req, -1) + res, err = app.Test(req, fiber.TestConfig{ + Timeout: -1, + ErrOnTimeout: false, + }) require.NoError(t, err) errBody, err := io.ReadAll(res.Body) require.NoError(t, err) @@ -350,7 +359,10 @@ func Test_MultipleKeyAuth(t *testing.T) { req.Header.Set("key", test.APIKey) } - res, err := app.Test(req, -1) + res, err := app.Test(req, fiber.TestConfig{ + Timeout: -1, + ErrOnTimeout: false, + }) require.NoError(t, err, test.description) diff --git a/middleware/logger/logger_test.go b/middleware/logger/logger_test.go index 0bc06531c9..b6df6d1daf 100644 --- a/middleware/logger/logger_test.go +++ b/middleware/logger/logger_test.go @@ -300,7 +300,10 @@ func Test_Logger_WithLatency(t *testing.T) { sleepDuration = 1 * tu.div // Create a new HTTP request to the test route - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 3*time.Second) + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{ + Timeout: 3*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) @@ -342,7 +345,10 @@ func Test_Logger_WithLatency_DefaultFormat(t *testing.T) { sleepDuration = 1 * tu.div // Create a new HTTP request to the test route - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 2*time.Second) + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{ + Timeout: 2*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) diff --git a/middleware/pprof/pprof_test.go b/middleware/pprof/pprof_test.go index 7a279d488d..5f26e95011 100644 --- a/middleware/pprof/pprof_test.go +++ b/middleware/pprof/pprof_test.go @@ -105,7 +105,10 @@ func Test_Pprof_Subs(t *testing.T) { if sub == "profile" { target += "?seconds=1" } - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), 5*time.Second) + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), fiber.TestConfig{ + Timeout: 5*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) }) @@ -132,7 +135,10 @@ func Test_Pprof_Subs_WithPrefix(t *testing.T) { if sub == "profile" { target += "?seconds=1" } - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), 5*time.Second) + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), fiber.TestConfig{ + Timeout: 5*time.Second, + ErrOnTimeout: true, + }) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) })