fix(subscriptions): correctly inherit subscriptions
Some checks failed
continuous-integration/drone/push Build is failing

Resolves https://community.vikunja.io/t/e-mail-notification-twice/2740/20
This commit is contained in:
kolaente 2024-09-03 22:03:55 +02:00
parent 7bd84a845c
commit 06305eb6b3
Signed by: konrad
GPG Key ID: F40E70337AB24C9B
3 changed files with 59 additions and 42 deletions

View File

@ -28,3 +28,13 @@
entity_id: 32
user_id: 6
created: 2021-02-01 15:13:12
- id: 9
entity_type: 3 # Task
entity_id: 18
user_id: 6
created: 2021-02-01 15:13:12
- id: 10
entity_type: 2 # Project
entity_id: 9
user_id: 6
created: 2021-02-01 15:13:12

View File

@ -196,23 +196,16 @@ func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs [
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
// that task, if there is none it will look for a subscription on the project the task belongs to.
func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) {
subs, err := GetSubscriptions(s, entityType, []int64{entityID}, a)
subs, err := GetSubscriptions(s, entityType, entityID, a)
if err != nil || len(subs) == 0 {
return nil, err
}
if sub, exists := subs[entityID]; exists && len(sub) > 0 {
return sub[0], nil // Take exact match first, if available
}
for _, sub := range subs {
if len(sub) > 0 {
return sub[0], nil // For parents, take next available
}
}
return nil, nil
return subs[0], nil
}
// GetSubscriptions returns a map of subscriptions to a set of given entity IDs
func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) {
// GetSubscriptions returns a list of subscriptions to for an entity ID
func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) {
u, is := a.(*user.User)
if u != nil && !is {
return
@ -223,23 +216,37 @@ func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entity
switch entityType {
case SubscriptionEntityProject:
projects, err := GetProjectsByIDs(s, entityIDs)
project, err := GetProjectSimpleByID(s, entityID)
if err != nil {
return nil, err
}
return GetSubscriptionsForProjects(s, projects, u)
subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u)
if err != nil {
return nil, err
}
if _, has := subs[entityID]; has && subs[entityID] != nil {
return subs[entityID], nil
}
for _, sub := range subs {
// Fallback to the first non-nil subscription
if len(sub) > 0 {
return sub, nil
}
}
return nil, nil
case SubscriptionEntityTask:
subs, err := getSubscriptionsForTasks(s, entityIDs, u)
subs, err := getSubscriptionsForTask(s, entityID, u)
if err != nil {
return nil, err
}
projects, err := GetProjectsSimplByTaskIDs(s, entityIDs)
if err != nil {
return nil, err
if len(subs) > 0 {
return subs, nil
}
tasks, err := GetTasksSimpleByIDs(s, entityIDs)
projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID})
if err != nil {
return nil, err
}
@ -249,18 +256,14 @@ func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entity
return nil, err
}
for _, task := range tasks {
// If a task is already subscribed through the parent project,
// remove the task subscription since that's a duplicate.
// But if the user is not subscribed to the task but a parent project is, add that to the subscriptions
psub, hasProjectSub := projectSubscriptions[task.ProjectID]
_, hasTaskSub := subs[task.ID]
if hasProjectSub && hasTaskSub {
delete(subs, task.ID)
}
if _, has := projectSubscriptions[projects[0].ID]; has {
return projectSubscriptions[projects[0].ID], nil
}
if !hasTaskSub && !hasProjectSub {
subs[task.ID] = psub
for _, psub := range projectSubscriptions {
// Fallback to the first non-nil subscription
if len(psub) > 0 {
return psub, nil
}
}
@ -360,26 +363,23 @@ func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Aut
return projectsToSubscriptions, nil
}
func getSubscriptionsForTasks(s *xorm.Session, taskIDs []int64, u *user.User) (projectsToSubscriptions map[int64][]*Subscription, err error) {
var subscriptions []*Subscription
func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) {
if u != nil {
err = s.
Where("user_id = ?", u.ID).
And(getSubscriberCondForEntities(SubscriptionEntityTask, taskIDs)).
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions)
} else {
err = s.
And(getSubscriberCondForEntities(SubscriptionEntityTask, taskIDs)).
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions)
}
if err != nil {
return nil, err
}
projectsToSubscriptions = make(map[int64][]*Subscription)
for _, sub := range subscriptions {
sub.Entity = sub.EntityType.String()
projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub)
}
return
@ -390,18 +390,16 @@ func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType,
return nil, err
}
subs, err := GetSubscriptions(s, entityType, []int64{entityID}, nil)
subs, err := GetSubscriptions(s, entityType, entityID, nil)
if err != nil {
return
}
userIDs := []int64{}
subscriptions = make([]*Subscription, 0, len(subs))
for _, subss := range subs {
for _, subscription := range subss {
userIDs = append(userIDs, subscription.UserID)
subscriptions = append(subscriptions, subscription)
}
for _, subscription := range subs {
userIDs = append(userIDs, subscription.UserID)
subscriptions = append(subscriptions, subscription)
}
users, err := user.GetUsersByIDs(s, userIDs)

View File

@ -313,4 +313,13 @@ func TestSubscriptionGet(t *testing.T) {
require.Error(t, err)
assert.True(t, IsErrUnknownSubscriptionEntityType(err))
})
t.Run("double subscription should be returned once", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sub, err := GetSubscription(s, SubscriptionEntityTask, 18, u)
require.NoError(t, err)
assert.Equal(t, int64(9), sub.ID)
})
}