From d4153992cd311ff1057b3a73cc85f315006ea005 Mon Sep 17 00:00:00 2001 From: yinheli Date: Sat, 14 Sep 2024 23:04:15 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=A9=B9=20Fix:=20static=20server=20in=20su?= =?UTF-8?q?b=20app=20with=20mount=20(#3104)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mount.go | 11 +++++++++++ router.go | 5 +++++ router_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/mount.go b/mount.go index abb5695e9f..c154f08ca2 100644 --- a/mount.go +++ b/mount.go @@ -23,6 +23,8 @@ type mountFields struct { subAppsProcessed sync.Once // Prefix of app if it was mounted mountPath string + // Parent app of the current app + parentApp *App } // Create empty mountFields instance @@ -50,6 +52,7 @@ func (app *App) Mount(prefix string, subApp *App) Router { subApp.mountFields.mountPath = path app.mountFields.appList[path] = subApp + subApp.mountFields.parentApp = app } // register mounted group @@ -99,6 +102,14 @@ func (app *App) MountPath() string { return app.mountFields.mountPath } +// FullMountPath returns the full mount path of the app, including the parent app's mount path. +func (app *App) FullMountPath() string { + if app.mountFields.parentApp == nil { + return app.mountFields.mountPath + } + return getGroupPath(app.mountFields.parentApp.FullMountPath(), app.mountFields.mountPath) +} + // hasMountedApps Checks if there are any mounted apps in the current application. func (app *App) hasMountedApps() bool { return len(app.mountFields.appList) > 1 diff --git a/router.go b/router.go index 4afa741537..c8772f9bb9 100644 --- a/router.go +++ b/router.go @@ -5,6 +5,7 @@ package fiber import ( + "bytes" "fmt" "html" "sort" @@ -357,6 +358,10 @@ func (app *App) registerStatic(prefix, root string, config ...Static) { IndexNames: []string{"index.html"}, PathRewrite: func(fctx *fasthttp.RequestCtx) []byte { path := fctx.Path() + mountPath := app.FullMountPath() + if n := len(mountPath); n > 0 && bytes.Equal(path[:n], utils.UnsafeBytes(mountPath)) { + path = path[n:] + } if len(path) >= prefixLen { if isStar && app.getString(path[0:prefixLen]) == prefix { path = append(path[0:0], '/') diff --git a/router_test.go b/router_test.go index 6a43db5937..675a93ccea 100644 --- a/router_test.go +++ b/router_test.go @@ -2,7 +2,7 @@ // 📃 Github Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io -//nolint:bodyclose // Much easier to just ignore memory leaks in tests +//nolint:bodyclose,goconst // Much easier to just ignore memory leaks in tests and constant variables in tests package fiber import ( @@ -471,6 +471,53 @@ func Test_Route_Static_HasPrefix(t *testing.T) { body, err = io.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + + app = New() + app.Static("/css", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/css/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) +} + +func Test_Route_Static_SubApp(t *testing.T) { + t.Parallel() + + dir := "./.github/testdata/fs/css" + app := New() + + // subapp + subApp := New() + subApp.Static("/css", dir) + app.Mount("/sub", subApp) + + // nested subapp + nestApp := New() + nestApp.Static("/css", dir) + subApp.Mount("/nest", nestApp) + + // test subapp + resp, err := app.Test(httptest.NewRequest(MethodGet, "/sub/css/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + + // test nested subapp + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/sub/nest/css/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) } func Test_Router_NotFound(t *testing.T) {