fix(subscriptions): cleanup and simplify fetching subscribers for tasks and projects logic
Some checks failed
continuous-integration/drone/push Build is failing

Vikunja now uses one recursive CTE and a few optimizations to fetch all subscribers for a task or project. This makes the relevant code easier to maintain and more performant.
This commit is contained in:
kolaente 2024-09-04 19:54:22 +02:00
parent 850ec7efb0
commit 4ff8815fe1
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
7 changed files with 330 additions and 265 deletions

View File

@ -1745,7 +1745,7 @@ func IsErrSubscriptionAlreadyExists(err error) bool {
}
func (err *ErrSubscriptionAlreadyExists) Error() string {
return fmt.Sprintf("Subscription for this (entity_id, entity_type, user_id) already exists [EntityType: %d, EntityID: %d, ID: %d]", err.EntityType, err.EntityID, err.UserID)
return fmt.Sprintf("Subscription for this (entity_id, entity_type, user_id) already exists [EntityType: %d, EntityID: %d, UserID: %d]", err.EntityType, err.EntityID, err.UserID)
}
// ErrCodeSubscriptionAlreadyExists holds the unique world-error code of this error
@ -1760,6 +1760,32 @@ func (err ErrSubscriptionAlreadyExists) HTTPError() web.HTTPError {
}
}
// ErrMustProvideUser represents an error where you need to provide a user to fetch subscriptions
type ErrMustProvideUser struct {
}
// IsErrMustProvideUser checks if an error is ErrMustProvideUser.
func IsErrMustProvideUser(err error) bool {
_, ok := err.(*ErrMustProvideUser)
return ok
}
func (err *ErrMustProvideUser) Error() string {
return "no user provided while fetching subscriptions"
}
// ErrCodeMustProvideUser holds the unique world-error code of this error
const ErrCodeMustProvideUser = 12003
// HTTPError holds the http error description
func (err ErrMustProvideUser) HTTPError() web.HTTPError {
return web.HTTPError{
HTTPCode: http.StatusPreconditionFailed,
Code: ErrCodeMustProvideUser,
Message: "You must provide a user to fetch subscriptions",
}
}
// =================
// Link Share errors
// =================

View File

@ -199,7 +199,7 @@ func (s *SendTaskCommentNotification) Handle(msg *message.Message) (err error) {
return err
}
subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID)
subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID)
if err != nil {
return err
}
@ -279,7 +279,7 @@ func (s *SendTaskAssignedNotification) Handle(msg *message.Message) (err error)
sess := db.NewSession()
defer sess.Close()
subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID)
subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID)
if err != nil {
return err
}
@ -340,12 +340,12 @@ func (s *SendTaskDeletedNotification) Handle(msg *message.Message) (err error) {
sess := db.NewSession()
defer sess.Close()
var subscribers []*Subscription
subscribers, err = getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID)
var subscribers []*SubscriptionWithUser
subscribers, err = GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID)
// If the task does not exist and no one has explicitly subscribed to it, we won't find any subscriptions for it.
// Hence, we need to check for subscriptions to the parent project manually.
if err != nil && (IsErrTaskDoesNotExist(err) || IsErrProjectDoesNotExist(err)) {
subscribers, err = getSubscribersForEntity(sess, SubscriptionEntityProject, event.Task.ProjectID)
subscribers, err = GetSubscriptionsForEntity(sess, SubscriptionEntityProject, event.Task.ProjectID)
}
if err != nil {
return err
@ -801,7 +801,7 @@ func (s *SendProjectCreatedNotification) Handle(msg *message.Message) (err error
sess := db.NewSession()
defer sess.Close()
subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityProject, event.Project.ID)
subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityProject, event.Project.ID)
if err != nil {
return err
}

View File

