diff --git a/pkg/models/label.go b/pkg/models/label.go index 890a7164fa..8d8acef82e 100644 --- a/pkg/models/label.go +++ b/pkg/models/label.go @@ -146,10 +146,7 @@ func (l *Label) ReadAll(s *xorm.Session, a web.Auth, search string, page int, pe return nil, 0, 0, ErrGenericForbidden{} } - u := &user.User{ID: a.GetID()} - - // Get all tasks - taskIDs, err := getUserTaskIDs(s, u) + u, err := user.GetUserByID(s, a.GetID()) if err != nil { return nil, 0, 0, err } @@ -157,7 +154,7 @@ func (l *Label) ReadAll(s *xorm.Session, a web.Auth, search string, page int, pe return getLabelsByTaskIDs(s, &LabelByTaskIDsOptions{ Search: search, User: u, - TaskIDs: taskIDs, + GetForUser: u.ID, Page: page, PerPage: perPage, GetUnusedLabels: true, @@ -206,34 +203,3 @@ func getLabelByIDSimple(s *xorm.Session, labelID int64) (*Label, error) { } return &label, err } - -// Helper method to get all task ids a user has -func getUserTaskIDs(s *xorm.Session, u *user.User) (taskIDs []int64, err error) { - - // Get all lists - lists, _, _, err := getRawListsForUser( - s, - &listOptions{ - user: u, - page: -1, - }, - ) - if err != nil { - return nil, err - } - - tasks, _, _, err := getRawTasksForLists(s, lists, u, &taskOptions{ - page: -1, - perPage: 0, - }) - if err != nil { - return nil, err - } - - // make a slice of task ids - for _, t := range tasks { - taskIDs = append(taskIDs, t.ID) - } - - return -} diff --git a/pkg/models/label_rights.go b/pkg/models/label_rights.go index bd8c153371..9bb2926f2e 100644 --- a/pkg/models/label_rights.go +++ b/pkg/models/label_rights.go @@ -64,21 +64,28 @@ func (l *Label) isLabelOwner(s *xorm.Session, a web.Auth) (bool, error) { // Helper method to check if a user can see a specific label func (l *Label) hasAccessToLabel(s *xorm.Session, a web.Auth) (has bool, maxRight int, err error) { - // TODO: add an extra check for link share handling + if _, is := a.(*LinkSharing); is { + return false, 0, nil + } - // Get all tasks - taskIDs, err := getUserTaskIDs(s, &user.User{ID: a.GetID()}) + u, err := user.GetUserByID(s, a.GetID()) if err != nil { return false, 0, err } - // Get all labels associated with these tasks + cond := builder.In("label_task.task_id", + builder. + Select("id"). + From("tasks"). + Where(builder.In("list_id", getUserListsStatement(u.ID).Select("l.id"))), + ) + ll := &LabelTask{} has, err = s.Table("labels"). Select("label_task.*"). Join("LEFT", "label_task", "label_task.label_id = labels.id"). - Where("label_task.label_id is not null OR labels.created_by_id = ?", a.GetID()). - Or(builder.In("label_task.task_id", taskIDs)). + Where("label_task.label_id is not null OR labels.created_by_id = ?", u.ID). + Or(cond). And("labels.id = ?", l.ID). Exist(ll) if err != nil { diff --git a/pkg/models/label_task.go b/pkg/models/label_task.go index 5c94a5c77e..ae6672d654 100644 --- a/pkg/models/label_task.go +++ b/pkg/models/label_task.go @@ -149,6 +149,7 @@ type LabelByTaskIDsOptions struct { TaskIDs []int64 GetUnusedLabels bool GroupByLabelIDsOnly bool + GetForUser int64 } // Helper function to get all labels for a set of tasks @@ -168,22 +169,32 @@ func getLabelsByTaskIDs(s *xorm.Session, opts *LabelByTaskIDsOptions) (ls []*lab // Get all labels associated with these tasks var labels []*labelWithTaskID cond := builder.And(builder.NotNull{"label_task.label_id"}) - if len(opts.TaskIDs) > 0 { + if len(opts.TaskIDs) > 0 && opts.GetForUser == 0 { cond = builder.And(builder.In("label_task.task_id", opts.TaskIDs), cond) } + if opts.GetForUser != 0 { + cond = builder.And(builder.In("label_task.task_id", + builder. + Select("id"). + From("tasks"). + Where(builder.In("list_id", getUserListsStatement(opts.GetForUser).Select("l.id"))), + ), cond) + } if opts.GetUnusedLabels { cond = builder.Or(cond, builder.Eq{"labels.created_by_id": opts.User.ID}) } - vals := strings.Split(opts.Search, ",") ids := []int64{} - for _, val := range vals { - v, err := strconv.ParseInt(val, 10, 64) - if err != nil { - log.Debugf("Label search string part '%s' is not a number: %s", val, err) - continue + if opts.Search != "" { + vals := strings.Split(opts.Search, ",") + for _, val := range vals { + v, err := strconv.ParseInt(val, 10, 64) + if err != nil { + log.Debugf("Label search string part '%s' is not a number: %s", val, err) + continue + } + ids = append(ids, v) } - ids = append(ids, v) } if len(ids) > 0 { diff --git a/pkg/models/list.go b/pkg/models/list.go index a742402fe4..5ced67d42a 100644 --- a/pkg/models/list.go +++ b/pkg/models/list.go @@ -21,11 +21,10 @@ import ( "strings" "time" + "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/events" - - "code.vikunja.io/api/pkg/log" - "code.vikunja.io/api/pkg/files" + "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/user" "code.vikunja.io/web" "xorm.io/builder" @@ -329,6 +328,32 @@ type listOptions struct { isArchived bool } +func getUserListsStatement(userID int64) *builder.Builder { + dialect := config.DatabaseType.GetString() + if dialect == "sqlite" { + dialect = builder.SQLITE + } + + return builder.Dialect(dialect). + Select("l.*"). + From("list", "l"). + Join("INNER", "namespaces n", "l.namespace_id = n.id"). + Join("LEFT", "team_namespaces tn", "tn.namespace_id = n.id"). + Join("LEFT", "team_members tm", "tm.team_id = tn.team_id"). + Join("LEFT", "team_list tl", "l.id = tl.list_id"). + Join("LEFT", "team_members tm2", "tm2.team_id = tl.team_id"). + Join("LEFT", "users_list ul", "ul.list_id = l.id"). + Join("LEFT", "users_namespace un", "un.namespace_id = l.namespace_id"). + Where(builder.Or( + builder.Eq{"tm.user_id": userID}, + builder.Eq{"tm2.user_id": userID}, + builder.Eq{"ul.user_id": userID}, + builder.Eq{"un.user_id": userID}, + builder.Eq{"l.owner_id": userID}, + )). + GroupBy("l.id") +} + // Gets the lists only, without any tasks or so func getRawListsForUser(s *xorm.Session, opts *listOptions) (lists []*List, resultCount int, totalItems int64, err error) { fullUser, err := user.GetUserByID(s, opts.user.ID) @@ -367,54 +392,23 @@ func getRawListsForUser(s *xorm.Session, opts *listOptions) (lists []*List, resu // Gets all Lists where the user is either owner or in a team which has access to the list // Or in a team which has namespace read access - query := s.Select("l.*"). - Table("list"). - Alias("l"). - Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id"). - Join("LEFT", []string{"team_namespaces", "tn"}, "tn.namespace_id = n.id"). - Join("LEFT", []string{"team_members", "tm"}, "tm.team_id = tn.team_id"). - Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id"). - Join("LEFT", []string{"team_members", "tm2"}, "tm2.team_id = tl.team_id"). - Join("LEFT", []string{"users_list", "ul"}, "ul.list_id = l.id"). - Join("LEFT", []string{"users_namespace", "un"}, "un.namespace_id = l.namespace_id"). - Where(builder.Or( - builder.Eq{"tm.user_id": fullUser.ID}, - builder.Eq{"tm2.user_id": fullUser.ID}, - builder.Eq{"ul.user_id": fullUser.ID}, - builder.Eq{"un.user_id": fullUser.ID}, - builder.Eq{"l.owner_id": fullUser.ID}, - )). - GroupBy("l.id"). + + query := getUserListsStatement(fullUser.ID). Where(filterCond). Where(isArchivedCond) if limit > 0 { query = query.Limit(limit, start) } - err = query.Find(&lists) + err = s.SQL(query).Find(&lists) if err != nil { return nil, 0, 0, err } - totalItems, err = s. - Table("list"). - Alias("l"). - Join("INNER", []string{"namespaces", "n"}, "l.namespace_id = n.id"). - Join("LEFT", []string{"team_namespaces", "tn"}, "tn.namespace_id = n.id"). - Join("LEFT", []string{"team_members", "tm"}, "tm.team_id = tn.team_id"). - Join("LEFT", []string{"team_list", "tl"}, "l.id = tl.list_id"). - Join("LEFT", []string{"team_members", "tm2"}, "tm2.team_id = tl.team_id"). - Join("LEFT", []string{"users_list", "ul"}, "ul.list_id = l.id"). - Join("LEFT", []string{"users_namespace", "un"}, "un.namespace_id = l.namespace_id"). - Where(builder.Or( - builder.Eq{"tm.user_id": fullUser.ID}, - builder.Eq{"tm2.user_id": fullUser.ID}, - builder.Eq{"ul.user_id": fullUser.ID}, - builder.Eq{"un.user_id": fullUser.ID}, - builder.Eq{"l.owner_id": fullUser.ID}, - )). - GroupBy("l.id"). + query = getUserListsStatement(fullUser.ID). Where(filterCond). - Where(isArchivedCond). + Where(isArchivedCond) + totalItems, err = s. + SQL(query.Select("count(*)")). Count(&List{}) return lists, len(lists), totalItems, err }