sqs_control_router.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package multitenant
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "math/rand"
  7. "sync"
  8. "time"
  9. "context"
  10. "github.com/aws/aws-sdk-go/aws"
  11. "github.com/aws/aws-sdk-go/aws/session"
  12. "github.com/aws/aws-sdk-go/service/sqs"
  13. "github.com/prometheus/client_golang/prometheus"
  14. log "github.com/sirupsen/logrus"
  15. "github.com/weaveworks/common/instrument"
  16. "github.com/weaveworks/scope/app"
  17. "github.com/weaveworks/scope/common/xfer"
  18. )
  19. var (
  20. longPollTime = aws.Int64(10)
  21. sqsRequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
  22. Namespace: "scope",
  23. Name: "sqs_request_duration_seconds",
  24. Help: "Time in seconds spent doing SQS requests.",
  25. Buckets: prometheus.DefBuckets,
  26. }, []string{"method", "status_code"})
  27. )
  28. func registerSQSMetrics() {
  29. prometheus.MustRegister(sqsRequestDuration)
  30. }
  31. var registerSQSMetricsOnce sync.Once
  32. // sqsControlRouter:
  33. // Creates a queue for every probe that connects to it, and a queue for
  34. // responses back to it. When it receives a request, posts it to the
  35. // probe queue. When probe receives a request, handles it and posts the
  36. // response back to the response queue.
  37. type sqsControlRouter struct {
  38. service *sqs.SQS
  39. responseQueueURL *string
  40. userIDer UserIDer
  41. prefix string
  42. rpcTimeout time.Duration
  43. mtx sync.Mutex
  44. responses map[string]chan xfer.Response
  45. probeWorkers map[int64]*probeWorker
  46. }
  47. type sqsRequestMessage struct {
  48. ID string
  49. Request xfer.Request
  50. ResponseQueueURL string
  51. }
  52. type sqsResponseMessage struct {
  53. ID string
  54. Response xfer.Response
  55. }
  56. // NewSQSControlRouter the harbinger of death
  57. func NewSQSControlRouter(config *aws.Config, userIDer UserIDer, prefix string, rpcTimeout time.Duration) app.ControlRouter {
  58. registerSQSMetricsOnce.Do(registerSQSMetrics)
  59. result := &sqsControlRouter{
  60. service: sqs.New(session.New(config)),
  61. responseQueueURL: nil,
  62. userIDer: userIDer,
  63. prefix: prefix,
  64. rpcTimeout: rpcTimeout,
  65. responses: map[string]chan xfer.Response{},
  66. probeWorkers: map[int64]*probeWorker{},
  67. }
  68. go result.loop()
  69. return result
  70. }
  71. func (cr *sqsControlRouter) Stop() error {
  72. return nil
  73. }
  74. func (cr *sqsControlRouter) setResponseQueueURL(url *string) {
  75. cr.mtx.Lock()
  76. defer cr.mtx.Unlock()
  77. cr.responseQueueURL = url
  78. }
  79. func (cr *sqsControlRouter) getResponseQueueURL() *string {
  80. cr.mtx.Lock()
  81. defer cr.mtx.Unlock()
  82. return cr.responseQueueURL
  83. }
  84. func (cr *sqsControlRouter) getOrCreateQueue(ctx context.Context, name string) (*string, error) {
  85. // CreateQueue creates a queue or if it already exists, returns url of said queue
  86. var createQueueRes *sqs.CreateQueueOutput
  87. var err error
  88. err = instrument.TimeRequestHistogram(ctx, "SQS.CreateQueue", sqsRequestDuration, func(_ context.Context) error {
  89. createQueueRes, err = cr.service.CreateQueue(&sqs.CreateQueueInput{
  90. QueueName: aws.String(name),
  91. })
  92. return err
  93. })
  94. if err != nil {
  95. return nil, err
  96. }
  97. return createQueueRes.QueueUrl, nil
  98. }
  99. func (cr *sqsControlRouter) loop() {
  100. var (
  101. responseQueueURL *string
  102. err error
  103. ctx = context.Background()
  104. )
  105. for {
  106. // This app has a random id and uses this as a return path for all responses from probes.
  107. name := fmt.Sprintf("%scontrol-app-%d", cr.prefix, rand.Int63())
  108. responseQueueURL, err = cr.getOrCreateQueue(ctx, name)
  109. if err != nil {
  110. log.Errorf("Failed to create queue: %v", err)
  111. time.Sleep(1 * time.Second)
  112. continue
  113. }
  114. cr.setResponseQueueURL(responseQueueURL)
  115. break
  116. }
  117. for {
  118. var res *sqs.ReceiveMessageOutput
  119. var err error
  120. err = instrument.TimeRequestHistogram(ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
  121. res, err = cr.service.ReceiveMessage(&sqs.ReceiveMessageInput{
  122. QueueUrl: responseQueueURL,
  123. WaitTimeSeconds: longPollTime,
  124. })
  125. return err
  126. })
  127. if err != nil {
  128. log.Errorf("Error receiving message from %s: %v", *responseQueueURL, err)
  129. continue
  130. }
  131. if len(res.Messages) == 0 {
  132. continue
  133. }
  134. if err := cr.deleteMessages(ctx, responseQueueURL, res.Messages); err != nil {
  135. log.Errorf("Error deleting message from %s: %v", *responseQueueURL, err)
  136. }
  137. cr.handleResponses(res)
  138. }
  139. }
  140. func (cr *sqsControlRouter) deleteMessages(ctx context.Context, queueURL *string, messages []*sqs.Message) error {
  141. entries := []*sqs.DeleteMessageBatchRequestEntry{}
  142. for _, message := range messages {
  143. entries = append(entries, &sqs.DeleteMessageBatchRequestEntry{
  144. ReceiptHandle: message.ReceiptHandle,
  145. Id: message.MessageId,
  146. })
  147. }
  148. return instrument.TimeRequestHistogram(ctx, "SQS.DeleteMessageBatch", sqsRequestDuration, func(_ context.Context) error {
  149. _, err := cr.service.DeleteMessageBatch(&sqs.DeleteMessageBatchInput{
  150. QueueUrl: queueURL,
  151. Entries: entries,
  152. })
  153. return err
  154. })
  155. }
  156. func (cr *sqsControlRouter) handleResponses(res *sqs.ReceiveMessageOutput) {
  157. cr.mtx.Lock()
  158. defer cr.mtx.Unlock()
  159. for _, message := range res.Messages {
  160. var sqsResponse sqsResponseMessage
  161. if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsResponse); err != nil {
  162. log.Errorf("Error decoding message: %v", err)
  163. continue
  164. }
  165. waiter, ok := cr.responses[sqsResponse.ID]
  166. if !ok {
  167. log.Errorf("Dropping response %s - no one waiting for it!", sqsResponse.ID)
  168. continue
  169. }
  170. waiter <- sqsResponse.Response
  171. }
  172. }
  173. func (cr *sqsControlRouter) sendMessage(ctx context.Context, queueURL *string, message interface{}) error {
  174. buf := bytes.Buffer{}
  175. if err := json.NewEncoder(&buf).Encode(message); err != nil {
  176. return err
  177. }
  178. log.Debugf("sendMessage to %s: %s", *queueURL, buf.String())
  179. return instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, func(_ context.Context) error {
  180. _, err := cr.service.SendMessage(&sqs.SendMessageInput{
  181. QueueUrl: queueURL,
  182. MessageBody: aws.String(buf.String()),
  183. })
  184. return err
  185. })
  186. }
  187. func (cr *sqsControlRouter) Handle(ctx context.Context, probeID string, req xfer.Request) (xfer.Response, error) {
  188. // Make sure we know the users
  189. userID, err := cr.userIDer(ctx)
  190. if err != nil {
  191. return xfer.Response{}, err
  192. }
  193. // Get the queue url for the local (control app) queue, and for the probe.
  194. responseQueueURL := cr.getResponseQueueURL()
  195. if responseQueueURL == nil {
  196. return xfer.Response{}, fmt.Errorf("no SQS queue yet")
  197. }
  198. var probeQueueURL *sqs.GetQueueUrlOutput
  199. err = instrument.TimeRequestHistogram(ctx, "SQS.GetQueueUrl", sqsRequestDuration, func(_ context.Context) error {
  200. probeQueueName := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
  201. probeQueueURL, err = cr.service.GetQueueUrl(&sqs.GetQueueUrlInput{
  202. QueueName: aws.String(probeQueueName),
  203. })
  204. return err
  205. })
  206. if err != nil {
  207. return xfer.Response{}, err
  208. }
  209. // Add a response channel before we send the request, to prevent races
  210. id := fmt.Sprintf("request-%s-%d", userID, rand.Int63())
  211. waiter := make(chan xfer.Response, 1)
  212. cr.mtx.Lock()
  213. cr.responses[id] = waiter
  214. cr.mtx.Unlock()
  215. defer func() {
  216. cr.mtx.Lock()
  217. delete(cr.responses, id)
  218. cr.mtx.Unlock()
  219. }()
  220. // Next, send the request to that queue
  221. if err := instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, func(ctx context.Context) error {
  222. return cr.sendMessage(ctx, probeQueueURL.QueueUrl, sqsRequestMessage{
  223. ID: id,
  224. Request: req,
  225. ResponseQueueURL: *responseQueueURL,
  226. })
  227. }); err != nil {
  228. return xfer.Response{}, err
  229. }
  230. // Finally, wait for a response on our queue
  231. select {
  232. case response := <-waiter:
  233. return response, nil
  234. case <-time.After(cr.rpcTimeout):
  235. return xfer.Response{}, fmt.Errorf("request timed out")
  236. }
  237. }
  238. func (cr *sqsControlRouter) Register(ctx context.Context, probeID string, handler xfer.ControlHandlerFunc) (int64, error) {
  239. userID, err := cr.userIDer(ctx)
  240. if err != nil {
  241. return 0, err
  242. }
  243. name := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
  244. queueURL, err := cr.getOrCreateQueue(ctx, name)
  245. if err != nil {
  246. return 0, err
  247. }
  248. pwID := rand.Int63()
  249. pw := &probeWorker{
  250. ctx: ctx,
  251. router: cr,
  252. requestQueueURL: queueURL,
  253. handler: handler,
  254. quit: make(chan struct{}),
  255. }
  256. pw.done.Add(1)
  257. go pw.loop()
  258. cr.mtx.Lock()
  259. defer cr.mtx.Unlock()
  260. cr.probeWorkers[pwID] = pw
  261. return pwID, nil
  262. }
  263. func (cr *sqsControlRouter) Deregister(_ context.Context, probeID string, id int64) error {
  264. cr.mtx.Lock()
  265. pw, ok := cr.probeWorkers[id]
  266. delete(cr.probeWorkers, id)
  267. cr.mtx.Unlock()
  268. if ok {
  269. pw.stop()
  270. }
  271. return nil
  272. }
  273. // a probeWorker encapsulates a goroutine serving a probe's websocket connection.
  274. type probeWorker struct {
  275. ctx context.Context
  276. router *sqsControlRouter
  277. requestQueueURL *string
  278. handler xfer.ControlHandlerFunc
  279. quit chan struct{}
  280. done sync.WaitGroup
  281. }
  282. func (pw *probeWorker) stop() {
  283. close(pw.quit)
  284. pw.done.Wait()
  285. }
  286. func (pw *probeWorker) loop() {
  287. defer pw.done.Done()
  288. for {
  289. // have we been stopped?
  290. select {
  291. case <-pw.quit:
  292. return
  293. default:
  294. }
  295. var res *sqs.ReceiveMessageOutput
  296. var err error
  297. err = instrument.TimeRequestHistogram(pw.ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
  298. res, err = pw.router.service.ReceiveMessage(&sqs.ReceiveMessageInput{
  299. QueueUrl: pw.requestQueueURL,
  300. WaitTimeSeconds: longPollTime,
  301. })
  302. return err
  303. })
  304. if err != nil {
  305. log.Errorf("Error receiving message: %v", err)
  306. continue
  307. }
  308. if len(res.Messages) == 0 {
  309. continue
  310. }
  311. if err := pw.router.deleteMessages(pw.ctx, pw.requestQueueURL, res.Messages); err != nil {
  312. log.Errorf("Error deleting message from %s: %v", *pw.requestQueueURL, err)
  313. }
  314. for _, message := range res.Messages {
  315. var sqsRequest sqsRequestMessage
  316. if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsRequest); err != nil {
  317. log.Errorf("Error decoding message from: %v", err)
  318. continue
  319. }
  320. response := pw.handler(sqsRequest.Request)
  321. if err := pw.router.sendMessage(pw.ctx, &sqsRequest.ResponseQueueURL, sqsResponseMessage{
  322. ID: sqsRequest.ID,
  323. Response: response,
  324. }); err != nil {
  325. log.Errorf("Error sending response: %v", err)
  326. }
  327. }
  328. }
  329. }