-
Notifications
You must be signed in to change notification settings - Fork 0
/
httpgraceful.go
137 lines (116 loc) · 3.64 KB
/
httpgraceful.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package httpgraceful
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// Server is the minimal contract a server must satisfy so that WrapServer can
// start it and stop it gracefully. *http.Server satisfies it.
type Server interface {
	// ListenAndServe runs the server. Implementations are expected to block
	// until the server stops, either because it was shut down or because it
	// could not start.
	ListenAndServe() error

	// Shutdown stops the server gracefully. Implementations are expected to
	// wait for active connections to finish, giving up when ctx is done.
	Shutdown(ctx context.Context) error

	// Close stops the server immediately. The graceful wrapper calls it when
	// Shutdown does not complete before the wait timeout.
	Close() error
}
// gracefulServer decorates any Server (not only *http.Server) with
// signal-driven graceful shutdown. The embedded Server's methods are
// promoted; gracefulServer supplies its own ListenAndServe on top of them.
type gracefulServer struct {
	// Server is the wrapped server.
	Server

	// signalListener delivers the signals that initiate shutdown.
	signalListener chan os.Signal

	// waitTimeout bounds how long a graceful Shutdown may run before the
	// server is force closed.
	waitTimeout time.Duration

	// shutdownDone is allocated by WrapServer but is not read by any code
	// visible in this file.
	shutdownDone chan struct{}
}
// Option customizes a gracefulServer while WrapServer assembles it. Obtain
// one from the With* constructors in this package.
type Option func(*gracefulServer)

// apply invokes the option on gs so WrapServer can treat all options alike.
func (opt Option) apply(gs *gracefulServer) {
	opt(gs)
}
// WithSignals sets the OS signals that initiate shutdown. When WrapServer
// receives it, the default SIGTERM and SIGINT are not registered.
//
// The signals are registered with signal.Notify as soon as the option is
// applied, i.e. inside WrapServer, and nothing in this file unregisters them.
//
// Calling WithSignals with no arguments disables OS signal handling. Shutdown
// can then only be triggered by sending on the listener directly (see
// SendSignal). This deliberately differs from signal.Notify, which relays
// every incoming signal when given none. That would include harmless ones such
// as SIGWINCH (terminal resize) or SIGURG (used internally by the Go runtime),
// and each of them would cause a spurious shutdown.
func WithSignals(signals ...os.Signal) Option {
	return func(s *gracefulServer) {
		// Buffer of one: signal.Notify never blocks when sending, so an
		// unbuffered channel could drop a signal that arrives before
		// ListenAndServe starts receiving.
		signalListener := make(chan os.Signal, 1)
		if len(signals) > 0 {
			signal.Notify(signalListener, signals...)
		}
		s.signalListener = signalListener
	}
}
// WithWaitTimeout limits how long a graceful shutdown may wait for active
// connections to finish before the server is force closed. WrapServer
// replaces a zero or negative value with its 5-second default.
func WithWaitTimeout(timeout time.Duration) Option {
	setTimeout := func(gs *gracefulServer) {
		gs.waitTimeout = timeout
	}
	return setTimeout
}
// WrapServer returns server decorated with graceful shutdown.
//
// Defaults, each overridable through opts:
//   - signals: SIGTERM and SIGINT (WithSignals). They are registered with the
//     OS here, at wrap time.
//   - wait timeout: 5 seconds (WithWaitTimeout). A zero or negative timeout
//     falls back to this default.
//
// Calling ListenAndServe on the result serves until a signal arrives. It then
// calls Shutdown, which waits up to the wait timeout for active connections to
// finish. If that deadline is exceeded, it force closes the server with Close.
func WrapServer(server Server, opts ...Option) Server {
	gs := &gracefulServer{
		Server:       server,
		shutdownDone: make(chan struct{}),
	}
	for _, opt := range opts {
		opt.apply(gs)
	}

	// Fill in defaults for whatever the options left unset.
	if gs.signalListener == nil {
		WithSignals(syscall.SIGTERM, syscall.SIGINT).apply(gs)
	}
	if gs.waitTimeout <= 0 {
		WithWaitTimeout(5 * time.Second).apply(gs)
	}
	return gs
}
// ListenAndServe runs the wrapped server and blocks until it stops.
//
// It returns nil when:
//   - a signal arrives on the listener and either Shutdown completes within
//     the wait timeout, or Shutdown times out and the forced Close succeeds;
//   - the wrapped server stops on its own with http.ErrServerClosed or nil.
//     This happens, for example, when Shutdown is called on the wrapper,
//     which promotes the wrapped server's Shutdown.
//
// It returns a wrapped error when the wrapped server fails (e.g. cannot bind
// its address), when Shutdown fails for a reason other than the timeout, or
// when the forced Close fails.
func (s *gracefulServer) ListenAndServe() error {
	// Buffered so the serving goroutine can always deliver its result and
	// exit, even when this method returns without reading it (force close,
	// shutdown failure). With an unbuffered channel that goroutine would
	// block forever.
	serveDone := make(chan error, 1)
	go func() {
		serveDone <- s.Server.ListenAndServe()
	}()

	select {
	case sig := <-s.signalListener:
		ctx, cancel := context.WithTimeout(context.Background(), s.waitTimeout)
		defer cancel()

		err := s.Server.Shutdown(ctx)
		switch {
		case errors.Is(err, context.DeadlineExceeded):
			// Connections did not drain in time: force close.
			if closeErr := s.Server.Close(); closeErr != nil {
				return fmt.Errorf("deadline exceeded, force shutdown failed: %w", closeErr)
			}
			return nil
		case err != nil:
			return fmt.Errorf("shutdown failed, signal: %s: %w", sig, err)
		}

		// Wait for the serving goroutine to finish. After a successful
		// Shutdown, http.ErrServerClosed is the expected result (or nil from
		// a Server that is not an *http.Server).
		if serveErr := <-serveDone; serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
			return fmt.Errorf("server stopped with error during shutdown: %w", serveErr)
		}
		return nil

	case err := <-serveDone:
		if err == nil || errors.Is(err, http.ErrServerClosed) {
			// Stopped without a signal (e.g. Shutdown called directly): this
			// is a clean stop, not a failure.
			return nil
		}
		return fmt.Errorf("server failed to start: %w", err)
	}
}
// SendSignal injects sig into the server's signal listener as if the OS had
// delivered it. This is mainly useful in tests.
//
// The value of sig is not checked against the signals registered with
// WithSignals: any signal that ListenAndServe receives initiates shutdown.
//
// The send never blocks. If the listener cannot accept sig immediately (for
// example, because a signal is already pending), sig is dropped. This loses
// nothing, since the pending signal already requests shutdown. A nil listener
// makes this a no-op.
func (s *gracefulServer) SendSignal(sig os.Signal) {
	if s.signalListener == nil {
		return
	}
	select {
	case s.signalListener <- sig:
	default:
		// Each ListenAndServe call receives at most one signal, so a
		// blocking send here could wait forever.
	}
}