@ -297,10 +297,13 @@ func (p *Project) ReadOne(s *xorm.Session, a web.Auth) (err error) {
return
}
p.Subscription, err = GetSubscription(s, SubscriptionEntityProject, p.ID, a)
subs, err := GetSubscriptionForUser(s, SubscriptionEntityProject, p.ID, a)
if err != nil && IsErrProjectDoesNotExist(err) && isFilter {
return nil
}
if subs != nil {
p.Subscription = &subs.Subscription
}
p.Views, err = getViewsForProject(s, p.ID)
return
@ -629,10 +632,23 @@ func addProjectDetails(s *xorm.Session, projects []*Project, a web.Auth) (err er
return err
}
subscriptions, err := GetSubscriptionsForProjects(s, projects, a)
if err != nil {
log.Errorf("An error occurred while getting project subscriptions for a project: %s", err.Error())
subscriptions = make(map[int64][]*Subscription)
var subscriptions = make(map[int64][]*Subscription)
u, is := a.(*user.User)
if is {
subscriptionsWithUser, err := GetSubscriptionsForEntitiesAndUser(s, SubscriptionEntityProject, projectIDs, u)
if err != nil {
log.Errorf("An error occurred while getting project subscriptions for a project: %s", err.Error())
}
if err == nil {
for pID, subs := range subscriptionsWithUser {
for _, sub := range subs {
if _, has := subscriptions[pID]; !has {
subscriptions[pID] = []*Subscription{}
}
subscriptions[pID] = append(subscriptions[pID], &sub.Subscription)
}
}
}
}
views := []*ProjectView{}

View File

@ -17,12 +17,13 @@
package models
import (
"strconv"
"time"
"xorm.io/builder"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/api/pkg/utils"
"code.vikunja.io/api/pkg/web"
"xorm.io/xorm"
)
@ -52,8 +53,7 @@ type Subscription struct {
EntityID int64 `xorm:"bigint index not null" json:"entity_id" param:"entityID"`
// The user who made this subscription
User *user.User `xorm:"-" json:"user"`
UserID int64 `xorm:"bigint index not null" json:"-"`
UserID int64 `xorm:"bigint index not null" json:"-"`
// A timestamp when this subscription was created. You cannot change this value.
Created time.Time `xorm:"created not null" json:"created"`
@ -62,7 +62,18 @@ type Subscription struct {
web.Rights `xorm:"-" json:"-"`
}
// TableName gives us a better tabel name for the subscriptions table
type SubscriptionWithUser struct {
Subscription `xorm:"extends"`
User *user.User `xorm:"extends" json:"user"`
}
type subscriptionResolved struct {
OriginalEntityID int64
SubscriptionID int64
SubscriptionWithUser `xorm:"extends"`
}
// TableName gives us a better table name for the subscriptions table
func (sb *Subscription) TableName() string {
return "subscriptions"
}
@ -115,28 +126,23 @@ func (et SubscriptionEntityType) validate() error {
// @Failure 500 {object} models.Message "Internal error"
// @Router /subscriptions/{entity}/{entityID} [put]
func (sb *Subscription) Create(s *xorm.Session, auth web.Auth) (err error) {
// Rights method alread does the validation of the entity type so we don't need to do that here
// Rights method already does the validation of the entity type, so we don't need to do that here
sb.UserID = auth.GetID()
sub, err := GetSubscription(s, sb.EntityType, sb.EntityID, auth)
sub, err := GetSubscriptionForUser(s, sb.EntityType, sb.EntityID, auth)
if err != nil {
return err
}
if sub != nil {
return &ErrSubscriptionAlreadyExists{
EntityID: sb.EntityID,
EntityType: sb.EntityType,
UserID: sb.UserID,
EntityID: sub.EntityID,
EntityType: sub.EntityType,
UserID: sub.UserID,
}
}
_, err = s.Insert(sb)
if err != nil {
return
}
sb.User, err = user.GetFromAuth(auth)
return
}
@ -163,261 +169,228 @@ func (sb *Subscription) Delete(s *xorm.Session, auth web.Auth) (err error) {
return
}
func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs []int64) (cond builder.Cond) {
if entityType == SubscriptionEntityProject {
return builder.And(
builder.In("entity_id", entityIDs),
builder.Eq{"entity_type": SubscriptionEntityProject},
)
}
if entityType == SubscriptionEntityTask {
return builder.Or(
builder.And(
builder.In("entity_id", entityIDs),
builder.Eq{"entity_type": SubscriptionEntityTask},
),
builder.And(
builder.Eq{"entity_id": builder.
Select("project_id").
From("tasks").
Where(builder.In("id", entityIDs)),
// TODO parent project
},
builder.Eq{"entity_type": SubscriptionEntityProject},
),
)
}
return
}
// GetSubscription returns a matching subscription for an entity and user.
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
// that task, if there is none it will look for a subscription on the project the task belongs to.
func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) {
subs, err := GetSubscriptions(s, entityType, entityID, a)
if err != nil || len(subs) == 0 {
return nil, err
}
return subs[0], nil
}
// GetSubscriptions returns a list of subscriptions to for an entity ID
func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) {
func GetSubscriptionForUser(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *SubscriptionWithUser, err error) {
u, is := a.(*user.User)
if u != nil && !is {
return
}
subs, err := GetSubscriptionsForEntitiesAndUser(s, entityType, []int64{entityID}, u)
if err != nil || len(subs) == 0 || len(subs[entityID]) == 0 {
return nil, err
}
return subs[entityID][0], nil
}
// GetSubscriptionsForEntities returns a list of subscriptions to for an entity ID
func GetSubscriptionsForEntities(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, nil, false)
}
func GetSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, u, true)
}
func GetSubscriptionsForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*SubscriptionWithUser, err error) {
subs, err := GetSubscriptionsForEntities(s, entityType, []int64{entityID})
if err != nil || len(subs[entityID]) == 0 {
return
}
return subs[entityID], nil
}
// This function returns a matching subscription for an entity and user.
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
// that task, if there is none it will look for a subscription on the project the task belongs to.
// It will return a map where the key is the entity id and the value is a slice with all subscriptions for that entity.
func getSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User, userOnly bool) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
if err := entityType.validate(); err != nil {
return nil, err
}
rawSubscriptions := []*subscriptionResolved{}
entityIDString := utils.JoinInt64Slice(entityIDs, ", ")
var sUserCond string
if userOnly {
if u == nil {
return nil, &ErrMustProvideUser{}
}
sUserCond = " AND s.user_id = " + strconv.FormatInt(u.ID, 10)
}
switch entityType {
case SubscriptionEntityProject:
project, err := GetProjectSimpleByID(s, entityID)
if err != nil {
return nil, err
}
subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u)
if err != nil {
return nil, err
}
if _, has := subs[entityID]; has && subs[entityID] != nil {
return subs[entityID], nil
}
err = s.SQL(`
WITH RECURSIVE project_hierarchy AS (
-- Base case: Start with the specified projects
SELECT
id,
parent_project_id,
0 AS level,
id AS original_project_id
FROM projects
WHERE id IN (`+entityIDString+`)
for _, sub := range subs {
// Fallback to the first non-nil subscription
if len(sub) > 0 {
return sub, nil
}
}
UNION ALL
return nil, nil
-- Recursive case: Get parent projects
SELECT
p.id,
p.parent_project_id,
ph.level + 1,
ph.original_project_id
FROM projects p
INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id
),
subscription_hierarchy AS (
-- Check for project subscriptions (including parent projects)
SELECT
s.id,
s.entity_type,
s.entity_id,
s.created,
s.user_id,
CASE
WHEN s.entity_id = ph.original_project_id THEN 1 -- Direct project match
ELSE ph.level + 1 -- Parent projects
END AS priority,
ph.original_project_id
FROM subscriptions s
INNER JOIN project_hierarchy ph ON s.entity_id = ph.id
WHERE s.entity_type = ?`+sUserCond+`
)
SELECT
p.id AS original_entity_id,
sh.id AS subscription_id,
sh.entity_type,
sh.entity_id,
sh.created,
sh.user_id,
CASE
WHEN sh.priority = 1 THEN 'Direct Project'
ELSE 'Parent Project'
END
AS subscription_level,
users.*
FROM projects p
LEFT JOIN (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY original_project_id, user_id ORDER BY priority) AS rn
FROM subscription_hierarchy
) sh ON p.id = sh.original_project_id AND sh.rn = 1
LEFT JOIN users ON sh.user_id = users.id
WHERE p.id IN (`+entityIDString+`)
ORDER BY p.id, sh.user_id`, SubscriptionEntityProject).
Find(&rawSubscriptions)
case SubscriptionEntityTask:
subs, err := getSubscriptionsForTask(s, entityID, u)
if err != nil {
return nil, err
}
err = s.SQL(`
WITH RECURSIVE project_hierarchy AS (
-- Base case: Start with the projects associated with the tasks
SELECT
p.id,
p.parent_project_id,
0 AS level,
t.id AS task_id
FROM tasks t
JOIN projects p ON t.project_id = p.id
WHERE t.id IN (`+entityIDString+`)
for _, sub := range subs {
// The subscriptions might also contain the immediate parent subscription, if that exists.
// This loop makes sure to only return the task subscription if it exists. The fallback
// happens in the next if after the loop.
if sub.EntityID == entityID && sub.EntityType == SubscriptionEntityTask {
return []*Subscription{sub}, nil
}
}
UNION ALL
if len(subs) > 0 {
return subs, nil
}
-- Recursive case: Get parent projects
SELECT
p.id,
p.parent_project_id,
ph.level + 1,
ph.task_id
FROM projects p
INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id
),
projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID})
if err != nil {
return nil, err
}
subscription_hierarchy AS (
-- Check for task subscriptions
SELECT
s.id,
s.entity_type,
s.entity_id,
s.created,
s.user_id,
1 AS priority,
t.id AS task_id
FROM subscriptions s
JOIN tasks t ON s.entity_id = t.id
WHERE s.entity_type = ? AND t.id IN (`+entityIDString+`)`+sUserCond+`
projectSubscriptions, err := GetSubscriptionsForProjects(s, projects, u)
if err != nil {
return nil, err
}
UNION ALL
if _, has := projectSubscriptions[projects[0].ID]; has {
return projectSubscriptions[projects[0].ID], nil
}
-- Check for project subscriptions (including parent projects)
SELECT
s.id,
s.entity_type,
s.entity_id,
s.created,
s.user_id,
ph.level + 2 AS priority,
ph.task_id
FROM subscriptions s
INNER JOIN project_hierarchy ph ON s.entity_id = ph.id
WHERE s.entity_type = ?
)
for _, psub := range projectSubscriptions {
// Fallback to the first non-nil subscription
if len(psub) > 0 {
return psub, nil
}
}
return subs, nil
SELECT
t.id AS original_entity_id,
sh.id AS subscription_id,
sh.entity_type,
sh.entity_id,
sh.created,
sh.user_id,
CASE
WHEN sh.entity_type = ? THEN 'Task'
WHEN sh.priority = ? THEN 'Direct Project'
ELSE 'Parent Project'
END
AS subscription_level,
users.*
FROM tasks t
LEFT JOIN (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY task_id, user_id ORDER BY priority) AS rn
FROM subscription_hierarchy
) sh ON t.id = sh.task_id AND sh.rn = 1
LEFT JOIN users ON sh.user_id = users.id
WHERE t.id IN (`+entityIDString+`)
ORDER BY t.id, sh.user_id`,
SubscriptionEntityTask, SubscriptionEntityProject, SubscriptionEntityTask, SubscriptionEntityProject).
Find(&rawSubscriptions)
}
if err != nil {
return nil, err
}
return
}
subscriptions = make(map[int64][]*SubscriptionWithUser)
for _, sub := range rawSubscriptions {
func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) {
u, is := a.(*user.User)
if u != nil && !is {
return
}
var ps = make(map[int64]*Project)
origProjectIDs := make([]int64, 0, len(projects))
allProjectIDs := make([]int64, 0, len(projects))
for _, p := range projects {
ps[p.ID] = p
origProjectIDs = append(origProjectIDs, p.ID)
allProjectIDs = append(allProjectIDs, p.ID)
}
// We can't just use the projects we have, we need to fetch the parents
// because they may not be loaded in the same object
for _, p := range projects {
if p.ParentProjectID == 0 {
if sub.Subscription.EntityID == 0 {
continue
}
if _, has := ps[p.ParentProjectID]; has {
continue
_, has := subscriptions[sub.OriginalEntityID]
if !has {
subscriptions[sub.OriginalEntityID] = []*SubscriptionWithUser{}
}
parents, err := GetAllParentProjects(s, p.ID)
if err != nil {
return nil, err
sub.Subscription.ID = sub.SubscriptionID
if sub.User != nil {
sub.User.ID = sub.UserID
}
// Walk the tree up until we reach the top
var parent = parents[p.ParentProjectID] // parent now has a pointer…
ps[p.ID].ParentProject = parents[p.ParentProjectID]
for parent != nil {
allProjectIDs = append(allProjectIDs, parent.ID)
parent = parents[parent.ParentProjectID] // … which means we can update it here and then update the pointer in the map
}
subscriptions[sub.OriginalEntityID] = append(subscriptions[sub.OriginalEntityID], &sub.SubscriptionWithUser)
}
var subscriptions []*Subscription
if u != nil {
err = s.
Where("user_id = ?", u.ID).
And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)).
Find(&subscriptions)
} else {
err = s.
And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)).
Find(&subscriptions)
}
if err != nil {
return nil, err
}
projectsToSubscriptions = make(map[int64][]*Subscription)
for _, sub := range subscriptions {
sub.Entity = sub.EntityType.String()
projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub)
}
// Rearrange so that subscriptions trickle down
for _, eID := range origProjectIDs {
// If the current project does not have a subscription, climb up the tree until a project has one,
// then use that subscription for all child projects
_, has := projectsToSubscriptions[eID]
_, hasProject := ps[eID]
if !has && hasProject {
_, exists := ps[eID]
if !exists {
continue
}
var parent = ps[eID].ParentProject
for parent != nil {
sub, has := projectsToSubscriptions[parent.ID]
projectsToSubscriptions[eID] = sub
parent = parent.ParentProject
if has { // reached the top of the tree
break
}
}
}
}
return projectsToSubscriptions, nil
}
func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) {
if u != nil {
err = s.
Where("user_id = ?", u.ID).
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions)
} else {
err = s.
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions)
}
if err != nil {
return nil, err
}
for _, sub := range subscriptions {
sub.Entity = sub.EntityType.String()
}
return
}
func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*Subscription, err error) {
if err := entityType.validate(); err != nil {
return nil, err
}
subs, err := GetSubscriptions(s, entityType, entityID, nil)
if err != nil {
return
}
userIDs := []int64{}
subscriptions = make([]*Subscription, 0, len(subs))
for _, subscription := range subs {
userIDs = append(userIDs, subscription.UserID)
subscriptions = append(subscriptions, subscription)
}
users, err := user.GetUsersByIDs(s, userIDs)
if err != nil {
return
}
for _, subscription := range subscriptions {
subscription.User = users[subscription.UserID]
}
return
return subscriptions, nil
}

