diff --git a/middleware/proxy/README.md b/middleware/proxy/README.md index 0c68b1d86a..6cc03e8b9d 100644 --- a/middleware/proxy/README.md +++ b/middleware/proxy/README.md @@ -12,9 +12,16 @@ Proxy middleware for [Fiber](https://github.com/gofiber/fiber) that allows you t ### Signatures ```go +// Balancer create a load balancer among multiple upstrem servers. func Balancer(config Config) fiber.Handler +// Forward performs the given http request and fills the given http response. func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler +// Do performs the given http request and fills the given http response. func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error +// DomainForward the given http request based on the given domain and fills the given http response +func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler +// BalancerForward performs the given http request based round robin balancer and fills the given http response +func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler ``` ### Examples @@ -23,8 +30,8 @@ Import the middleware package that is part of the Fiber web framework ```go import ( - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/proxy" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/proxy" ) ``` @@ -39,54 +46,64 @@ proxy.WithTlsConfig(&tls.Config{ // if you need to use global self-custom client, you should use proxy.WithClient. proxy.WithClient(&fasthttp.Client{ - NoDefaultUserAgentHeader: true, - DisablePathNormalizing: true, + NoDefaultUserAgentHeader: true, + DisablePathNormalizing: true, }) // Forward to url app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) +// If you want to forward with a specific domain. You have to use proxy.DomainForward. +app.Get("/payments", proxy.DomainForward("docs.gofiber.io", "http://localhost:8000")) + // Forward to url with local custom client app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif", &fasthttp.Client{ - NoDefaultUserAgentHeader: true, - DisablePathNormalizing: true, + NoDefaultUserAgentHeader: true, + DisablePathNormalizing: true, })) // Make request within handler app.Get("/:id", func(c *fiber.Ctx) error { - url := "https://i.imgur.com/"+c.Params("id")+".gif" - if err := proxy.Do(c, url); err != nil { - return err - } - // Remove Server header from response - c.Response().Header.Del(fiber.HeaderServer) - return nil + url := "https://i.imgur.com/"+c.Params("id")+".gif" + if err := proxy.Do(c, url); err != nil { + return err + } + // Remove Server header from response + c.Response().Header.Del(fiber.HeaderServer) + return nil }) // Minimal round robin balancer app.Use(proxy.Balancer(proxy.Config{ - Servers: []string{ - "http://localhost:3001", - "http://localhost:3002", - "http://localhost:3003", - }, + Servers: []string{ + "http://localhost:3001", + "http://localhost:3002", + "http://localhost:3003", + }, })) // Or extend your balancer for customization app.Use(proxy.Balancer(proxy.Config{ - Servers: []string{ - "http://localhost:3001", - "http://localhost:3002", - "http://localhost:3003", - }, - ModifyRequest: func(c *fiber.Ctx) error { - c.Request().Header.Add("X-Real-IP", c.IP()) - return nil - }, - ModifyResponse: func(c *fiber.Ctx) error { - c.Response().Header.Del(fiber.HeaderServer) - return nil - }, + Servers: []string{ + "http://localhost:3001", + "http://localhost:3002", + "http://localhost:3003", + }, + ModifyRequest: func(c *fiber.Ctx) error { + c.Request().Header.Add("X-Real-IP", c.IP()) + return nil + }, + ModifyResponse: func(c *fiber.Ctx) error { + c.Response().Header.Del(fiber.HeaderServer) + return nil + }, +})) + +// Or this way if the balancer is using https and the destination server is only using http. +app.Use(proxy.BalancerForward([]string{ + "http://localhost:3001", + "http://localhost:3002", + "http://localhost:3003", })) ``` @@ -95,50 +112,50 @@ app.Use(proxy.Balancer(proxy.Config{ ```go // Config defines the config for middleware. type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Servers defines a list of :// HTTP servers, - // - // which are used in a round-robin manner. - // i.e.: "https://foobar.com, http://www.foobar.com" - // - // Required - Servers []string - - // ModifyRequest allows you to alter the request - // - // Optional. Default: nil - ModifyRequest fiber.Handler - - // ModifyResponse allows you to alter the response - // - // Optional. Default: nil - ModifyResponse fiber.Handler - - // Timeout is the request timeout used when calling the proxy client - // - // Optional. Default: 1 second - Timeout time.Duration - - // Per-connection buffer size for requests' reading. - // This also limits the maximum header size. - // Increase this buffer if your clients send multi-KB RequestURIs - // and/or multi-KB headers (for example, BIG cookies). - ReadBufferSize int + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Servers defines a list of :// HTTP servers, + // + // which are used in a round-robin manner. + // i.e.: "https://foobar.com, http://www.foobar.com" + // + // Required + Servers []string + + // ModifyRequest allows you to alter the request + // + // Optional. Default: nil + ModifyRequest fiber.Handler + + // ModifyResponse allows you to alter the response + // + // Optional. Default: nil + ModifyResponse fiber.Handler + + // Timeout is the request timeout used when calling the proxy client + // + // Optional. Default: 1 second + Timeout time.Duration + + // Per-connection buffer size for requests' reading. + // This also limits the maximum header size. + // Increase this buffer if your clients send multi-KB RequestURIs + // and/or multi-KB headers (for example, BIG cookies). + ReadBufferSize int - // Per-connection buffer size for responses' writing. - WriteBufferSize int - - // tls config for the http client. - TlsConfig *tls.Config - - // Client is custom client when client config is complex. - // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig - // will not be used if the client are set. - Client *fasthttp.LBClient + // Per-connection buffer size for responses' writing. + WriteBufferSize int + + // tls config for the http client. + TlsConfig *tls.Config + + // Client is custom client when client config is complex. + // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig + // will not be used if the client are set. + Client *fasthttp.LBClient } ``` diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 830d71f985..008342631f 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -178,3 +178,53 @@ func getScheme(uri []byte) []byte { } return uri[:i-1] } + +// DomainForward performs an http request based on the given domain and populates the given http response. +// This method will return an fiber.Handler +func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler { + return func(c *fiber.Ctx) error { + host := string(c.Request().Host()) + if host == hostname { + return Do(c, addr+c.OriginalURL(), clients...) + } + return nil + } +} + +type roundrobin struct { + sync.Mutex + + current int + pool []string +} + +// this method will return a string of addr server from list server. +func (r *roundrobin) get() string { + r.Lock() + defer r.Unlock() + + if r.current >= len(r.pool) { + r.current %= len(r.pool) + } + + result := r.pool[r.current] + r.current++ + return result +} + +// BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response. +// This method will return an fiber.Handler +func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler { + r := &roundrobin{ + current: 0, + pool: servers, + } + return func(c *fiber.Ctx) error { + server := r.get() + if !strings.HasPrefix(server, "http") { + server = "http://" + server + } + c.Request().Header.Add("X-Real-IP", c.IP()) + return Do(c, server+c.OriginalURL(), clients...) + } +} diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index d87145514a..6ed169ed69 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -473,3 +473,60 @@ func Test_ProxyBalancer_Custom_Client(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) } + +// go test -run Test_Proxy_Domain_Forward_Local +func Test_Proxy_Domain_Forward_Local(t *testing.T) { + t.Parallel() + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + utils.AssertEqual(t, nil, err) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + // target server + ln1, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + utils.AssertEqual(t, nil, err) + app1 := fiber.New(fiber.Config{DisableStartupMessage: true}) + + app1.Get("/test", func(c *fiber.Ctx) error { + return c.SendString("test_local_client:" + c.Query("query_test")) + }) + + proxyAddr := ln.Addr().String() + targetAddr := ln1.Addr().String() + localDomain := strings.Replace(proxyAddr, "127.0.0.1", "localhost", 1) + app.Use(DomainForward(localDomain, "http://"+targetAddr, &fasthttp.Client{ + NoDefaultUserAgentHeader: true, + DisablePathNormalizing: true, + + Dial: fasthttp.Dial, + })) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + go func() { utils.AssertEqual(t, nil, app1.Listener(ln1)) }() + + code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String() + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, fiber.StatusOK, code) + utils.AssertEqual(t, "test_local_client:true", body) +} + +// go test -run Test_Proxy_Balancer_Forward_Local +func Test_Proxy_Balancer_Forward_Local(t *testing.T) { + t.Parallel() + + app := fiber.New() + + _, addr := createProxyTestServer(t, func(c *fiber.Ctx) error { + return c.SendString("forwarded") + }) + + app.Use(BalancerForward([]string{addr})) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + + utils.AssertEqual(t, string(b), "forwarded") +}