From f9b48ec091b387fde5b6a0e70c992a46ecc4294b Mon Sep 17 00:00:00 2001 From: kolaente Date: Tue, 8 Nov 2022 17:03:07 +0100 Subject: [PATCH] fix(filter): only check for 0 values in filter fields with numeric values --- pkg/models/task_collection_filter.go | 24 +++++++++++++++--------- pkg/models/tasks.go | 5 ++++- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/pkg/models/task_collection_filter.go b/pkg/models/task_collection_filter.go index 11718dbcf22..dbd92825249 100644 --- a/pkg/models/task_collection_filter.go +++ b/pkg/models/task_collection_filter.go @@ -48,6 +48,7 @@ type taskFilter struct { field string value interface{} // Needs to be an interface to be able to hold the field's native value comparator taskFilterComparator + isNumeric bool } func getTaskFiltersByCollections(c *TaskCollection) (filters []*taskFilter, err error) { @@ -90,8 +91,9 @@ func getTaskFiltersByCollections(c *TaskCollection) (filters []*taskFilter, err } // Cast the field value to its native type + var reflectValue *reflect.StructField if len(c.FilterValue) > i { - filter.value, err = getNativeValueForTaskField(filter.field, filter.comparator, c.FilterValue[i]) + reflectValue, filter.value, err = getNativeValueForTaskField(filter.field, filter.comparator, c.FilterValue[i]) if err != nil { return nil, ErrInvalidTaskFilterValue{ Value: filter.field, @@ -99,6 +101,9 @@ func getTaskFiltersByCollections(c *TaskCollection) (filters []*taskFilter, err } } } + if reflectValue != nil { + filter.isNumeric = reflectValue.Type.Kind() == reflect.Int64 + } filters = append(filters, filter) } @@ -192,7 +197,7 @@ func getValueForField(field reflect.StructField, rawValue string) (value interfa return } -func getNativeValueForTaskField(fieldName string, comparator taskFilterComparator, value string) (nativeValue interface{}, err error) { +func getNativeValueForTaskField(fieldName string, comparator taskFilterComparator, value string) (reflectField *reflect.StructField, nativeValue interface{}, err error) { realFieldName := strings.ReplaceAll(strcase.ToCamel(fieldName), "Id", "ID") @@ -203,11 +208,11 @@ func getNativeValueForTaskField(fieldName string, comparator taskFilterComparato for _, val := range vals { v, err := strconv.ParseInt(val, 10, 64) if err != nil { - return nil, err + return nil, nil, err } valueSlice = append(valueSlice, v) } - return valueSlice, nil + return nil, valueSlice, nil } nativeValue, err = strconv.ParseInt(value, 10, 64) @@ -217,12 +222,12 @@ func getNativeValueForTaskField(fieldName string, comparator taskFilterComparato if realFieldName == "Assignees" { vals := strings.Split(value, ",") valueSlice := append([]string{}, vals...) - return valueSlice, nil + return nil, valueSlice, nil } field, ok := reflect.TypeOf(&Task{}).Elem().FieldByName(realFieldName) if !ok { - return nil, ErrInvalidTaskField{TaskField: fieldName} + return nil, nil, ErrInvalidTaskField{TaskField: fieldName} } if comparator == taskFilterComparatorIn { @@ -231,12 +236,13 @@ func getNativeValueForTaskField(fieldName string, comparator taskFilterComparato for _, val := range vals { v, err := getValueForField(field, val) if err != nil { - return nil, err + return nil, nil, err } valueSlice = append(valueSlice, v) } - return valueSlice, nil + return nil, valueSlice, nil } - return getValueForField(field, value) + val, err := getValueForField(field, value) + return &field, val, err } diff --git a/pkg/models/tasks.go b/pkg/models/tasks.go index bcc04a7512f..707e065909c 100644 --- a/pkg/models/tasks.go +++ b/pkg/models/tasks.go @@ -226,7 +226,10 @@ func getFilterCond(f *taskFilter, includeNulls bool) (cond builder.Cond, err err } if includeNulls { - cond = builder.Or(cond, &builder.IsNull{field}, &builder.Eq{field: 0}) + cond = builder.Or(cond, &builder.IsNull{field}) + if f.isNumeric { + cond = builder.Or(cond, &builder.IsNull{field}, &builder.Eq{field: 0}) + } } return