View File

@ -52,7 +52,6 @@ func TestSubscription_Create(t *testing.T) {
sb := &Subscription{
Entity: "task",
EntityID: 1,
UserID: u.ID,
}
can, err := sb.CanCreate(s, u)
@ -61,7 +60,6 @@ func TestSubscription_Create(t *testing.T) {
err = sb.Create(s, u)
require.NoError(t, err)
assert.NotNil(t, sb.User)
db.AssertExists(t, "subscriptions", map[string]interface{}{
"entity_type": 3,
@ -69,6 +67,26 @@ func TestSubscription_Create(t *testing.T) {
"user_id": u.ID,
}, false)
})
t.Run("already exists", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sb := &Subscription{
Entity: "task",
EntityID: 2,
UserID: u.ID,
}
can, err := sb.CanCreate(s, u)
require.NoError(t, err)
assert.True(t, can)
err = sb.Create(s, u)
require.Error(t, err)
terr := &ErrSubscriptionAlreadyExists{}
assert.ErrorAs(t, err, &terr)
})
t.Run("forbidden for link shares", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
@ -86,7 +104,7 @@ func TestSubscription_Create(t *testing.T) {
require.Error(t, err)
assert.False(t, can)
})
t.Run("noneixsting project", func(t *testing.T) {
t.Run("nonexisting project", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
@ -240,7 +258,7 @@ func TestSubscriptionGet(t *testing.T) {
s := db.NewSession()
defer s.Close()
sub, err := GetSubscription(s, SubscriptionEntityProject, 12, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 12, u)
require.NoError(t, err)
assert.NotNil(t, sub)
assert.Equal(t, int64(3), sub.ID)
@ -250,7 +268,7 @@ func TestSubscriptionGet(t *testing.T) {
s := db.NewSession()
defer s.Close()
sub, err := GetSubscription(s, SubscriptionEntityTask, 22, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 22, u)
require.NoError(t, err)
assert.NotNil(t, sub)
assert.Equal(t, int64(4), sub.ID)
@ -263,7 +281,7 @@ func TestSubscriptionGet(t *testing.T) {
defer s.Close()
// Project 25 belongs to project 12 where user 6 has subscribed to
sub, err := GetSubscription(s, SubscriptionEntityProject, 25, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 25, u)
require.NoError(t, err)
assert.NotNil(t, sub)
assert.Equal(t, int64(12), sub.EntityID)
@ -275,7 +293,7 @@ func TestSubscriptionGet(t *testing.T) {
defer s.Close()
// Project 26 belongs to project 25 which belongs to project 12 where user 6 has subscribed to
sub, err := GetSubscription(s, SubscriptionEntityProject, 26, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 26, u)
require.NoError(t, err)
assert.NotNil(t, sub)
assert.Equal(t, int64(12), sub.EntityID)
@ -287,7 +305,7 @@ func TestSubscriptionGet(t *testing.T) {
defer s.Close()
// Task 39 belongs to project 25 which belongs to project 12 where the user has subscribed
sub, err := GetSubscription(s, SubscriptionEntityTask, 39, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 39, u)
require.NoError(t, err)
assert.NotNil(t, sub)
// assert.Equal(t, int64(2), sub.ID) TODO
@ -298,7 +316,7 @@ func TestSubscriptionGet(t *testing.T) {
defer s.Close()
// Task 21 belongs to project 32 which the user has subscribed to
sub, err := GetSubscription(s, SubscriptionEntityTask, 21, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 21, u)
require.NoError(t, err)
assert.NotNil(t, sub)
assert.Equal(t, int64(8), sub.ID)
@ -309,7 +327,7 @@ func TestSubscriptionGet(t *testing.T) {
s := db.NewSession()
defer s.Close()
_, err := GetSubscription(s, 2342, 21, u)
_, err := GetSubscriptionForUser(s, 2342, 21, u)
require.Error(t, err)
assert.True(t, IsErrUnknownSubscriptionEntityType(err))
})
@ -318,7 +336,7 @@ func TestSubscriptionGet(t *testing.T) {
s := db.NewSession()
defer s.Close()
sub, err := GetSubscription(s, SubscriptionEntityTask, 18, u)
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 18, u)
require.NoError(t, err)
assert.Equal(t, int64(9), sub.ID)
})

View File

@ -1574,10 +1574,11 @@ func (t *Task) ReadOne(s *xorm.Session, a web.Auth) (err error) {
*t = *taskMap[t.ID]
t.Subscription, err = GetSubscription(s, SubscriptionEntityTask, t.ID, a)
subs, err := GetSubscriptionForUser(s, SubscriptionEntityTask, t.ID, a)
if err != nil && IsErrProjectDoesNotExist(err) {
return nil
}
t.Subscription = &subs.Subscription
return
}

31
pkg/utils/strings.go Normal file
View File

@ -0,0 +1,31 @@
// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public Licensee as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public Licensee for more details.
//
// You should have received a copy of the GNU Affero General Public Licensee
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package utils
import "strconv"
func JoinInt64Slice(ints []int64, delim string) string {
b := ""
for _, v := range ints {
if len(b) > 0 {
b += delim
}
b += strconv.FormatInt(v, 10)
}
return b
}