diff --git a/network/handlers/drain.go b/network/handlers/drain.go index b8dc6ff1b4..d6ef19f4ea 100644 --- a/network/handlers/drain.go +++ b/network/handlers/drain.go @@ -19,6 +19,7 @@ package handlers import ( "fmt" "net/http" + "strings" "sync" "time" @@ -74,6 +75,10 @@ type Drainer struct { // timer is used to orchestrate the drain. timer timer + + // HealthCheckUAPrefixes are the additional user agent prefixes that trigger the + // drainer's health check + HealthCheckUAPrefixes []string } // Ensure Drainer implements http.Handler @@ -81,7 +86,8 @@ var _ http.Handler = (*Drainer)(nil) // ServeHTTP implements http.Handler func (d *Drainer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if network.IsKubeletProbe(r) { // Respond to probes regardless of path. + // Respond to probes regardless of path. + if d.isHealthCheckRequest(r) { if d.draining() { http.Error(w, "shutting down", http.StatusServiceUnavailable) } else if d.HealthCheck != nil { @@ -124,6 +130,21 @@ func (d *Drainer) Drain() { }) } +// isHealthcheckRequest validates if the request has a user agent that is for healthcheck +func (d *Drainer) isHealthCheckRequest(r *http.Request) bool { + if network.IsKubeletProbe(r) { + return true + } + + for _, ua := range d.HealthCheckUAPrefixes { + if strings.HasPrefix(r.Header.Get(network.UserAgentKey), ua) { + return true + } + } + + return false +} + // reset resets the drain timer to the full amount of time. func (d *Drainer) reset() { if func() bool { diff --git a/network/handlers/drain_test.go b/network/handlers/drain_test.go index 0d0ead0ac7..3970df9253 100644 --- a/network/handlers/drain_test.go +++ b/network/handlers/drain_test.go @@ -319,47 +319,120 @@ func TestDefaultQuietPeriod(t *testing.T) { mt.advance(network.DefaultDrainTimeout) } -func TestHealthCheck(t *testing.T) { - var ( - w http.ResponseWriter - req = &http.Request{} - probe = &http.Request{ +func TestHealthCheckWithProbeType(t *testing.T) { + tests := []struct { + name string + Header http.Header + UserAgents []string + }{{ + name: "with kube-probe header", + Header: http.Header{ + network.UserAgentKey: []string{network.KubeProbeUAPrefix}, + }, + UserAgents: []string{}, + }, { + name: "with extra probe header", + Header: http.Header{ + network.UserAgentKey: []string{"extra"}, + }, + UserAgents: []string{"extra"}, + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var ( + w http.ResponseWriter + req = &http.Request{} + cnt = 0 + inner = http.HandlerFunc(func(http.ResponseWriter, *http.Request) { cnt++ }) + checker = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL != nil && req.URL.Path == "/healthz" { + w.WriteHeader(http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusAccepted) + }) + probe = &http.Request{ + URL: &url.URL{ + Path: "/healthz", + }, + Header: tc.Header, + } + ) + + drainer := &Drainer{ + HealthCheck: checker, + Inner: inner, + HealthCheckUAPrefixes: tc.UserAgents, + } + + // Works before Drain is called. + drainer.ServeHTTP(w, req) + drainer.ServeHTTP(w, req) + drainer.ServeHTTP(w, req) + if cnt != 3 { + t.Error("Inner handler was not properly invoked") + } + + // Works for HealthCheck. + resp := httptest.NewRecorder() + drainer.ServeHTTP(resp, probe) + if got, want := resp.Code, http.StatusBadRequest; got != want { + t.Errorf("Probe status = %d, wanted %d", got, want) + } + }) + } +} + +func TestIsHealthcheckRequest(t *testing.T) { + tests := []struct { + name string + UserAgents []string + request *http.Request + result bool + }{{ + name: "with kube-probe header", + UserAgents: []string{}, + request: &http.Request{ URL: &url.URL{ Path: "/healthz", }, Header: http.Header{ network.UserAgentKey: []string{network.KubeProbeUAPrefix}, }, - } - cnt = 0 - inner = http.HandlerFunc(func(http.ResponseWriter, *http.Request) { cnt++ }) - checker = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.URL != nil && req.URL.Path == "/healthz" { - w.WriteHeader(http.StatusBadRequest) - return + }, + result: true, + }, { + name: "with extra probe header", + UserAgents: []string{"extra"}, + request: &http.Request{ + URL: &url.URL{ + Path: "/healthz", + }, + Header: http.Header{ + network.UserAgentKey: []string{"extra"}, + }, + }, + result: true, + }, { + name: "without probe header", + UserAgents: []string{}, + request: &http.Request{ + URL: &url.URL{ + Path: "/healthz", + }, + Header: http.Header{ + network.UserAgentKey: []string{"not-a-probe"}, + }, + }, + result: false, + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + d := Drainer{ + HealthCheckUAPrefixes: tc.UserAgents, } - w.WriteHeader(http.StatusAccepted) + d.isHealthCheckRequest(tc.request) }) - ) - - drainer := &Drainer{ - HealthCheck: checker, - Inner: inner, - } - - // Works before Drain is called. - drainer.ServeHTTP(w, req) - drainer.ServeHTTP(w, req) - drainer.ServeHTTP(w, req) - if cnt != 3 { - t.Error("Inner handler was not properly invoked") - } - - // Works for HealthCheck. - resp := httptest.NewRecorder() - drainer.ServeHTTP(resp, probe) - if got, want := resp.Code, http.StatusBadRequest; got != want { - t.Errorf("Probe status = %d, wanted %d", got, want) } }