// Source: vikunja/vendor/github.com/ulule/limiter/v3/drivers/store/redis/store.go
//
// 321 lines, 8.0 KiB, Go

package redis
import (
"context"
"fmt"
"time"
libredis "github.com/go-redis/redis"
"github.com/pkg/errors"
"github.com/ulule/limiter/v3"
"github.com/ulule/limiter/v3/drivers/store/common"
)
// Client is the set of redis commands the store depends on. Expressing it as
// an interface allows a redis cluster client or a redis single-node client to
// be used seamlessly.
type Client interface {
	// Connectivity.
	Ping() *libredis.StatusCmd

	// Key access.
	Get(key string) *libredis.StringCmd
	Set(key string, value interface{}, expiration time.Duration) *libredis.StatusCmd
	SetNX(key string, value interface{}, expiration time.Duration) *libredis.BoolCmd
	Del(keys ...string) *libredis.IntCmd

	// Scripting and watched transactions.
	Eval(script string, keys []string, args ...interface{}) *libredis.Cmd
	Watch(handler func(*libredis.Tx) error, keys ...string) error
}
// Store is the redis store.
// It holds the key prefix, the retry budget and the redis client that the
// Get, Peek and Reset methods operate with.
type Store struct {
// Prefix used for the key. Get, Peek and Reset build the redis key as
// "<Prefix>:<key>".
Prefix string
// MaxRetry is the maximum number of retry under race conditions.
// It bounds the loops of the do* helpers; NewStoreWithOptions replaces a
// non-positive value with 1.
MaxRetry int
// client used to communicate with redis server.
client Client
}
// NewStore builds a redis store on top of the given client, using the limiter
// package defaults for the key prefix, the clean-up interval and the retry
// count. It returns whatever NewStoreWithOptions returns for those options.
func NewStore(client Client) (limiter.Store, error) {
	defaults := limiter.StoreOptions{
		Prefix:          limiter.DefaultPrefix,
		CleanUpInterval: limiter.DefaultCleanUpInterval,
		MaxRetry:        limiter.DefaultMaxRetry,
	}
	return NewStoreWithOptions(client, defaults)
}
// NewStoreWithOptions builds a redis store from the given client and options.
//
// Only options.Prefix and options.MaxRetry are read here; a MaxRetry that is
// zero or negative is replaced by 1 so that every operation is attempted at
// least once. Before returning, the server is pinged and any ping error is
// returned instead of a store.
func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) {
	maxRetry := options.MaxRetry
	if maxRetry <= 0 {
		maxRetry = 1
	}

	store := &Store{
		client:   client,
		Prefix:   options.Prefix,
		MaxRetry: maxRetry,
	}

	// Only the error matters here; the boolean reply is not inspected.
	if _, err := store.ping(); err != nil {
		return nil, err
	}
	return store, nil
}
// Get counts one hit for the given identifier and returns the resulting limit
// context.
//
// The key is first created with a value of 1 and an expiration of rate.Period
// if it does not exist yet; otherwise the existing counter is incremented.
// Both steps run inside a Watch handler on the prefixed key. The ctx argument
// is not used by this implementation.
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	var result limiter.Context
	handler := func(rtx *libredis.Tx) error {
		created, err := store.doSetValue(rtx, key, rate.Period)
		if err != nil {
			return err
		}

		// A freshly created counter holds 1 and expires a full period from now.
		hits := int64(1)
		expiresAt := now.Add(rate.Period)

		if !created {
			var ttl time.Duration
			hits, ttl, err = store.doUpdateValue(rtx, key, rate.Period)
			if err != nil {
				return err
			}
			// Only a positive ttl describes a real remaining lifetime; otherwise
			// keep the full-period default.
			if ttl > 0 {
				expiresAt = now.Add(ttl)
			}
		}

		result = common.GetContextFromState(now, rate, expiresAt, hits)
		return nil
	}

	if err := store.client.Watch(handler, key); err != nil {
		return limiter.Context{}, errors.Wrapf(err, "limiter: cannot get value for %s", key)
	}
	return result, nil
}
// Peek returns the limit context for the given identifier without modifying
// the stored counter.
//
// The counter and its remaining lifetime are read inside a Watch handler on
// the prefixed key. When no positive ttl is reported, the expiration defaults
// to a full rate.Period from now. The ctx argument is not used by this
// implementation.
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	var result limiter.Context
	handler := func(rtx *libredis.Tx) error {
		hits, ttl, err := store.doPeekValue(rtx, key)
		if err != nil {
			return err
		}

		remaining := rate.Period
		if ttl > 0 {
			remaining = ttl
		}

		result = common.GetContextFromState(now, rate, now.Add(remaining), hits)
		return nil
	}

	if err := store.client.Watch(handler, key); err != nil {
		return limiter.Context{}, errors.Wrapf(err, "limiter: cannot peek value for %s", key)
	}
	return result, nil
}
// Reset clears the counter of the given identifier and returns a limit context
// describing a count of zero with an expiration one rate.Period from now.
//
// The deletion runs inside a Watch handler on the prefixed key. The ctx
// argument is not used by this implementation.
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	var result limiter.Context
	handler := func(rtx *libredis.Tx) error {
		if err := store.doResetValue(rtx, key); err != nil {
			return err
		}
		result = common.GetContextFromState(now, rate, now.Add(rate.Period), int64(0))
		return nil
	}

	if err := store.client.Watch(handler, key); err != nil {
		return limiter.Context{}, errors.Wrapf(err, "limiter: cannot reset value for %s", key)
	}
	return result, nil
}
// doPeekValue executes peekValue with a retry mechanism (optimistic locking),
// making at most store.MaxRetry attempts.
//
// It returns the counter and ttl of the first successful attempt. When every
// attempt fails, the returned error reads "retry limit exceeded" and carries
// the error of the last attempt as its cause, so the underlying redis failure
// is not lost.
func (store *Store) doPeekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) {
	var lastErr error
	for i := 0; i < store.MaxRetry; i++ {
		count, ttl, err := peekValue(rtx, key)
		if err == nil {
			return count, ttl, nil
		}
		lastErr = err
	}

	if lastErr == nil {
		// No attempt was made (MaxRetry is not positive). errors.Wrap would
		// return nil for a nil cause, so report the plain error instead.
		return 0, 0, errors.New("retry limit exceeded")
	}
	return 0, 0, errors.Wrap(lastErr, "retry limit exceeded")
}
// peekValue reads the counter stored at key together with its remaining time
// to live, issuing GET and PTTL through a single pipeline.
//
// A libredis.Nil reply (from the pipeline execution or from the GET result) is
// not treated as a failure. Any other error is returned with zero values.
func peekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) {
	pipe := rtx.Pipeline()
	getCmd := pipe.Get(key)
	pttlCmd := pipe.PTTL(key)

	if _, err := pipe.Exec(); !(err == nil || err == libredis.Nil) {
		return 0, 0, err
	}

	count, err := getCmd.Int64()
	if !(err == nil || err == libredis.Nil) {
		return 0, 0, err
	}

	ttl, err := pttlCmd.Result()
	if err != nil {
		return 0, 0, err
	}

	return count, ttl, nil
}
// doSetValue executes setValue with a retry mechanism (optimistic locking),
// making at most store.MaxRetry attempts.
//
// It returns whether the key was created by the first successful attempt. When
// every attempt fails, the returned error reads "retry limit exceeded" and
// carries the error of the last attempt as its cause, so the underlying redis
// failure is not lost.
func (store *Store) doSetValue(rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) {
	var lastErr error
	for i := 0; i < store.MaxRetry; i++ {
		created, err := setValue(rtx, key, expiration)
		if err == nil {
			return created, nil
		}
		lastErr = err
	}

	if lastErr == nil {
		// No attempt was made (MaxRetry is not positive). errors.Wrap would
		// return nil for a nil cause, so report the plain error instead.
		return false, errors.New("retry limit exceeded")
	}
	return false, errors.Wrap(lastErr, "retry limit exceeded")
}
// setValue issues SETNX to initialize the counter at key with the value 1 and
// the given expiration. The boolean result is the SETNX reply: it reports
// whether the key was set by this call.
func setValue(rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) {
	created, err := rtx.SetNX(key, 1, expiration).Result()
	if err != nil {
		return false, err
	}
	return created, nil
}
// doUpdateValue executes updateValue with a retry mechanism (optimistic
// locking), making at most store.MaxRetry attempts.
//
// It returns the counter and ttl of the first successful attempt. A failed
// attempt that reports a negative ttl is returned immediately without retrying:
// updateValue only returns a negative ttl alongside an error after its pipeline
// (including the increment) has already executed, so another attempt would
// count the hit twice. When every attempt fails, the returned error reads
// "retry limit exceeded" and carries the error of the last attempt as its cause.
func (store *Store) doUpdateValue(rtx *libredis.Tx, key string,
	expiration time.Duration) (int64, time.Duration, error) {
	var lastErr error
	for i := 0; i < store.MaxRetry; i++ {
		count, ttl, err := updateValue(rtx, key, expiration)
		if err == nil {
			return count, ttl, nil
		}

		// If ttl is negative and there is an error, do not retry an update.
		if ttl < 0 {
			return 0, 0, err
		}
		lastErr = err
	}

	if lastErr == nil {
		// No attempt was made (MaxRetry is not positive). errors.Wrap would
		// return nil for a nil cause, so report the plain error instead.
		return 0, 0, errors.New("retry limit exceeded")
	}
	return 0, 0, errors.Wrap(lastErr, "retry limit exceeded")
}
// updateValue increments the counter stored at key and reads its remaining
// time to live, issuing INCR and PTTL through a single pipeline.
//
// If the reported ttl is exactly -1ms, the given expiration is applied to the
// key with EXPIRE. In that case the returned ttl stays at -1ms, also when
// EXPIRE fails, where it is returned alongside the count and the error.
func updateValue(rtx *libredis.Tx, key string, expiration time.Duration) (int64, time.Duration, error) {
	// PTTL return values changed as of Redis 2.8: the command returns -2ms if
	// the key does not exist, and -1ms if the key exists but has no expiry set.
	// We shouldn't try to set an expiry on a key that doesn't exist, so only
	// the -1ms case triggers EXPIRE.
	const noExpiry = -1 * time.Millisecond

	pipe := rtx.Pipeline()
	incrCmd := pipe.Incr(key)
	pttlCmd := pipe.PTTL(key)
	if _, err := pipe.Exec(); err != nil {
		return 0, 0, err
	}

	count, err := incrCmd.Result()
	if err != nil {
		return 0, 0, err
	}

	ttl, err := pttlCmd.Result()
	if err != nil {
		return 0, 0, err
	}

	if ttl != noExpiry {
		return count, ttl, nil
	}

	// The key has no expiry: define one.
	ok, err := rtx.Expire(key, expiration).Result()
	switch {
	case err != nil:
		return count, ttl, err
	case !ok:
		return count, ttl, errors.New("cannot configure timeout on key")
	}
	return count, ttl, nil
}
// doResetValue executes resetValue with a retry mechanism (optimistic
// locking), making at most store.MaxRetry attempts.
//
// It returns nil as soon as one attempt succeeds. When every attempt fails, the
// returned error reads "retry limit exceeded" and carries the error of the last
// attempt as its cause, so the underlying failure is not lost.
func (store *Store) doResetValue(rtx *libredis.Tx, key string) error {
	var lastErr error
	for i := 0; i < store.MaxRetry; i++ {
		err := resetValue(rtx, key)
		if err == nil {
			return nil
		}
		lastErr = err
	}

	if lastErr == nil {
		// No attempt was made (MaxRetry is not positive). errors.Wrap would
		// return nil for a nil cause, so report the plain error instead.
		return errors.New("retry limit exceeded")
	}
	return errors.Wrap(lastErr, "retry limit exceeded")
}
// resetValue deletes the counter identified by the given key.
//
// Only a command error is reported. The number of keys DEL removed is not
// checked: a reply of 0 means the key was already absent, which is exactly the
// state a reset should leave behind. Treating that reply as a failure made
// resetting a never-used identifier return an error after useless retries.
func resetValue(rtx *libredis.Tx, key string) error {
	if _, err := rtx.Del(key).Result(); err != nil {
		return err
	}
	return nil
}
// ping sends PING through the store's client. It returns true when the reply
// is exactly "PONG"; a command error is wrapped and returned with false.
func (store *Store) ping() (bool, error) {
	reply, err := store.client.Ping().Result()
	if err != nil {
		return false, errors.Wrap(err, "limiter: cannot ping redis server")
	}

	alive := reply == "PONG"
	return alive, nil
}