123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- package multitenant
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "time"
- consul "github.com/hashicorp/consul/api"
- opentracing "github.com/opentracing/opentracing-go"
- log "github.com/sirupsen/logrus"
- )
const (
	// longPollDuration is the WaitTime handed to Consul on each blocking
	// List call issued by WatchPrefix.
	longPollDuration = 10 * time.Second
)
// ConsulClient is a high-level client for Consul, that exposes operations
// such as CAS and Watch which take callbacks. It also deals with serialisation.
//
// The method notes below describe the consulClient implementation in this
// file, which stores values as JSON.
type ConsulClient interface {
	// Get reads key and decodes its value into out. It returns ErrNotFound
	// when Consul has no entry for key.
	Get(ctx context.Context, key string, out interface{}) error
	// CAS reads key, decodes it into out, passes the result to f and writes
	// f's return value back with a check-and-set, retrying a bounded number
	// of times.
	CAS(ctx context.Context, key string, out interface{}, f CASCallback) error
	// WatchPrefix long-polls every key under prefix and calls f with each
	// key and its value decoded into out. It returns when done is readable
	// or closed, or when f returns false.
	WatchPrefix(prefix string, out interface{}, done chan struct{}, f func(string, interface{}) bool)
}
// CASCallback is the type of the callback to CAS. If err is nil, out must be non-nil.
//
// in is the decoded current value, or nil when the key was not found. out is
// the value to write back. retry is only consulted when err is non-nil: CAS
// gives up and returns err if retry is false, and makes another attempt
// otherwise. Returning a nil out together with a nil err makes CAS panic.
type CASCallback func(in interface{}) (out interface{}, retry bool, err error)
- // NewConsulClient returns a new ConsulClient
- func NewConsulClient(addr string) (ConsulClient, error) {
- client, err := consul.NewClient(&consul.Config{
- Address: addr,
- Scheme: "http",
- })
- if err != nil {
- return nil, err
- }
- return &consulClient{client.KV()}, nil
- }
var (
	// queryOptions are the options used for single-key reads (Get and the
	// read half of CAS); they request consistent reads.
	queryOptions = &consul.QueryOptions{
		RequireConsistent: true,
	}
	// writeOptions are the (default) options passed to the KV CAS write.
	writeOptions = &consul.WriteOptions{}
	// ErrNotFound is returned by ConsulClient.Get when Consul returns no
	// entry for the requested key.
	ErrNotFound = fmt.Errorf("Not found")
)
// kv is the subset of the Consul KV API that consulClient calls.
// NewConsulClient satisfies it with the KV handle of a real Consul client.
// NOTE(review): presumably kept as an interface so tests can substitute a
// fake — confirm against the test files.
type kv interface {
	// CAS writes p only if its ModifyIndex matches; the bool reports whether
	// the write was applied.
	CAS(p *consul.KVPair, q *consul.WriteOptions) (bool, *consul.WriteMeta, error)
	// Get fetches a single key; consulClient treats a nil pair as "not found".
	Get(key string, q *consul.QueryOptions) (*consul.KVPair, *consul.QueryMeta, error)
	// List fetches every pair under prefix.
	List(prefix string, q *consul.QueryOptions) (consul.KVPairs, *consul.QueryMeta, error)
}
// consulClient implements the Get, CAS and WatchPrefix methods of
// ConsulClient on top of a kv backend.
type consulClient struct {
	// kv is the backend every operation is issued against.
	kv kv
}
- // Get and deserialise a JSON value from consul.
- func (c *consulClient) Get(ctx context.Context, key string, out interface{}) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "Consul Get", opentracing.Tag{Key: "key", Value: key})
- defer span.Finish()
- kvp, _, err := c.kv.Get(key, queryOptions)
- if err != nil {
- return err
- }
- if kvp == nil {
- return ErrNotFound
- }
- return json.NewDecoder(bytes.NewReader(kvp.Value)).Decode(out)
- }
- // CAS atomically modify a value in a callback.
- // If value doesn't exist you'll get nil as a argument to your callback.
- func (c *consulClient) CAS(ctx context.Context, key string, out interface{}, f CASCallback) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "Consul CAS", opentracing.Tag{Key: "key", Value: key})
- defer span.Finish()
- var (
- index = uint64(0)
- retries = 10
- retry = true
- intermediate interface{}
- )
- for i := 0; i < retries; i++ {
- kvp, _, err := c.kv.Get(key, queryOptions)
- if err != nil {
- log.Errorf("Error getting %s: %v", key, err)
- continue
- }
- if kvp != nil {
- if err := json.NewDecoder(bytes.NewReader(kvp.Value)).Decode(out); err != nil {
- log.Errorf("Error deserialising %s: %v", key, err)
- continue
- }
- index = kvp.ModifyIndex // if key doesn't exist, index will be 0
- intermediate = out
- }
- intermediate, retry, err = f(intermediate)
- if err != nil {
- log.Errorf("Error CASing %s: %v", key, err)
- if !retry {
- return err
- }
- continue
- }
- if intermediate == nil {
- panic("Callback must instantiate value!")
- }
- value := bytes.Buffer{}
- if err := json.NewEncoder(&value).Encode(intermediate); err != nil {
- log.Errorf("Error serialising value for %s: %v", key, err)
- continue
- }
- ok, _, err := c.kv.CAS(&consul.KVPair{
- Key: key,
- Value: value.Bytes(),
- ModifyIndex: index,
- }, writeOptions)
- if err != nil {
- log.Errorf("Error CASing %s: %v", key, err)
- continue
- }
- if !ok {
- log.Errorf("Error CASing %s, trying again %d", key, index)
- continue
- }
- return nil
- }
- return fmt.Errorf("Failed to CAS %s", key)
- }
- func (c *consulClient) WatchPrefix(prefix string, out interface{}, done chan struct{}, f func(string, interface{}) bool) {
- const (
- initialBackoff = 1 * time.Second
- maxBackoff = 1 * time.Minute
- )
- var (
- backoff = initialBackoff / 2
- index = uint64(0)
- )
- for {
- select {
- case <-done:
- return
- default:
- }
- kvps, meta, err := c.kv.List(prefix, &consul.QueryOptions{
- RequireConsistent: true,
- WaitIndex: index,
- WaitTime: longPollDuration,
- })
- if err != nil {
- log.Errorf("Error getting path %s: %v", prefix, err)
- backoff = backoff * 2
- if backoff > maxBackoff {
- backoff = maxBackoff
- }
- select {
- case <-done:
- return
- case <-time.After(backoff):
- continue
- }
- }
- backoff = initialBackoff
- if index == meta.LastIndex {
- continue
- }
- index = meta.LastIndex
- for _, kvp := range kvps {
- if err := json.NewDecoder(bytes.NewReader(kvp.Value)).Decode(out); err != nil {
- log.Errorf("Error deserialising %s: %v", kvp.Key, err)
- continue
- }
- if !f(kvp.Key, out) {
- return
- }
- }
- }
- }
|