From ffcfc85b0086eee3e09690466116336422d8577a Mon Sep 17 00:00:00 2001 From: kolaente Date: Wed, 22 Nov 2023 10:33:03 +0100 Subject: [PATCH] fix(filter): correctly filter for buckets --- pkg/models/error.go | 2 +- pkg/models/kanban.go | 37 ++++++++++++++++++---------- pkg/models/kanban_test.go | 24 ++++++++++++++++++ pkg/models/task_collection.go | 2 +- pkg/models/task_collection_filter.go | 14 +++++------ 5 files changed, 57 insertions(+), 22 deletions(-) diff --git a/pkg/models/error.go b/pkg/models/error.go index 84220540f..431d9e7d5 100644 --- a/pkg/models/error.go +++ b/pkg/models/error.go @@ -1045,7 +1045,7 @@ func IsErrInvalidFilterExpression(err error) bool { } func (err ErrInvalidFilterExpression) Error() string { - return fmt.Sprintf("Task filter expression is invalid [ExpressionError: %v]", err.ExpressionError) + return fmt.Sprintf("Task filter expression '%s' is invalid [ExpressionError: %v]", err.Expression, err.ExpressionError) } // ErrCodeInvalidFilterExpression holds the unique world-error code of this error diff --git a/pkg/models/kanban.go b/pkg/models/kanban.go index 3da2ebe20..3cf007a82 100644 --- a/pkg/models/kanban.go +++ b/pkg/models/kanban.go @@ -17,6 +17,8 @@ package models import ( + "strconv" + "strings" "time" "code.vikunja.io/api/pkg/log" @@ -175,26 +177,35 @@ func (b *Bucket) ReadAll(s *xorm.Session, auth web.Auth, search string, page int opts.search = search opts.filterConcat = filterConcatAnd - var bucketFilterIndex int - for i, filter := range opts.filters { + for _, filter := range opts.filters { if filter.field == taskPropertyBucketID { - bucketFilterIndex = i + + // Limiting the map to the one filter we're looking for is the easiest way to ensure we only + // get tasks in this bucket + bucketID := filter.value.(int64) + bucket := bucketMap[bucketID] + + bucketMap = make(map[int64]*Bucket, 1) + bucketMap[bucketID] = bucket break } } - if bucketFilterIndex == 0 { - opts.filters = append(opts.filters, &taskFilter{ - field: taskPropertyBucketID, - value: 0, - comparator: taskFilterComparatorEquals, - }) - bucketFilterIndex = len(opts.filters) - 1 - } - + originalFilter := opts.filter for id, bucket := range bucketMap { - opts.filters[bucketFilterIndex].value = id + if !strings.Contains(originalFilter, "bucket_id") { + var filterString string + if originalFilter == "" { + filterString = "bucket_id = " + strconv.FormatInt(id, 10) + } else { + filterString = "(" + originalFilter + ") && bucket_id = " + strconv.FormatInt(id, 10) + } + opts.filters, err = getTaskFiltersFromFilterString(filterString) + if err != nil { + return + } + } ts, _, total, err := getRawTasksForProjects(s, []*Project{{ID: bucket.ProjectID}}, auth, opts) if err != nil { diff --git a/pkg/models/kanban_test.go b/pkg/models/kanban_test.go index 5993ed9ab..8a1f59a87 100644 --- a/pkg/models/kanban_test.go +++ b/pkg/models/kanban_test.go @@ -92,6 +92,30 @@ func TestBucket_ReadAll(t *testing.T) { assert.Equal(t, int64(2), buckets[0].Tasks[0].ID) assert.Equal(t, int64(33), buckets[0].Tasks[1].ID) }) + t.Run("filtered by bucket", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + testuser := &user.User{ID: 1} + b := &Bucket{ + ProjectID: 1, + TaskCollection: TaskCollection{ + Filter: "title ~ 'task' && bucket_id = 2", + }, + } + bucketsInterface, _, _, err := b.ReadAll(s, testuser, "", -1, 0) + assert.NoError(t, err) + + buckets := bucketsInterface.([]*Bucket) + assert.Len(t, buckets, 3) + assert.Len(t, buckets[0].Tasks, 0) + assert.Len(t, buckets[1].Tasks, 3) + assert.Len(t, buckets[2].Tasks, 0) + assert.Equal(t, int64(3), buckets[1].Tasks[0].ID) + assert.Equal(t, int64(4), buckets[1].Tasks[1].ID) + assert.Equal(t, int64(5), buckets[1].Tasks[2].ID) + }) t.Run("accessed by link share", func(t *testing.T) { db.LoadAndAssertFixtures(t) s := db.NewSession() diff --git a/pkg/models/task_collection.go b/pkg/models/task_collection.go index f8fd6cbac..ab2f65321 100644 --- a/pkg/models/task_collection.go +++ b/pkg/models/task_collection.go @@ -105,7 +105,7 @@ func getTaskFilterOptsFromCollection(tf *TaskCollection) (opts *taskSearchOption filter: tf.Filter, } - opts.filters, err = getTaskFiltersByCollections(tf) + opts.filters, err = getTaskFiltersFromFilterString(tf.Filter) return opts, err } diff --git a/pkg/models/task_collection_filter.go b/pkg/models/task_collection_filter.go index 7069a70b1..3162106c4 100644 --- a/pkg/models/task_collection_filter.go +++ b/pkg/models/task_collection_filter.go @@ -144,29 +144,29 @@ func parseFilterFromExpression(f fexpr.ExprGroup) (filter *taskFilter, err error return filter, nil } -func getTaskFiltersByCollections(c *TaskCollection) (filters []*taskFilter, err error) { +func getTaskFiltersFromFilterString(filter string) (filters []*taskFilter, err error) { - if c.Filter == "" { + if filter == "" { return } - c.Filter = strings.ReplaceAll(c.Filter, " in ", " ?= ") + filter = strings.ReplaceAll(filter, " in ", " ?= ") - parsedFilter, err := fexpr.Parse(c.Filter) + parsedFilter, err := fexpr.Parse(filter) if err != nil { return nil, &ErrInvalidFilterExpression{ - Expression: c.Filter, + Expression: filter, ExpressionError: err, } } filters = make([]*taskFilter, 0, len(parsedFilter)) for _, f := range parsedFilter { - filter, err := parseFilterFromExpression(f) + parsedFilter, err := parseFilterFromExpression(f) if err != nil { return nil, err } - filters = append(filters, filter) + filters = append(filters, parsedFilter) } return