diff --git a/pkg/metrics/active_users.go b/pkg/metrics/active_users.go index 4feda345b..deeddc334 100644 --- a/pkg/metrics/active_users.go +++ b/pkg/metrics/active_users.go @@ -91,12 +91,8 @@ func SetUserActive(a web.Auth) (err error) { // getActiveUsers returns the active users from redis func getActiveUsers() (users activeUsersMap, err error) { - u, _, err := keyvalue.Get(ActiveUsersKey) - if err != nil { - return nil, err - } - - users = u.(activeUsersMap) + users = activeUsersMap{} + _, err = keyvalue.GetWithValue(ActiveUsersKey, &users) return } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index aa79edede..0441f061e 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -17,6 +17,8 @@ package metrics import ( + "strconv" + "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/modules/keyvalue" "github.com/prometheus/client_golang/prometheus" @@ -132,7 +134,11 @@ func GetCount(key string) (count int64, err error) { return 0, nil } - count = cnt.(int64) + if s, is := cnt.(string); is { + count, err = strconv.ParseInt(s, 10, 64) + } else { + count = cnt.(int64) + } return } diff --git a/pkg/modules/auth/openid/openid.go b/pkg/modules/auth/openid/openid.go index 8c62664e3..f5835097a 100644 --- a/pkg/modules/auth/openid/openid.go +++ b/pkg/modules/auth/openid/openid.go @@ -46,12 +46,12 @@ type Callback struct { // Provider is the structure of an OpenID Connect provider type Provider struct { - Name string `json:"name"` - Key string `json:"key"` - AuthURL string `json:"auth_url"` - ClientID string `json:"client_id"` - ClientSecret string `json:"-"` - OpenIDProvider *oidc.Provider `json:"-"` + Name string `json:"name"` + Key string `json:"key"` + AuthURL string `json:"auth_url"` + ClientID string `json:"client_id"` + ClientSecret string `json:"-"` + openIDProvider *oidc.Provider Oauth2Config *oauth2.Config `json:"-"` } @@ -66,6 +66,11 @@ func init() { rand.Seed(time.Now().UTC().UnixNano()) } +func (p *Provider) setOicdProvider() (err error) { + p.openIDProvider, err = oidc.NewProvider(context.Background(), p.AuthURL) + return err +} + // HandleCallback handles the auth request callback after redirecting from the provider with an auth code // @Summary Authenticate a user with OpenID Connect // @Description After a redirect from the OpenID Connect provider to the frontend has been made with the authentication `code`, this endpoint can be used to obtain a jwt token for that user and thus log them in. @@ -122,7 +127,7 @@ func HandleCallback(c echo.Context) error { return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"}) } - verifier := provider.OpenIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID}) + verifier := provider.openIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID}) // Parse and verify ID Token payload. idToken, err := verifier.Verify(context.Background(), rawIDToken) @@ -140,7 +145,7 @@ func HandleCallback(c echo.Context) error { } if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" { - info, err := provider.OpenIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token)) + info, err := provider.openIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token)) if err != nil { log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err) return handler.HandleHTTPError(err, c) diff --git a/pkg/modules/auth/openid/providers.go b/pkg/modules/auth/openid/providers.go index 5ed8ce598..6482a3188 100644 --- a/pkg/modules/auth/openid/providers.go +++ b/pkg/modules/auth/openid/providers.go @@ -17,7 +17,6 @@ package openid import ( - "context" "regexp" "strconv" "strings" @@ -36,7 +35,8 @@ func GetAllProviders() (providers []*Provider, err error) { return nil, nil } - ps, exists, err := keyvalue.Get("openid_providers") + providers = []*Provider{} + exists, err := keyvalue.GetWithValue("openid_providers", &providers) if !exists { rawProviders := config.AuthOpenIDProviders.Get() if rawProviders == nil { @@ -68,31 +68,30 @@ func GetAllProviders() (providers []*Provider, err error) { err = keyvalue.Put("openid_providers", providers) } - if ps != nil { - return ps.([]*Provider), nil - } - return } // GetProvider retrieves a provider from keyvalue func GetProvider(key string) (provider *Provider, err error) { - var p interface{} - p, exists, err := keyvalue.Get("openid_provider_" + key) + provider = &Provider{} + exists, err := keyvalue.GetWithValue("openid_provider_"+key, provider) + if err != nil { + return nil, err + } if !exists { _, err = GetAllProviders() // This will put all providers in cache if err != nil { return nil, err } - p, _, err = keyvalue.Get("openid_provider_" + key) + _, err = keyvalue.GetWithValue("openid_provider_"+key, provider) + if err != nil { + return nil, err + } } - if p != nil { - return p.(*Provider), nil - } - - return nil, err + err = provider.setOicdProvider() + return } func getKeyFromName(name string) string { @@ -100,7 +99,7 @@ func getKeyFromName(name string) string { return reg.ReplaceAllString(strings.ToLower(name), "") } -func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) { +func getProviderFromMap(pi map[interface{}]interface{}) (provider *Provider, err error) { name, is := pi["name"].(string) if !is { return nil, nil @@ -108,7 +107,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) { k := getKeyFromName(name) - provider := &Provider{ + provider = &Provider{ Name: pi["name"].(string), Key: k, AuthURL: pi["authurl"].(string), @@ -122,10 +121,9 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) { provider.ClientID = pi["clientid"].(string) } - var err error - provider.OpenIDProvider, err = oidc.NewProvider(context.Background(), provider.AuthURL) + err = provider.setOicdProvider() if err != nil { - return provider, err + return } provider.Oauth2Config = &oauth2.Config{ @@ -134,7 +132,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) { RedirectURL: config.AuthOpenIDRedirectURL.GetString() + k, // Discovery returns the OAuth2 endpoints. - Endpoint: provider.OpenIDProvider.Endpoint(), + Endpoint: provider.openIDProvider.Endpoint(), // "openid" is a required scope for OpenID Connect flows. Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, @@ -142,5 +140,5 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) { provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL - return provider, nil + return } diff --git a/pkg/modules/avatar/initials/initials.go b/pkg/modules/avatar/initials/initials.go index 1f062201d..6d92c3d8c 100644 --- a/pkg/modules/avatar/initials/initials.go +++ b/pkg/modules/avatar/initials/initials.go @@ -127,7 +127,8 @@ func getCacheKey(prefix string, keys ...int64) string { func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) { cacheKey := getCacheKey("full", u.ID) - a, exists, err := keyvalue.Get(cacheKey) + fullSizeAvatar = &image.RGBA64{} + exists, err := keyvalue.GetWithValue(cacheKey, fullSizeAvatar) if err != nil { return nil, err } @@ -145,8 +146,6 @@ func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) { if err != nil { return nil, err } - } else { - fullSizeAvatar = a.(*image.RGBA64) } return fullSizeAvatar, nil @@ -156,7 +155,7 @@ func getAvatarForUser(u *user.User) (fullSizeAvatar *image.RGBA64, err error) { func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType string, err error) { cacheKey := getCacheKey("resized", u.ID, size) - a, exists, err := keyvalue.Get(cacheKey) + exists, err := keyvalue.GetWithValue(cacheKey, &avatar) if err != nil { return nil, "", err } @@ -180,7 +179,6 @@ func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType return nil, "", err } } else { - avatar = a.([]byte) log.Debugf("Serving initials avatar for user %d and size %d from cache", u.ID, size) } diff --git a/pkg/modules/avatar/upload/upload.go b/pkg/modules/avatar/upload/upload.go index 9348d4dee..e75f2ba66 100644 --- a/pkg/modules/avatar/upload/upload.go +++ b/pkg/modules/avatar/upload/upload.go @@ -39,22 +39,17 @@ func (p *Provider) GetAvatar(u *user.User, size int64) (avatar []byte, mimeType cacheKey := "avatar_upload_" + strconv.Itoa(int(u.ID)) - ai, exists, err := keyvalue.Get(cacheKey) + var cached map[int64][]byte + exists, err := keyvalue.GetWithValue(cacheKey, &cached) if err != nil { return nil, "", err } - var cached map[int64][]byte - - if ai != nil { - cached = ai.(map[int64][]byte) - } - if !exists { // Nothing ever cached for this user so we need to create the size map to avoid panics cached = make(map[int64][]byte) } else { - a := ai.(map[int64][]byte) + a := cached if a != nil && a[size] != nil { log.Debugf("Serving uploaded avatar for user %d and size %d from cache.", u.ID, size) return a[size], "", nil diff --git a/pkg/modules/background/unsplash/unsplash.go b/pkg/modules/background/unsplash/unsplash.go index 8dd29292b..7f13be8ef 100644 --- a/pkg/modules/background/unsplash/unsplash.go +++ b/pkg/modules/background/unsplash/unsplash.go @@ -122,7 +122,8 @@ func getImageID(fullURL string) string { // Gets an unsplash photo either from cache or directly from the unsplash api func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) { - p, exists, err := keyvalue.Get(cachePrefix + photoID) + photo = &Photo{} + exists, err := keyvalue.GetWithValue(cachePrefix+photoID, photo) if err != nil { return nil, err } @@ -134,8 +135,6 @@ func getUnsplashPhotoInfoByID(photoID string) (photo *Photo, err error) { if err != nil { return } - } else { - photo = p.(*Photo) } return } diff --git a/pkg/modules/keyvalue/keyvalue.go b/pkg/modules/keyvalue/keyvalue.go index cc2f26afb..0e76ba6be 100644 --- a/pkg/modules/keyvalue/keyvalue.go +++ b/pkg/modules/keyvalue/keyvalue.go @@ -26,6 +26,7 @@ import ( type Storage interface { Put(key string, value interface{}) (err error) Get(key string) (value interface{}, exists bool, err error) + GetWithValue(key string, value interface{}) (exists bool, err error) Del(key string) (err error) IncrBy(key string, update int64) (err error) DecrBy(key string, update int64) (err error) @@ -55,6 +56,10 @@ func Get(key string) (value interface{}, exists bool, err error) { return store.Get(key) } +func GetWithValue(key string, value interface{}) (exists bool, err error) { + return store.GetWithValue(key, value) +} + // Del removes a save value from a storage backend func Del(key string) (err error) { return store.Del(key) diff --git a/pkg/modules/keyvalue/memory/memory.go b/pkg/modules/keyvalue/memory/memory.go index 6530f506a..230dd296e 100644 --- a/pkg/modules/keyvalue/memory/memory.go +++ b/pkg/modules/keyvalue/memory/memory.go @@ -17,6 +17,7 @@ package memory import ( + "reflect" "sync" e "code.vikunja.io/api/pkg/modules/keyvalue/error" @@ -52,6 +53,21 @@ func (s *Storage) Get(key string) (value interface{}, exists bool, err error) { return } +func (s *Storage) GetWithValue(key string, value interface{}) (exists bool, err error) { + v, exists, err := s.Get(key) + if !exists { + return exists, err + } + + val := reflect.ValueOf(value) + if val.Kind() != reflect.Ptr { + panic("some: check must be a pointer") + } + + val.Elem().Set(reflect.ValueOf(v)) + return exists, err +} + // Del removes a saved value from a memory storage func (s *Storage) Del(key string) (err error) { s.mutex.Lock() diff --git a/pkg/modules/keyvalue/redis/redis.go b/pkg/modules/keyvalue/redis/redis.go index c5357376c..9031c4826 100644 --- a/pkg/modules/keyvalue/redis/redis.go +++ b/pkg/modules/keyvalue/redis/redis.go @@ -17,8 +17,10 @@ package redis import ( + "bytes" "context" - "encoding/json" + "encoding/gob" + "errors" "code.vikunja.io/api/pkg/red" "github.com/go-redis/redis/v8" @@ -40,9 +42,28 @@ func NewStorage() *Storage { // Put puts a value into redis func (s *Storage) Put(key string, value interface{}) (err error) { - v, err := json.Marshal(value) - if err != nil { - return err + + var v interface{} + + switch value.(type) { + case int: + v = value + case int8: + v = value + case int16: + v = value + case int32: + v = value + case int64: + v = value + default: + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err = enc.Encode(value) + if err != nil { + return err + } + return s.client.Set(context.Background(), key, buf.Bytes(), 0).Err() } return s.client.Set(context.Background(), key, v, 0).Err() @@ -50,13 +71,32 @@ func (s *Storage) Put(key string, value interface{}) (err error) { // Get retrieves a saved value from redis func (s *Storage) Get(key string) (value interface{}, exists bool, err error) { + value, err = s.client.Get(context.Background(), key).Result() + if err != nil && errors.Is(err, redis.Nil) { + return nil, false, nil + } + return value, true, err +} + +func (s *Storage) GetWithValue(key string, value interface{}) (exists bool, err error) { b, err := s.client.Get(context.Background(), key).Bytes() if err != nil { - return nil, false, err + if errors.Is(err, redis.Nil) { + return false, nil + } + + return } - err = json.Unmarshal(b, value) - return + var buf bytes.Buffer + _, err = buf.Write(b) + if err != nil { + return + } + + dec := gob.NewDecoder(&buf) + err = dec.Decode(value) + return true, err } // Del removed a value from redis