fix(subscriptions): cleanup and simplify fetching subscribers for tasks and projects logic
Some checks failed
continuous-integration/drone/push Build is failing
Some checks failed
continuous-integration/drone/push Build is failing
Vikunja now uses one recursive CTE and a few optimizations to fetch all subscribers for a task or project. This makes the relevant code easier to maintain and more performant.
This commit is contained in:
parent
850ec7efb0
commit
4ff8815fe1
@ -1745,7 +1745,7 @@ func IsErrSubscriptionAlreadyExists(err error) bool {
|
||||
}
|
||||
|
||||
func (err *ErrSubscriptionAlreadyExists) Error() string {
|
||||
return fmt.Sprintf("Subscription for this (entity_id, entity_type, user_id) already exists [EntityType: %d, EntityID: %d, ID: %d]", err.EntityType, err.EntityID, err.UserID)
|
||||
return fmt.Sprintf("Subscription for this (entity_id, entity_type, user_id) already exists [EntityType: %d, EntityID: %d, UserID: %d]", err.EntityType, err.EntityID, err.UserID)
|
||||
}
|
||||
|
||||
// ErrCodeSubscriptionAlreadyExists holds the unique world-error code of this error
|
||||
@ -1760,6 +1760,32 @@ func (err ErrSubscriptionAlreadyExists) HTTPError() web.HTTPError {
|
||||
}
|
||||
}
|
||||
|
||||
// ErrMustProvideUser represents an error where you need to provide a user to fetch subscriptions
|
||||
type ErrMustProvideUser struct {
|
||||
}
|
||||
|
||||
// IsErrMustProvideUser checks if an error is ErrMustProvideUser.
|
||||
func IsErrMustProvideUser(err error) bool {
|
||||
_, ok := err.(*ErrMustProvideUser)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err *ErrMustProvideUser) Error() string {
|
||||
return "no user provided while fetching subscriptions"
|
||||
}
|
||||
|
||||
// ErrCodeMustProvideUser holds the unique world-error code of this error
|
||||
const ErrCodeMustProvideUser = 12003
|
||||
|
||||
// HTTPError holds the http error description
|
||||
func (err ErrMustProvideUser) HTTPError() web.HTTPError {
|
||||
return web.HTTPError{
|
||||
HTTPCode: http.StatusPreconditionFailed,
|
||||
Code: ErrCodeMustProvideUser,
|
||||
Message: "You must provide a user to fetch subscriptions",
|
||||
}
|
||||
}
|
||||
|
||||
// =================
|
||||
// Link Share errors
|
||||
// =================
|
||||
|
@ -199,7 +199,7 @@ func (s *SendTaskCommentNotification) Handle(msg *message.Message) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID)
|
||||
subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -279,7 +279,7 @@ func (s *SendTaskAssignedNotification) Handle(msg *message.Message) (err error)
|
||||
sess := db.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID)
|
||||
subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -340,12 +340,12 @@ func (s *SendTaskDeletedNotification) Handle(msg *message.Message) (err error) {
|
||||
sess := db.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
var subscribers []*Subscription
|
||||
subscribers, err = getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID)
|
||||
var subscribers []*SubscriptionWithUser
|
||||
subscribers, err = GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID)
|
||||
// If the task does not exist and no one has explicitly subscribed to it, we won't find any subscriptions for it.
|
||||
// Hence, we need to check for subscriptions to the parent project manually.
|
||||
if err != nil && (IsErrTaskDoesNotExist(err) || IsErrProjectDoesNotExist(err)) {
|
||||
subscribers, err = getSubscribersForEntity(sess, SubscriptionEntityProject, event.Task.ProjectID)
|
||||
subscribers, err = GetSubscriptionsForEntity(sess, SubscriptionEntityProject, event.Task.ProjectID)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@ -801,7 +801,7 @@ func (s *SendProjectCreatedNotification) Handle(msg *message.Message) (err error
|
||||
sess := db.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityProject, event.Project.ID)
|
||||
subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityProject, event.Project.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -297,10 +297,13 @@ func (p *Project) ReadOne(s *xorm.Session, a web.Auth) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
p.Subscription, err = GetSubscription(s, SubscriptionEntityProject, p.ID, a)
|
||||
subs, err := GetSubscriptionForUser(s, SubscriptionEntityProject, p.ID, a)
|
||||
if err != nil && IsErrProjectDoesNotExist(err) && isFilter {
|
||||
return nil
|
||||
}
|
||||
if subs != nil {
|
||||
p.Subscription = &subs.Subscription
|
||||
}
|
||||
|
||||
p.Views, err = getViewsForProject(s, p.ID)
|
||||
return
|
||||
@ -629,10 +632,23 @@ func addProjectDetails(s *xorm.Session, projects []*Project, a web.Auth) (err er
|
||||
return err
|
||||
}
|
||||
|
||||
subscriptions, err := GetSubscriptionsForProjects(s, projects, a)
|
||||
if err != nil {
|
||||
log.Errorf("An error occurred while getting project subscriptions for a project: %s", err.Error())
|
||||
subscriptions = make(map[int64][]*Subscription)
|
||||
var subscriptions = make(map[int64][]*Subscription)
|
||||
u, is := a.(*user.User)
|
||||
if is {
|
||||
subscriptionsWithUser, err := GetSubscriptionsForEntitiesAndUser(s, SubscriptionEntityProject, projectIDs, u)
|
||||
if err != nil {
|
||||
log.Errorf("An error occurred while getting project subscriptions for a project: %s", err.Error())
|
||||
}
|
||||
if err == nil {
|
||||
for pID, subs := range subscriptionsWithUser {
|
||||
for _, sub := range subs {
|
||||
if _, has := subscriptions[pID]; !has {
|
||||
subscriptions[pID] = []*Subscription{}
|
||||
}
|
||||
subscriptions[pID] = append(subscriptions[pID], &sub.Subscription)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
views := []*ProjectView{}
|
||||
|
@ -17,12 +17,13 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"xorm.io/builder"
|
||||
|
||||
"code.vikunja.io/api/pkg/user"
|
||||
"code.vikunja.io/api/pkg/utils"
|
||||
"code.vikunja.io/api/pkg/web"
|
||||
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
@ -52,8 +53,7 @@ type Subscription struct {
|
||||
EntityID int64 `xorm:"bigint index not null" json:"entity_id" param:"entityID"`
|
||||
|
||||
// The user who made this subscription
|
||||
User *user.User `xorm:"-" json:"user"`
|
||||
UserID int64 `xorm:"bigint index not null" json:"-"`
|
||||
UserID int64 `xorm:"bigint index not null" json:"-"`
|
||||
|
||||
// A timestamp when this subscription was created. You cannot change this value.
|
||||
Created time.Time `xorm:"created not null" json:"created"`
|
||||
@ -62,7 +62,18 @@ type Subscription struct {
|
||||
web.Rights `xorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// TableName gives us a better tabel name for the subscriptions table
|
||||
type SubscriptionWithUser struct {
|
||||
Subscription `xorm:"extends"`
|
||||
User *user.User `xorm:"extends" json:"user"`
|
||||
}
|
||||
|
||||
type subscriptionResolved struct {
|
||||
OriginalEntityID int64
|
||||
SubscriptionID int64
|
||||
SubscriptionWithUser `xorm:"extends"`
|
||||
}
|
||||
|
||||
// TableName gives us a better table name for the subscriptions table
|
||||
func (sb *Subscription) TableName() string {
|
||||
return "subscriptions"
|
||||
}
|
||||
@ -115,28 +126,23 @@ func (et SubscriptionEntityType) validate() error {
|
||||
// @Failure 500 {object} models.Message "Internal error"
|
||||
// @Router /subscriptions/{entity}/{entityID} [put]
|
||||
func (sb *Subscription) Create(s *xorm.Session, auth web.Auth) (err error) {
|
||||
// Rights method alread does the validation of the entity type so we don't need to do that here
|
||||
// Rights method already does the validation of the entity type, so we don't need to do that here
|
||||
|
||||
sb.UserID = auth.GetID()
|
||||
|
||||
sub, err := GetSubscription(s, sb.EntityType, sb.EntityID, auth)
|
||||
sub, err := GetSubscriptionForUser(s, sb.EntityType, sb.EntityID, auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sub != nil {
|
||||
return &ErrSubscriptionAlreadyExists{
|
||||
EntityID: sb.EntityID,
|
||||
EntityType: sb.EntityType,
|
||||
UserID: sb.UserID,
|
||||
EntityID: sub.EntityID,
|
||||
EntityType: sub.EntityType,
|
||||
UserID: sub.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
_, err = s.Insert(sb)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sb.User, err = user.GetFromAuth(auth)
|
||||
return
|
||||
}
|
||||
|
||||
@ -163,261 +169,228 @@ func (sb *Subscription) Delete(s *xorm.Session, auth web.Auth) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs []int64) (cond builder.Cond) {
|
||||
if entityType == SubscriptionEntityProject {
|
||||
return builder.And(
|
||||
builder.In("entity_id", entityIDs),
|
||||
builder.Eq{"entity_type": SubscriptionEntityProject},
|
||||
)
|
||||
}
|
||||
|
||||
if entityType == SubscriptionEntityTask {
|
||||
return builder.Or(
|
||||
builder.And(
|
||||
builder.In("entity_id", entityIDs),
|
||||
builder.Eq{"entity_type": SubscriptionEntityTask},
|
||||
),
|
||||
builder.And(
|
||||
builder.Eq{"entity_id": builder.
|
||||
Select("project_id").
|
||||
From("tasks").
|
||||
Where(builder.In("id", entityIDs)),
|
||||
// TODO parent project
|
||||
},
|
||||
builder.Eq{"entity_type": SubscriptionEntityProject},
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetSubscription returns a matching subscription for an entity and user.
|
||||
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
|
||||
// that task, if there is none it will look for a subscription on the project the task belongs to.
|
||||
func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) {
|
||||
subs, err := GetSubscriptions(s, entityType, entityID, a)
|
||||
if err != nil || len(subs) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs[0], nil
|
||||
}
|
||||
|
||||
// GetSubscriptions returns a list of subscriptions to for an entity ID
|
||||
func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) {
|
||||
func GetSubscriptionForUser(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *SubscriptionWithUser, err error) {
|
||||
u, is := a.(*user.User)
|
||||
if u != nil && !is {
|
||||
return
|
||||
}
|
||||
|
||||
subs, err := GetSubscriptionsForEntitiesAndUser(s, entityType, []int64{entityID}, u)
|
||||
if err != nil || len(subs) == 0 || len(subs[entityID]) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return subs[entityID][0], nil
|
||||
}
|
||||
|
||||
// GetSubscriptionsForEntities returns a list of subscriptions to for an entity ID
|
||||
func GetSubscriptionsForEntities(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
|
||||
return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, nil, false)
|
||||
}
|
||||
|
||||
func GetSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
|
||||
return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, u, true)
|
||||
}
|
||||
|
||||
func GetSubscriptionsForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*SubscriptionWithUser, err error) {
|
||||
subs, err := GetSubscriptionsForEntities(s, entityType, []int64{entityID})
|
||||
if err != nil || len(subs[entityID]) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
return subs[entityID], nil
|
||||
}
|
||||
|
||||
// This function returns a matching subscription for an entity and user.
|
||||
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
|
||||
// that task, if there is none it will look for a subscription on the project the task belongs to.
|
||||
// It will return a map where the key is the entity id and the value is a slice with all subscriptions for that entity.
|
||||
func getSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User, userOnly bool) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
|
||||
if err := entityType.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawSubscriptions := []*subscriptionResolved{}
|
||||
entityIDString := utils.JoinInt64Slice(entityIDs, ", ")
|
||||
|
||||
var sUserCond string
|
||||
if userOnly {
|
||||
if u == nil {
|
||||
return nil, &ErrMustProvideUser{}
|
||||
}
|
||||
sUserCond = " AND s.user_id = " + strconv.FormatInt(u.ID, 10)
|
||||
}
|
||||
|
||||
switch entityType {
|
||||
case SubscriptionEntityProject:
|
||||
project, err := GetProjectSimpleByID(s, entityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, has := subs[entityID]; has && subs[entityID] != nil {
|
||||
return subs[entityID], nil
|
||||
}
|
||||
err = s.SQL(`
|
||||
WITH RECURSIVE project_hierarchy AS (
|
||||
-- Base case: Start with the specified projects
|
||||
SELECT
|
||||
id,
|
||||
parent_project_id,
|
||||
0 AS level,
|
||||
id AS original_project_id
|
||||
FROM projects
|
||||
WHERE id IN (`+entityIDString+`)
|
||||
|
||||
for _, sub := range subs {
|
||||
// Fallback to the first non-nil subscription
|
||||
if len(sub) > 0 {
|
||||
return sub, nil
|
||||
}
|
||||
}
|
||||
UNION ALL
|
||||
|
||||
return nil, nil
|
||||
-- Recursive case: Get parent projects
|
||||
SELECT
|
||||
p.id,
|
||||
p.parent_project_id,
|
||||
ph.level + 1,
|
||||
ph.original_project_id
|
||||
FROM projects p
|
||||
INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id
|
||||
),
|
||||
|
||||
subscription_hierarchy AS (
|
||||
-- Check for project subscriptions (including parent projects)
|
||||
SELECT
|
||||
s.id,
|
||||
s.entity_type,
|
||||
s.entity_id,
|
||||
s.created,
|
||||
s.user_id,
|
||||
CASE
|
||||
WHEN s.entity_id = ph.original_project_id THEN 1 -- Direct project match
|
||||
ELSE ph.level + 1 -- Parent projects
|
||||
END AS priority,
|
||||
ph.original_project_id
|
||||
FROM subscriptions s
|
||||
INNER JOIN project_hierarchy ph ON s.entity_id = ph.id
|
||||
WHERE s.entity_type = ?`+sUserCond+`
|
||||
)
|
||||
|
||||
SELECT
|
||||
p.id AS original_entity_id,
|
||||
sh.id AS subscription_id,
|
||||
sh.entity_type,
|
||||
sh.entity_id,
|
||||
sh.created,
|
||||
sh.user_id,
|
||||
CASE
|
||||
WHEN sh.priority = 1 THEN 'Direct Project'
|
||||
ELSE 'Parent Project'
|
||||
END
|
||||
AS subscription_level,
|
||||
users.*
|
||||
FROM projects p
|
||||
LEFT JOIN (
|
||||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY original_project_id, user_id ORDER BY priority) AS rn
|
||||
FROM subscription_hierarchy
|
||||
) sh ON p.id = sh.original_project_id AND sh.rn = 1
|
||||
LEFT JOIN users ON sh.user_id = users.id
|
||||
WHERE p.id IN (`+entityIDString+`)
|
||||
ORDER BY p.id, sh.user_id`, SubscriptionEntityProject).
|
||||
Find(&rawSubscriptions)
|
||||
case SubscriptionEntityTask:
|
||||
subs, err := getSubscriptionsForTask(s, entityID, u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.SQL(`
|
||||
WITH RECURSIVE project_hierarchy AS (
|
||||
-- Base case: Start with the projects associated with the tasks
|
||||
SELECT
|
||||
p.id,
|
||||
p.parent_project_id,
|
||||
0 AS level,
|
||||
t.id AS task_id
|
||||
FROM tasks t
|
||||
JOIN projects p ON t.project_id = p.id
|
||||
WHERE t.id IN (`+entityIDString+`)
|
||||
|
||||
for _, sub := range subs {
|
||||
// The subscriptions might also contain the immediate parent subscription, if that exists.
|
||||
// This loop makes sure to only return the task subscription if it exists. The fallback
|
||||
// happens in the next if after the loop.
|
||||
if sub.EntityID == entityID && sub.EntityType == SubscriptionEntityTask {
|
||||
return []*Subscription{sub}, nil
|
||||
}
|
||||
}
|
||||
UNION ALL
|
||||
|
||||
if len(subs) > 0 {
|
||||
return subs, nil
|
||||
}
|
||||
-- Recursive case: Get parent projects
|
||||
SELECT
|
||||
p.id,
|
||||
p.parent_project_id,
|
||||
ph.level + 1,
|
||||
ph.task_id
|
||||
FROM projects p
|
||||
INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id
|
||||
),
|
||||
|
||||
projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscription_hierarchy AS (
|
||||
-- Check for task subscriptions
|
||||
SELECT
|
||||
s.id,
|
||||
s.entity_type,
|
||||
s.entity_id,
|
||||
s.created,
|
||||
s.user_id,
|
||||
1 AS priority,
|
||||
t.id AS task_id
|
||||
FROM subscriptions s
|
||||
JOIN tasks t ON s.entity_id = t.id
|
||||
WHERE s.entity_type = ? AND t.id IN (`+entityIDString+`)`+sUserCond+`
|
||||
|
||||
projectSubscriptions, err := GetSubscriptionsForProjects(s, projects, u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
UNION ALL
|
||||
|
||||
if _, has := projectSubscriptions[projects[0].ID]; has {
|
||||
return projectSubscriptions[projects[0].ID], nil
|
||||
}
|
||||
-- Check for project subscriptions (including parent projects)
|
||||
SELECT
|
||||
s.id,
|
||||
s.entity_type,
|
||||
s.entity_id,
|
||||
s.created,
|
||||
s.user_id,
|
||||
ph.level + 2 AS priority,
|
||||
ph.task_id
|
||||
FROM subscriptions s
|
||||
INNER JOIN project_hierarchy ph ON s.entity_id = ph.id
|
||||
WHERE s.entity_type = ?
|
||||
)
|
||||
|
||||
for _, psub := range projectSubscriptions {
|
||||
// Fallback to the first non-nil subscription
|
||||
if len(psub) > 0 {
|
||||
return psub, nil
|
||||
}
|
||||
}
|
||||
|
||||
return subs, nil
|
||||
SELECT
|
||||
t.id AS original_entity_id,
|
||||
sh.id AS subscription_id,
|
||||
sh.entity_type,
|
||||
sh.entity_id,
|
||||
sh.created,
|
||||
sh.user_id,
|
||||
CASE
|
||||
WHEN sh.entity_type = ? THEN 'Task'
|
||||
WHEN sh.priority = ? THEN 'Direct Project'
|
||||
ELSE 'Parent Project'
|
||||
END
|
||||
AS subscription_level,
|
||||
users.*
|
||||
FROM tasks t
|
||||
LEFT JOIN (
|
||||
SELECT *,
|
||||
ROW_NUMBER() OVER (PARTITION BY task_id, user_id ORDER BY priority) AS rn
|
||||
FROM subscription_hierarchy
|
||||
) sh ON t.id = sh.task_id AND sh.rn = 1
|
||||
LEFT JOIN users ON sh.user_id = users.id
|
||||
WHERE t.id IN (`+entityIDString+`)
|
||||
ORDER BY t.id, sh.user_id`,
|
||||
SubscriptionEntityTask, SubscriptionEntityProject, SubscriptionEntityTask, SubscriptionEntityProject).
|
||||
Find(&rawSubscriptions)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
subscriptions = make(map[int64][]*SubscriptionWithUser)
|
||||
for _, sub := range rawSubscriptions {
|
||||
|
||||
func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) {
|
||||
u, is := a.(*user.User)
|
||||
if u != nil && !is {
|
||||
return
|
||||
}
|
||||
|
||||
var ps = make(map[int64]*Project)
|
||||
origProjectIDs := make([]int64, 0, len(projects))
|
||||
allProjectIDs := make([]int64, 0, len(projects))
|
||||
|
||||
for _, p := range projects {
|
||||
ps[p.ID] = p
|
||||
origProjectIDs = append(origProjectIDs, p.ID)
|
||||
allProjectIDs = append(allProjectIDs, p.ID)
|
||||
}
|
||||
|
||||
// We can't just use the projects we have, we need to fetch the parents
|
||||
// because they may not be loaded in the same object
|
||||
|
||||
for _, p := range projects {
|
||||
if p.ParentProjectID == 0 {
|
||||
if sub.Subscription.EntityID == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, has := ps[p.ParentProjectID]; has {
|
||||
continue
|
||||
_, has := subscriptions[sub.OriginalEntityID]
|
||||
if !has {
|
||||
subscriptions[sub.OriginalEntityID] = []*SubscriptionWithUser{}
|
||||
}
|
||||
|
||||
parents, err := GetAllParentProjects(s, p.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
sub.Subscription.ID = sub.SubscriptionID
|
||||
if sub.User != nil {
|
||||
sub.User.ID = sub.UserID
|
||||
}
|
||||
|
||||
// Walk the tree up until we reach the top
|
||||
var parent = parents[p.ParentProjectID] // parent now has a pointer…
|
||||
ps[p.ID].ParentProject = parents[p.ParentProjectID]
|
||||
for parent != nil {
|
||||
allProjectIDs = append(allProjectIDs, parent.ID)
|
||||
parent = parents[parent.ParentProjectID] // … which means we can update it here and then update the pointer in the map
|
||||
}
|
||||
subscriptions[sub.OriginalEntityID] = append(subscriptions[sub.OriginalEntityID], &sub.SubscriptionWithUser)
|
||||
}
|
||||
|
||||
var subscriptions []*Subscription
|
||||
if u != nil {
|
||||
err = s.
|
||||
Where("user_id = ?", u.ID).
|
||||
And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)).
|
||||
Find(&subscriptions)
|
||||
} else {
|
||||
err = s.
|
||||
And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)).
|
||||
Find(&subscriptions)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
projectsToSubscriptions = make(map[int64][]*Subscription)
|
||||
for _, sub := range subscriptions {
|
||||
sub.Entity = sub.EntityType.String()
|
||||
projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub)
|
||||
}
|
||||
|
||||
// Rearrange so that subscriptions trickle down
|
||||
|
||||
for _, eID := range origProjectIDs {
|
||||
// If the current project does not have a subscription, climb up the tree until a project has one,
|
||||
// then use that subscription for all child projects
|
||||
_, has := projectsToSubscriptions[eID]
|
||||
_, hasProject := ps[eID]
|
||||
if !has && hasProject {
|
||||
_, exists := ps[eID]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
var parent = ps[eID].ParentProject
|
||||
for parent != nil {
|
||||
sub, has := projectsToSubscriptions[parent.ID]
|
||||
projectsToSubscriptions[eID] = sub
|
||||
parent = parent.ParentProject
|
||||
if has { // reached the top of the tree
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return projectsToSubscriptions, nil
|
||||
}
|
||||
|
||||
func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) {
|
||||
if u != nil {
|
||||
err = s.
|
||||
Where("user_id = ?", u.ID).
|
||||
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
|
||||
Find(&subscriptions)
|
||||
} else {
|
||||
err = s.
|
||||
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
|
||||
Find(&subscriptions)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sub := range subscriptions {
|
||||
sub.Entity = sub.EntityType.String()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*Subscription, err error) {
|
||||
if err := entityType.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subs, err := GetSubscriptions(s, entityType, entityID, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
userIDs := []int64{}
|
||||
subscriptions = make([]*Subscription, 0, len(subs))
|
||||
for _, subscription := range subs {
|
||||
userIDs = append(userIDs, subscription.UserID)
|
||||
subscriptions = append(subscriptions, subscription)
|
||||
}
|
||||
|
||||
users, err := user.GetUsersByIDs(s, userIDs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, subscription := range subscriptions {
|
||||
subscription.User = users[subscription.UserID]
|
||||
}
|
||||
return
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
@ -52,7 +52,6 @@ func TestSubscription_Create(t *testing.T) {
|
||||
sb := &Subscription{
|
||||
Entity: "task",
|
||||
EntityID: 1,
|
||||
UserID: u.ID,
|
||||
}
|
||||
|
||||
can, err := sb.CanCreate(s, u)
|
||||
@ -61,7 +60,6 @@ func TestSubscription_Create(t *testing.T) {
|
||||
|
||||
err = sb.Create(s, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sb.User)
|
||||
|
||||
db.AssertExists(t, "subscriptions", map[string]interface{}{
|
||||
"entity_type": 3,
|
||||
@ -69,6 +67,26 @@ func TestSubscription_Create(t *testing.T) {
|
||||
"user_id": u.ID,
|
||||
}, false)
|
||||
})
|
||||
t.Run("already exists", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
sb := &Subscription{
|
||||
Entity: "task",
|
||||
EntityID: 2,
|
||||
UserID: u.ID,
|
||||
}
|
||||
|
||||
can, err := sb.CanCreate(s, u)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, can)
|
||||
|
||||
err = sb.Create(s, u)
|
||||
require.Error(t, err)
|
||||
terr := &ErrSubscriptionAlreadyExists{}
|
||||
assert.ErrorAs(t, err, &terr)
|
||||
})
|
||||
t.Run("forbidden for link shares", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
@ -86,7 +104,7 @@ func TestSubscription_Create(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.False(t, can)
|
||||
})
|
||||
t.Run("noneixsting project", func(t *testing.T) {
|
||||
t.Run("nonexisting project", func(t *testing.T) {
|
||||
db.LoadAndAssertFixtures(t)
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
@ -240,7 +258,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
sub, err := GetSubscription(s, SubscriptionEntityProject, 12, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 12, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
assert.Equal(t, int64(3), sub.ID)
|
||||
@ -250,7 +268,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
sub, err := GetSubscription(s, SubscriptionEntityTask, 22, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 22, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
assert.Equal(t, int64(4), sub.ID)
|
||||
@ -263,7 +281,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Project 25 belongs to project 12 where user 6 has subscribed to
|
||||
sub, err := GetSubscription(s, SubscriptionEntityProject, 25, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 25, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
assert.Equal(t, int64(12), sub.EntityID)
|
||||
@ -275,7 +293,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Project 26 belongs to project 25 which belongs to project 12 where user 6 has subscribed to
|
||||
sub, err := GetSubscription(s, SubscriptionEntityProject, 26, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 26, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
assert.Equal(t, int64(12), sub.EntityID)
|
||||
@ -287,7 +305,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Task 39 belongs to project 25 which belongs to project 12 where the user has subscribed
|
||||
sub, err := GetSubscription(s, SubscriptionEntityTask, 39, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 39, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
// assert.Equal(t, int64(2), sub.ID) TODO
|
||||
@ -298,7 +316,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Task 21 belongs to project 32 which the user has subscribed to
|
||||
sub, err := GetSubscription(s, SubscriptionEntityTask, 21, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 21, u)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, sub)
|
||||
assert.Equal(t, int64(8), sub.ID)
|
||||
@ -309,7 +327,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
_, err := GetSubscription(s, 2342, 21, u)
|
||||
_, err := GetSubscriptionForUser(s, 2342, 21, u)
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsErrUnknownSubscriptionEntityType(err))
|
||||
})
|
||||
@ -318,7 +336,7 @@ func TestSubscriptionGet(t *testing.T) {
|
||||
s := db.NewSession()
|
||||
defer s.Close()
|
||||
|
||||
sub, err := GetSubscription(s, SubscriptionEntityTask, 18, u)
|
||||
sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 18, u)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(9), sub.ID)
|
||||
})
|
||||
|
@ -1574,10 +1574,11 @@ func (t *Task) ReadOne(s *xorm.Session, a web.Auth) (err error) {
|
||||
|
||||
*t = *taskMap[t.ID]
|
||||
|
||||
t.Subscription, err = GetSubscription(s, SubscriptionEntityTask, t.ID, a)
|
||||
subs, err := GetSubscriptionForUser(s, SubscriptionEntityTask, t.ID, a)
|
||||
if err != nil && IsErrProjectDoesNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
t.Subscription = &subs.Subscription
|
||||
|
||||
return
|
||||
}
|
||||
|
31
pkg/utils/strings.go
Normal file
31
pkg/utils/strings.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Vikunja is a to-do list application to facilitate your life.
|
||||
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public Licensee as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public Licensee for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public Licensee
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package utils
|
||||
|
||||
import "strconv"
|
||||
|
||||
func JoinInt64Slice(ints []int64, delim string) string {
|
||||
b := ""
|
||||
for _, v := range ints {
|
||||
if len(b) > 0 {
|
||||
b += delim
|
||||
}
|
||||
b += strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user