graceful.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. package graceful
  2. import (
  3. "crypto/tls"
  4. "log"
  5. "net"
  6. "net/http"
  7. "os"
  8. "os/signal"
  9. "sync"
  10. "syscall"
  11. "time"
  12. )
  13. // Server wraps an http.Server with graceful connection handling.
  14. // It may be used directly in the same way as http.Server, or may
  15. // be constructed with the global functions in this package.
  16. //
  17. // Example:
  18. // srv := &graceful.Server{
  19. // Timeout: 5 * time.Second,
  20. // Server: &http.Server{Addr: ":1234", Handler: handler},
  21. // }
  22. // srv.ListenAndServe()
  23. type Server struct {
  24. *http.Server
  25. // Timeout is the duration to allow outstanding requests to survive
  26. // before forcefully terminating them.
  27. Timeout time.Duration
  28. // Limit the number of outstanding requests
  29. ListenLimit int
  30. // TCPKeepAlive sets the TCP keep-alive timeouts on accepted
  31. // connections. It prunes dead TCP connections ( e.g. closing
  32. // laptop mid-download)
  33. TCPKeepAlive time.Duration
  34. // ConnState specifies an optional callback function that is
  35. // called when a client connection changes state. This is a proxy
  36. // to the underlying http.Server's ConnState, and the original
  37. // must not be set directly.
  38. ConnState func(net.Conn, http.ConnState)
  39. // BeforeShutdown is an optional callback function that is called
  40. // before the listener is closed. Returns true if shutdown is allowed
  41. BeforeShutdown func() bool
  42. // ShutdownInitiated is an optional callback function that is called
  43. // when shutdown is initiated. It can be used to notify the client
  44. // side of long lived connections (e.g. websockets) to reconnect.
  45. ShutdownInitiated func()
  46. // NoSignalHandling prevents graceful from automatically shutting down
  47. // on SIGINT and SIGTERM. If set to true, you must shut down the server
  48. // manually with Stop().
  49. NoSignalHandling bool
  50. // Logger used to notify of errors on startup and on stop.
  51. Logger *log.Logger
  52. // LogFunc can be assigned with a logging function of your choice, allowing
  53. // you to use whatever logging approach you would like
  54. LogFunc func(format string, args ...interface{})
  55. // Interrupted is true if the server is handling a SIGINT or SIGTERM
  56. // signal and is thus shutting down.
  57. Interrupted bool
  58. // interrupt signals the listener to stop serving connections,
  59. // and the server to shut down.
  60. interrupt chan os.Signal
  61. // stopLock is used to protect against concurrent calls to Stop
  62. stopLock sync.Mutex
  63. // stopChan is the channel on which callers may block while waiting for
  64. // the server to stop.
  65. stopChan chan struct{}
  66. // chanLock is used to protect access to the various channel constructors.
  67. chanLock sync.RWMutex
  68. // connections holds all connections managed by graceful
  69. connections map[net.Conn]struct{}
  70. // idleConnections holds all idle connections managed by graceful
  71. idleConnections map[net.Conn]struct{}
  72. }
  73. // Run serves the http.Handler with graceful shutdown enabled.
  74. //
  75. // timeout is the duration to wait until killing active requests and stopping the server.
  76. // If timeout is 0, the server never times out. It waits for all active requests to finish.
  77. func Run(addr string, timeout time.Duration, n http.Handler) {
  78. srv := &Server{
  79. Timeout: timeout,
  80. TCPKeepAlive: 3 * time.Minute,
  81. Server: &http.Server{Addr: addr, Handler: n},
  82. // Logger: DefaultLogger(),
  83. }
  84. if err := srv.ListenAndServe(); err != nil {
  85. if opErr, ok := err.(*net.OpError); !ok || (ok && opErr.Op != "accept") {
  86. srv.logf("%s", err)
  87. os.Exit(1)
  88. }
  89. }
  90. }
  91. // RunWithErr is an alternative version of Run function which can return error.
  92. //
  93. // Unlike Run this version will not exit the program if an error is encountered but will
  94. // return it instead.
  95. func RunWithErr(addr string, timeout time.Duration, n http.Handler) error {
  96. srv := &Server{
  97. Timeout: timeout,
  98. TCPKeepAlive: 3 * time.Minute,
  99. Server: &http.Server{Addr: addr, Handler: n},
  100. Logger: DefaultLogger(),
  101. }
  102. return srv.ListenAndServe()
  103. }
  104. // ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled.
  105. //
  106. // timeout is the duration to wait until killing active requests and stopping the server.
  107. // If timeout is 0, the server never times out. It waits for all active requests to finish.
  108. func ListenAndServe(server *http.Server, timeout time.Duration) error {
  109. srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
  110. return srv.ListenAndServe()
  111. }
  112. // ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled.
  113. func (srv *Server) ListenAndServe() error {
  114. // Create the listener so we can control their lifetime
  115. addr := srv.Addr
  116. if addr == "" {
  117. addr = ":http"
  118. }
  119. conn, err := srv.newTCPListener(addr)
  120. if err != nil {
  121. return err
  122. }
  123. return srv.Serve(conn)
  124. }
  125. // ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled.
  126. //
  127. // timeout is the duration to wait until killing active requests and stopping the server.
  128. // If timeout is 0, the server never times out. It waits for all active requests to finish.
  129. func ListenAndServeTLS(server *http.Server, certFile, keyFile string, timeout time.Duration) error {
  130. srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
  131. return srv.ListenAndServeTLS(certFile, keyFile)
  132. }
  133. // ListenTLS is a convenience method that creates an https listener using the
  134. // provided cert and key files. Use this method if you need access to the
  135. // listener object directly. When ready, pass it to the Serve method.
  136. func (srv *Server) ListenTLS(certFile, keyFile string) (net.Listener, error) {
  137. // Create the listener ourselves so we can control its lifetime
  138. addr := srv.Addr
  139. if addr == "" {
  140. addr = ":https"
  141. }
  142. config := &tls.Config{}
  143. if srv.TLSConfig != nil {
  144. *config = *srv.TLSConfig
  145. }
  146. var err error
  147. config.Certificates = make([]tls.Certificate, 1)
  148. config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  149. if err != nil {
  150. return nil, err
  151. }
  152. // Enable http2
  153. enableHTTP2ForTLSConfig(config)
  154. conn, err := srv.newTCPListener(addr)
  155. if err != nil {
  156. return nil, err
  157. }
  158. srv.TLSConfig = config
  159. tlsListener := tls.NewListener(conn, config)
  160. return tlsListener, nil
  161. }
  162. // Enable HTTP2ForTLSConfig explicitly enables http/2 for a TLS Config. This is due to changes in Go 1.7 where
  163. // http servers are no longer automatically configured to enable http/2 if the server's TLSConfig is set.
  164. // See https://github.com/golang/go/issues/15908
  165. func enableHTTP2ForTLSConfig(t *tls.Config) {
  166. if TLSConfigHasHTTP2Enabled(t) {
  167. return
  168. }
  169. t.NextProtos = append(t.NextProtos, "h2")
  170. }
  171. // TLSConfigHasHTTP2Enabled checks to see if a given TLS Config has http2 enabled.
  172. func TLSConfigHasHTTP2Enabled(t *tls.Config) bool {
  173. for _, value := range t.NextProtos {
  174. if value == "h2" {
  175. return true
  176. }
  177. }
  178. return false
  179. }
  180. // ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled.
  181. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
  182. l, err := srv.ListenTLS(certFile, keyFile)
  183. if err != nil {
  184. return err
  185. }
  186. return srv.Serve(l)
  187. }
  188. // ListenAndServeTLSConfig can be used with an existing TLS config and is equivalent to
  189. // http.Server.ListenAndServeTLS with graceful shutdown enabled,
  190. func (srv *Server) ListenAndServeTLSConfig(config *tls.Config) error {
  191. addr := srv.Addr
  192. if addr == "" {
  193. addr = ":https"
  194. }
  195. conn, err := srv.newTCPListener(addr)
  196. if err != nil {
  197. return err
  198. }
  199. srv.TLSConfig = config
  200. tlsListener := tls.NewListener(conn, config)
  201. return srv.Serve(tlsListener)
  202. }
  203. // Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
  204. //
  205. // timeout is the duration to wait until killing active requests and stopping the server.
  206. // If timeout is 0, the server never times out. It waits for all active requests to finish.
  207. func Serve(server *http.Server, l net.Listener, timeout time.Duration) error {
  208. srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()}
  209. return srv.Serve(l)
  210. }
  211. // Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
  212. func (srv *Server) Serve(listener net.Listener) error {
  213. if srv.ListenLimit != 0 {
  214. listener = LimitListener(listener, srv.ListenLimit)
  215. }
  216. // Make our stopchan
  217. srv.StopChan()
  218. // Track connection state
  219. add := make(chan net.Conn)
  220. idle := make(chan net.Conn)
  221. active := make(chan net.Conn)
  222. remove := make(chan net.Conn)
  223. srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
  224. switch state {
  225. case http.StateNew:
  226. add <- conn
  227. case http.StateActive:
  228. active <- conn
  229. case http.StateIdle:
  230. idle <- conn
  231. case http.StateClosed, http.StateHijacked:
  232. remove <- conn
  233. }
  234. srv.stopLock.Lock()
  235. defer srv.stopLock.Unlock()
  236. if srv.ConnState != nil {
  237. srv.ConnState(conn, state)
  238. }
  239. }
  240. // Manage open connections
  241. shutdown := make(chan chan struct{})
  242. kill := make(chan struct{})
  243. go srv.manageConnections(add, idle, active, remove, shutdown, kill)
  244. interrupt := srv.interruptChan()
  245. // Set up the interrupt handler
  246. if !srv.NoSignalHandling {
  247. signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
  248. }
  249. quitting := make(chan struct{})
  250. go srv.handleInterrupt(interrupt, quitting, listener)
  251. // Serve with graceful listener.
  252. // Execution blocks here until listener.Close() is called, above.
  253. err := srv.Server.Serve(listener)
  254. if err != nil {
  255. // If the underlying listening is closed, Serve returns an error
  256. // complaining about listening on a closed socket. This is expected, so
  257. // let's ignore the error if we are the ones who explicitly closed the
  258. // socket.
  259. select {
  260. case <-quitting:
  261. err = nil
  262. default:
  263. }
  264. }
  265. srv.shutdown(shutdown, kill)
  266. return err
  267. }
  268. // Stop instructs the type to halt operations and close
  269. // the stop channel when it is finished.
  270. //
  271. // timeout is grace period for which to wait before shutting
  272. // down the server. The timeout value passed here will override the
  273. // timeout given when constructing the server, as this is an explicit
  274. // command to stop the server.
  275. func (srv *Server) Stop(timeout time.Duration) {
  276. srv.stopLock.Lock()
  277. defer srv.stopLock.Unlock()
  278. srv.Timeout = timeout
  279. interrupt := srv.interruptChan()
  280. interrupt <- syscall.SIGINT
  281. }
  282. // StopChan gets the stop channel which will block until
  283. // stopping has completed, at which point it is closed.
  284. // Callers should never close the stop channel.
  285. func (srv *Server) StopChan() <-chan struct{} {
  286. srv.chanLock.Lock()
  287. defer srv.chanLock.Unlock()
  288. if srv.stopChan == nil {
  289. srv.stopChan = make(chan struct{})
  290. }
  291. return srv.stopChan
  292. }
  293. // DefaultLogger returns the logger used by Run, RunWithErr, ListenAndServe, ListenAndServeTLS and Serve.
  294. // The logger outputs to STDERR by default.
  295. func DefaultLogger() *log.Logger {
  296. return log.New(os.Stderr, "[graceful] ", 0)
  297. }
  298. func (srv *Server) manageConnections(add, idle, active, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) {
  299. var done chan struct{}
  300. srv.connections = map[net.Conn]struct{}{}
  301. srv.idleConnections = map[net.Conn]struct{}{}
  302. for {
  303. select {
  304. case conn := <-add:
  305. srv.connections[conn] = struct{}{}
  306. case conn := <-idle:
  307. srv.idleConnections[conn] = struct{}{}
  308. case conn := <-active:
  309. delete(srv.idleConnections, conn)
  310. case conn := <-remove:
  311. delete(srv.connections, conn)
  312. delete(srv.idleConnections, conn)
  313. if done != nil && len(srv.connections) == 0 {
  314. done <- struct{}{}
  315. return
  316. }
  317. case done = <-shutdown:
  318. if len(srv.connections) == 0 && len(srv.idleConnections) == 0 {
  319. done <- struct{}{}
  320. return
  321. }
  322. // a shutdown request has been received. if we have open idle
  323. // connections, we must close all of them now. this prevents idle
  324. // connections from holding the server open while waiting for them to
  325. // hit their idle timeout.
  326. for k := range srv.idleConnections {
  327. if err := k.Close(); err != nil {
  328. srv.logf("[ERROR] %s", err)
  329. }
  330. }
  331. case <-kill:
  332. srv.stopLock.Lock()
  333. defer srv.stopLock.Unlock()
  334. srv.Server.ConnState = nil
  335. for k := range srv.connections {
  336. if err := k.Close(); err != nil {
  337. srv.logf("[ERROR] %s", err)
  338. }
  339. }
  340. return
  341. }
  342. }
  343. }
  344. func (srv *Server) interruptChan() chan os.Signal {
  345. srv.chanLock.Lock()
  346. defer srv.chanLock.Unlock()
  347. if srv.interrupt == nil {
  348. srv.interrupt = make(chan os.Signal, 1)
  349. }
  350. return srv.interrupt
  351. }
  352. func (srv *Server) handleInterrupt(interrupt chan os.Signal, quitting chan struct{}, listener net.Listener) {
  353. for _ = range interrupt {
  354. if srv.Interrupted {
  355. srv.logf("already shutting down")
  356. continue
  357. }
  358. srv.logf("shutdown initiated")
  359. srv.Interrupted = true
  360. if srv.BeforeShutdown != nil {
  361. if !srv.BeforeShutdown() {
  362. srv.Interrupted = false
  363. continue
  364. }
  365. }
  366. close(quitting)
  367. srv.SetKeepAlivesEnabled(false)
  368. if err := listener.Close(); err != nil {
  369. srv.logf("[ERROR] %s", err)
  370. }
  371. if srv.ShutdownInitiated != nil {
  372. srv.ShutdownInitiated()
  373. }
  374. }
  375. }
  376. func (srv *Server) logf(format string, args ...interface{}) {
  377. if srv.LogFunc != nil {
  378. srv.LogFunc(format, args...)
  379. } else if srv.Logger != nil {
  380. srv.Logger.Printf(format, args...)
  381. }
  382. }
  383. func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) {
  384. // Request done notification
  385. done := make(chan struct{})
  386. shutdown <- done
  387. if srv.Timeout > 0 {
  388. select {
  389. case <-done:
  390. case <-time.After(srv.Timeout):
  391. close(kill)
  392. }
  393. } else {
  394. <-done
  395. }
  396. // Close the stopChan to wake up any blocked goroutines.
  397. srv.chanLock.Lock()
  398. if srv.stopChan != nil {
  399. close(srv.stopChan)
  400. }
  401. srv.chanLock.Unlock()
  402. }
  403. func (srv *Server) newTCPListener(addr string) (net.Listener, error) {
  404. conn, err := net.Listen("tcp", addr)
  405. if err != nil {
  406. return conn, err
  407. }
  408. if srv.TCPKeepAlive != 0 {
  409. conn = keepAliveListener{conn, srv.TCPKeepAlive}
  410. }
  411. return conn, nil
  412. }