websocket.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package xfer
  2. import (
  3. "io"
  4. "net/http"
  5. "sync"
  6. "time"
  7. "github.com/gorilla/websocket"
  8. log "github.com/sirupsen/logrus"
  9. "github.com/ugorji/go/codec"
  10. "github.com/weaveworks/common/mtime"
  11. )
  const (
	// writeWait bounds every individual write: data messages (via
	// SetWriteDeadline) and ping control frames (via WriteControl).
	writeWait = 10 * time.Second

	// pongWait is how far the read deadline is pushed out, initially and
	// again on every pong. It needs to be shorter than the idle timeout of
	// whatever frontend proxies the websocket connections (e.g. nginx).
	pongWait = 60 * time.Second

	// pingPeriod is the delay between pings and must be less than pongWait.
	// The peer has to answer within pongWait, but the ping itself may take up
	// to writeWait to be sent, and the peer may be busy. So the period is two
	// thirds of (pongWait - writeWait). With the values above that is ~33s,
	// leaving ~27s for the pong to come back.
	pingPeriod = (pongWait - writeWait) * 2 / 3
)
  // Websocket exposes the bits of *websocket.Conn we actually use.
//
// The read and write methods carry the same signatures as their
// *websocket.Conn counterparts; Close releases the connection.
type Websocket interface {
	// Reading.
	ReadMessage() (messageType int, p []byte, err error)
	ReadJSON(v interface{}) error

	// Writing.
	WriteMessage(messageType int, data []byte) error
	WriteJSON(v interface{}) error

	// Teardown.
	Close() error
}
  34. type pingingWebsocket struct {
  35. pinger *time.Timer
  36. readLock sync.Mutex
  37. writeLock sync.Mutex
  38. conn *websocket.Conn
  39. }
  40. var upgrader = websocket.Upgrader{
  41. CheckOrigin: func(r *http.Request) bool { return true },
  42. }
  43. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
  44. func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Websocket, error) {
  45. wsConn, err := upgrader.Upgrade(w, r, responseHeader)
  46. if err != nil {
  47. return nil, err
  48. }
  49. return Ping(wsConn), nil
  50. }
  51. // WSDialer can dial a new websocket
  52. type WSDialer interface {
  53. Dial(urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error)
  54. }
  55. // DialWS creates a new client connection. Use requestHeader to specify the
  56. // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
  57. // Use the response.Header to get the selected subprotocol
  58. // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
  59. func DialWS(d WSDialer, urlStr string, requestHeader http.Header) (Websocket, *http.Response, error) {
  60. wsConn, resp, err := d.Dial(urlStr, requestHeader)
  61. if err != nil {
  62. return nil, resp, err
  63. }
  64. return Ping(wsConn), resp, nil
  65. }
  66. // Ping adds a periodic ping to a websocket connection.
  67. func Ping(c *websocket.Conn) Websocket {
  68. p := &pingingWebsocket{conn: c}
  69. p.conn.SetPongHandler(p.pong)
  70. p.conn.SetReadDeadline(mtime.Now().Add(pongWait))
  71. p.pinger = time.AfterFunc(pingPeriod, p.ping)
  72. return p
  73. }
  74. func (p *pingingWebsocket) ping() {
  75. p.writeLock.Lock()
  76. defer p.writeLock.Unlock()
  77. if err := p.conn.WriteControl(websocket.PingMessage, nil, mtime.Now().Add(writeWait)); err != nil {
  78. log.Errorf("websocket ping error: %v", err)
  79. p.conn.Close()
  80. return
  81. }
  82. p.pinger.Reset(pingPeriod)
  83. }
  84. func (p *pingingWebsocket) pong(string) error {
  85. p.conn.SetReadDeadline(mtime.Now().Add(pongWait))
  86. return nil
  87. }
  88. // ReadMessage is a helper method for getting a reader using NextReader and
  89. // reading from that reader to a buffer.
  90. func (p *pingingWebsocket) ReadMessage() (int, []byte, error) {
  91. p.readLock.Lock()
  92. defer p.readLock.Unlock()
  93. return p.conn.ReadMessage()
  94. }
  95. // WriteMessage is a helper method for getting a writer using NextWriter,
  96. // writing the message and closing the writer.
  97. func (p *pingingWebsocket) WriteMessage(messageType int, data []byte) error {
  98. p.writeLock.Lock()
  99. defer p.writeLock.Unlock()
  100. if err := p.conn.SetWriteDeadline(mtime.Now().Add(writeWait)); err != nil {
  101. return err
  102. }
  103. return p.conn.WriteMessage(messageType, data)
  104. }
  105. // WriteJSON writes the JSON encoding of v to the connection.
  106. func (p *pingingWebsocket) WriteJSON(v interface{}) error {
  107. p.writeLock.Lock()
  108. defer p.writeLock.Unlock()
  109. w, err := p.conn.NextWriter(websocket.TextMessage)
  110. if err != nil {
  111. return err
  112. }
  113. if err := p.conn.SetWriteDeadline(mtime.Now().Add(writeWait)); err != nil {
  114. return err
  115. }
  116. err1 := codec.NewEncoder(w, &codec.JsonHandle{}).Encode(v)
  117. err2 := w.Close()
  118. if err1 != nil {
  119. return err1
  120. }
  121. return err2
  122. }
  123. // ReadJSON reads the next JSON-encoded message from the connection and stores
  124. // it in the value pointed to by v.
  125. func (p *pingingWebsocket) ReadJSON(v interface{}) error {
  126. p.readLock.Lock()
  127. defer p.readLock.Unlock()
  128. _, r, err := p.conn.NextReader()
  129. if err != nil {
  130. return err
  131. }
  132. err = codec.NewDecoder(r, &codec.JsonHandle{}).Decode(v)
  133. if err == io.EOF {
  134. // One value is expected in the message.
  135. err = io.ErrUnexpectedEOF
  136. }
  137. return err
  138. }
  139. // Close closes the connection
  140. func (p *pingingWebsocket) Close() error {
  141. p.writeLock.Lock()
  142. defer p.writeLock.Unlock()
  143. p.pinger.Stop()
  144. return p.conn.Close()
  145. }
  146. // IsExpectedWSCloseError returns boolean indicating whether the error is a
  147. // clean disconnection.
  148. func IsExpectedWSCloseError(err error) bool {
  149. return err == io.EOF || err == io.ErrClosedPipe || websocket.IsCloseError(err,
  150. websocket.CloseNormalClosure,
  151. websocket.CloseGoingAway,
  152. websocket.CloseNoStatusReceived,
  153. websocket.CloseAbnormalClosure,
  154. )
  155. }