123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372 |
- package multitenant
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "math/rand"
- "sync"
- "time"
- "context"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/aws/aws-sdk-go/service/sqs"
- "github.com/prometheus/client_golang/prometheus"
- log "github.com/sirupsen/logrus"
- "github.com/weaveworks/common/instrument"
- "github.com/weaveworks/scope/app"
- "github.com/weaveworks/scope/common/xfer"
- )
- var (
- longPollTime = aws.Int64(10)
- sqsRequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
- Namespace: "scope",
- Name: "sqs_request_duration_seconds",
- Help: "Time in seconds spent doing SQS requests.",
- Buckets: prometheus.DefBuckets,
- }, []string{"method", "status_code"})
- )
- func registerSQSMetrics() {
- prometheus.MustRegister(sqsRequestDuration)
- }
- var registerSQSMetricsOnce sync.Once
- // sqsControlRouter:
- // Creates a queue for every probe that connects to it, and a queue for
- // responses back to it. When it receives a request, posts it to the
- // probe queue. When probe receives a request, handles it and posts the
- // response back to the response queue.
- type sqsControlRouter struct {
- service *sqs.SQS
- responseQueueURL *string
- userIDer UserIDer
- prefix string
- rpcTimeout time.Duration
- mtx sync.Mutex
- responses map[string]chan xfer.Response
- probeWorkers map[int64]*probeWorker
- }
- type sqsRequestMessage struct {
- ID string
- Request xfer.Request
- ResponseQueueURL string
- }
- type sqsResponseMessage struct {
- ID string
- Response xfer.Response
- }
- // NewSQSControlRouter the harbinger of death
- func NewSQSControlRouter(config *aws.Config, userIDer UserIDer, prefix string, rpcTimeout time.Duration) app.ControlRouter {
- registerSQSMetricsOnce.Do(registerSQSMetrics)
- result := &sqsControlRouter{
- service: sqs.New(session.New(config)),
- responseQueueURL: nil,
- userIDer: userIDer,
- prefix: prefix,
- rpcTimeout: rpcTimeout,
- responses: map[string]chan xfer.Response{},
- probeWorkers: map[int64]*probeWorker{},
- }
- go result.loop()
- return result
- }
- func (cr *sqsControlRouter) Stop() error {
- return nil
- }
- func (cr *sqsControlRouter) setResponseQueueURL(url *string) {
- cr.mtx.Lock()
- defer cr.mtx.Unlock()
- cr.responseQueueURL = url
- }
- func (cr *sqsControlRouter) getResponseQueueURL() *string {
- cr.mtx.Lock()
- defer cr.mtx.Unlock()
- return cr.responseQueueURL
- }
- func (cr *sqsControlRouter) getOrCreateQueue(ctx context.Context, name string) (*string, error) {
- // CreateQueue creates a queue or if it already exists, returns url of said queue
- var createQueueRes *sqs.CreateQueueOutput
- var err error
- err = instrument.TimeRequestHistogram(ctx, "SQS.CreateQueue", sqsRequestDuration, func(_ context.Context) error {
- createQueueRes, err = cr.service.CreateQueue(&sqs.CreateQueueInput{
- QueueName: aws.String(name),
- })
- return err
- })
- if err != nil {
- return nil, err
- }
- return createQueueRes.QueueUrl, nil
- }
- func (cr *sqsControlRouter) loop() {
- var (
- responseQueueURL *string
- err error
- ctx = context.Background()
- )
- for {
- // This app has a random id and uses this as a return path for all responses from probes.
- name := fmt.Sprintf("%scontrol-app-%d", cr.prefix, rand.Int63())
- responseQueueURL, err = cr.getOrCreateQueue(ctx, name)
- if err != nil {
- log.Errorf("Failed to create queue: %v", err)
- time.Sleep(1 * time.Second)
- continue
- }
- cr.setResponseQueueURL(responseQueueURL)
- break
- }
- for {
- var res *sqs.ReceiveMessageOutput
- var err error
- err = instrument.TimeRequestHistogram(ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
- res, err = cr.service.ReceiveMessage(&sqs.ReceiveMessageInput{
- QueueUrl: responseQueueURL,
- WaitTimeSeconds: longPollTime,
- })
- return err
- })
- if err != nil {
- log.Errorf("Error receiving message from %s: %v", *responseQueueURL, err)
- continue
- }
- if len(res.Messages) == 0 {
- continue
- }
- if err := cr.deleteMessages(ctx, responseQueueURL, res.Messages); err != nil {
- log.Errorf("Error deleting message from %s: %v", *responseQueueURL, err)
- }
- cr.handleResponses(res)
- }
- }
- func (cr *sqsControlRouter) deleteMessages(ctx context.Context, queueURL *string, messages []*sqs.Message) error {
- entries := []*sqs.DeleteMessageBatchRequestEntry{}
- for _, message := range messages {
- entries = append(entries, &sqs.DeleteMessageBatchRequestEntry{
- ReceiptHandle: message.ReceiptHandle,
- Id: message.MessageId,
- })
- }
- return instrument.TimeRequestHistogram(ctx, "SQS.DeleteMessageBatch", sqsRequestDuration, func(_ context.Context) error {
- _, err := cr.service.DeleteMessageBatch(&sqs.DeleteMessageBatchInput{
- QueueUrl: queueURL,
- Entries: entries,
- })
- return err
- })
- }
- func (cr *sqsControlRouter) handleResponses(res *sqs.ReceiveMessageOutput) {
- cr.mtx.Lock()
- defer cr.mtx.Unlock()
- for _, message := range res.Messages {
- var sqsResponse sqsResponseMessage
- if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsResponse); err != nil {
- log.Errorf("Error decoding message: %v", err)
- continue
- }
- waiter, ok := cr.responses[sqsResponse.ID]
- if !ok {
- log.Errorf("Dropping response %s - no one waiting for it!", sqsResponse.ID)
- continue
- }
- waiter <- sqsResponse.Response
- }
- }
- func (cr *sqsControlRouter) sendMessage(ctx context.Context, queueURL *string, message interface{}) error {
- buf := bytes.Buffer{}
- if err := json.NewEncoder(&buf).Encode(message); err != nil {
- return err
- }
- log.Debugf("sendMessage to %s: %s", *queueURL, buf.String())
- return instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, func(_ context.Context) error {
- _, err := cr.service.SendMessage(&sqs.SendMessageInput{
- QueueUrl: queueURL,
- MessageBody: aws.String(buf.String()),
- })
- return err
- })
- }
- func (cr *sqsControlRouter) Handle(ctx context.Context, probeID string, req xfer.Request) (xfer.Response, error) {
- // Make sure we know the users
- userID, err := cr.userIDer(ctx)
- if err != nil {
- return xfer.Response{}, err
- }
- // Get the queue url for the local (control app) queue, and for the probe.
- responseQueueURL := cr.getResponseQueueURL()
- if responseQueueURL == nil {
- return xfer.Response{}, fmt.Errorf("no SQS queue yet")
- }
- var probeQueueURL *sqs.GetQueueUrlOutput
- err = instrument.TimeRequestHistogram(ctx, "SQS.GetQueueUrl", sqsRequestDuration, func(_ context.Context) error {
- probeQueueName := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
- probeQueueURL, err = cr.service.GetQueueUrl(&sqs.GetQueueUrlInput{
- QueueName: aws.String(probeQueueName),
- })
- return err
- })
- if err != nil {
- return xfer.Response{}, err
- }
- // Add a response channel before we send the request, to prevent races
- id := fmt.Sprintf("request-%s-%d", userID, rand.Int63())
- waiter := make(chan xfer.Response, 1)
- cr.mtx.Lock()
- cr.responses[id] = waiter
- cr.mtx.Unlock()
- defer func() {
- cr.mtx.Lock()
- delete(cr.responses, id)
- cr.mtx.Unlock()
- }()
- // Next, send the request to that queue
- if err := instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, func(ctx context.Context) error {
- return cr.sendMessage(ctx, probeQueueURL.QueueUrl, sqsRequestMessage{
- ID: id,
- Request: req,
- ResponseQueueURL: *responseQueueURL,
- })
- }); err != nil {
- return xfer.Response{}, err
- }
- // Finally, wait for a response on our queue
- select {
- case response := <-waiter:
- return response, nil
- case <-time.After(cr.rpcTimeout):
- return xfer.Response{}, fmt.Errorf("request timed out")
- }
- }
- func (cr *sqsControlRouter) Register(ctx context.Context, probeID string, handler xfer.ControlHandlerFunc) (int64, error) {
- userID, err := cr.userIDer(ctx)
- if err != nil {
- return 0, err
- }
- name := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
- queueURL, err := cr.getOrCreateQueue(ctx, name)
- if err != nil {
- return 0, err
- }
- pwID := rand.Int63()
- pw := &probeWorker{
- ctx: ctx,
- router: cr,
- requestQueueURL: queueURL,
- handler: handler,
- quit: make(chan struct{}),
- }
- pw.done.Add(1)
- go pw.loop()
- cr.mtx.Lock()
- defer cr.mtx.Unlock()
- cr.probeWorkers[pwID] = pw
- return pwID, nil
- }
- func (cr *sqsControlRouter) Deregister(_ context.Context, probeID string, id int64) error {
- cr.mtx.Lock()
- pw, ok := cr.probeWorkers[id]
- delete(cr.probeWorkers, id)
- cr.mtx.Unlock()
- if ok {
- pw.stop()
- }
- return nil
- }
- // a probeWorker encapsulates a goroutine serving a probe's websocket connection.
- type probeWorker struct {
- ctx context.Context
- router *sqsControlRouter
- requestQueueURL *string
- handler xfer.ControlHandlerFunc
- quit chan struct{}
- done sync.WaitGroup
- }
- func (pw *probeWorker) stop() {
- close(pw.quit)
- pw.done.Wait()
- }
- func (pw *probeWorker) loop() {
- defer pw.done.Done()
- for {
- // have we been stopped?
- select {
- case <-pw.quit:
- return
- default:
- }
- var res *sqs.ReceiveMessageOutput
- var err error
- err = instrument.TimeRequestHistogram(pw.ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
- res, err = pw.router.service.ReceiveMessage(&sqs.ReceiveMessageInput{
- QueueUrl: pw.requestQueueURL,
- WaitTimeSeconds: longPollTime,
- })
- return err
- })
- if err != nil {
- log.Errorf("Error receiving message: %v", err)
- continue
- }
- if len(res.Messages) == 0 {
- continue
- }
- if err := pw.router.deleteMessages(pw.ctx, pw.requestQueueURL, res.Messages); err != nil {
- log.Errorf("Error deleting message from %s: %v", *pw.requestQueueURL, err)
- }
- for _, message := range res.Messages {
- var sqsRequest sqsRequestMessage
- if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsRequest); err != nil {
- log.Errorf("Error decoding message from: %v", err)
- continue
- }
- response := pw.handler(sqsRequest.Request)
- if err := pw.router.sendMessage(pw.ctx, &sqsRequest.ResponseQueueURL, sqsResponseMessage{
- ID: sqsRequest.ID,
- Response: response,
- }); err != nil {
- log.Errorf("Error sending response: %v", err)
- }
- }
- }
- }
|